Normalizing Flows

The flows subpackage provides normalizing-flow models for configuration sampling and physics-guided training.

Flow Networks

class qvartools.flows.networks.discrete_flow.DiscreteFlowSampler(num_sites, num_coupling_layers=6, hidden_dims=None, prior_std=1.0, n_mc_samples=100)[source]

Bases: Module

RealNVP normalizing flow mapping continuous latent to discrete configs.

Uses alternating binary masks to split dimensions across coupling layers. The multi-modal (mixture-of-Gaussians) prior is intended to give balanced coverage of both {0, 1} values. Continuous outputs are clamped to [-1, 1] and discretised by thresholding at zero.

Parameters:
  • num_sites (int) – Number of binary sites in each configuration.

  • num_coupling_layers (int, optional) – Number of affine coupling layers (default 6).

  • hidden_dims (list of int, optional) – Hidden-layer sizes for the coupling networks (default [128, 128]).

  • prior_std (float, optional) – Standard deviation of the mixture-of-Gaussians prior components (default 1.0).

  • n_mc_samples (int, optional) – Number of Monte Carlo samples for discrete log-probability estimation (default 100).
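The alternating-mask affine coupling step used by RealNVP-style flows can be sketched in plain Python. This is an illustrative sketch, not the qvartools implementation. The helpers `alternating_masks`, `coupling_forward` and `coupling_inverse` are hypothetical, and `toy_net` is a made-up stand-in for a learned coupling network.

```python
import math
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for a learned coupling network: it sees only the
# masked (conditioning) sites and returns a per-site log-scale s and shift t.
ScaleShift = Callable[[List[float]], Tuple[List[float], List[float]]]


def alternating_masks(num_sites: int, num_layers: int) -> List[List[int]]:
    """Mask value 1 means the site passes through unchanged in that layer."""
    return [[1 if (j + i) % 2 == 0 else 0 for j in range(num_sites)]
            for i in range(num_layers)]


def coupling_forward(z: Sequence[float], mask: Sequence[int],
                     net: ScaleShift) -> Tuple[List[float], float]:
    """y = m*z + (1 - m)*(z*exp(s) + t); returns (y, log|det J|)."""
    s, t = net([zi * mi for zi, mi in zip(z, mask)])
    y: List[float] = []
    log_det = 0.0
    for zi, mi, si, ti in zip(z, mask, s, t):
        if mi:
            y.append(zi)                      # conditioning site: identity
        else:
            y.append(zi * math.exp(si) + ti)  # transformed site: affine update
            log_det += si                     # diagonal Jacobian entry exp(s)
    return y, log_det


def coupling_inverse(y: Sequence[float], mask: Sequence[int],
                     net: ScaleShift) -> List[float]:
    """Exact inverse: masked sites are unchanged, so s and t can be recomputed."""
    s, t = net([yi * mi for yi, mi in zip(y, mask)])
    return [yi if mi else (yi - ti) * math.exp(-si)
            for yi, mi, si, ti in zip(y, mask, s, t)]


def toy_net(x: List[float]) -> Tuple[List[float], List[float]]:
    total = sum(x)
    return [0.1 * total] * len(x), [0.5 * total] * len(x)
```

Stacking layers with alternating masks means every site is transformed in every other layer, and the per-layer log-determinants simply add.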

num_sites

Number of sites.

Type:

int

num_coupling_layers

Number of coupling layers.

Type:

int

prior

The mixture-of-Gaussians prior.

Type:

MultiModalPrior

masks

Binary masks for each coupling layer.

Type:

list of torch.Tensor

coupling_nets

Coupling networks for each layer.

Type:

nn.ModuleList

n_mc_samples

Number of MC samples for discrete probability estimation.

Type:

int

Examples

>>> flow = DiscreteFlowSampler(num_sites=10, num_coupling_layers=4)
>>> configs, unique = flow.sample(batch_size=256)
>>> configs.shape
torch.Size([256, 10])

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

discretize(y)

Discretise continuous outputs to binary {0, 1} by thresholding.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(batch_size)

Forward pass: sample and compute log-probabilities.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

log_prob_continuous(y)

Compute log-probability in continuous space via change of variables.

log_prob_discrete(x)

Estimate discrete log-probability via Monte Carlo integration.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

sample(batch_size)

Sample discrete binary configurations.

sample_continuous(batch_size)

Sample continuous outputs from the flow, clamped to [-1, 1].

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

sample_continuous(batch_size)[source]

Sample continuous outputs from the flow, clamped to [-1, 1].

Parameters:

batch_size (int) – Number of samples to draw.

Returns:

Continuous samples clamped to [-1, 1], shape (batch_size, num_sites).

Return type:

torch.Tensor

static discretize(y)[source]

Discretise continuous outputs to binary {0, 1} by thresholding.

Values at or above zero are mapped to 1; values below zero are mapped to 0.

Parameters:

y (torch.Tensor) – Continuous tensor, shape (..., num_sites).

Returns:

Binary tensor with values in {0, 1}, same shape as y, dtype torch.float32.

Return type:

torch.Tensor
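The clamp-then-threshold rule can be illustrated on plain lists. This is a sketch only: `clamp` and `discretize_sketch` are hypothetical names, and the real methods operate on torch tensors.

```python
from typing import List, Sequence


def clamp(y: Sequence[float], lo: float = -1.0, hi: float = 1.0) -> List[float]:
    """Clamp continuous samples to [lo, hi], as sample_continuous() does."""
    return [min(max(v, lo), hi) for v in y]


def discretize_sketch(y: Sequence[float]) -> List[float]:
    """Values >= 0 become 1.0, values < 0 become 0.0."""
    return [1.0 if v >= 0.0 else 0.0 for v in y]


print(discretize_sketch(clamp([-1.7, -0.2, 0.0, 0.4, 3.1])))
# [0.0, 0.0, 1.0, 1.0, 1.0]
```

On torch tensors the same rule can be written as `(y >= 0).to(torch.float32)`. Clamping never changes the sign of a value, so it does not affect the discrete result.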

sample(batch_size)[source]

Sample discrete binary configurations.

Draws continuous samples from the flow, discretises them, and returns both the full batch and the unique configurations.

Parameters:

batch_size (int) – Number of samples to draw.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • configs (torch.Tensor) – All discrete configurations, shape (batch_size, num_sites).

  • unique_configs (torch.Tensor) – Unique configurations, shape (n_unique, num_sites) where n_unique <= batch_size.

log_prob_continuous(y)[source]

Compute log-probability in continuous space via change of variables.

Uses the inverse flow to map data-space samples back to the prior, then applies the change-of-variables formula: log p(y) = log p_prior(z) + log |det J^{-1}|.

Parameters:

y (torch.Tensor) – Continuous data-space samples, shape (batch, num_sites).

Returns:

Log-probabilities, shape (batch,).

Return type:

torch.Tensor
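A one-dimensional check of the change-of-variables formula. Assume a single affine flow y = a*z + b with a Gaussian prior z ~ N(0, sigma^2) instead of the mixture prior used by DiscreteFlowSampler. The inverse is z = (y - b)/a with log |det J^{-1}| = -log|a|, so the formula should reproduce the density of N(b, (|a|*sigma)^2). The function names are hypothetical.

```python
import math


def normal_logpdf(x: float, mean: float, std: float) -> float:
    return (-0.5 * ((x - mean) / std) ** 2
            - math.log(std) - 0.5 * math.log(2.0 * math.pi))


def affine_flow_log_prob(y: float, a: float, b: float, prior_std: float = 1.0) -> float:
    """log p(y) = log p_prior(z) + log|det J^{-1}| for the flow y = a*z + b."""
    z = (y - b) / a                   # inverse flow maps y back to the prior
    log_det_inv = -math.log(abs(a))   # dz/dy = 1/a
    return normal_logpdf(z, 0.0, prior_std) + log_det_inv
```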

log_prob_discrete(x)[source]

Estimate discrete log-probability via Monte Carlo integration.

For each discrete configuration x in {0, 1}^n, the probability is estimated by integrating the continuous density over the corresponding Voronoi cell ([-1, 0) for 0 and [0, 1] for 1 at each site). This is done by sampling uniform noise within each cell and averaging the continuous density; because every cell has unit volume, this average directly estimates the probability mass of x.

Parameters:

x (torch.Tensor) – Discrete configurations with values in {0, 1}, shape (batch, num_sites).

Returns:

Estimated log-probabilities, shape (batch,).

Return type:

torch.Tensor
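The Monte Carlo estimate can be sketched in plain Python as follows. The `log_density` argument is a hypothetical stand-in for log_prob_continuous. The average is taken in log space with log-mean-exp for numerical stability.

```python
import math
import random
from typing import Callable, Optional, Sequence


def log_prob_discrete_sketch(
    x: Sequence[int],
    log_density: Callable[[Sequence[float]], float],
    n_mc_samples: int = 100,
    rng: Optional[random.Random] = None,
) -> float:
    """Monte Carlo estimate of log P(x) for a binary configuration x."""
    rng = rng if rng is not None else random.Random(0)
    log_ps = []
    for _ in range(n_mc_samples):
        # Uniform point in the cell: [0, 1) for a 1, shifted to [-1, 0) for a 0.
        y = [rng.random() - (1.0 - xi) for xi in x]
        log_ps.append(log_density(y))
    # Unit cell volume, so the mean density estimates P(x); use log-mean-exp.
    m = max(log_ps)
    return m + math.log(sum(math.exp(lp - m) for lp in log_ps) / n_mc_samples)
```

For a density that is constant within each cell, every sample gives the same value, so the estimate is exact. This makes such densities convenient for sanity checks.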

forward(batch_size)[source]

Forward pass: sample and compute log-probabilities.

Parameters:

batch_size (int) – Number of samples to draw.

Return type:

Tuple[Tensor, Tensor, Tensor]

Returns:

  • configs (torch.Tensor) – Discrete configurations, shape (batch_size, num_sites).

  • unique_configs (torch.Tensor) – Unique configurations, shape (n_unique, num_sites).

  • log_probs (torch.Tensor) – Continuous log-probabilities at the pre-discretisation points, shape (batch_size,).

class qvartools.flows.networks.particle_conserving_flow.ParticleConservingFlowSampler(num_sites, n_alpha, n_beta, hidden_dims=None, temperature=1.0, min_temperature=0.01)[source]

Bases: Module

Normalizing flow that exactly conserves alpha and beta particle numbers.

Produces binary configurations of shape (num_sites,) where the first num_sites // 2 entries are alpha orbitals and the remaining entries are beta orbitals. Every sample has exactly n_alpha occupied alpha orbitals and n_beta occupied beta orbitals.

The flow works by:

  1. Scoring alpha orbitals with a learned network.

  2. Selecting the top n_alpha via differentiable top-k.

  3. Scoring beta orbitals conditioned on the alpha configuration.

  4. Selecting the top n_beta via differentiable top-k.

  5. Concatenating [alpha, beta] to form the full configuration.
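Without the differentiable relaxation, steps 1 to 5 reduce to two conditional hard top-k selections. This non-differentiable sketch makes that explicit. The scorer callables are hypothetical stand-ins for the learned alpha and beta scoring networks. The toy beta scorer in the test, which penalises alpha-occupied orbitals, is made up for illustration.

```python
from typing import Callable, List, Sequence


def top_k_mask(scores: Sequence[float], k: int) -> List[int]:
    """0/1 mask with ones at the k highest-scoring positions."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    chosen = set(order[:k])
    return [1 if i in chosen else 0 for i in range(len(scores))]


def particle_conserving_config(
    n_orbitals: int,
    n_alpha: int,
    n_beta: int,
    alpha_scorer: Callable[[], List[float]],
    beta_scorer: Callable[[List[int]], List[float]],
) -> List[int]:
    alpha = top_k_mask(alpha_scorer(), n_alpha)    # steps 1-2
    beta = top_k_mask(beta_scorer(alpha), n_beta)  # steps 3-4: conditioned on alpha
    if len(alpha) != n_orbitals or len(beta) != n_orbitals:
        raise ValueError("scorers must return one score per spatial orbital")
    return alpha + beta                            # step 5: [alpha, beta]
```

Particle numbers are conserved by construction, because each half always contains exactly k ones.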

Parameters:
  • num_sites (int) – Total number of spin-orbitals (must be even).

  • n_alpha (int) – Number of alpha electrons.

  • n_beta (int) – Number of beta electrons.

  • hidden_dims (list of int, optional) – Hidden-layer sizes for the scoring networks (default [128, 64]).

  • temperature (float, optional) – Initial temperature for differentiable top-k (default 1.0).

  • min_temperature (float, optional) – Minimum temperature (default 0.01).

num_sites

Total number of spin-orbitals.

Type:

int

n_orbitals

Number of spatial orbitals (num_sites // 2).

Type:

int

n_alpha

Number of alpha electrons.

Type:

int

n_beta

Number of beta electrons.

Type:

int

temperature

Current temperature for top-k selection.

Type:

float

alpha_scorer

Scoring network for alpha orbitals.

Type:

OrbitalScoringNetwork

beta_scorer

Scoring network for beta orbitals (conditioned on alpha config).

Type:

OrbitalScoringNetwork

selector

Differentiable top-k selector.

Type:

GumbelTopK

Examples

>>> flow = ParticleConservingFlowSampler(
...     num_sites=10, n_alpha=2, n_beta=2
... )
>>> configs, unique = flow.sample(batch_size=100)
>>> is_valid, stats = verify_particle_conservation(
...     configs, n_orbitals=5, n_alpha=2, n_beta=2
... )
>>> assert is_valid

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(batch_size[, temperature])

Forward pass that delegates to sample().

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

sample(batch_size[, temperature])

Sample particle-conserving binary configurations.

sample_without_replacement(batch_size[, ...])

Sample unique configurations using deterministic ordering.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

set_temperature(temperature)

Set the temperature for differentiable top-k selection.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

set_temperature(temperature)[source]

Set the temperature for differentiable top-k selection.

Parameters:

temperature (float) – New temperature value. Will be clamped to at least min_temperature.

Return type:

None

sample(batch_size, temperature=None)[source]

Sample particle-conserving binary configurations.

Each configuration has exactly n_alpha occupied alpha orbitals and n_beta occupied beta orbitals.

Parameters:
  • batch_size (int) – Number of configurations to sample.

  • temperature (float or None, optional) – Override temperature for this call. If None, uses the current instance temperature.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • all_configs (torch.Tensor) – All sampled configurations, shape (batch_size, num_sites). The first n_orbitals entries are alpha, the remaining are beta.

  • unique_configs (torch.Tensor) – Unique configurations, shape (n_unique, num_sites).

sample_without_replacement(batch_size, temperature=None)[source]

Sample unique configurations using deterministic ordering.

Generates a larger pool of samples and returns the unique configurations sorted by their logit scores (most probable first).

Parameters:
  • batch_size (int) – Desired number of unique configurations.

  • temperature (float or None, optional) – Override temperature. If None, uses instance temperature.

Returns:

Unique configurations, shape (n_unique, num_sites) where n_unique <= batch_size. Sorted by descending score.

Return type:

torch.Tensor
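The oversample, deduplicate and sort pattern described above can be sketched as follows. The `draw` and `score` callables and the `pool_factor` oversampling ratio are hypothetical; the method's actual pool size is not documented here.

```python
from typing import Callable, List, Tuple

Config = Tuple[int, ...]


def unique_by_score(
    draw: Callable[[], Config],
    score: Callable[[Config], float],
    batch_size: int,
    pool_factor: int = 4,
) -> List[Config]:
    """Oversample, deduplicate, and return up to batch_size configs, best first."""
    pool = {draw() for _ in range(pool_factor * batch_size)}  # set removes duplicates
    ranked = sorted(pool, key=score, reverse=True)             # descending score
    return ranked[:batch_size]                                 # n_unique <= batch_size
```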

forward(batch_size, temperature=None)[source]

Forward pass that delegates to sample().

Parameters:
  • batch_size (int) – Number of configurations to sample.

  • temperature (float or None, optional) – Override temperature.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • all_configs (torch.Tensor) – All configurations, shape (batch_size, num_sites).

  • unique_configs (torch.Tensor) – Unique configurations, shape (n_unique, num_sites).

Training

class qvartools.flows.training.physics_guided_training.PhysicsGuidedFlowTrainer(flow, nqs, hamiltonian, config, device='cpu')[source]

Bases: object

Mixed-objective trainer for joint flow + NQS optimisation.

Combines three loss terms with configurable weights:

  • Teacher loss: -sum_x p_nqs(x) * log p_flow(x); trains the flow to reproduce the NQS distribution.

  • Physics loss: variational energy E = sum_x |psi(x)|^2 * E_loc(x) / Z, where E_loc(x) = sum_{x'} H_{x,x'} * psi(x') / psi(x) and Z = sum_x |psi(x)|^2; minimises the ground-state energy estimate.

  • Entropy loss: -H[p_flow] (negative entropy); discourages mode collapse by encouraging a more spread-out flow distribution.

The trainer also:

  • Accumulates unique configurations into a growing basis set.

  • Anneals the flow temperature over early epochs.

  • Optionally injects essential (HF + singles + doubles) configurations into the basis.

  • Caches Hamiltonian connections for efficiency.
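The annealing schedule is not specified above. One plausible reading is linear interpolation from initial_temperature to final_temperature over temperature_decay_epochs, held constant afterwards. The sketch below assumes that schedule and clamps the result to a minimum temperature, in the same way as set_temperature(). The function name is hypothetical.

```python
def annealed_temperature(epoch: int, initial: float = 2.0, final: float = 0.1,
                         decay_epochs: int = 100,
                         min_temperature: float = 0.01) -> float:
    """Linear anneal from initial to final (assumed schedule), then held constant."""
    frac = min(epoch / decay_epochs, 1.0) if decay_epochs > 0 else 1.0
    temperature = initial + frac * (final - initial)
    # Clamp like set_temperature(): never go below min_temperature.
    return max(temperature, min_temperature)
```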

Parameters:
  • flow (nn.Module) – The normalizing flow sampler. Must implement sample(batch_size) returning (all_configs, unique_configs).

  • nqs (nn.Module) – The neural quantum state. Must implement log_amplitude(x) returning log-amplitudes of shape (batch,).

  • hamiltonian (Hamiltonian) – The Hamiltonian operator.

  • config (PhysicsGuidedConfig) – Training hyperparameters.

  • device (str, optional) – Torch device override (default uses config.device).
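The three objectives are combined using the weights in PhysicsGuidedConfig. The exact formula is not given here, so this sketch assumes a plain weighted sum (`total_loss` is a hypothetical name). Under the documented defaults, only the teacher term contributes.

```python
def total_loss(teacher: float, physics: float, entropy: float,
               teacher_weight: float = 1.0,
               physics_weight: float = 0.0,
               entropy_weight: float = 0.0) -> float:
    """Weighted sum of the three objectives (assumed combination rule)."""
    # Defaults mirror PhysicsGuidedConfig: only the teacher term is active
    # unless physics_weight or entropy_weight is raised explicitly.
    return (teacher_weight * teacher
            + physics_weight * physics
            + entropy_weight * entropy)
```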

flow

The flow model.

Type:

nn.Module

nqs

The NQS model.

Type:

nn.Module

hamiltonian

The Hamiltonian.

Type:

Hamiltonian

config

Training configuration.

Type:

PhysicsGuidedConfig

device

Active device.

Type:

torch.device

accumulated_basis

Growing set of unique configurations seen during training.

Type:

torch.Tensor or None

flow_optimizer

Optimiser for the flow parameters.

Type:

torch.optim.Adam

nqs_optimizer

Optimiser for the NQS parameters.

Type:

torch.optim.Adam

energy_baseline

Running EMA baseline for variance reduction.

Type:

float

connection_cache

Cache for Hamiltonian connections.

Type:

ConnectionCache or None

Methods

train([progress])

Run the full training loop.

train(progress=True)[source]

Run the full training loop.

Trains for up to config.num_epochs epochs. Once config.min_epochs epochs have run, training stops early when the unique-configuration ratio changes by less than config.convergence_threshold for two consecutive epochs.

Parameters:

progress (bool, optional) – If True, log epoch-level metrics at INFO level (default True).

Returns:

Training history with keys matching the epoch metrics, each mapping to a list of per-epoch values:

  • "teacher_loss" : list of float

  • "physics_loss" : list of float

  • "entropy_loss" : list of float

  • "total_loss" : list of float

  • "mean_energy" : list of float

  • "unique_ratio" : list of float

  • "basis_size" : list of int

  • "temperature" : list of float

Return type:

dict
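A sketch of the early-stopping test, written against the "unique_ratio" history list. It assumes that "two consecutive epochs" means the last two epoch-to-epoch changes are both below the threshold. The exact boundary handling in qvartools may differ, and `should_stop` is a hypothetical name.

```python
from typing import Sequence


def should_stop(unique_ratios: Sequence[float], min_epochs: int = 50,
                threshold: float = 0.01) -> bool:
    """True once min_epochs have run and the last two epoch-to-epoch changes
    in the unique-configuration ratio are both below threshold."""
    if len(unique_ratios) < max(min_epochs, 3):
        return False
    recent = unique_ratios[-3:]
    changes = [abs(recent[1] - recent[0]), abs(recent[2] - recent[1])]
    return all(c < threshold for c in changes)
```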

class qvartools.flows.training.physics_guided_training.PhysicsGuidedConfig(samples_per_batch=500, num_batches=10, num_epochs=200, min_epochs=50, convergence_threshold=0.01, flow_lr=0.001, nqs_lr=0.001, teacher_weight=1.0, physics_weight=0.0, entropy_weight=0.0, use_energy_baseline=True, ema_decay=0.99, use_connection_cache=True, max_cache_size=100000, initial_temperature=2.0, final_temperature=0.1, temperature_decay_epochs=100, inject_essential_configs=True, include_singles_in_basis=True, include_doubles_in_basis=True, device='cpu')[source]

Bases: object

Hyperparameters for PhysicsGuidedFlowTrainer.

All fields have sensible defaults for molecular-scale problems. The class is frozen (immutable) to prevent accidental mutation during training.

Parameters:
  • samples_per_batch (int) – Number of flow samples per mini-batch (default 500).

  • num_batches (int) – Number of mini-batches per epoch (default 10).

  • num_epochs (int) – Maximum number of training epochs (default 200).

  • min_epochs (int) – Minimum epochs before convergence checks activate (default 50).

  • convergence_threshold (float) – Training stops early when the unique-configuration ratio changes by less than this amount for two consecutive epochs, once min_epochs has been reached (default 0.01).

  • flow_lr (float) – Learning rate for the flow optimiser (default 1e-3).

  • nqs_lr (float) – Learning rate for the NQS optimiser (default 1e-3).

  • teacher_weight (float) – Weight of the teacher (KL) loss (default 1.0).

  • physics_weight (float) – Weight of the variational energy loss (default 0.0).

  • entropy_weight (float) – Weight of the entropy regularisation loss (default 0.0).

  • use_energy_baseline (bool) – Whether to subtract a running baseline from the energy for variance reduction (default True).

  • ema_decay (float) – Exponential moving average decay for the energy baseline (default 0.99).

  • use_connection_cache (bool) – Whether to cache Hamiltonian connections for repeated configs (default True).

  • max_cache_size (int) – Maximum number of entries in the connection cache (default 100000).

  • initial_temperature (float) – Starting temperature for flow annealing (default 2.0).

  • final_temperature (float) – Final temperature after annealing (default 0.1).

  • temperature_decay_epochs (int) – Number of epochs over which to anneal temperature (default 100).

  • inject_essential_configs (bool) – Whether to inject Hartree–Fock and nearby configurations into the basis (default True).

  • include_singles_in_basis (bool) – Whether to include single excitations in the essential basis (default True).

  • include_doubles_in_basis (bool) – Whether to include double excitations in the essential basis (default True).

  • device (str) – Torch device for training (default "cpu").

samples_per_batch: int = 500
num_batches: int = 10
num_epochs: int = 200
min_epochs: int = 50
convergence_threshold: float = 0.01
flow_lr: float = 0.001
nqs_lr: float = 0.001
teacher_weight: float = 1.0
physics_weight: float = 0.0
entropy_weight: float = 0.0
use_energy_baseline: bool = True
ema_decay: float = 0.99
use_connection_cache: bool = True
max_cache_size: int = 100000
initial_temperature: float = 2.0
final_temperature: float = 0.1
temperature_decay_epochs: int = 100
inject_essential_configs: bool = True
include_singles_in_basis: bool = True
include_doubles_in_basis: bool = True
device: str = 'cpu'
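Because the configuration is frozen, hyperparameters are changed by building a new instance rather than by assigning to a field. The field listing suggests that PhysicsGuidedConfig behaves like a frozen `dataclasses.dataclass`, but this is an assumption. The sketch uses a hypothetical three-field stand-in, `MiniConfig`.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MiniConfig:  # hypothetical stand-in for PhysicsGuidedConfig
    teacher_weight: float = 1.0
    physics_weight: float = 0.0
    entropy_weight: float = 0.0


base = MiniConfig()
try:
    base.physics_weight = 0.5      # frozen: assignment is rejected
    frozen_rejected = False
except dataclasses.FrozenInstanceError:
    frozen_rejected = True

# Build a modified copy instead; unspecified fields keep their values.
mixed = dataclasses.replace(base, physics_weight=0.5, entropy_weight=0.05)
```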

Loss Functions

loss_functions — Loss computation for physics-guided flow training

Standalone loss functions and supporting utilities extracted from the physics-guided training loop.

class qvartools.flows.training.loss_functions.ConnectionCache(max_size=100000)[source]

Bases: object

LRU-style cache for Hamiltonian connection lookups.

Stores the result of hamiltonian.get_connections(config) keyed by the configuration as a tuple of integers. Evicts the oldest entries when max_size is exceeded.

Parameters:

max_size (int) – Maximum number of cached entries (default 100000).

max_size

Capacity limit.

Type:

int

Methods

clear()

Remove all cached entries.

get(config, hamiltonian)

Retrieve connections for config, computing and caching if absent.

get(config, hamiltonian)[source]

Retrieve connections for config, computing and caching if absent.

Parameters:
  • config (torch.Tensor) – Single configuration, shape (num_sites,).

  • hamiltonian (Hamiltonian) – The Hamiltonian to query for connections.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • connected_configs (torch.Tensor) – Connected configurations, shape (n_conn, num_sites).

  • matrix_elements (torch.Tensor) – Corresponding matrix elements, shape (n_conn,).

clear()[source]

Remove all cached entries.

Return type:

None
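A minimal LRU cache with the same keying scheme (the configuration converted to a tuple of ints) can be built on `collections.OrderedDict`. This is a sketch, not the qvartools class. `LRUConnectionCache` is hypothetical, and the connection lookup is passed in as a plain callable so that the example runs without a Hamiltonian.

```python
from collections import OrderedDict
from typing import Callable, Sequence, Tuple


class LRUConnectionCache:
    """Hypothetical sketch of an LRU cache for Hamiltonian connection lookups."""

    def __init__(self, max_size: int = 100_000) -> None:
        self.max_size = max_size
        self._store: "OrderedDict[Tuple[int, ...], object]" = OrderedDict()

    def get(self, config: Sequence[int],
            get_connections: Callable[[Sequence[int]], object]) -> object:
        key = tuple(int(v) for v in config)
        if key in self._store:
            self._store.move_to_end(key)          # mark as most recently used
            return self._store[key]
        value = get_connections(config)           # miss: compute and cache
        self._store[key] = value
        if len(self._store) > self.max_size:
            self._store.popitem(last=False)       # evict least recently used
        return value

    def clear(self) -> None:
        self._store.clear()

    def __len__(self) -> int:
        return len(self._store)
```

The description above says that the oldest entries are evicted. Whether a cache hit also refreshes an entry's recency (true LRU, as here) or eviction follows insertion order alone is not stated.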

qvartools.flows.training.loss_functions.compute_teacher_loss(configs, log_probs_flow, nqs)[source]

Compute the teacher (KL divergence) loss.

L_teacher = -sum_x p_nqs(x) * log p_flow(x)

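The NQS probabilities are detached (treated as fixed targets), so this cross-entropy differs from the KL divergence KL(p_nqs || p_flow) only by the constant entropy of p_nqs; minimising one minimises the other.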

Parameters:
  • configs (torch.Tensor) – Sampled configurations, shape (batch, num_sites).

  • log_probs_flow (torch.Tensor) – Flow log-probabilities, shape (batch,).

  • nqs (nn.Module) – Neural quantum state with a log_amplitude(x) method.

Returns:

Scalar teacher loss.

Return type:

torch.Tensor
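A batch-level sketch of the teacher loss. It assumes that log_amplitude returns log|psi(x)|, so that |psi(x)|^2 = exp(2 * log_amplitude(x)), and that p_nqs is normalised over the sampled batch. Both points are assumptions. `teacher_loss_sketch` is a hypothetical name.

```python
import math
from typing import Sequence


def teacher_loss_sketch(log_amplitudes: Sequence[float],
                        log_probs_flow: Sequence[float]) -> float:
    """L = -sum_x p_nqs(x) * log p_flow(x), with p_nqs held fixed."""
    log_w = [2.0 * la for la in log_amplitudes]   # log|psi|^2 = 2 * log|psi|
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    z = sum(w)
    p_nqs = [wi / z for wi in w]                  # normalised over the batch (assumed)
    return -sum(p * lp for p, lp in zip(p_nqs, log_probs_flow))
```

When the flow matches the NQS exactly, the loss reduces to the entropy of p_nqs, which is its minimum (Gibbs' inequality).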

qvartools.flows.training.loss_functions.compute_physics_loss(configs, nqs, hamiltonian, device, energy_baseline, baseline_initialized, use_energy_baseline, ema_decay, connection_cache=None)[source]

Compute the variational energy (physics) loss.

L_physics = sum_x |psi(x)|^2 * E_loc(x) / Z

Uses a running EMA baseline for variance reduction when enabled.

Parameters:
  • configs (torch.Tensor) – Sampled configurations, shape (batch, num_sites).

  • nqs (nn.Module) – Neural quantum state with a log_amplitude(x) method.

  • hamiltonian (Hamiltonian) – The Hamiltonian operator.

  • device (torch.device) – Torch device for computation.

  • energy_baseline (float) – Current EMA energy baseline value.

  • baseline_initialized (bool) – Whether the baseline has been initialised.

  • use_energy_baseline (bool) – Whether to apply variance reduction via EMA baseline.

  • ema_decay (float) – Exponential moving average decay for the baseline.

  • connection_cache (ConnectionCache or None, optional) – Optional cache for Hamiltonian connections.

Return type:

Tuple[Tensor, float, float, bool]

Returns:

  • loss (torch.Tensor) – Scalar physics loss.

  • mean_energy (float) – Mean local energy (for logging).

  • updated_baseline (float) – Updated EMA energy baseline.

  • updated_initialized (bool) – Whether the baseline is now initialised.
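The baseline bookkeeping implied by the (updated_baseline, updated_initialized) return values can be sketched with the standard EMA update. This sketch assumes that the first batch seeds the baseline and that later batches blend in with ema_decay. It does not reproduce how the baseline enters the loss, which is not documented here. `update_baseline` is a hypothetical name.

```python
from typing import Tuple


def update_baseline(mean_energy: float, baseline: float, initialized: bool,
                    ema_decay: float = 0.99) -> Tuple[float, bool]:
    """Return (updated_baseline, updated_initialized) for an EMA energy baseline."""
    if not initialized:
        return mean_energy, True              # first batch seeds the baseline
    return ema_decay * baseline + (1.0 - ema_decay) * mean_energy, True
```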

qvartools.flows.training.loss_functions.compute_entropy_loss(log_probs_flow)[source]

Compute the negative entropy of the flow distribution.

L_entropy = sum_x p_flow(x) * log p_flow(x) = -H[p_flow]

Minimising this loss maximises the entropy.

Parameters:

log_probs_flow (torch.Tensor) – Flow log-probabilities, shape (batch,).

Returns:

Scalar entropy loss (negative entropy).

Return type:

torch.Tensor
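Because the input has shape (batch,), a natural reading is a Monte Carlo estimate. For samples x_i drawn from the flow itself, sum_x p_flow(x) log p_flow(x) is estimated by the sample mean of log p_flow(x_i). This sketch assumes that estimator; `entropy_loss_sketch` is a hypothetical name.

```python
import math
from typing import Sequence


def entropy_loss_sketch(log_probs_flow: Sequence[float]) -> float:
    """Sample estimate of sum_x p(x) log p(x) = -H[p], for x_i drawn from p."""
    return sum(log_probs_flow) / len(log_probs_flow)


# Eight samples from a uniform flow over four states: H = log 4.
uniform_four = [math.log(0.25)] * 8
```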

qvartools.flows.training.loss_functions.compute_local_energy(configs, nqs, hamiltonian, device, connection_cache=None)[source]

Compute the local energy E_loc(x) for each configuration.

E_loc(x) = H_{x,x} + sum_{x' != x} H_{x,x'} * psi(x') / psi(x)

Optimised to minimise CPU-GPU transfers and to batch all NQS evaluations into a single call.

Parameters:
  • configs (torch.Tensor) – Configurations, shape (batch, num_sites).

  • nqs (nn.Module) – Neural quantum state with a log_amplitude(x) method.

  • hamiltonian (Hamiltonian) – The Hamiltonian operator.

  • device (torch.device) – Torch device for computation.

  • connection_cache (ConnectionCache or None, optional) – Optional cache for Hamiltonian connections.

Returns:

Local energies, shape (batch,).

Return type:

torch.Tensor
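A worked check on a two-state Hamiltonian H = [[0, -1], [-1, 0]], whose ground state (1, 1) has eigenvalue -1. For an exact eigenstate, E_loc(x) equals the eigenvalue at every x. For any other state, the variational energy lies above it. The connections are stored in a plain dict here; in qvartools they come from the Hamiltonian. All names are hypothetical.

```python
from typing import Dict, List, Tuple

# Diagonal entries H_{x,x} and off-diagonal connections x -> [(x', H_{x,x'})].
DIAG: Dict[int, float] = {0: 0.0, 1: 0.0}
CONNECTIONS: Dict[int, List[Tuple[int, float]]] = {0: [(1, -1.0)], 1: [(0, -1.0)]}


def local_energy(x: int, psi: Dict[int, float]) -> float:
    """E_loc(x) = H_{x,x} + sum_{x' != x} H_{x,x'} * psi(x') / psi(x)."""
    return DIAG[x] + sum(h * psi[xp] / psi[x] for xp, h in CONNECTIONS[x])


def variational_energy(psi: Dict[int, float]) -> float:
    """E = sum_x |psi(x)|^2 * E_loc(x) / Z with Z = sum_x |psi(x)|^2."""
    z = sum(abs(a) ** 2 for a in psi.values())
    return sum(abs(psi[x]) ** 2 * local_energy(x, psi) for x in psi) / z


ground = {0: 1.0, 1: 1.0}   # eigenvector with eigenvalue -1
trial = {0: 1.0, 1: 0.5}    # not an eigenvector
```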

Gumbel Top-k

gumbel_topk — Differentiable top-k selection mechanisms

Provides differentiable approximations to top-k selection for use in particle-number-conserving normalizing flows:

  • GumbelTopK — Gumbel-Softmax-based iterative selection.

  • SigmoidTopK — Sigmoid thresholding with implicit binary search.
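SigmoidTopK is described only briefly above. One way to implement sigmoid thresholding with a binary search is to bisect for a threshold tau such that the soft weights sigmoid((logit - tau) / T) sum to k. This is given as an illustrative assumption, not as the qvartools algorithm. `sigmoid_topk` is a hypothetical name.

```python
import math
from typing import List, Sequence


def sigmoid(v: float) -> float:
    # Numerically safe logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def sigmoid_topk(logits: Sequence[float], k: int, temperature: float = 1.0,
                 iters: int = 100) -> List[float]:
    """Soft k-hot weights in (0, 1) whose sum is (approximately) k."""
    if not 0 < k < len(logits):
        raise ValueError("k must satisfy 0 < k < len(logits)")
    lo = min(logits) - 50.0 * temperature   # weights sum to about n here
    hi = max(logits) + 50.0 * temperature   # weights sum to about 0 here
    for _ in range(iters):
        tau = 0.5 * (lo + hi)
        total = sum(sigmoid((l - tau) / temperature) for l in logits)
        if total > k:
            lo = tau          # too much mass selected: raise the threshold
        else:
            hi = tau
    tau = 0.5 * (lo + hi)
    return [sigmoid((l - tau) / temperature) for l in logits]
```

The sum of the weights decreases monotonically in tau, so bisection always converges. Lowering the temperature pushes the weights towards a hard 0/1 top-k.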

class qvartools.flows.training.gumbel_topk.GumbelTopK(temperature=1.0, min_temperature=0.01)[source]

Bases: Module

Gumbel-Softmax-based differentiable top-k selection.

Adds Gumbel noise to the logits and applies a softmax to produce a soft approximation of top-k selection. As the temperature decreases, the selection approaches a hard top-k of the noise-perturbed logits; as it increases, the soft selection weights spread out towards uniform.

Parameters:
  • temperature (float, optional) – Initial temperature for the Gumbel-Softmax (default 1.0).

  • min_temperature (float, optional) – Minimum temperature to prevent numerical issues (default 0.01).

temperature

Current temperature.

Type:

float

min_temperature

Lower bound on temperature.

Type:

float
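A common iterative relaxation is k rounds of softmax over Gumbel-perturbed logits. Before each round, the positions already selected are down-weighted by adding log(1 - previous soft one-hot) to the scores. This is a sketch of that standard scheme and may differ from GumbelTopK's exact implementation. `gumbel_topk_sketch` is a hypothetical name, and passing rng=None disables the noise for deterministic checks.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(values: Sequence[float]) -> List[float]:
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    s = sum(exps)
    return [e / s for e in exps]


def gumbel_topk_sketch(logits: Sequence[float], k: int, temperature: float = 1.0,
                       min_temperature: float = 0.01,
                       rng: Optional[random.Random] = None) -> List[float]:
    """Relaxed k-hot vector from k rounds of softmax over (perturbed) logits."""
    t = max(temperature, min_temperature)
    if rng is not None:
        # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
        scores = [l - math.log(-math.log(rng.random() or 1e-12)) for l in logits]
    else:
        scores = list(logits)
    khot = [0.0] * len(scores)
    onehot = [0.0] * len(scores)
    for _ in range(k):
        # Suppress positions that took mass in the previous round.
        scores = [s + math.log(max(1.0 - p, 1e-20)) for s, p in zip(scores, onehot)]
        onehot = softmax([s / t for s in scores])
        khot = [a + b for a, b in zip(khot, onehot)]
    return khot
```

Each round contributes a distribution that sums to 1, so the relaxed k-hot vector always sums to k. At low temperature it approaches a hard top-k indicator.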

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(logits, k[, temperature])

Select k elements via Gumbel-Softmax relaxation.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

forward(logits, k, temperature=None)[source]

Select k elements via Gumbel-Softmax relaxation.

Parameters:
  • logits (torch.Tensor) – Unnormalised scores, shape (batch, n).

  • k (int) – Number of elements to select.

  • temperature (float or None, optional) – Override temperature for this call. If None, uses the instance temperature.

Returns:

Soft selection mask, shape (batch, n). Values are in [0, 1] and approximately sum to k per row.

Return type:

torch.Tensor
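The internals of this layer are not shown here. The pure-Python sketch below works on a single row and illustrates one common Gumbel-Softmax top-k relaxation: perturb the scores with Gumbel noise, then take k successive softmax steps, each of which suppresses the mass selected by the previous step. It is only an assumption that GumbelTopK uses this exact scheme. The names `relaxed_topk`, `gumbel_noise`, and `softmax`, and the `add_noise` switch, are hypothetical.

```python
import math
import random


def gumbel_noise(n, rng):
    """Draw n i.i.d. standard Gumbel samples via g = -log(-log(u)), u ~ U(0, 1)."""
    eps = 1e-12
    samples = []
    for _ in range(n):
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), eps), 1.0 - eps)
        samples.append(-math.log(-math.log(u)))
    return samples


def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def relaxed_topk(logits, k, temperature=1.0, rng=None, add_noise=True):
    """Soft k-hot mask for one row of scores via k successive softmax steps.

    Hypothetical helper: before each step the running scores are shifted by
    log(1 - p), where p is the previous step's softmax output. This
    discourages re-selecting mass that was just selected. Each step's
    softmax sums to 1, so the returned mask sums to k.
    """
    rng = rng if rng is not None else random.Random()
    scores = list(logits)
    if add_noise:
        # The Gumbel perturbation turns the relaxation into a stochastic sample.
        scores = [s + g for s, g in zip(scores, gumbel_noise(len(scores), rng))]
    probs = [0.0] * len(scores)
    mask = [0.0] * len(scores)
    for _ in range(k):
        # Floor 1 - p at a tiny value so the log stays finite when p == 1.
        scores = [s + math.log(max(1.0 - p, 1e-20)) for s, p in zip(scores, probs)]
        probs = softmax([s / temperature for s in scores])
        mask = [m + p for m, p in zip(mask, probs)]
    return mask
```

With `add_noise=False` and a small temperature, `relaxed_topk([3.0, 0.1, 2.5, -1.0, 0.0], 2, temperature=0.01, add_noise=False)` is close to the hard mask [1, 0, 1, 0, 0]. This simple variant guarantees only that each row sums to k. At high temperature, individual entries can exceed 1, so it does not reproduce the documented [0, 1] bound by itself.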

class qvartools.flows.training.gumbel_topk.SigmoidTopK(temperature=1.0, min_temperature=0.01)[source]

Bases: Module

Sigmoid-based differentiable top-k selection with an implicit threshold.

Uses a learned or computed threshold to produce per-element sigmoid activations, then normalises to select exactly k elements in expectation.

Parameters:
  • temperature (float, optional) – Initial temperature controlling sigmoid sharpness (default 1.0).

  • min_temperature (float, optional) – Minimum temperature (default 0.01).

temperature

Current temperature.

Type:

float

min_temperature

Lower bound on temperature.

Type:

float

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(logits, k[, temperature])

Select k elements via sigmoid thresholding.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

forward(logits, k, temperature=None)[source]

Select k elements via sigmoid thresholding.

For each row, finds a threshold such that sum(sigmoid((logits - threshold) / T)) is approximately k, where T is the temperature, then returns the sigmoid activations.

Parameters:
  • logits (torch.Tensor) – Unnormalised scores, shape (batch, n).

  • k (int) – Number of elements to select.

  • temperature (float or None, optional) – Override temperature for this call. If None, uses the instance temperature.

Returns:

Soft selection mask, shape (batch, n). Values are in [0, 1] and approximately sum to k per row.

Return type:

torch.Tensor
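The threshold search described above can be sketched for a single row in pure Python. This sketch assumes the threshold is found by bisection. The documentation states only that a threshold is found, so the search method, bracket width, and iteration count are assumptions, and the names `sigmoid_topk` and `stable_sigmoid` are hypothetical. Bisection applies because the summed activation decreases monotonically as the threshold rises.

```python
import math


def stable_sigmoid(z):
    """Logistic function that never overflows math.exp."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_topk(logits, k, temperature=1.0, iters=100):
    """Soft k-hot mask for one row: sigmoid((x - tau) / T) with sum close to k.

    sum(sigmoid((x - tau) / T)) decreases monotonically in tau, so tau is
    located by bisection between two bracketing values (assumed approach).
    """
    n = len(logits)
    if not 0 < k < n:
        raise ValueError("need 0 < k < len(logits)")

    def total(tau):
        return sum(stable_sigmoid((x - tau) / temperature) for x in logits)

    # At lo every activation is ~1 (sum ~ n > k); at hi every one is ~0 (sum ~ 0 < k).
    lo = min(logits) - 40.0 * temperature
    hi = max(logits) + 40.0 * temperature
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if total(mid) > k:
            lo = mid  # too much mass selected: raise the threshold
        else:
            hi = mid  # too little mass selected: lower the threshold
    tau = 0.5 * (lo + hi)
    return [stable_sigmoid((x - tau) / temperature) for x in logits]
```

Because every element shares the same threshold, the activations preserve the ranking of the logits. Lowering the temperature pushes the mask toward a hard top-k selection.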

Utilities

qvartools.flows.networks.particle_conserving_flow.verify_particle_conservation(configs, n_orbitals, n_alpha, n_beta)[source]

Validate that all configurations conserve particle numbers.

Checks that each configuration has exactly n_alpha occupied alpha orbitals (first n_orbitals sites) and n_beta occupied beta orbitals (remaining n_orbitals sites).

Parameters:
  • configs (torch.Tensor) – Binary configurations, shape (n_configs, 2 * n_orbitals).

  • n_orbitals (int) – Number of spatial orbitals (half the number of sites in each configuration).

  • n_alpha (int) – Expected number of alpha electrons per configuration.

  • n_beta (int) – Expected number of beta electrons per configuration.

Return type:

Tuple[bool, Dict[str, object]]

Returns:

  • is_valid (bool) – True if every configuration has exactly the correct particle numbers.

  • stats (dict) – Dictionary with detailed statistics:

    • "n_configs" : int — total number of configurations.

    • "n_valid" : int — number of valid configurations.

    • "n_invalid" : int — number of invalid configurations.

    • "alpha_counts" : torch.Tensor — alpha electron count per config.

    • "beta_counts" : torch.Tensor — beta electron count per config.

    • "alpha_violations" : int — configs with wrong alpha count.

    • "beta_violations" : int — configs with wrong beta count.

Examples

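>>> import torch
>>> from qvartools.flows.networks.particle_conserving_flow import verify_particle_conservation
>>> configs = torch.tensor([[1, 1, 0, 1, 0, 1]])  # 2 alpha, 2 beta
>>> is_valid, stats = verify_particle_conservation(configs, 3, 2, 2)
>>> is_valid
True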
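The following pure-Python sketch mirrors the check described above using plain lists. It splits each row into its alpha half (first n_orbitals sites) and beta half (remaining sites), counts the occupied sites, and assembles the same statistics keys. The helper name `check_particle_numbers` is hypothetical. It returns lists rather than `torch.Tensor` counts and is not the library's implementation.

```python
def check_particle_numbers(configs, n_orbitals, n_alpha, n_beta):
    """Pure-Python analogue of the documented check (hypothetical helper).

    configs is a list of 0/1 lists of length 2 * n_orbitals. The first
    n_orbitals entries are alpha spin-orbitals and the rest are beta.
    """
    alpha_counts = [sum(c[:n_orbitals]) for c in configs]
    beta_counts = [sum(c[n_orbitals:2 * n_orbitals]) for c in configs]
    alpha_ok = [a == n_alpha for a in alpha_counts]
    beta_ok = [b == n_beta for b in beta_counts]
    # A configuration is valid only if both spin sectors have the right count.
    n_valid = sum(1 for a, b in zip(alpha_ok, beta_ok) if a and b)
    stats = {
        "n_configs": len(configs),
        "n_valid": n_valid,
        "n_invalid": len(configs) - n_valid,
        "alpha_counts": alpha_counts,
        "beta_counts": beta_counts,
        "alpha_violations": alpha_ok.count(False),
        "beta_violations": beta_ok.count(False),
    }
    return n_valid == len(configs), stats
```

Consider `[[1, 1, 0, 1, 0, 1], [1, 0, 0, 1, 1, 1]]` with `n_orbitals=3` and `n_alpha=n_beta=2`. The second row has one alpha and three beta electrons, so the sketch returns `False` with `n_valid == 1`, `alpha_violations == 1`, and `beta_violations == 1`.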