Neural Quantum States

The nqs subpackage provides neural quantum state architectures that parameterize the many-body wavefunction as a neural network.

Base Class

class qvartools.nqs.neural_state.NeuralQuantumState(num_sites, local_dim=2, complex_output=False)

Bases: Module, ABC

Abstract base class for all neural quantum state ansätze.

Every subclass must implement log_amplitude() and phase(). The base class provides convenience methods for evaluating the full log-wavefunction, the wavefunction itself, Born-rule probabilities, and normalised probabilities over a discrete basis set.

Parameters:
  • num_sites (int) – Number of lattice / orbital sites (input dimensionality).

  • local_dim (int, optional) – Dimension of the local Hilbert space on each site (default 2 for spin-1/2 / qubit systems).

  • complex_output (bool, optional) – If True the NQS represents a complex-valued wavefunction with a non-trivial phase network. If False the phase is identically zero and log_psi() returns a single real tensor.

num_sites

Number of sites.

Type:

int

local_dim

Local Hilbert-space dimension.

Type:

int

complex_output

Whether the NQS has a non-trivial phase.

Type:

bool

abstractmethod log_amplitude(x)

Compute the log-amplitude ln|psi(x)| for each configuration.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor

abstractmethod phase(x)

Compute the phase arg(psi(x)) for each configuration.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phases in radians, shape (batch,). Must be identically zero when complex_output is False.

Return type:

torch.Tensor

log_psi(x)

Compute the log-wavefunction.

For real-valued NQS (complex_output is False), returns only the log-amplitude. For complex-valued NQS, returns a tuple of (log_amplitude, phase).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

If complex_output is False: log-amplitude tensor of shape (batch,). If complex_output is True: tuple (log_amp, phase) each of shape (batch,).

Return type:

torch.Tensor or tuple of torch.Tensor

psi(x)

Evaluate the full wavefunction psi(x).

Computes exp(log_amp) * exp(i * phase). When complex_output is False the result is real-valued (dtype matches input); otherwise it is complex.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Wavefunction values, shape (batch,). Complex dtype when complex_output is True.

Return type:

torch.Tensor
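
As an illustration of the formula exp(log_amp) * exp(i * phase), the following dependency-free sketch combines a log-amplitude and a phase into a wavefunction value and checks that the unnormalised Born probability depends only on the log-amplitude. It works on plain Python scalars rather than torch tensors, and the helper names psi_from_log and born_probability are hypothetical, not part of qvartools.

```python
import cmath
import math


def psi_from_log(log_amp: float, phase: float = 0.0) -> complex:
    """Return exp(log_amp) * exp(i * phase) for one configuration."""
    return math.exp(log_amp) * cmath.exp(1j * phase)


def born_probability(log_amp: float) -> float:
    """Unnormalised |psi|^2, which depends only on the log-amplitude."""
    return math.exp(2.0 * log_amp)


value = psi_from_log(log_amp=-0.5, phase=math.pi / 3)
# |psi|^2 computed from the complex value agrees with exp(2 * log_amp).
assert math.isclose(abs(value) ** 2, born_probability(-0.5))
```

Working with the log-amplitude and exponentiating only at the end is what keeps the probability expressible as exp(2 * log_amplitude(x)) without ever forming the phase factor.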

probability(x)

Compute the Born-rule probability |psi(x)|^2 (unnormalised).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Unnormalised probabilities exp(2 * log_amplitude(x)), shape (batch,).

Return type:

torch.Tensor

normalized_probability(x, basis_set)

Compute normalised Born-rule probabilities over a basis set.

The normalisation constant Z is computed as the sum of |psi(s)|^2 over every configuration s in basis_set.

Parameters:
  • x (torch.Tensor) – Configurations to evaluate, shape (batch, num_sites).

  • basis_set (torch.Tensor) – Complete (or reference) set of configurations used to compute the partition function, shape (n_basis, num_sites).

Returns:

Normalised probabilities, shape (batch,). Each entry is |psi(x_i)|^2 / Z.

Return type:

torch.Tensor
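
The normalisation over a basis set can be sketched in a few lines of plain Python. This is a minimal illustration, not the library's implementation: toy_log_amp is a hypothetical stand-in for a trained network, and the log-sum-exp shift is an added numerical safeguard that the documentation above does not mention.

```python
import math
from itertools import product


def normalized_probability(log_amp_fn, configs, basis_set):
    """Return |psi(x)|^2 / Z for each x, with Z summed over basis_set."""
    log_p_basis = [2.0 * log_amp_fn(s) for s in basis_set]
    # Subtract the maximum before exponentiating to avoid overflow.
    shift = max(log_p_basis)
    log_z = shift + math.log(sum(math.exp(lp - shift) for lp in log_p_basis))
    return [math.exp(2.0 * log_amp_fn(x) - log_z) for x in configs]


def toy_log_amp(config):
    """Hypothetical log-amplitude: favours configurations with more 1s."""
    return 0.3 * sum(config)


basis = list(product([0, 1], repeat=3))  # all 8 configurations of 3 sites
probs = normalized_probability(toy_log_amp, basis, basis)
assert math.isclose(sum(probs), 1.0)
```

The probabilities sum to one only when the evaluated configurations coincide with the basis set. If basis_set is a reference subset rather than the complete basis, Z is a partial sum and the returned values are normalised relative to that subset only.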

static encode_configuration(config)

Convert a configuration tensor to float for network input.

Parameters:

config (torch.Tensor) – Configuration tensor of any integer or float dtype, shape (..., num_sites).

Returns:

Float tensor with the same shape, dtype torch.float32.

Return type:

torch.Tensor

forward(x)

Forward pass — delegates to log_psi().

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Same as log_psi().

Return type:

torch.Tensor or tuple of torch.Tensor
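
To make the base-class contract concrete, here is a scalar analogue written with the standard library only. ToyQuantumState and ParityState are hypothetical names used for illustration. The real base class derives from torch's Module and operates on batched tensors, which this sketch does not attempt to reproduce.

```python
import math
from abc import ABC, abstractmethod


class ToyQuantumState(ABC):
    """Hypothetical scalar analogue of the base-class contract."""

    def __init__(self, complex_output: bool = False):
        self.complex_output = complex_output

    @abstractmethod
    def log_amplitude(self, config): ...

    @abstractmethod
    def phase(self, config): ...

    def log_psi(self, config):
        # Real states return only ln|psi|; complex states return a pair.
        if not self.complex_output:
            return self.log_amplitude(config)
        return self.log_amplitude(config), self.phase(config)

    def __call__(self, config):
        # Mirrors forward(): calling the state delegates to log_psi().
        return self.log_psi(config)


class ParityState(ToyQuantumState):
    """Toy ansatz: uniform amplitude, phase pi on odd-parity configurations."""

    def log_amplitude(self, config):
        return 0.0

    def phase(self, config):
        if not self.complex_output:
            return 0.0  # the phase must vanish for real-valued states
        return math.pi * (sum(config) % 2)


real_state = ParityState(complex_output=False)
complex_state = ParityState(complex_output=True)
print(real_state([1, 0, 0]))     # a single float
print(complex_state([1, 0, 0]))  # a (log_amp, phase) tuple
```

Because the return type of log_psi() depends on complex_output, calling code that handles both kinds of state has to branch on that flag (or on whether the result is a tuple).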

Dense Architectures

class qvartools.nqs.architectures.dense.DenseNQS(num_sites, hidden_dims=None, complex_output=False)

Bases: NeuralQuantumState

Fully connected feedforward neural quantum state.

The amplitude network maps a configuration vector to a scalar log-amplitude via a stack of Linear + ReLU layers, followed by a final Linear + Tanh layer whose output is scaled by a learnable log_amp_scale parameter.

If complex_output is True, a separate phase network of the same depth produces the wavefunction phase in (-pi, pi).

Parameters:
  • num_sites (int) – Number of lattice / orbital sites.

  • hidden_dims (list of int, optional) – Hidden-layer sizes for the amplitude (and phase) networks (default [128, 64]).

  • complex_output (bool, optional) – Whether to include a phase network (default False).

amplitude_net

The amplitude MLP (output before scaling).

Type:

nn.Sequential

log_amp_scale

Learnable scalar that multiplies the Tanh output.

Type:

nn.Parameter

phase_net

Phase MLP when complex_output is True, else None.

Type:

nn.Sequential or None

Examples

>>> import torch
>>> from qvartools.nqs.architectures.dense import DenseNQS
>>> nqs = DenseNQS(num_sites=10, hidden_dims=[64, 32])
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = nqs.log_amplitude(x)  # shape (8,)

log_amplitude(x)

Compute log-amplitude ln|psi(x)|.

The raw amplitude network output (bounded in (-1, 1) by Tanh) is multiplied by the learnable log_amp_scale.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor

phase(x)

Compute the wavefunction phase.

Returns zeros for real-valued NQS. For complex NQS the phase network output (in (-1, 1)) is scaled by pi so the phase lies in (-pi, pi).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phases in radians, shape (batch,).

Return type:

torch.Tensor
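
The two output scalings described for DenseNQS (Tanh times log_amp_scale for the amplitude, Tanh times pi for the phase) amount to the following arithmetic. This is a scalar sketch under the assumption that raw is the pre-Tanh output of the final Linear layer. The function names are hypothetical.

```python
import math


def scaled_log_amplitude(raw: float, log_amp_scale: float) -> float:
    """Bound the log-amplitude to (-scale, scale) for a positive scale."""
    return log_amp_scale * math.tanh(raw)


def scaled_phase(raw: float) -> float:
    """Map an unbounded pre-activation to a phase in (-pi, pi)."""
    return math.pi * math.tanh(raw)


for raw in (-2.0, 0.0, 2.0):
    print(raw, scaled_log_amplitude(raw, log_amp_scale=5.0), scaled_phase(raw))
```

One consequence of the bounded log-amplitude is that the unnormalised probabilities exp(2 * log_amp) of any two configurations differ by a factor of less than exp(4 * log_amp_scale), so the scale limits how sharply peaked the represented distribution can be.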

class qvartools.nqs.architectures.dense.SignedDenseNQS(num_sites, hidden_dims=None)

Bases: NeuralQuantumState

Dense NQS with explicit sign structure.

Uses a shared feature extractor whose output feeds into two heads:

  • Amplitude head — produces the log-amplitude via Softplus to ensure non-negative output.

  • Sign head — produces a logit whose sigmoid is thresholded at 0.5 to yield a phase of either 0 (positive) or pi (negative).

Feature caching avoids redundant computation when log_amplitude() and phase() are called on the same input within one evaluation.

Parameters:
  • num_sites (int) – Number of lattice / orbital sites.

  • hidden_dims (list of int, optional) – Hidden-layer sizes for the shared feature extractor (default [128, 64]).

feature_net

Shared feature extractor.

Type:

nn.Sequential

amplitude_head

Maps features to log-amplitude (Softplus output).

Type:

nn.Sequential

sign_head

Maps features to a sign logit.

Type:

nn.Linear

Examples

>>> import torch
>>> from qvartools.nqs.architectures.dense import SignedDenseNQS
>>> nqs = SignedDenseNQS(num_sites=10)
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = nqs.log_amplitude(x)  # shape (8,)
>>> phi = nqs.phase(x)              # shape (8,), values in {0, pi}

clear_feature_cache()

Clear the feature cache.

Call this between training steps or whenever the input batch changes to avoid stale cached values.

Return type:

None

log_amplitude(x)

Compute log-amplitude from the amplitude head.

The Softplus activation keeps the raw amplitude positive; the log-amplitude is the logarithm of that value.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor

phase(x)

Compute the sign-derived phase.

The sign logit is passed through a sigmoid. Values above 0.5 correspond to a positive sign (phase = 0); values below 0.5 correspond to a negative sign (phase = pi).

During training, a soft interpolation is used for gradient flow: phase = pi * (1 - sigmoid(sign_logit)).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phase values, shape (batch,). During training these are continuous in (0, pi); at eval they snap to {0, pi}.

Return type:

torch.Tensor
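
The difference between the soft training-time phase and the thresholded evaluation-time phase can be sketched as follows. This is a scalar illustration with hypothetical function names. The behaviour at exactly sigmoid = 0.5 is not specified above, so assigning it to phase pi is an assumption of this sketch.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def sign_phase(sign_logit: float, training: bool) -> float:
    """Phase from a sign logit: soft in training, snapped to {0, pi} in eval."""
    s = sigmoid(sign_logit)
    if training:
        # Differentiable interpolation between pi (s -> 0) and 0 (s -> 1).
        return math.pi * (1.0 - s)
    # Hard threshold: s > 0.5 means positive sign, i.e. phase 0.
    # Assumption: the tie s == 0.5 is assigned to phase pi.
    return 0.0 if s > 0.5 else math.pi


print(sign_phase(3.0, training=True), sign_phase(3.0, training=False))
print(sign_phase(-3.0, training=True), sign_phase(-3.0, training=False))
```

The hard threshold has zero gradient almost everywhere, which is why a soft surrogate is needed during training. The two modes agree only in the limit of large |sign_logit|.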

class qvartools.nqs.architectures.complex_nqs.ComplexNQS(num_sites, hidden_dims=None)

Bases: NeuralQuantumState

Complex-valued NQS with shared feature extractor.

A shared MLP backbone produces feature vectors that are fed into two independent heads:

  • Amplitude head — maps features to a scalar log-amplitude.

  • Phase head — maps features to a phase in (-pi, pi).

Feature caching avoids redundant computation when log_amplitude() and phase() are called sequentially on the same input tensor.

Parameters:
  • num_sites (int) – Number of lattice / orbital sites.

  • hidden_dims (list of int, optional) – Hidden-layer sizes for the shared feature extractor (default [128, 64]).

feature_net

Shared feature MLP.

Type:

nn.Sequential

amplitude_head

Linear projection from features to scalar log-amplitude.

Type:

nn.Linear

phase_head

Maps features to phase via Linear + Tanh (scaled by pi).

Type:

nn.Sequential

Examples

>>> import torch
>>> from qvartools.nqs.architectures.complex_nqs import ComplexNQS
>>> nqs = ComplexNQS(num_sites=10, hidden_dims=[64, 32])
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = nqs.log_amplitude(x)  # shape (8,)
>>> phi = nqs.phase(x)              # shape (8,), in (-pi, pi)

clear_feature_cache()

Clear the feature cache.

Call this between training steps or when the input batch changes to avoid stale cached values.

Return type:

None
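
How the cache is keyed is not documented here. As one plausible design, the sketch below uses a single-entry cache keyed on the identity of the input object. CachedFeatures and its attributes are hypothetical and only illustrate why clearing matters.

```python
class CachedFeatures:
    """Hypothetical single-entry cache shared by two output heads."""

    def __init__(self, feature_fn):
        self.feature_fn = feature_fn
        self.calls = 0            # counts real feature evaluations
        self._cached_input = None
        self._cached_features = None

    def features(self, x):
        # Reuse the stored result only when the very same object is passed again.
        if self._cached_input is not x:
            self.calls += 1
            self._cached_features = self.feature_fn(x)
            self._cached_input = x
        return self._cached_features

    def clear_feature_cache(self):
        self._cached_input = None
        self._cached_features = None


cache = CachedFeatures(lambda x: [2 * v for v in x])
batch = [1, 0, 1]
cache.features(batch)   # computed (e.g. for the amplitude head)
cache.features(batch)   # reused (e.g. for the phase head)
assert cache.calls == 1
cache.clear_feature_cache()
cache.features(batch)   # recomputed after clearing
assert cache.calls == 2
```

A cache of this kind cannot tell that the input was modified in place or that the network parameters changed after an optimiser step. Both cases would return stale features, which is consistent with the advice to clear the cache between training steps.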

log_amplitude(x)

Compute log-amplitude from shared features.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor

phase(x)

Compute phase from shared features.

The Tanh output (in (-1, 1)) is scaled by pi to produce a phase in (-pi, pi).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phases in radians, shape (batch,).

Return type:

torch.Tensor

Restricted Boltzmann Machine

class qvartools.nqs.architectures.rbm.RBMQuantumState(num_sites, num_hidden=None, complex_weights=False)

Bases: NeuralQuantumState

Restricted Boltzmann Machine neural quantum state.

Implements the RBM ansatz of Carleo & Troyer (Science, 2017):

\[\psi(\mathbf{x}) = \exp\!\Bigl(\sum_j a_j x_j\Bigr) \prod_i \cosh\!\Bigl(b_i + \sum_j W_{ij} x_j\Bigr)\]

When complex_weights is True, the parameters a, b, and W are complex-valued, and the wavefunction acquires a non-trivial phase.

Parameters:
  • num_sites (int) – Number of visible units (lattice / orbital sites).

  • num_hidden (int, optional) – Number of hidden units (default num_sites).

  • complex_weights (bool, optional) – If True, use complex-valued RBM parameters to represent a complex wavefunction (default False).

a_real

Real part of the visible bias, shape (num_sites,).

Type:

nn.Parameter

a_imag

Imaginary part of the visible bias (only if complex_weights).

Type:

nn.Parameter or None

b_real

Real part of the hidden bias, shape (num_hidden,).

Type:

nn.Parameter

b_imag

Imaginary part of the hidden bias (only if complex_weights).

Type:

nn.Parameter or None

W_real

Real part of the weight matrix, shape (num_hidden, num_sites).

Type:

nn.Parameter

W_imag

Imaginary part of the weight matrix (only if complex_weights).

Type:

nn.Parameter or None

Examples

>>> import torch
>>> from qvartools.nqs.architectures.rbm import RBMQuantumState
>>> rbm = RBMQuantumState(num_sites=10, num_hidden=20)
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = rbm.log_amplitude(x)  # shape (8,)

log_amplitude(x)

Compute the log-amplitude of the RBM wavefunction.

For real weights:

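\[\ln|\psi(\mathbf{x})| = \mathrm{Re}(\mathbf{a}) \cdot \mathbf{x} + \sum_i \ln\cosh(\theta_i), \qquad \theta_i = b_i + \sum_j W_{ij} x_j\]

For complex weights, the log-amplitude is the real part of the full log-wavefunction, i.e. each \ln\cosh(\theta_i) term is replaced by its real part.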

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor
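
For real weights the log-amplitude formula can be evaluated directly. The sketch below uses plain Python lists instead of tensors and a rewritten ln cosh that avoids overflow for large hidden pre-activations. Whether the library uses this particular rewriting is not stated here, and the function names are hypothetical.

```python
import math


def log_cosh(t: float) -> float:
    """Overflow-safe ln(cosh(t)) = |t| + ln(1 + exp(-2|t|)) - ln(2)."""
    a = abs(t)
    return a + math.log1p(math.exp(-2.0 * a)) - math.log(2.0)


def rbm_log_amplitude(x, a, b, W):
    """ln|psi(x)| for a real RBM; W has shape (num_hidden, num_sites)."""
    visible = sum(a_j * x_j for a_j, x_j in zip(a, x))
    thetas = [b_i + sum(w * x_j for w, x_j in zip(row, x)) for b_i, row in zip(b, W)]
    return visible + sum(log_cosh(t) for t in thetas)


# Two visible sites, one hidden unit.
value = rbm_log_amplitude(x=[1, 0], a=[0.1, 0.2], b=[0.0], W=[[0.5, -0.5]])
assert math.isclose(value, 0.1 + math.log(math.cosh(0.5)))
```

The naive math.log(math.cosh(t)) raises OverflowError for |t| of roughly 711 and above in double precision, whereas the rewritten form stays finite and approaches |t| - ln 2.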

phase(x)

Compute the wavefunction phase.

For real weights the phase is identically zero. For complex weights, the phase is the imaginary part of the full log-wavefunction:

\[\arg\psi(\mathbf{x}) = \mathrm{Im}(\mathbf{a}) \cdot \mathbf{x} + \sum_i \mathrm{Im}\bigl(\ln\cosh(\theta_i)\bigr)\]

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phases in radians, shape (batch,).

Return type:

torch.Tensor
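
With complex parameters, the log-amplitude and phase are the real and imaginary parts of one complex log-wavefunction. The following sketch shows this using Python's cmath on a small hand-picked parameter set. The numbers and the function name rbm_log_psi are illustrative assumptions, and the library itself stores real and imaginary parts as separate parameters.

```python
import cmath


def rbm_log_psi(x, a, b, W):
    """Complex ln(psi(x)) = a.x + sum_i ln cosh(theta_i) with complex parameters."""
    log_psi = sum(a_j * x_j for a_j, x_j in zip(a, x))
    for b_i, row in zip(b, W):
        theta = b_i + sum(w * x_j for w, x_j in zip(row, x))
        log_psi += cmath.log(cmath.cosh(theta))
    return log_psi


x = [1, 0, 1]
a = [0.1 + 0.2j, -0.3 + 0.1j, 0.05 - 0.4j]
b = [0.2 - 0.1j, -0.1 + 0.3j]
W = [[0.3 + 0.1j, -0.2j, 0.1 + 0j], [-0.1 + 0.2j, 0.4 + 0j, 0.2 - 0.3j]]

log_psi = rbm_log_psi(x, a, b, W)
log_amp, phase = log_psi.real, log_psi.imag
print(log_amp, phase)
```

Because the phase is a sum of several terms, it is defined only modulo 2*pi and need not lie in (-pi, pi]. Exponentiating log_amp + i * phase still reproduces the product form of the ansatz exactly.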

Autoregressive Transformer

class qvartools.nqs.transformer.autoregressive.AutoregressiveTransformer(n_orbitals, n_alpha, n_beta, embed_dim=64, n_heads=4, n_layers=4, dropout=0.0)

Bases: Module

Autoregressive transformer NQS with alpha/beta spin channels.

Models the joint probability of occupying orbitals by factorising it autoregressively:

\[p(\mathbf{x}) = \prod_{i=1}^{N_{\text{orb}}} p(x^\alpha_i | x^\alpha_{<i}) \;\prod_{i=1}^{N_{\text{orb}}} p(x^\beta_i | x^\beta_{<i}, \mathbf{x}^\alpha)\]

The alpha channel uses causal self-attention only. The beta channel uses causal self-attention plus cross-attention to the full alpha representation, enabling spin-spin correlations.

Sampling enforces particle conservation: exactly n_alpha electrons in the alpha channel and n_beta in the beta channel.

Parameters:
  • n_orbitals (int) – Number of spatial orbitals per spin channel.

  • n_alpha (int) – Number of alpha electrons.

  • n_beta (int) – Number of beta electrons.

  • embed_dim (int, optional) – Embedding dimensionality (default 64).

  • n_heads (int, optional) – Number of attention heads (default 4).

  • n_layers (int, optional) – Number of transformer layers per channel (default 4).

  • dropout (float, optional) – Dropout probability (default 0.0).

alpha_blocks

Transformer blocks for the alpha channel (self-attention only).

Type:

nn.ModuleList

beta_blocks

Transformer blocks for the beta channel (self + cross-attention).

Type:

nn.ModuleList

alpha_head

Output head producing alpha occupation logits.

Type:

nn.Linear

beta_head

Output head producing beta occupation logits.

Type:

nn.Linear

Examples

>>> from qvartools.nqs.transformer.autoregressive import AutoregressiveTransformer
>>> model = AutoregressiveTransformer(
...     n_orbitals=6, n_alpha=2, n_beta=2, embed_dim=32, n_heads=4
... )
>>> configs = model.sample(n_samples=16)  # shape (16, 12)

log_prob(alpha, beta)

Compute the log-probability of a configuration.

Parameters:
  • alpha (torch.Tensor) – Alpha spin-orbital occupations, shape (batch, n_orbitals) with entries in {0, 1}.

  • beta (torch.Tensor) – Beta spin-orbital occupations, shape (batch, n_orbitals) with entries in {0, 1}.

Returns:

Log-probabilities, shape (batch,).

Return type:

torch.Tensor
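
The log-probability of an autoregressive model is the sum of the log-conditionals in its factorisation. The sketch below shows this for a single channel with a toy conditional in place of the transformer. conditional_p1 is a hypothetical stand-in, and the alpha/beta split and the particle-number constraint are deliberately left out.

```python
import math
from itertools import product


def conditional_p1(prefix):
    """Hypothetical stand-in for the network: p(x_i = 1 | x_<i)."""
    return 1.0 / (1.0 + math.exp(-(0.5 * sum(prefix) - 0.25 * len(prefix))))


def log_prob(config):
    """Sum the log-conditionals of an autoregressive factorisation."""
    total = 0.0
    for i, bit in enumerate(config):
        p1 = conditional_p1(config[:i])
        total += math.log(p1 if bit == 1 else 1.0 - p1)
    return total


# Because every conditional is normalised, the joint sums to 1 with no
# explicit partition function.
total_prob = sum(math.exp(log_prob(c)) for c in product([0, 1], repeat=4))
assert math.isclose(total_prob, 1.0)
```

This built-in normalisation is the main practical difference from the NeuralQuantumState subclasses above, whose probabilities are unnormalised unless a basis set is supplied.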

sample(n_samples, temperature=1.0)

Generate particle-conserving configurations autoregressively.

Samples alpha orbitals first (enforcing exactly n_alpha electrons), then samples beta orbitals with cross-attention to alpha (enforcing exactly n_beta electrons). The returned configuration is [alpha, beta] concatenated along the orbital axis.

KV caching is used for efficient autoregressive generation.

Parameters:
  • n_samples (int) – Number of configurations to generate.

  • temperature (float, optional) – Sampling temperature. Values > 1 increase randomness; values < 1 sharpen the distribution (default 1.0).

Returns:

Sampled configurations, shape (n_samples, 2 * n_orbitals) with entries in {0, 1}. The first n_orbitals columns are alpha occupations and the last n_orbitals are beta.

Return type:

torch.Tensor
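
One common way to enforce a fixed particle number during autoregressive sampling is to force the remaining choices whenever the electron budget leaves no freedom. The sketch below illustrates that idea with the standard library. It is not the library's implementation: flat_logit is a hypothetical placeholder for the network, the beta channel here ignores alpha (no cross-attention), and there is no KV cache.

```python
import math
import random


def sample_channel(n_orbitals, n_electrons, logit_fn, rng, temperature=1.0):
    """Sample occupations with exactly n_electrons ones among n_orbitals."""
    if not 0 <= n_electrons <= n_orbitals:
        raise ValueError("n_electrons must lie between 0 and n_orbitals")
    occ = []
    for i in range(n_orbitals):
        remaining_sites = n_orbitals - i
        remaining_electrons = n_electrons - sum(occ)
        if remaining_electrons == 0:
            bit = 0                      # no electrons left to place
        elif remaining_electrons == remaining_sites:
            bit = 1                      # every remaining orbital must be filled
        else:
            # Dividing the logit by the temperature flattens (T > 1) or
            # sharpens (T < 1) the conditional distribution.
            p1 = 1.0 / (1.0 + math.exp(-logit_fn(occ) / temperature))
            bit = 1 if rng.random() < p1 else 0
        occ.append(bit)
    return occ


def sample(n_samples, n_orbitals, n_alpha, n_beta, seed=0):
    rng = random.Random(seed)
    flat_logit = lambda prefix: 0.0      # hypothetical placeholder for the network
    configs = []
    for _ in range(n_samples):
        alpha = sample_channel(n_orbitals, n_alpha, flat_logit, rng)
        beta = sample_channel(n_orbitals, n_beta, flat_logit, rng)
        configs.append(alpha + beta)     # [alpha, beta] along the orbital axis
    return configs


configs = sample(n_samples=16, n_orbitals=6, n_alpha=2, n_beta=2)
assert all(sum(c[:6]) == 2 and sum(c[6:]) == 2 for c in configs)
```

With n_orbitals=6 each sampled configuration has 12 entries, matching the (n_samples, 2 * n_orbitals) layout described above.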

forward(alpha, beta)

Forward pass — delegates to log_prob().

Parameters:
  • alpha (torch.Tensor) – Alpha occupations, shape (batch, n_orbitals).

  • beta (torch.Tensor) – Beta occupations, shape (batch, n_orbitals).

Returns:

Log-probabilities, shape (batch,).

Return type:

torch.Tensor

Utilities

qvartools.nqs.compile_nqs(model, mode='reduce-overhead')

Apply torch.compile to an NQS model with graceful fallback.

Parameters:
  • model (nn.Module) – The neural quantum state model to compile.

  • mode (str, optional) – Compilation mode passed to torch.compile. Common choices are "reduce-overhead" (default) and "max-autotune".

Returns:

The compiled model, or the original model unchanged if compilation fails (e.g. unsupported platform or PyTorch version).

Return type:

nn.Module
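
The "graceful fallback" behaviour can be pictured as a try/except wrapper around the compile call. The sketch below is a generic version of that pattern, not the source of compile_nqs. compile_with_fallback and its compile_fn argument are hypothetical, and the compile step is injected so the example runs without PyTorch.

```python
import warnings


def compile_with_fallback(model, compile_fn, **compile_kwargs):
    """Return compile_fn(model, **kwargs), or model itself if that raises."""
    try:
        return compile_fn(model, **compile_kwargs)
    except Exception as exc:  # broad on purpose: any failure means "skip compilation"
        warnings.warn(f"compilation skipped: {exc}")
        return model


def unsupported_compile(model, **kwargs):
    """Stand-in for a compile step that fails on this platform."""
    raise RuntimeError("backend not available")


model = object()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = compile_with_fallback(model, unsupported_compile, mode="reduce-overhead")
assert result is model  # the original object comes back unchanged
```

With PyTorch installed, passing torch.compile as compile_fn would give the real behaviour. Note that torch.compile generally defers compilation until the first call of the wrapped module, so a wrapper like this only catches errors raised when the wrapper is created. Whether compile_nqs also guards against first-call failures is not stated here.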

Examples

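>>> from qvartools.nqs import compile_nqs
>>> from qvartools.nqs.architectures.dense import DenseNQS
>>> nqs = DenseNQS(num_sites=10, hidden_dims=[64, 32])
>>> nqs = compile_nqs(nqs, mode="max-autotune")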