Neural Quantum States

The nqs subpackage provides neural quantum state architectures that parameterize the many-body wavefunction as a neural network.

Base Class

class qvartools.nqs.neural_state.NeuralQuantumState(num_sites, local_dim=2, complex_output=False)[source]

Bases: Module, ABC

Abstract base class for all neural quantum state ansaetze.

Every subclass must implement log_amplitude() and phase(). The base class provides convenience methods for evaluating the full log-wavefunction, the wavefunction itself, Born-rule probabilities, and normalised probabilities over a discrete basis set.

Parameters:
  • num_sites (int) – Number of lattice / orbital sites (input dimensionality).

  • local_dim (int, optional) – Dimension of the local Hilbert space on each site (default 2 for spin-1/2 / qubit systems).

  • complex_output (bool, optional) – If True, the NQS represents a complex-valued wavefunction with a non-trivial phase network. If False, the phase is identically zero and log_psi() returns a single real tensor.

num_sites

Number of sites.

Type:

int

local_dim

Local Hilbert-space dimension.

Type:

int

complex_output

Whether the NQS has a non-trivial phase.

Type:

bool

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

encode_configuration(config)

Convert a configuration tensor to float for network input.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Forward pass: delegates to log_psi().

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

log_amplitude(x)

Compute the log-amplitude ln|psi(x)| for each configuration.

log_psi(x)

Compute the log-wavefunction.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

normalized_probability(x, basis_set)

Compute normalised Born-rule probabilities over a basis set.

parameters([recurse])

Return an iterator over module parameters.

phase(x)

Compute the phase arg(psi(x)) for each configuration.

probability(x)

Compute the Born-rule probability |psi(x)|^2 (unnormalised).

psi(x)

Evaluate the full wavefunction psi(x).

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

abstractmethod log_amplitude(x)[source]

Compute the log-amplitude ln|psi(x)| for each configuration.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor

abstractmethod phase(x)[source]

Compute the phase arg(psi(x)) for each configuration.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phases in radians, shape (batch,). Must be identically zero when complex_output is False.

Return type:

torch.Tensor

log_psi(x)[source]

Compute the log-wavefunction.

For real-valued NQS (complex_output is False), returns only the log-amplitude. For complex-valued NQS, returns a tuple of (log_amplitude, phase).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

If complex_output is False: log-amplitude tensor of shape (batch,). If complex_output is True: tuple (log_amp, phase) each of shape (batch,).

Return type:

torch.Tensor or tuple of torch.Tensor
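
Because log_psi() returns either a single tensor or a (log_amp, phase) tuple, calling code usually has to handle both cases. The following is a minimal sketch of one way to do that. The helper name split_log_psi is hypothetical and not part of qvartools, and plain Python floats stand in for tensors so the snippet runs without PyTorch.

```python
def split_log_psi(out):
    """Return (log_amp, phase) for either log_psi() return convention.

    `out` is either a single value (real-valued NQS) or a 2-tuple
    (log_amp, phase) (complex-valued NQS). For the real case the phase
    is reported as None so the caller can skip phase-dependent work.
    """
    if isinstance(out, tuple):
        log_amp, phase = out
        return log_amp, phase
    return out, None


# Real-valued convention: a single object comes back.
real_case = split_log_psi(-0.7)
# Complex-valued convention: a (log_amp, phase) tuple comes back.
complex_case = split_log_psi((-0.7, 1.2))
print(real_case, complex_case)
```

Returning None for the phase, rather than a zero, is a design choice of this sketch; it lets the caller tell the two conventions apart.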

psi(x)[source]

Evaluate the full wavefunction psi(x).

Computes exp(log_amp) * exp(i * phase). When complex_output is False the result is real-valued (dtype matches input); otherwise it is complex.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Wavefunction values, shape (batch,). Complex dtype when complex_output is True.

Return type:

torch.Tensor
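
The relation psi = exp(log_amp) * exp(i * phase) can be checked by hand for a single configuration. The sketch below uses only the Python standard library; psi_from_log is a hypothetical helper that illustrates the formula, not the qvartools implementation.

```python
import cmath
import math


def psi_from_log(log_amp, phase=None):
    """Evaluate psi = exp(log_amp) * exp(i * phase) for one configuration."""
    if phase is None:
        # Real-valued state: the phase is zero, so psi is a real number.
        return math.exp(log_amp)
    return math.exp(log_amp) * cmath.exp(1j * phase)


real_psi = psi_from_log(0.0)           # amplitude 1, no phase
neg_psi = psi_from_log(0.0, math.pi)   # amplitude 1, phase pi (a sign flip)
born = abs(psi_from_log(-0.5, 0.3)) ** 2
print(real_psi, neg_psi, born)
```

The last line illustrates that |psi|^2 depends only on the log-amplitude: it equals exp(2 * log_amp) whatever the phase.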

probability(x)[source]

Compute the Born-rule probability |psi(x)|^2 (unnormalised).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Unnormalised probabilities exp(2 * log_amplitude(x)), shape (batch,).

Return type:

torch.Tensor

normalized_probability(x, basis_set)[source]

Compute normalised Born-rule probabilities over a basis set.

The normalisation constant Z is computed as the sum of |psi(s)|^2 over every configuration s in basis_set.

Parameters:
  • x (torch.Tensor) – Configurations to evaluate, shape (batch, num_sites).

  • basis_set (torch.Tensor) – Complete (or reference) set of configurations over which the normalisation constant Z is summed, shape (n_basis, num_sites).

Returns:

Normalised probabilities, shape (batch,). Each entry is |psi(x_i)|^2 / Z.

Return type:

torch.Tensor
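
The normalisation can be sketched directly from log-amplitudes. The example below works in log space and subtracts the largest log-probability before exponentiating (the log-sum-exp trick), which avoids overflow for large log-amplitudes. Whether qvartools does the same internally is not stated in this reference; the function name and the use of plain Python lists are assumptions made for illustration.

```python
import math


def normalized_probabilities(log_amps_x, log_amps_basis):
    """Return |psi(x_i)|^2 / Z with Z = sum over the basis of |psi(s)|^2."""
    # Log-probabilities are 2 * ln|psi|; shift by the maximum so that the
    # largest exponent is exp(0) = 1 and nothing overflows.
    log_p_basis = [2.0 * la for la in log_amps_basis]
    shift = max(log_p_basis)
    log_z = shift + math.log(sum(math.exp(lp - shift) for lp in log_p_basis))
    return [math.exp(2.0 * la - log_z) for la in log_amps_x]


basis_log_amps = [0.0, math.log(2.0), -1.0]
probs = normalized_probabilities(basis_log_amps, basis_log_amps)
# A naive exp(2 * 500) would overflow a float; the shifted version does not.
large = normalized_probabilities([500.0], [500.0, 500.0])
print(sum(probs), large)
```

When x ranges over the same configurations as basis_set, the returned probabilities sum to one; if basis_set is only a reference subset, they are normalised relative to that subset.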

static encode_configuration(config)[source]

Convert a configuration tensor to float for network input.

Parameters:

config (torch.Tensor) – Configuration tensor of any integer or float dtype, shape (..., num_sites).

Returns:

Float tensor with the same shape, dtype torch.float32.

Return type:

torch.Tensor

forward(x)[source]

Forward pass: delegates to log_psi().

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Same as log_psi().

Return type:

torch.Tensor or tuple of torch.Tensor

Dense Architectures

class qvartools.nqs.architectures.dense.DenseNQS(num_sites, hidden_dims=None, complex_output=False)[source]

Bases: NeuralQuantumState

Fully connected feedforward neural quantum state.

The amplitude network maps a configuration vector to a scalar log-amplitude via a stack of Linear + ReLU layers, followed by a final Linear + Tanh layer whose output is scaled by a learnable log_amp_scale parameter.

If complex_output is True, a separate phase network of the same depth produces the wavefunction phase in (-pi, pi).

Parameters:
  • num_sites (int) – Number of lattice / orbital sites.

  • hidden_dims (list of int, optional) – Hidden-layer sizes for the amplitude (and phase) networks (default [128, 64]).

  • complex_output (bool, optional) – Whether to include a phase network (default False).

amplitude_net

The amplitude MLP (output before scaling).

Type:

nn.Sequential

log_amp_scale

Learnable scalar that multiplies the Tanh output.

Type:

nn.Parameter

phase_net

Phase MLP when complex_output is True, else None.

Type:

nn.Sequential or None

Examples

>>> import torch
>>> from qvartools.nqs.architectures.dense import DenseNQS
>>> nqs = DenseNQS(num_sites=10, hidden_dims=[64, 32])
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = nqs.log_amplitude(x)  # shape (8,)

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

encode_configuration(config)

Convert a configuration tensor to float for network input.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Forward pass: delegates to log_psi().

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

log_amplitude(x)

Compute log-amplitude ln|psi(x)|.

log_psi(x)

Compute the log-wavefunction.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

normalized_probability(x, basis_set)

Compute normalised Born-rule probabilities over a basis set.

parameters([recurse])

Return an iterator over module parameters.

phase(x)

Compute the wavefunction phase.

probability(x)

Compute the Born-rule probability |psi(x)|^2 (unnormalised).

psi(x)

Evaluate the full wavefunction psi(x).

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

log_amplitude(x)[source]

Compute log-amplitude ln|psi(x)|.

The raw amplitude network output (bounded in (-1, 1) by Tanh) is multiplied by the learnable log_amp_scale.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor
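
A consequence of the Tanh bound is that the log-amplitude can never leave the interval [-|log_amp_scale|, |log_amp_scale|], so the ratio between any two unnormalised probabilities is at most exp(4 * |log_amp_scale|). The sketch below illustrates this with scalar values; dense_log_amplitude and pre_activation are hypothetical names, and the real network operates on tensors.

```python
import math


def dense_log_amplitude(pre_activation, log_amp_scale):
    """Scale a Tanh-bounded output to a log-amplitude."""
    return log_amp_scale * math.tanh(pre_activation)


scale = 3.0
values = [dense_log_amplitude(r, scale) for r in (-50.0, -1.0, 0.0, 1.0, 50.0)]
# Probabilities are exp(2 * log_amp), so the largest possible ratio between
# two configurations is exp(2 * scale) / exp(-2 * scale) = exp(4 * scale).
max_prob_ratio = math.exp(4.0 * abs(scale))
print(values, max_prob_ratio)
```

A small log_amp_scale therefore limits how sharply peaked the represented distribution can be, which is presumably why the scale is learnable.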

phase(x)[source]

Compute the wavefunction phase.

Returns zeros for real-valued NQS. For complex NQS the phase network output (in (-1, 1)) is scaled by pi so the phase lies in (-pi, pi).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phases in radians, shape (batch,).

Return type:

torch.Tensor

class qvartools.nqs.architectures.dense.SignedDenseNQS(num_sites, hidden_dims=None)[source]

Bases: NeuralQuantumState

Dense NQS with explicit sign structure.

Uses a shared feature extractor whose output feeds into two heads:

  • Amplitude head — produces the log-amplitude via Softplus to ensure non-negative output.

  • Sign head — produces a logit whose sigmoid is thresholded at 0.5 to yield a phase of either 0 (positive) or pi (negative).

Feature caching avoids redundant computation when log_amplitude() and phase() are called on the same input within one evaluation.

Parameters:
  • num_sites (int) – Number of lattice / orbital sites.

  • hidden_dims (list of int, optional) – Hidden-layer sizes for the shared feature extractor (default [128, 64]).

feature_net

Shared feature extractor.

Type:

nn.Sequential

amplitude_head

Maps features to log-amplitude (Softplus output).

Type:

nn.Sequential

sign_head

Maps features to a sign logit.

Type:

nn.Linear

Examples

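>>> import torch
>>> from qvartools.nqs.architectures.dense import SignedDenseNQS
>>> nqs = SignedDenseNQS(num_sites=10).eval()  # eval mode, so phases snap to {0, pi}
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = nqs.log_amplitude(x)  # shape (8,)
>>> phi = nqs.phase(x)              # shape (8,), values in {0, pi}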

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

clear_feature_cache()

Clear the feature cache.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

encode_configuration(config)

Convert a configuration tensor to float for network input.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Forward pass: delegates to log_psi().

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

log_amplitude(x)

Compute log-amplitude from the amplitude head.

log_psi(x)

Compute the log-wavefunction.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

normalized_probability(x, basis_set)

Compute normalised Born-rule probabilities over a basis set.

parameters([recurse])

Return an iterator over module parameters.

phase(x)

Compute the sign-derived phase.

probability(x)

Compute the Born-rule probability |psi(x)|^2 (unnormalised).

psi(x)

Evaluate the full wavefunction psi(x).

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

clear_feature_cache()[source]

Clear the feature cache.

Call this between training steps or whenever the input batch changes to avoid stale cached values.

Return type:

None
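
To see why a stale cache matters, consider the toy model below. This reference does not say how the cache is keyed, so the sketch assumes a one-entry cache keyed on the identity of the input object; CachedFeatures and its attributes are hypothetical and use plain Python lists in place of tensors.

```python
class CachedFeatures:
    """Toy stand-in for a shared feature extractor with a one-entry cache."""

    def __init__(self, feature_fn):
        self.feature_fn = feature_fn
        self.calls = 0                # counts real feature computations
        self._cached_input = None
        self._cached_features = None

    def features(self, x):
        # Reuse the cached result only when the very same object is passed.
        if self._cached_input is x:
            return self._cached_features
        self.calls += 1
        self._cached_input = x
        self._cached_features = self.feature_fn(x)
        return self._cached_features

    def clear_feature_cache(self):
        self._cached_input = None
        self._cached_features = None


net = CachedFeatures(lambda x: [2 * v for v in x])
batch = [1, 0, 1]
net.features(batch)           # computed
net.features(batch)           # served from the cache
batch[0] = 0                  # in-place change: the cached result is now stale
stale = net.features(batch)   # still the old features
net.clear_feature_cache()
fresh = net.features(batch)   # recomputed from the modified batch
print(net.calls, stale, fresh)
```

Under this assumption, an in-place update of the batch (or of the network weights between training steps) would go unnoticed by the cache, which is the situation clear_feature_cache() guards against.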

log_amplitude(x)[source]

Compute log-amplitude from the amplitude head.

The Softplus activation ensures the raw amplitude is non-negative; the log-amplitude is the logarithm of that value.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor

phase(x)[source]

Compute the sign-derived phase.

The sign logit is passed through a sigmoid. Values above 0.5 correspond to a positive sign (phase = 0); values below 0.5 correspond to a negative sign (phase = pi).

During training, a soft interpolation is used for gradient flow: phase = pi * (1 - sigmoid(sign_logit)).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phase values, shape (batch,). During training these are continuous in (0, pi); at eval they snap to {0, pi}.

Return type:

torch.Tensor
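
The two regimes described above can be written out for a single sign logit. This is a sketch of the stated formulas, not the library code; the function name is hypothetical, and assigning the exact tie sigmoid = 0.5 to phase pi is an assumption, since this reference does not specify it.

```python
import math


def sign_phase(sign_logit, training):
    """Phase from a sign logit: soft during training, hard at evaluation."""
    # Plain sigmoid; adequate for the moderate logits used in this example.
    s = 1.0 / (1.0 + math.exp(-sign_logit))
    if training:
        # Smooth interpolation so gradients can flow through the sign head.
        return math.pi * (1.0 - s)
    # Hard threshold at 0.5: positive sign -> 0, negative sign -> pi.
    return 0.0 if s > 0.5 else math.pi


soft = [sign_phase(z, training=True) for z in (-2.0, 0.0, 2.0)]
hard = [sign_phase(z, training=False) for z in (-2.0, 2.0)]
print(soft, hard)
```

The soft phase approaches the hard values as the logit grows in magnitude, so a confidently trained sign head gives nearly the same wavefunction in both modes.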

class qvartools.nqs.architectures.complex_nqs.ComplexNQS(num_sites, hidden_dims=None)[source]

Bases: NeuralQuantumState

Complex-valued NQS with shared feature extractor.

A shared MLP backbone produces feature vectors that are fed into two independent heads:

  • Amplitude head — maps features to a scalar log-amplitude.

  • Phase head — maps features to a phase in (-pi, pi).

Feature caching avoids redundant computation when log_amplitude() and phase() are called sequentially on the same input tensor.

Parameters:
  • num_sites (int) – Number of lattice / orbital sites.

  • hidden_dims (list of int, optional) – Hidden-layer sizes for the shared feature extractor (default [128, 64]).

feature_net

Shared feature MLP.

Type:

nn.Sequential

amplitude_head

Linear projection from features to scalar log-amplitude.

Type:

nn.Linear

phase_head

Maps features to phase via Linear + Tanh (scaled by pi).

Type:

nn.Sequential

Examples

>>> import torch
>>> from qvartools.nqs.architectures.complex_nqs import ComplexNQS
>>> nqs = ComplexNQS(num_sites=10, hidden_dims=[64, 32])
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = nqs.log_amplitude(x)  # shape (8,)
>>> phi = nqs.phase(x)              # shape (8,), in (-pi, pi)

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

clear_feature_cache()

Clear the feature cache.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

encode_configuration(config)

Convert a configuration tensor to float for network input.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Forward pass: delegates to log_psi().

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

log_amplitude(x)

Compute log-amplitude from shared features.

log_psi(x)

Compute the log-wavefunction.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

normalized_probability(x, basis_set)

Compute normalised Born-rule probabilities over a basis set.

parameters([recurse])

Return an iterator over module parameters.

phase(x)

Compute phase from shared features.

probability(x)

Compute the Born-rule probability |psi(x)|^2 (unnormalised).

psi(x)

Evaluate the full wavefunction psi(x).

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

clear_feature_cache()[source]

Clear the feature cache.

Call this between training steps or when the input batch changes to avoid stale cached values.

Return type:

None

log_amplitude(x)[source]

Compute log-amplitude from shared features.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor

phase(x)[source]

Compute phase from shared features.

The Tanh output (in (-1, 1)) is scaled by pi to produce a phase in (-pi, pi).

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phases in radians, shape (batch,).

Return type:

torch.Tensor

Restricted Boltzmann Machine

class qvartools.nqs.architectures.rbm.RBMQuantumState(num_sites, num_hidden=None, complex_weights=False)[source]

Bases: NeuralQuantumState

Restricted Boltzmann Machine neural quantum state.

Implements the RBM ansatz of Carleo & Troyer (Science, 2017):

\[\psi(\mathbf{x}) = \exp\!\Bigl(\sum_j a_j x_j\Bigr) \prod_i \cosh\!\Bigl(b_i + \sum_j W_{ij} x_j\Bigr)\]

When complex_weights is True, the parameters a, b, and W are complex-valued, and the wavefunction acquires a non-trivial phase.

Parameters:
  • num_sites (int) – Number of visible units (lattice / orbital sites).

  • num_hidden (int, optional) – Number of hidden units (default num_sites).

  • complex_weights (bool, optional) – If True, use complex-valued RBM parameters to represent a complex wavefunction (default False).

a_real

Real part of the visible bias, shape (num_sites,).

Type:

nn.Parameter

a_imag

Imaginary part of the visible bias (only if complex_weights).

Type:

nn.Parameter or None

b_real

Real part of the hidden bias, shape (num_hidden,).

Type:

nn.Parameter

b_imag

Imaginary part of the hidden bias (only if complex_weights).

Type:

nn.Parameter or None

W_real

Real part of the weight matrix, shape (num_hidden, num_sites).

Type:

nn.Parameter

W_imag

Imaginary part of the weight matrix (only if complex_weights).

Type:

nn.Parameter or None

Examples

>>> import torch
>>> from qvartools.nqs.architectures.rbm import RBMQuantumState
>>> rbm = RBMQuantumState(num_sites=10, num_hidden=20)
>>> x = torch.randint(0, 2, (8, 10)).float()
>>> log_amp = rbm.log_amplitude(x)  # shape (8,)
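
For real parameters the RBM ansatz is strictly positive, so taking the logarithm of the formula for this class gives ln|psi(x)| = sum_j a_j x_j + sum_i ln cosh(b_i + sum_j W_ij x_j). The sketch below evaluates that expression for one configuration with the standard library only. The function names are hypothetical, and the stable form of ln cosh is a common numerical choice rather than something this reference attributes to qvartools.

```python
import math


def log_cosh(z):
    """Numerically stable ln cosh(z) for real z."""
    # cosh(z) = exp(|z|) * (1 + exp(-2|z|)) / 2, so no large exponent appears.
    az = abs(z)
    return az + math.log1p(math.exp(-2.0 * az)) - math.log(2.0)


def rbm_log_amplitude(x, a, b, W):
    """ln|psi(x)| for a real RBM; W has shape (num_hidden, num_sites)."""
    visible = sum(a_j * x_j for a_j, x_j in zip(a, x))
    hidden = 0.0
    for b_i, w_row in zip(b, W):
        theta = b_i + sum(w_ij * x_j for w_ij, x_j in zip(w_row, x))
        hidden += log_cosh(theta)
    return visible + hidden


x = [1.0, 0.0]
a = [0.5, -0.2]
b = [0.1]
W = [[0.3, 0.4]]
value = rbm_log_amplitude(x, a, b, W)
print(value)
```

Here the single hidden unit sees theta = 0.1 + 0.3 = 0.4, so the result equals 0.5 + ln cosh(0.4).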

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

encode_configuration(config)

Convert a configuration tensor to float for network input.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Forward pass: delegates to log_psi().

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

log_amplitude(x)

Compute the log-amplitude of the RBM wavefunction.

log_psi(x)

Compute the log-wavefunction.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

normalized_probability(x, basis_set)

Compute normalised Born-rule probabilities over a basis set.

parameters([recurse])

Return an iterator over module parameters.

phase(x)

Compute the wavefunction phase.

probability(x)

Compute the Born-rule probability |psi(x)|^2 (unnormalised).

psi(x)

Evaluate the full wavefunction psi(x).

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

log_amplitude(x)[source]

Compute the log-amplitude of the RBM wavefunction.

For real weights:

\[\ln|\psi(\mathbf{x})| = \mathrm{Re}(\mathbf{a}) \cdot \mathbf{x} + \sum_i \ln\cosh(\theta_i)\]

For complex weights, takes the real part of the full log-wavefunction.

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Log-amplitudes, shape (batch,).

Return type:

torch.Tensor
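
As a rough illustration of the real-weight formula above, the following standalone sketch evaluates it with plain PyTorch tensors. It is not the library's implementation: the helper name rbm_log_amplitude is hypothetical, and it assumes θ = b + W x with W of shape (num_hidden, num_sites), which the formula itself does not spell out.

```python
import math
import torch

def rbm_log_amplitude(x, a, b, W):
    """Real-weight RBM log-amplitude: a.x + sum_i ln cosh(theta_i).

    Assumes theta = b + W x (hypothetical helper, not part of qvartools).
    x: (batch, num_sites), a: (num_sites,), b: (num_hidden,),
    W: (num_hidden, num_sites).
    """
    theta = x @ W.T + b                      # (batch, num_hidden)
    abs_theta = theta.abs()
    # Overflow-safe identity: ln cosh(t) = |t| + ln(1 + exp(-2|t|)) - ln 2
    log_cosh = abs_theta + torch.log1p(torch.exp(-2.0 * abs_theta)) - math.log(2.0)
    return x @ a + log_cosh.sum(dim=-1)      # (batch,)

torch.manual_seed(0)
num_sites, num_hidden = 10, 20
a = 0.1 * torch.randn(num_sites)
b = 0.1 * torch.randn(num_hidden)
W = 0.1 * torch.randn(num_hidden, num_sites)
x = torch.randint(0, 2, (8, num_sites)).float()
log_amp = rbm_log_amplitude(x, a, b, W)      # shape (8,)
```

The rewritten ln cosh avoids overflow for large |θ|, where evaluating cosh directly would return inf.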

phase(x)[source]

Compute the wavefunction phase.

For real weights the phase is identically zero. For complex weights, the phase is the imaginary part of the full log-wavefunction:

\[\arg\psi(\mathbf{x}) = \mathrm{Im}(\mathbf{a}) \cdot \mathbf{x} + \sum_i \mathrm{Im}\bigl(\ln\cosh(\theta_i)\bigr)\]

Parameters:

x (torch.Tensor) – Batch of configurations, shape (batch, num_sites).

Returns:

Phases in radians, shape (batch,).

Return type:

torch.Tensor
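
The complex-weight phase formula above can be sketched with PyTorch complex tensors. This is an illustrative stand-in rather than the library's code: rbm_phase is a hypothetical helper, the parameters are passed as single complex tensors instead of separate real and imaginary parts, and θ = b + W x is an assumption.

```python
import torch

def rbm_phase(x, a, b, W):
    """Phase from complex RBM parameters: Im(a).x + sum_i Im(ln cosh(theta_i)).

    Assumes theta = b + W x with complex a, b, W (hypothetical helper).
    """
    xc = x.to(a.dtype)                       # promote configurations to complex
    theta = xc @ W.T + b                     # (batch, num_hidden), complex
    log_cosh = torch.log(torch.cosh(theta))  # principal branch per hidden unit
    return (xc @ a).imag + log_cosh.imag.sum(dim=-1)

torch.manual_seed(0)
num_sites, num_hidden = 6, 4
a = 0.3 * torch.randn(num_sites, dtype=torch.complex128)
b = 0.3 * torch.randn(num_hidden, dtype=torch.complex128)
W = 0.3 * torch.randn(num_hidden, num_sites, dtype=torch.complex128)
x = torch.randint(0, 2, (5, num_sites)).double()
phases = rbm_phase(x, a, b, W)               # shape (5,), radians
```

Because the per-unit terms are summed without wrapping, the result agrees with arg ψ only modulo 2π. With purely real parameters every term vanishes, which matches the statement that the phase is identically zero in that case.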

Autoregressive Transformer

class qvartools.nqs.transformer.autoregressive.AutoregressiveTransformer(n_orbitals, n_alpha, n_beta, embed_dim=64, n_heads=4, n_layers=4, dropout=0.0)[source]

Bases: Module

Autoregressive transformer NQS with alpha/beta spin channels.

Models the joint probability of occupying orbitals by factorising it autoregressively:

\[p(\mathbf{x}) = \prod_{i=1}^{N_{\text{orb}}} p(x^\alpha_i | x^\alpha_{<i}) \;\prod_{i=1}^{N_{\text{orb}}} p(x^\beta_i | x^\beta_{<i}, \mathbf{x}^\alpha)\]

The alpha channel uses causal self-attention only. The beta channel uses causal self-attention plus cross-attention to the full alpha representation, enabling spin-spin correlations.

Sampling enforces particle conservation: exactly n_alpha electrons in the alpha channel and n_beta in the beta channel.

Parameters:
  • n_orbitals (int) – Number of spatial orbitals per spin channel.

  • n_alpha (int) – Number of alpha electrons.

  • n_beta (int) – Number of beta electrons.

  • embed_dim (int, optional) – Embedding dimensionality (default 64).

  • n_heads (int, optional) – Number of attention heads (default 4).

  • n_layers (int, optional) – Number of transformer layers per channel (default 4).

  • dropout (float, optional) – Dropout probability (default 0.0).

alpha_blocks

Transformer blocks for the alpha channel (self-attention only).

Type:

nn.ModuleList

beta_blocks

Transformer blocks for the beta channel (self + cross-attention).

Type:

nn.ModuleList

alpha_head

Output head producing alpha occupation logits.

Type:

nn.Linear

beta_head

Output head producing beta occupation logits.

Type:

nn.Linear

Examples

>>> from qvartools.nqs.transformer.autoregressive import AutoregressiveTransformer
>>> model = AutoregressiveTransformer(
...     n_orbitals=6, n_alpha=2, n_beta=2, embed_dim=32, n_heads=4
... )
>>> configs = model.sample(n_samples=16)  # shape (16, 12)

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(alpha, beta)

Forward pass; delegates to log_prob().

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

log_prob(alpha, beta)

Compute the log-probability of a configuration.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

sample(n_samples[, temperature])

Generate particle-conserving configurations autoregressively.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

log_prob(alpha, beta)[source]

Compute the log-probability of a configuration.

Parameters:
  • alpha (torch.Tensor) – Alpha spin-orbital occupations, shape (batch, n_orbitals) with entries in {0, 1}.

  • beta (torch.Tensor) – Beta spin-orbital occupations, shape (batch, n_orbitals) with entries in {0, 1}.

Returns:

Log-probabilities, shape (batch,).

Return type:

torch.Tensor
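
Under the autoregressive factorisation, the log-probability is a sum of per-orbital log-conditionals. The sketch below shows that reduction in isolation. It assumes each conditional is a two-way softmax over {empty, occupied} with logits of shape (batch, n_orbitals, 2); the helper log_prob_from_logits and this logit layout are illustrative assumptions, not the documented internals, and the particle-number constraint is ignored here.

```python
import itertools
import torch

def log_prob_from_logits(logits, occ):
    """Sum per-orbital log-conditionals for the chosen occupations.

    logits: (batch, n_orbitals, 2) scores for {empty, occupied} (assumed layout).
    occ:    (batch, n_orbitals) with entries in {0, 1}.
    """
    log_cond = torch.log_softmax(logits, dim=-1)
    picked = log_cond.gather(-1, occ.long().unsqueeze(-1)).squeeze(-1)
    return picked.sum(dim=-1)                # (batch,)

torch.manual_seed(0)
n_orbitals = 3
# Enumerate all 2**3 occupation patterns and reuse one fixed logit table for
# each (an independent-orbital toy model, used only to check normalisation).
occ = torch.tensor(list(itertools.product([0, 1], repeat=n_orbitals)))
logits = torch.randn(n_orbitals, 2).expand(occ.shape[0], -1, -1)
log_p = log_prob_from_logits(logits, occ)
total = log_p.exp().sum()                    # probabilities over all patterns
```

In an actual autoregressive model the logits at position i would depend on the earlier occupations (and, for the beta channel, on the alpha configuration); the gather-and-sum step stays the same.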

sample(n_samples, temperature=1.0)[source]

Generate particle-conserving configurations autoregressively.

Samples alpha orbitals first (enforcing exactly n_alpha electrons), then samples beta orbitals with cross-attention to alpha (enforcing exactly n_beta electrons). The returned configuration is [alpha, beta] concatenated along the orbital axis.

KV caching is used for efficient autoregressive generation.

Parameters:
  • n_samples (int) – Number of configurations to generate.

  • temperature (float, optional) – Sampling temperature. Values > 1 increase randomness; values < 1 sharpen the distribution (default 1.0).

Returns:

Sampled configurations, shape (n_samples, 2 * n_orbitals) with entries in {0, 1}. The first n_orbitals columns are alpha occupations and the last n_orbitals are beta.

Return type:

torch.Tensor
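
One common way to enforce a fixed electron count during autoregressive sampling is to mask out choices that would make the target count unreachable. The following sketch shows that idea with a placeholder in place of the network. Whether the library uses exactly this masking scheme is an assumption; sample_fixed_count and the logits_fn interface are hypothetical, and no KV caching is shown.

```python
import torch

def sample_fixed_count(logits_fn, n_samples, n_orbitals, n_elec, temperature=1.0):
    """Sample occupations with exactly n_elec ones per row.

    logits_fn(prefix) -> (n_samples, 2) scores for {empty, occupied}; it stands
    in for the network and is a hypothetical interface.
    """
    occ = torch.zeros(n_samples, n_orbitals, dtype=torch.long)
    placed = torch.zeros(n_samples, dtype=torch.long)
    for i in range(n_orbitals):
        logits = logits_fn(occ[:, :i]) / temperature
        need = n_elec - placed               # electrons still to place
        left = n_orbitals - i                # orbitals still undecided
        logits[need == 0, 1] = float("-inf")     # none left: orbital stays empty
        logits[need == left, 0] = float("-inf")  # must fill every remaining orbital
        probs = torch.softmax(logits, dim=-1)
        x_i = torch.multinomial(probs, 1).squeeze(-1)
        occ[:, i] = x_i
        placed += x_i
    return occ

torch.manual_seed(0)
dummy = lambda prefix: torch.randn(prefix.shape[0], 2)  # placeholder for a network
alpha = sample_fixed_count(dummy, n_samples=16, n_orbitals=6, n_elec=2)
beta = sample_fixed_count(dummy, n_samples=16, n_orbitals=6, n_elec=2, temperature=0.5)
configs = torch.cat([alpha, beta], dim=1)    # shape (16, 12), alpha columns first
```

Dividing the logits by the temperature before the softmax reproduces the behaviour described for the temperature parameter: values below 1 sharpen each conditional, values above 1 flatten it.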

forward(alpha, beta)[source]

Forward pass — delegates to log_prob().

Parameters:
  • alpha (torch.Tensor) – Alpha occupations, shape (batch, n_orbitals).

  • beta (torch.Tensor) – Beta occupations, shape (batch, n_orbitals).

Returns:

Log-probabilities, shape (batch,).

Return type:

torch.Tensor

Utilities

qvartools.nqs.compile_nqs(model, mode='reduce-overhead')[source]

Apply torch.compile to an NQS model with graceful fallback.

Parameters:
  • model (nn.Module) – The neural quantum state model to compile.

  • mode (str, optional) – Compilation mode passed to torch.compile. Common choices are "reduce-overhead" (default) and "max-autotune".

Returns:

The compiled model, or the original model unchanged if compilation fails (e.g. unsupported platform or PyTorch version).

Return type:

nn.Module

Examples

>>> from qvartools.nqs import compile_nqs
>>> nqs = DenseNQS(num_sites=10, hidden_dims=[64, 32])
>>> nqs = compile_nqs(nqs, mode="max-autotune")
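
The "graceful fallback" behaviour described above can be approximated with a try/except around torch.compile. This is a minimal sketch of the pattern, not the source of compile_nqs: the helper name compile_with_fallback and the warning text are assumptions.

```python
import warnings
import torch
from torch import nn

def compile_with_fallback(model, mode="reduce-overhead"):
    """Try torch.compile; hand back the original model if that raises.

    Hypothetical helper illustrating the fallback pattern only.
    """
    try:
        return torch.compile(model, mode=mode)
    except Exception as exc:  # e.g. unsupported platform or PyTorch version
        warnings.warn(f"torch.compile unavailable, using eager model: {exc}")
        return model

net = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
maybe_compiled = compile_with_fallback(net, mode="max-autotune")
```

Note that torch.compile is lazy: the actual compilation runs on the first forward call, so a wrapper like this only catches errors raised at wrap time.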