"""
particle_conserving_flow --- Particle-number-conserving flow sampler
====================================================================
Implements a normalizing-flow sampler that exactly conserves the number
of alpha and beta electrons by construction. Instead of learning an
unconstrained bijection, this module learns a *scoring function* for
each orbital and selects the top-k orbitals via differentiable top-k
mechanisms (Gumbel-Softmax or sigmoid-based).
The result is a set of binary configurations in which exactly
``n_alpha`` alpha orbitals and ``n_beta`` beta orbitals are occupied,
guaranteeing valid Slater determinants for quantum chemistry.
"""
from __future__ import annotations
import logging
import torch
import torch.nn as nn
from qvartools.flows.training.gumbel_topk import GumbelTopK
__all__ = [
"ParticleConservingFlowSampler",
"verify_particle_conservation",
]
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Orbital scoring network
# ---------------------------------------------------------------------------
class OrbitalScoringNetwork(nn.Module):
"""Neural network that scores orbital occupations.
Given an optional context vector (e.g. from the alpha-spin
configuration), produces logits indicating how favourable it is to
occupy each orbital.
Parameters
----------
n_orbitals : int
Number of orbitals to score.
hidden_dims : list of int, optional
Hidden-layer sizes (default ``[128, 64]``).
context_dim : int, optional
Dimensionality of the context vector (default ``32``).
If zero, no context is used.
Attributes
----------
n_orbitals : int
Number of orbitals.
context_dim : int
Context vector size.
net : nn.Sequential
The scoring MLP.
"""
def __init__(
self,
n_orbitals: int,
hidden_dims: list[int] | None = None,
context_dim: int = 32,
) -> None:
super().__init__()
if hidden_dims is None:
hidden_dims = [128, 64]
self.n_orbitals: int = n_orbitals
self.context_dim: int = context_dim
# Input: orbital index embedding + optional context
input_dim = n_orbitals + context_dim
layers: list[nn.Module] = []
prev_dim = input_dim
for h_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, h_dim))
layers.append(nn.LeakyReLU(0.01))
prev_dim = h_dim
layers.append(nn.Linear(prev_dim, n_orbitals))
self.net: nn.Sequential = nn.Sequential(*layers)
# Learnable baseline logits for each orbital
self.baseline = nn.Parameter(torch.zeros(n_orbitals))
# Context projection (identity embedding when no external context)
if context_dim > 0:
self.context_proj = nn.Linear(n_orbitals, context_dim)
else:
self.context_proj = None
def forward(
self,
batch_size: int,
context: torch.Tensor | None = None,
) -> torch.Tensor:
"""Produce orbital-occupation logits.
Parameters
----------
batch_size : int
Number of samples in the batch.
context : torch.Tensor or None, optional
Context vector of shape ``(batch_size, n_orbitals)`` (e.g. the
alpha configuration). If ``None`` and ``context_dim > 0``,
a zero context is used.
Returns
-------
torch.Tensor
Logits for each orbital, shape ``(batch_size, n_orbitals)``.
"""
device = self.baseline.device
# Build input: [baseline_expanded, context_proj]
baseline_expanded = self.baseline.unsqueeze(0).expand(
batch_size, -1
) # (batch, n_orbitals)
if self.context_dim > 0:
if context is not None:
ctx = self.context_proj(context.float()) # (batch, context_dim)
else:
ctx = torch.zeros(batch_size, self.context_dim, device=device)
net_input = torch.cat(
[baseline_expanded, ctx], dim=-1
) # (batch, n_orbitals + context_dim)
else:
net_input = baseline_expanded
logits = self.net(net_input) # (batch, n_orbitals)
return logits + baseline_expanded
# ---------------------------------------------------------------------------
# ParticleConservingFlowSampler
# ---------------------------------------------------------------------------
[docs]
class ParticleConservingFlowSampler(nn.Module):
"""Normalizing flow that exactly conserves alpha and beta particle numbers.
Produces binary configurations of shape ``(num_sites,)`` where the
first ``num_sites // 2`` entries are alpha orbitals and the remaining
are beta orbitals. Exactly ``n_alpha`` alpha and ``n_beta`` beta
orbitals are occupied in every sample.
The flow works by:
1. Scoring alpha orbitals with a learned network.
2. Selecting the top ``n_alpha`` via differentiable top-k.
3. Scoring beta orbitals conditioned on the alpha configuration.
4. Selecting the top ``n_beta`` via differentiable top-k.
5. Concatenating ``[alpha, beta]`` to form the full configuration.
Parameters
----------
num_sites : int
Total number of spin-orbitals (must be even).
n_alpha : int
Number of alpha electrons.
n_beta : int
Number of beta electrons.
hidden_dims : list of int, optional
Hidden-layer sizes for the scoring networks (default ``[128, 64]``).
temperature : float, optional
Initial temperature for differentiable top-k (default ``1.0``).
min_temperature : float, optional
Minimum temperature (default ``0.01``).
Attributes
----------
num_sites : int
Total number of spin-orbitals.
n_orbitals : int
Number of spatial orbitals (``num_sites // 2``).
n_alpha : int
Number of alpha electrons.
n_beta : int
Number of beta electrons.
temperature : float
Current temperature for top-k selection.
alpha_scorer : OrbitalScoringNetwork
Scoring network for alpha orbitals.
beta_scorer : OrbitalScoringNetwork
Scoring network for beta orbitals (conditioned on alpha config).
selector : GumbelTopK
Differentiable top-k selector.
Examples
--------
>>> flow = ParticleConservingFlowSampler(
... num_sites=10, n_alpha=2, n_beta=2
... )
>>> configs, unique = flow.sample(batch_size=100)
>>> is_valid, stats = verify_particle_conservation(
... configs, n_orbitals=5, n_alpha=2, n_beta=2
... )
>>> assert is_valid
"""
def __init__(
self,
num_sites: int,
n_alpha: int,
n_beta: int,
hidden_dims: list[int] | None = None,
temperature: float = 1.0,
min_temperature: float = 0.01,
) -> None:
super().__init__()
if num_sites < 2 or num_sites % 2 != 0:
raise ValueError(
f"num_sites must be a positive even integer, got {num_sites}"
)
if hidden_dims is None:
hidden_dims = [128, 64]
n_orbitals = num_sites // 2
if n_alpha < 0 or n_alpha > n_orbitals:
raise ValueError(f"n_alpha must be in [0, {n_orbitals}], got {n_alpha}")
if n_beta < 0 or n_beta > n_orbitals:
raise ValueError(f"n_beta must be in [0, {n_orbitals}], got {n_beta}")
self.num_sites: int = num_sites
self.n_orbitals: int = n_orbitals
self.n_alpha: int = n_alpha
self.n_beta: int = n_beta
self.temperature: float = temperature
self.min_temperature: float = min_temperature
# Scoring networks
self.alpha_scorer: OrbitalScoringNetwork = OrbitalScoringNetwork(
n_orbitals=n_orbitals,
hidden_dims=hidden_dims,
context_dim=0, # Alpha has no conditioning context
)
self.beta_scorer: OrbitalScoringNetwork = OrbitalScoringNetwork(
n_orbitals=n_orbitals,
hidden_dims=hidden_dims,
context_dim=32, # Beta conditioned on alpha
)
# Differentiable top-k selector
self.selector: GumbelTopK = GumbelTopK(
temperature=temperature,
min_temperature=min_temperature,
)
[docs]
def set_temperature(self, temperature: float) -> None:
"""Set the temperature for differentiable top-k selection.
Parameters
----------
temperature : float
New temperature value. Will be clamped to at least
``min_temperature``.
"""
self.temperature = max(temperature, self.min_temperature)
self.selector.temperature = self.temperature
def _soft_to_hard(self, soft_mask: torch.Tensor, k: int) -> torch.Tensor:
"""Convert a soft selection mask to a hard binary mask.
Selects the top-k entries by value and sets them to 1, all
others to 0. Uses straight-through estimation for gradient
flow: the forward pass uses hard values, but the backward pass
uses soft gradients.
Parameters
----------
soft_mask : torch.Tensor
Soft selection, shape ``(batch, n)``.
k : int
Number of elements to select.
Returns
-------
torch.Tensor
Hard binary mask, shape ``(batch, n)``, with exactly ``k``
ones per row.
"""
_, top_indices = torch.topk(soft_mask, k, dim=-1)
hard_mask = torch.zeros_like(soft_mask)
hard_mask.scatter_(1, top_indices, 1.0)
# Straight-through estimator: hard in forward, soft gradients in backward
return hard_mask - soft_mask.detach() + soft_mask
[docs]
def sample(
self,
batch_size: int,
temperature: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Sample particle-conserving binary configurations.
Each configuration has exactly ``n_alpha`` occupied alpha
orbitals and ``n_beta`` occupied beta orbitals.
Parameters
----------
batch_size : int
Number of configurations to sample.
temperature : float or None, optional
Override temperature for this call. If ``None``, uses the
current instance temperature.
Returns
-------
all_configs : torch.Tensor
All sampled configurations, shape ``(batch_size, num_sites)``.
The first ``n_orbitals`` entries are alpha, the remaining are
beta.
unique_configs : torch.Tensor
Unique configurations, shape ``(n_unique, num_sites)``.
"""
temp = temperature if temperature is not None else self.temperature
# Step 1: Score and select alpha orbitals
alpha_logits = self.alpha_scorer(batch_size)
alpha_soft = self.selector(alpha_logits, self.n_alpha, temperature=temp)
alpha_config = self._soft_to_hard(alpha_soft, self.n_alpha)
# Step 2: Score and select beta orbitals, conditioned on alpha
beta_logits = self.beta_scorer(batch_size, context=alpha_config)
beta_soft = self.selector(beta_logits, self.n_beta, temperature=temp)
beta_config = self._soft_to_hard(beta_soft, self.n_beta)
# Step 3: Concatenate [alpha, beta]
all_configs = torch.cat([alpha_config, beta_config], dim=-1)
# Step 4: Extract unique configurations
unique_configs = torch.unique(all_configs, dim=0)
return all_configs, unique_configs
[docs]
def sample_without_replacement(
self,
batch_size: int,
temperature: float | None = None,
) -> torch.Tensor:
"""Sample unique configurations using deterministic ordering.
Generates a larger pool of samples and returns the unique
configurations sorted by their logit scores (most probable first).
Parameters
----------
batch_size : int
Desired number of unique configurations.
temperature : float or None, optional
Override temperature. If ``None``, uses instance temperature.
Returns
-------
torch.Tensor
Unique configurations, shape ``(n_unique, num_sites)`` where
``n_unique <= batch_size``. Sorted by descending score.
"""
# Over-sample to increase chance of getting enough unique configs
oversample_factor = 4
pool_size = batch_size * oversample_factor
all_configs, unique_configs = self.sample(pool_size, temperature=temperature)
if unique_configs.shape[0] >= batch_size:
return unique_configs[:batch_size]
return unique_configs
[docs]
def forward(
self,
batch_size: int,
temperature: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass --- delegates to :meth:`sample`.
Parameters
----------
batch_size : int
Number of configurations to sample.
temperature : float or None, optional
Override temperature.
Returns
-------
all_configs : torch.Tensor
All configurations, shape ``(batch_size, num_sites)``.
unique_configs : torch.Tensor
Unique configurations, shape ``(n_unique, num_sites)``.
"""
return self.sample(batch_size, temperature=temperature)
# ---------------------------------------------------------------------------
# Particle-conservation verification
# ---------------------------------------------------------------------------
[docs]
def verify_particle_conservation(
configs: torch.Tensor,
n_orbitals: int,
n_alpha: int,
n_beta: int,
) -> tuple[bool, dict[str, object]]:
"""Validate that all configurations conserve particle numbers.
Checks that each configuration has exactly ``n_alpha`` occupied
alpha orbitals (first ``n_orbitals`` sites) and ``n_beta`` occupied
beta orbitals (remaining ``n_orbitals`` sites).
Parameters
----------
configs : torch.Tensor
Binary configurations, shape ``(n_configs, 2 * n_orbitals)``.
n_orbitals : int
Number of spatial orbitals (half of ``num_sites``).
n_alpha : int
Expected number of alpha electrons per configuration.
n_beta : int
Expected number of beta electrons per configuration.
Returns
-------
is_valid : bool
``True`` if every configuration has exactly the correct particle
numbers.
stats : dict
Dictionary with detailed statistics:
- ``"n_configs"`` : int --- total number of configurations.
- ``"n_valid"`` : int --- number of valid configurations.
- ``"n_invalid"`` : int --- number of invalid configurations.
- ``"alpha_counts"`` : torch.Tensor --- alpha electron count per config.
- ``"beta_counts"`` : torch.Tensor --- beta electron count per config.
- ``"alpha_violations"`` : int --- configs with wrong alpha count.
- ``"beta_violations"`` : int --- configs with wrong beta count.
Examples
--------
>>> configs = torch.tensor([[1, 1, 0, 1, 0, 1]]) # 2 alpha, 2 beta
>>> is_valid, stats = verify_particle_conservation(configs, 3, 2, 2)
>>> is_valid
True
"""
if configs.ndim != 2:
raise ValueError(f"configs must be 2-dimensional, got shape {configs.shape}")
expected_cols = 2 * n_orbitals
if configs.shape[1] != expected_cols:
raise ValueError(
f"configs must have {expected_cols} columns "
f"(2 * n_orbitals), got {configs.shape[1]}"
)
alpha_part = configs[:, :n_orbitals]
beta_part = configs[:, n_orbitals:]
alpha_counts = alpha_part.sum(dim=-1)
beta_counts = beta_part.sum(dim=-1)
alpha_valid = alpha_counts == n_alpha
beta_valid = beta_counts == n_beta
all_valid = alpha_valid & beta_valid
n_configs = configs.shape[0]
n_valid = int(all_valid.sum().item())
alpha_violations = int((~alpha_valid).sum().item())
beta_violations = int((~beta_valid).sum().item())
stats: dict[str, object] = {
"n_configs": n_configs,
"n_valid": n_valid,
"n_invalid": n_configs - n_valid,
"alpha_counts": alpha_counts,
"beta_counts": beta_counts,
"alpha_violations": alpha_violations,
"beta_violations": beta_violations,
}
is_valid = n_valid == n_configs
return is_valid, stats