Neural Quantum States
The nqs subpackage provides neural quantum state architectures that
parameterize the many-body wavefunction as a neural network.
Base Class
- class qvartools.nqs.neural_state.NeuralQuantumState(num_sites, local_dim=2, complex_output=False)
Abstract base class for all neural quantum state ansaetze.
Every subclass must implement log_amplitude() and phase(). The base class provides convenience methods for evaluating the full log-wavefunction, the wavefunction itself, Born-rule probabilities, and normalised probabilities over a discrete basis set.
- Parameters:
num_sites (int) – Number of lattice / orbital sites (input dimensionality).
local_dim (int, optional) – Dimension of the local Hilbert space on each site (default 2 for spin-1/2 / qubit systems).
complex_output (bool, optional) – If True, the NQS represents a complex-valued wavefunction with a non-trivial phase network. If False, the phase is identically zero and log_psi() returns a single real tensor.
- abstractmethod log_amplitude(x)
Compute the log-amplitude ln|psi(x)| for each configuration.
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Log-amplitudes, shape (batch,).
- Return type:
torch.Tensor
- abstractmethod phase(x)
Compute the phase arg(psi(x)) for each configuration.
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Phases in radians, shape (batch,). Must be identically zero when complex_output is False.
- Return type:
torch.Tensor
- log_psi(x)
Compute the log-wavefunction.
For real-valued NQS (complex_output is False), returns only the log-amplitude. For complex-valued NQS, returns a tuple of (log_amplitude, phase).
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
If complex_output is False: log-amplitude tensor of shape (batch,). If complex_output is True: tuple (log_amp, phase), each of shape (batch,).
- Return type:
torch.Tensor or tuple of torch.Tensor
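The contract described above can be illustrated without PyTorch. The following is a minimal plain-Python sketch, not the library's implementation: ToyNQS and UniformToy are hypothetical names, configurations are plain tuples, and a batch is a list. It only shows how log_psi() switches between a single value and a (log_amplitude, phase) pair depending on complex_output.

```python
from abc import ABC, abstractmethod
import math


class ToyNQS(ABC):
    """Hypothetical stand-in for the base-class contract (no PyTorch)."""

    def __init__(self, num_sites, complex_output=False):
        self.num_sites = num_sites
        self.complex_output = complex_output

    @abstractmethod
    def log_amplitude(self, x):
        """Return ln|psi(x)| for each configuration in the batch."""

    @abstractmethod
    def phase(self, x):
        """Return arg(psi(x)) for each configuration in the batch."""

    def log_psi(self, x):
        # Real-valued state: log-amplitude only.
        if not self.complex_output:
            return self.log_amplitude(x)
        # Complex-valued state: (log-amplitude, phase) pair.
        return self.log_amplitude(x), self.phase(x)


class UniformToy(ToyNQS):
    """Toy state with equal weight on every configuration."""

    def log_amplitude(self, x):
        return [-0.5 * self.num_sites * math.log(2.0) for _ in x]

    def phase(self, x):
        return [0.0 for _ in x]


batch = [(0, 1, 0), (1, 1, 0)]
real_out = UniformToy(num_sites=3).log_psi(batch)
complex_out = UniformToy(num_sites=3, complex_output=True).log_psi(batch)
```

Callers that accept both kinds of state therefore need to check complex_output (or the return type) before unpacking the result.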
- psi(x)
Evaluate the full wavefunction psi(x).
Computes exp(log_amp) * exp(i * phase). When complex_output is False the result is real-valued (dtype matches input); otherwise it is complex.
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Wavefunction values, shape (batch,). Complex dtype when complex_output is True.
- Return type:
torch.Tensor
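As a worked instance of the formula exp(log_amp) * exp(i * phase), here is a small sketch using only the standard library. The helper name psi_from_log is hypothetical and operates on a single configuration rather than a tensor batch.

```python
import cmath
import math


def psi_from_log(log_amp, phase):
    """psi = exp(log_amp) * exp(i * phase) for one configuration."""
    return math.exp(log_amp) * cmath.exp(1j * phase)


# |psi| = 0.5 with phase pi gives psi = -0.5 (up to rounding).
value = psi_from_log(math.log(0.5), math.pi)
```

Keeping the state in (log-amplitude, phase) form and exponentiating only when needed avoids overflow or underflow for configurations with very large or very small amplitudes.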
- probability(x)
Compute the Born-rule probability |psi(x)|^2 (unnormalised).
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Unnormalised probabilities exp(2 * log_amplitude(x)), shape (batch,).
- Return type:
torch.Tensor
- normalized_probability(x, basis_set)
Compute normalised Born-rule probabilities over a basis set.
The normalisation constant Z is computed as the sum of |psi(s)|^2 over every configuration s in basis_set.
- Parameters:
x (torch.Tensor) – Configurations to evaluate, shape (batch, num_sites).
basis_set (torch.Tensor) – Complete (or reference) set of configurations used to compute the partition function, shape (n_basis, num_sites).
- Returns:
Normalised probabilities, shape (batch,). Each entry is |psi(x_i)|^2 / Z.
- Return type:
torch.Tensor
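The normalisation described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library code: the toy log_amplitude function is hypothetical, configurations are tuples, and the basis set is the complete set of 2**3 binary strings.

```python
import itertools
import math


def log_amplitude(config):
    # Hypothetical toy model: ln|psi| = -(number of occupied sites).
    return -float(sum(config))


def probability(config):
    # Unnormalised Born probability |psi|^2 = exp(2 * ln|psi|).
    return math.exp(2.0 * log_amplitude(config))


def normalized_probability(configs, basis_set):
    # Z is the sum of |psi(s)|^2 over every s in basis_set.
    z = sum(probability(s) for s in basis_set)
    return [probability(x) / z for x in configs]


num_sites = 3
basis_set = list(itertools.product([0, 1], repeat=num_sites))  # all 2**3 states
probs = normalized_probability(basis_set, basis_set)
```

The probabilities sum to one only when basis_set is complete. With a reference subset, the result is normalised relative to that subset alone.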
- static encode_configuration(config)
Convert a configuration tensor to float for network input.
- Parameters:
config (torch.Tensor) – Configuration tensor of any integer or float dtype, shape (..., num_sites).
- Returns:
Float tensor with the same shape, dtype torch.float32.
- Return type:
torch.Tensor
- forward(x)
Forward pass; delegates to log_psi().
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Same as log_psi().
- Return type:
torch.Tensor or tuple of torch.Tensor
Dense Architectures
- class qvartools.nqs.architectures.dense.DenseNQS(num_sites, hidden_dims=None, complex_output=False)
Bases: NeuralQuantumState
Fully connected feedforward neural quantum state.
The amplitude network maps a configuration vector to a scalar log-amplitude via a stack of Linear + ReLU layers, followed by a final Linear + Tanh layer whose output is scaled by a learnable log_amp_scale parameter.
If complex_output is True, a separate phase network of the same depth produces the wavefunction phase in (-pi, pi).
- Parameters:
- amplitude_net
The amplitude MLP (output before scaling).
- Type:
nn.Sequential
- log_amp_scale
Learnable scalar that multiplies the Tanh output.
- Type:
nn.Parameter
Examples
>>> import torch
>>> from qvartools.nqs.architectures.dense import DenseNQS
>>> nqs = DenseNQS(num_sites=10, hidden_dims=[64, 32])
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = nqs.log_amplitude(x)  # shape (8,)
- log_amplitude(x)
Compute log-amplitude ln|psi(x)|.
The raw amplitude network output (bounded in (-1, 1) by Tanh) is multiplied by the learnable log_amp_scale.
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Log-amplitudes, shape (batch,).
- Return type:
torch.Tensor
- phase(x)
Compute the wavefunction phase.
Returns zeros for real-valued NQS. For complex NQS the phase network output (in (-1, 1)) is scaled by pi so the phase lies in (-pi, pi).
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Phases in radians, shape (batch,).
- Return type:
torch.Tensor
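The two bounded outputs of DenseNQS can be illustrated with scalar arithmetic. This sketch assumes only what the descriptions above state (Tanh followed by a scale); the helper names, the raw pre-activation values, and the scale of 5.0 are hypothetical.

```python
import math


def bounded_log_amplitude(raw, log_amp_scale):
    # Tanh confines the network output to (-1, 1); the scale sets the range.
    return log_amp_scale * math.tanh(raw)


def bounded_phase(raw):
    # Tanh output in (-1, 1) scaled by pi gives a phase in (-pi, pi).
    return math.pi * math.tanh(raw)


scale = 5.0  # hypothetical value of the learnable scalar
log_amps = [bounded_log_amplitude(r, scale) for r in (-3.0, 0.0, 3.0)]
phases = [bounded_phase(r) for r in (-3.0, 0.0, 3.0)]
```

One consequence of this design is that ln|psi| always lies in (-scale, scale), so the ratio of probabilities between any two configurations is bounded by exp(4 * scale) until the scale itself is learned to be larger.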
- class qvartools.nqs.architectures.dense.SignedDenseNQS(num_sites, hidden_dims=None)
Bases: NeuralQuantumState
Dense NQS with explicit sign structure.
Uses a shared feature extractor whose output feeds into two heads:
Amplitude head: produces the log-amplitude via Softplus to ensure non-negative output.
Sign head: produces a logit whose sigmoid is thresholded at 0.5 to yield a phase of either 0 (positive) or pi (negative).
Feature caching avoids redundant computation when log_amplitude() and phase() are called on the same input within one evaluation.
- Parameters:
- feature_net
Shared feature extractor.
- Type:
nn.Sequential
- amplitude_head
Maps features to log-amplitude (Softplus output).
- Type:
nn.Sequential
- sign_head
Maps features to a sign logit.
- Type:
nn.Linear
Examples
>>> import torch
>>> from qvartools.nqs.architectures.dense import SignedDenseNQS
>>> nqs = SignedDenseNQS(num_sites=10)
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = nqs.log_amplitude(x)  # shape (8,)
>>> phi = nqs.phase(x)  # shape (8,), values in {0, pi}
- clear_feature_cache()
Clear the feature cache.
Call this between training steps or whenever the input batch changes to avoid stale cached values.
- Return type:
None
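Why a feature cache can go stale is easiest to see in a small model of one. The sketch below is hypothetical: the class CachedFeatures, its identity-based key, and the call counter are assumptions for illustration, and the document does not state how the library keys its cache.

```python
class CachedFeatures:
    """Hypothetical helper: remember features for the most recent input."""

    def __init__(self, feature_fn):
        self.feature_fn = feature_fn
        self._key = None
        self._features = None
        self.n_calls = 0  # counts real evaluations, for illustration only

    def __call__(self, x):
        # Reuse features only when the very same object is passed again.
        if self._key is not x:
            self._features = self.feature_fn(x)
            self._key = x
            self.n_calls += 1
        return self._features

    def clear(self):
        # Drop the cached entry so the next call recomputes.
        self._key = None
        self._features = None


features = CachedFeatures(lambda batch: [sum(row) for row in batch])
x = [[0, 1, 1], [1, 0, 0]]
features(x)   # computed
features(x)   # served from the cache
calls_before_clear = features.n_calls
features.clear()
features(x)   # recomputed after clearing
```

With a cache like this, an optimiser step or an in-place edit of the batch changes the correct features without changing the key, which is why clearing between training steps matters.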
- log_amplitude(x)
Compute log-amplitude from the amplitude head.
The Softplus activation ensures the raw amplitude is non-negative; the log-amplitude is the logarithm of that value.
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Log-amplitudes, shape (batch,).
- Return type:
torch.Tensor
- phase(x)
Compute the sign-derived phase.
The sign logit is passed through a sigmoid. Values above 0.5 correspond to a positive sign (phase = 0); values below 0.5 correspond to a negative sign (phase = pi).
During training, a soft interpolation is used for gradient flow: phase = pi * (1 - sigmoid(sign_logit)).
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Phase values, shape (batch,). During training these are continuous in (0, pi); at eval they snap to {0, pi}.
- Return type:
torch.Tensor
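The difference between the training and evaluation behaviour can be sketched for a single sign logit. The function sign_phase and its training flag are hypothetical names, and the handling of the exact tie sigmoid = 0.5 is an assumption, since the text above does not specify it.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sign_phase(sign_logit, training):
    s = sigmoid(sign_logit)
    if training:
        # Soft interpolation keeps the phase differentiable in the logit.
        return math.pi * (1.0 - s)
    # Hard threshold at 0.5; the tie s == 0.5 is sent to pi here (assumption).
    return 0.0 if s > 0.5 else math.pi


soft = [sign_phase(z, training=True) for z in (-2.0, 2.0)]
hard = [sign_phase(z, training=False) for z in (-2.0, 2.0)]
```

The soft form approaches the hard values as the logit grows in magnitude, so a confidently trained sign head gives nearly the same phase in both modes.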
- class qvartools.nqs.architectures.complex_nqs.ComplexNQS(num_sites, hidden_dims=None)
Bases: NeuralQuantumState
Complex-valued NQS with shared feature extractor.
A shared MLP backbone produces feature vectors that are fed into two independent heads:
Amplitude head: maps features to a scalar log-amplitude.
Phase head: maps features to a phase in (-pi, pi).
Feature caching avoids redundant computation when log_amplitude() and phase() are called sequentially on the same input tensor.
- Parameters:
- feature_net
Shared feature MLP.
- Type:
nn.Sequential
- amplitude_head
Linear projection from features to scalar log-amplitude.
- Type:
nn.Linear
- phase_head
Maps features to phase via Linear + Tanh (scaled by pi).
- Type:
nn.Sequential
Examples
>>> import torch
>>> from qvartools.nqs.architectures.complex_nqs import ComplexNQS
>>> nqs = ComplexNQS(num_sites=10, hidden_dims=[64, 32])
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = nqs.log_amplitude(x)  # shape (8,)
>>> phi = nqs.phase(x)  # shape (8,), in (-pi, pi)
- clear_feature_cache()
Clear the feature cache.
Call this between training steps or when the input batch changes to avoid stale cached values.
- Return type:
None
- log_amplitude(x)
Compute log-amplitude from shared features.
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Log-amplitudes, shape (batch,).
- Return type:
torch.Tensor
- phase(x)
Compute phase from shared features.
The Tanh output (in (-1, 1)) is scaled by pi to produce a phase in (-pi, pi).
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Phases in radians, shape (batch,).
- Return type:
torch.Tensor
Restricted Boltzmann Machine
- class qvartools.nqs.architectures.rbm.RBMQuantumState(num_sites, num_hidden=None, complex_weights=False)
Bases: NeuralQuantumState
Restricted Boltzmann Machine neural quantum state.
Implements the RBM ansatz of Carleo & Troyer (Science, 2017):
\[\psi(\mathbf{x}) = \exp\!\Bigl(\sum_j a_j x_j\Bigr) \prod_i \cosh\!\Bigl(b_i + \sum_j W_{ij} x_j\Bigr)\]
When complex_weights is True, the parameters a, b, and W are complex-valued, and the wavefunction acquires a non-trivial phase.
- Parameters:
- a_real
Real part of the visible bias, shape
(num_sites,).- Type:
nn.Parameter
- b_real
Real part of the hidden bias, shape
(num_hidden,).- Type:
nn.Parameter
- W_real
Real part of the weight matrix, shape
(num_hidden, num_sites).- Type:
nn.Parameter
Examples
>>> import torch
>>> from qvartools.nqs.architectures.rbm import RBMQuantumState
>>> rbm = RBMQuantumState(num_sites=10, num_hidden=20)
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = rbm.log_amplitude(x)  # shape (8,)
- log_amplitude(x)
Compute the log-amplitude of the RBM wavefunction.
For real weights:
\[\ln|\psi(\mathbf{x})| = \mathrm{Re}(\mathbf{a}) \cdot \mathbf{x} + \sum_i \ln\cosh(\theta_i)\]
where \(\theta_i = b_i + \sum_j W_{ij} x_j\) is the input to hidden unit i. For complex weights, takes the real part of the full log-wavefunction.
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Log-amplitudes, shape (batch,).
- Return type:
torch.Tensor
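The real-weight formula can be checked numerically against the product form of the ansatz. The following plain-Python sketch is an illustration, not the library code: the function name mirrors the method, but it takes the parameters a, b, W explicitly, and the numeric values are hypothetical.

```python
import math


def rbm_log_amplitude(x, a, b, W):
    """ln|psi(x)| = a . x + sum_i ln cosh(b_i + sum_j W[i][j] * x[j])."""
    visible = sum(a_j * x_j for a_j, x_j in zip(a, x))
    hidden = 0.0
    for b_i, w_row in zip(b, W):
        theta_i = b_i + sum(w_ij * x_j for w_ij, x_j in zip(w_row, x))
        hidden += math.log(math.cosh(theta_i))
    return visible + hidden


# Hypothetical parameters: 3 visible sites, 2 hidden units.
a = [0.1, -0.2, 0.3]
b = [0.0, 0.5]
W = [[0.2, -0.1, 0.4], [-0.3, 0.6, 0.1]]
x = [1.0, 0.0, 1.0]

log_amp = rbm_log_amplitude(x, a, b, W)

# Direct evaluation of psi(x) = exp(a . x) * prod_i cosh(theta_i),
# with a . x = 0.4, theta_0 = 0.6 and theta_1 = 0.3 for these values.
direct = math.exp(0.4) * math.cosh(0.6) * math.cosh(0.3)
```

For large |theta| a direct cosh overflows; the identity ln cosh(theta) = |theta| + ln(1 + exp(-2|theta|)) - ln 2 gives the same value without that problem.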
- phase(x)
Compute the wavefunction phase.
For real weights the phase is identically zero. For complex weights, the phase is the imaginary part of the full log-wavefunction:
\[\arg\psi(\mathbf{x}) = \mathrm{Im}(\mathbf{a}) \cdot \mathbf{x} + \sum_i \mathrm{Im}\bigl(\ln\cosh(\theta_i)\bigr)\]
- Parameters:
x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).
- Returns:
Phases in radians, shape (batch,).
- Return type:
torch.Tensor
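For complex weights, the amplitude and phase are the real and imaginary parts of one complex log-wavefunction. The sketch below uses the standard-library cmath module with the principal branch of the logarithm; the function rbm_log_psi and all parameter values are hypothetical.

```python
import cmath


def rbm_log_psi(x, a, b, W):
    """Complex ln psi(x) = a . x + sum_i ln cosh(theta_i), principal branch."""
    total = sum(a_j * x_j for a_j, x_j in zip(a, x))
    for b_i, w_row in zip(b, W):
        theta_i = b_i + sum(w_ij * x_j for w_ij, x_j in zip(w_row, x))
        total += cmath.log(cmath.cosh(theta_i))
    return total


# Hypothetical complex parameters: 2 visible sites, 2 hidden units.
a = [0.1 + 0.2j, -0.3 + 0.1j]
b = [0.2 - 0.1j, 0.4j]
W = [[0.5 + 0.3j, -0.2j], [0.1 - 0.4j, 0.3 + 0.2j]]
x = [1.0, 1.0]

log_psi = rbm_log_psi(x, a, b, W)
log_amp, phase = log_psi.real, log_psi.imag

# Direct product form of the wavefunction, for comparison.
psi = cmath.exp(sum(a_j * x_j for a_j, x_j in zip(a, x)))
for b_i, w_row in zip(b, W):
    psi *= cmath.cosh(b_i + sum(w_ij * x_j for w_ij, x_j in zip(w_row, x)))
```

Because each logarithm is taken on its principal branch, the summed imaginary part agrees with arg(psi) only modulo 2*pi, which does not affect exp(i * phase).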
Autoregressive Transformer
- class qvartools.nqs.transformer.autoregressive.AutoregressiveTransformer(n_orbitals, n_alpha, n_beta, embed_dim=64, n_heads=4, n_layers=4, dropout=0.0)
Bases: Module
Autoregressive transformer NQS with alpha/beta spin channels.
Models the joint probability of occupying orbitals by factorising it autoregressively:
\[p(\mathbf{x}) = \prod_{i=1}^{N_{\text{orb}}} p(x^\alpha_i | x^\alpha_{<i}) \;\prod_{i=1}^{N_{\text{orb}}} p(x^\beta_i | x^\beta_{<i}, \mathbf{x}^\alpha)\]
The alpha channel uses causal self-attention only. The beta channel uses causal self-attention plus cross-attention to the full alpha representation, enabling spin-spin correlations.
Sampling enforces particle conservation: exactly n_alpha electrons in the alpha channel and n_beta in the beta channel.
- Parameters:
n_orbitals (int) – Number of spatial orbitals per spin channel.
n_alpha (int) – Number of alpha electrons.
n_beta (int) – Number of beta electrons.
embed_dim (int, optional) – Embedding dimensionality (default 64).
n_heads (int, optional) – Number of attention heads (default 4).
n_layers (int, optional) – Number of transformer layers per channel (default 4).
dropout (float, optional) – Dropout probability (default 0.0).
- alpha_blocks
Transformer blocks for the alpha channel (self-attention only).
- Type:
nn.ModuleList
- beta_blocks
Transformer blocks for the beta channel (self + cross-attention).
- Type:
nn.ModuleList
- alpha_head
Output head producing alpha occupation logits.
- Type:
nn.Linear
- beta_head
Output head producing beta occupation logits.
- Type:
nn.Linear
Examples
>>> from qvartools.nqs.transformer.autoregressive import AutoregressiveTransformer
>>> model = AutoregressiveTransformer(
...     n_orbitals=6, n_alpha=2, n_beta=2, embed_dim=32, n_heads=4
... )
>>> configs = model.sample(n_samples=16)  # shape (16, 12)
- log_prob(alpha, beta)
Compute the log-probability of a configuration.
- Parameters:
alpha (torch.Tensor) – Alpha spin-orbital occupations, shape (batch, n_orbitals) with entries in {0, 1}.
beta (torch.Tensor) – Beta spin-orbital occupations, shape (batch, n_orbitals) with entries in {0, 1}.
- Returns:
Log-probabilities, shape (batch,).
- Return type:
torch.Tensor
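The autoregressive factorisation means a log-probability is a sum of log-conditionals, alpha first and then beta conditioned on the full alpha string. The sketch below replaces the attention layers with a hypothetical closed-form conditional (p_occupied) and omits the particle-number constraint, so it illustrates only the chain rule, not the model.

```python
import itertools
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def p_occupied(prefix, context=()):
    # Hypothetical conditional p(x_i = 1 | prefix, context); a real model
    # would compute this with attention layers instead.
    return sigmoid(0.3 * len(prefix) - 0.5 * sum(prefix) + 0.2 * sum(context))


def chain_log_prob(occ, context=()):
    # Sum of log-conditionals over one spin channel, left to right.
    total = 0.0
    for i, x_i in enumerate(occ):
        p1 = p_occupied(occ[:i], context)
        total += math.log(p1 if x_i == 1 else 1.0 - p1)
    return total


def log_prob(alpha, beta):
    # Alpha sees only its own prefix; beta also sees the full alpha string.
    return chain_log_prob(alpha) + chain_log_prob(beta, context=alpha)


n_orbitals = 2
states = list(itertools.product([0, 1], repeat=n_orbitals))
total_prob = sum(math.exp(log_prob(a, b)) for a in states for b in states)
```

Because every conditional is itself normalised, the joint distribution is normalised by construction, with no partition function to estimate.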
- sample(n_samples, temperature=1.0)
Generate particle-conserving configurations autoregressively.
Samples alpha orbitals first (enforcing exactly n_alpha electrons), then samples beta orbitals with cross-attention to alpha (enforcing exactly n_beta electrons). The returned configuration is [alpha, beta] concatenated along the orbital axis.
KV caching is used for efficient autoregressive generation.
- Parameters:
- Returns:
Sampled configurations, shape (n_samples, 2 * n_orbitals) with entries in {0, 1}. The first n_orbitals columns are alpha occupations and the last n_orbitals are beta.
- Return type:
torch.Tensor
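One way to enforce an exact electron count during sequential sampling is to force the choice whenever the remaining electrons either run out or exactly fill the remaining orbitals. The sketch below shows that idea in plain Python. It is an assumption about the general technique, not the library's implementation: the helper names and the constant conditionals are hypothetical, and temperature and KV caching are left out.

```python
import random


def sample_channel(n_orbitals, n_electrons, p_occupied, rng):
    """Sample one spin channel with exactly n_electrons occupied orbitals."""
    occ = []
    for i in range(n_orbitals):
        left = n_electrons - sum(occ)   # electrons still to place
        slots = n_orbitals - i          # orbitals still undecided
        if left == 0:
            x_i = 0                     # quota reached: must stay empty
        elif left == slots:
            x_i = 1                     # must fill every remaining orbital
        else:
            x_i = 1 if rng.random() < p_occupied(occ) else 0
        occ.append(x_i)
    return occ


def sample(n_samples, n_orbitals, n_alpha, n_beta, seed=0):
    rng = random.Random(seed)
    configs = []
    for _ in range(n_samples):
        # Hypothetical conditionals; the beta one depends on the alpha string.
        alpha = sample_channel(n_orbitals, n_alpha, lambda occ: 0.5, rng)
        beta = sample_channel(
            n_orbitals, n_beta,
            lambda occ: 0.3 if alpha[len(occ)] else 0.7, rng,
        )
        configs.append(alpha + beta)    # [alpha, beta] concatenation
    return configs


configs = sample(n_samples=16, n_orbitals=6, n_alpha=2, n_beta=2)
```

Every sampled row then has exactly n_alpha ones in its first n_orbitals columns and n_beta ones in the rest, matching the layout described above.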
- forward(alpha, beta)
Forward pass; delegates to log_prob().
- Parameters:
alpha (torch.Tensor) – Alpha occupations, shape (batch, n_orbitals).
beta (torch.Tensor) – Beta occupations, shape (batch, n_orbitals).
- Returns:
Log-probabilities, shape (batch,).
- Return type:
torch.Tensor
Utilities
- qvartools.nqs.compile_nqs(model, mode='reduce-overhead')
Apply torch.compile to an NQS model with graceful fallback.
- Parameters:
model (nn.Module) – The neural quantum state model to compile.
mode (str, optional) – Compilation mode passed to torch.compile. Common choices are "reduce-overhead" (default) and "max-autotune".
- Returns:
The compiled model, or the original model unchanged if compilation fails (e.g. unsupported platform or PyTorch version).
- Return type:
nn.Module
Examples
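>>> from qvartools.nqs import compile_nqs
>>> from qvartools.nqs.architectures.dense import DenseNQS
>>> nqs = DenseNQS(num_sites=10, hidden_dims=[64, 32])
>>> nqs = compile_nqs(nqs, mode="max-autotune")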
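A graceful-fallback wrapper of this kind can be sketched as a try/except around torch.compile. The function compile_with_fallback below is a hypothetical illustration of the pattern, not the source of compile_nqs. It imports torch lazily so that a missing installation also takes the fallback path, and it logs the reason instead of raising.

```python
import logging

logger = logging.getLogger(__name__)


def compile_with_fallback(model, mode="reduce-overhead"):
    """Hypothetical sketch: try torch.compile, else return model unchanged."""
    try:
        import torch  # imported lazily so a missing install also falls back

        return torch.compile(model, mode=mode)
    except Exception as exc:  # e.g. unsupported platform or PyTorch version
        logger.warning("torch.compile unavailable, using eager model: %s", exc)
        return model


def toy_model(x):
    # Stand-in for an NQS model; any callable works for this illustration.
    return 2 * x


maybe_compiled = compile_with_fallback(toy_model)
```

Note that torch.compile typically defers the actual compilation to the first call, so a wrapper like this catches setup failures but not necessarily errors raised later during that first invocation.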