Source code for qvartools.flows.training.physics_guided_training

"""
physics_guided_training --- Mixed-objective flow + NQS trainer
==============================================================

Implements :class:`PhysicsGuidedFlowTrainer`, a training orchestrator
that jointly optimises a normalizing flow and a neural quantum state
using a combination of three loss terms:

* **Teacher loss** --- trains the flow to match the NQS distribution by
  maximising ``log p_flow(x)`` weighted by the NQS probability.
* **Physics loss** --- minimises the variational energy via the local
  energy estimator ``E_loc(x) = sum_x' H_{x,x'} psi(x') / psi(x)``.
* **Entropy loss** --- encourages exploration by maximising the entropy
  of the flow distribution.

The training loop includes temperature annealing for particle-conserving
flows, essential-configuration injection (Hartree--Fock + singles +
doubles), and convergence detection based on the unique-configuration
ratio.
"""

from __future__ import annotations

import logging
import math
from dataclasses import dataclass

import torch
import torch.nn as nn

from qvartools.flows.training.loss_functions import (
    ConnectionCache,
    compute_entropy_loss,
    compute_physics_loss,
    compute_teacher_loss,
)
from qvartools.hamiltonians.hamiltonian import Hamiltonian

__all__ = [
    "PhysicsGuidedConfig",
    "PhysicsGuidedFlowTrainer",
]

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Configuration dataclass
# ---------------------------------------------------------------------------


[docs] @dataclass(frozen=True) class PhysicsGuidedConfig: """Hyperparameters for :class:`PhysicsGuidedFlowTrainer`. All fields have sensible defaults for molecular-scale problems. The class is frozen (immutable) to prevent accidental mutation during training. Parameters ---------- samples_per_batch : int Number of flow samples per mini-batch (default ``500``). num_batches : int Number of mini-batches per epoch (default ``10``). num_epochs : int Maximum number of training epochs (default ``200``). min_epochs : int Minimum epochs before convergence checks activate (default ``50``). convergence_threshold : float Training stops when the unique-configuration ratio changes by less than this amount over consecutive epochs (default ``0.01``). flow_lr : float Learning rate for the flow optimiser (default ``1e-3``). nqs_lr : float Learning rate for the NQS optimiser (default ``1e-3``). teacher_weight : float Weight of the teacher (KL) loss (default ``1.0``). physics_weight : float Weight of the variational energy loss (default ``0.0``). entropy_weight : float Weight of the entropy regularisation loss (default ``0.0``). use_energy_baseline : bool Whether to subtract a running baseline from the energy for variance reduction (default ``True``). ema_decay : float Exponential moving average decay for the energy baseline (default ``0.99``). use_connection_cache : bool Whether to cache Hamiltonian connections for repeated configs (default ``True``). max_cache_size : int Maximum number of entries in the connection cache (default ``100000``). initial_temperature : float Starting temperature for flow annealing (default ``2.0``). final_temperature : float Final temperature after annealing (default ``0.1``). temperature_decay_epochs : int Number of epochs over which to anneal temperature (default ``100``). inject_essential_configs : bool Whether to inject Hartree--Fock and nearby configurations into the basis (default ``True``). include_singles_in_basis : bool Whether to include single excitations in the essential basis (default ``True``). include_doubles_in_basis : bool Whether to include double excitations in the essential basis (default ``True``). device : str Torch device for training (default ``"cpu"``). """ samples_per_batch: int = 500 num_batches: int = 10 num_epochs: int = 200 min_epochs: int = 50 convergence_threshold: float = 0.01 flow_lr: float = 1e-3 nqs_lr: float = 1e-3 teacher_weight: float = 1.0 physics_weight: float = 0.0 entropy_weight: float = 0.0 use_energy_baseline: bool = True ema_decay: float = 0.99 use_connection_cache: bool = True max_cache_size: int = 100000 initial_temperature: float = 2.0 final_temperature: float = 0.1 temperature_decay_epochs: int = 100 inject_essential_configs: bool = True include_singles_in_basis: bool = True include_doubles_in_basis: bool = True device: str = "cpu"
# --------------------------------------------------------------------------- # Essential configuration generators # --------------------------------------------------------------------------- def _generate_hf_config(n_orbitals: int, n_alpha: int, n_beta: int) -> torch.Tensor: """Generate the Hartree--Fock reference configuration. The HF configuration occupies the lowest-energy orbitals: the first ``n_alpha`` alpha orbitals and the first ``n_beta`` beta orbitals. Parameters ---------- n_orbitals : int Number of spatial orbitals. n_alpha : int Number of alpha electrons. n_beta : int Number of beta electrons. Returns ------- torch.Tensor Binary HF configuration, shape ``(2 * n_orbitals,)``. """ config = torch.zeros(2 * n_orbitals, dtype=torch.float32) config[:n_alpha] = 1.0 config[n_orbitals : n_orbitals + n_beta] = 1.0 return config def _generate_single_excitations( hf_config: torch.Tensor, n_orbitals: int, n_alpha: int, n_beta: int, ) -> torch.Tensor: """Generate all single excitations from the HF configuration. Parameters ---------- hf_config : torch.Tensor The Hartree--Fock reference, shape ``(2 * n_orbitals,)``. n_orbitals : int Number of spatial orbitals. n_alpha : int Number of alpha electrons. n_beta : int Number of beta electrons. Returns ------- torch.Tensor Single-excitation configurations, shape ``(n_singles, 2 * n_orbitals)``. """ singles: list[torch.Tensor] = [] # Alpha single excitations: i (occupied) -> a (virtual) for i in range(n_alpha): for a in range(n_alpha, n_orbitals): config = hf_config.clone() config[i] = 0.0 config[a] = 1.0 singles.append(config) # Beta single excitations for i in range(n_orbitals, n_orbitals + n_beta): for a in range(n_orbitals + n_beta, 2 * n_orbitals): config = hf_config.clone() config[i] = 0.0 config[a] = 1.0 singles.append(config) if not singles: return torch.empty(0, 2 * n_orbitals, dtype=torch.float32) return torch.stack(singles) def _generate_double_excitations( hf_config: torch.Tensor, n_orbitals: int, n_alpha: int, n_beta: int, ) -> torch.Tensor: """Generate all double excitations from the HF configuration. Includes alpha-alpha, beta-beta, and alpha-beta double excitations. Parameters ---------- hf_config : torch.Tensor The Hartree--Fock reference, shape ``(2 * n_orbitals,)``. n_orbitals : int Number of spatial orbitals. n_alpha : int Number of alpha electrons. n_beta : int Number of beta electrons. Returns ------- torch.Tensor Double-excitation configurations, shape ``(n_doubles, 2 * n_orbitals)``. """ doubles: list[torch.Tensor] = [] alpha_occ = list(range(n_alpha)) alpha_vir = list(range(n_alpha, n_orbitals)) beta_occ = list(range(n_orbitals, n_orbitals + n_beta)) beta_vir = list(range(n_orbitals + n_beta, 2 * n_orbitals)) # Alpha-alpha doubles for i_idx in range(len(alpha_occ)): for j_idx in range(i_idx + 1, len(alpha_occ)): i, j = alpha_occ[i_idx], alpha_occ[j_idx] for a_idx in range(len(alpha_vir)): for b_idx in range(a_idx + 1, len(alpha_vir)): a, b = alpha_vir[a_idx], alpha_vir[b_idx] config = hf_config.clone() config[i] = 0.0 config[j] = 0.0 config[a] = 1.0 config[b] = 1.0 doubles.append(config) # Beta-beta doubles for i_idx in range(len(beta_occ)): for j_idx in range(i_idx + 1, len(beta_occ)): i, j = beta_occ[i_idx], beta_occ[j_idx] for a_idx in range(len(beta_vir)): for b_idx in range(a_idx + 1, len(beta_vir)): a, b = beta_vir[a_idx], beta_vir[b_idx] config = hf_config.clone() config[i] = 0.0 config[j] = 0.0 config[a] = 1.0 config[b] = 1.0 doubles.append(config) # Alpha-beta doubles for i in alpha_occ: for a in alpha_vir: for j in beta_occ: for b in beta_vir: config = hf_config.clone() config[i] = 0.0 config[a] = 1.0 config[j] = 0.0 config[b] = 1.0 doubles.append(config) if not doubles: return torch.empty(0, 2 * n_orbitals, dtype=torch.float32) return torch.stack(doubles) # --------------------------------------------------------------------------- # PhysicsGuidedFlowTrainer # ---------------------------------------------------------------------------
[docs] class PhysicsGuidedFlowTrainer: """Mixed-objective trainer for joint flow + NQS optimisation. Combines three loss terms with configurable weights: * **Teacher loss**: ``-sum_x p_nqs(x) * log p_flow(x)`` --- trains the flow to reproduce the NQS distribution. * **Physics loss**: variational energy ``E = sum_x |psi(x)|^2 * E_loc(x) / Z`` where ``E_loc(x) = sum_{x'} H_{x,x'} * psi(x') / psi(x)`` --- minimises the ground-state energy estimate. * **Entropy loss**: ``-H[p_flow]`` (negative entropy) --- prevents mode collapse by encouraging distribution spread. The trainer also: * Accumulates unique configurations into a growing basis set. * Anneals the flow temperature over early epochs. * Optionally injects essential (HF + singles + doubles) configurations into the basis. * Caches Hamiltonian connections for efficiency. Parameters ---------- flow : nn.Module The normalizing flow sampler. Must implement ``sample(batch_size)`` returning ``(all_configs, unique_configs)``. nqs : nn.Module The neural quantum state. Must implement ``log_amplitude(x)`` returning log-amplitudes of shape ``(batch,)``. hamiltonian : Hamiltonian The Hamiltonian operator. config : PhysicsGuidedConfig Training hyperparameters. device : str, optional Torch device override (default uses ``config.device``). Attributes ---------- flow : nn.Module The flow model. nqs : nn.Module The NQS model. hamiltonian : Hamiltonian The Hamiltonian. config : PhysicsGuidedConfig Training configuration. device : torch.device Active device. accumulated_basis : torch.Tensor or None Growing set of unique configurations seen during training. flow_optimizer : torch.optim.Adam Optimiser for the flow parameters. nqs_optimizer : torch.optim.Adam Optimiser for the NQS parameters. energy_baseline : float Running EMA baseline for variance reduction. connection_cache : ConnectionCache or None Cache for Hamiltonian connections. """ def __init__( self, flow: nn.Module, nqs: nn.Module, hamiltonian: Hamiltonian, config: PhysicsGuidedConfig, device: str = "cpu", ) -> None: effective_device = device if device != "cpu" else config.device self.device: torch.device = torch.device(effective_device) self.flow: nn.Module = flow.to(self.device) self.nqs: nn.Module = nqs.to(self.device) self.hamiltonian: Hamiltonian = hamiltonian self.config: PhysicsGuidedConfig = config self.accumulated_basis: torch.Tensor | None = None self.flow_optimizer: torch.optim.Adam = torch.optim.Adam( flow.parameters(), lr=config.flow_lr ) self.nqs_optimizer: torch.optim.Adam = torch.optim.Adam( nqs.parameters(), lr=config.nqs_lr ) self.energy_baseline: float = 0.0 self._baseline_initialized: bool = False self.connection_cache: ConnectionCache | None = None if config.use_connection_cache: self.connection_cache = ConnectionCache(max_size=config.max_cache_size) # Inject essential configurations if requested if config.inject_essential_configs: self._inject_essential_configs() def _inject_essential_configs(self) -> None: """Inject HF + singles + doubles into the accumulated basis.""" num_sites = self.hamiltonian.num_sites n_orbitals = num_sites // 2 # Attempt to infer n_alpha, n_beta from the flow n_alpha = getattr(self.flow, "n_alpha", None) n_beta = getattr(self.flow, "n_beta", None) if n_alpha is None or n_beta is None: logger.warning( "Cannot inject essential configs: flow does not expose " "n_alpha and n_beta attributes." ) return essential: list[torch.Tensor] = [] # Hartree-Fock reference hf = _generate_hf_config(n_orbitals, n_alpha, n_beta) essential.append(hf) # Single excitations if self.config.include_singles_in_basis: singles = _generate_single_excitations(hf, n_orbitals, n_alpha, n_beta) if singles.shape[0] > 0: essential.append(singles) # Double excitations if self.config.include_doubles_in_basis: doubles = _generate_double_excitations(hf, n_orbitals, n_alpha, n_beta) if doubles.shape[0] > 0: essential.append(doubles) if essential: all_essential = torch.cat( [e.unsqueeze(0) if e.ndim == 1 else e for e in essential], dim=0, ).to(self.device) self._accumulate_configs(all_essential) logger.info( "Injected %d essential configurations into basis.", all_essential.shape[0], ) def _accumulate_configs(self, new_configs: torch.Tensor) -> None: """Add new unique configurations to the accumulated basis. Parameters ---------- new_configs : torch.Tensor New configurations, shape ``(n, num_sites)``. """ if new_configs.numel() == 0: return new_configs = new_configs.to(self.device) if self.accumulated_basis is None: self.accumulated_basis = torch.unique(new_configs, dim=0) else: combined = torch.cat([self.accumulated_basis, new_configs], dim=0) self.accumulated_basis = torch.unique(combined, dim=0) def _get_temperature(self, epoch: int) -> float: """Compute the annealed temperature for the current epoch. Uses exponential decay from ``initial_temperature`` to ``final_temperature`` over ``temperature_decay_epochs``. Parameters ---------- epoch : int Current epoch index (0-based). Returns ------- float Temperature for the current epoch. """ cfg = self.config if epoch >= cfg.temperature_decay_epochs: return cfg.final_temperature decay_rate = math.log(cfg.initial_temperature / cfg.final_temperature) progress = epoch / max(cfg.temperature_decay_epochs, 1) return cfg.initial_temperature * math.exp(-decay_rate * progress) def _train_epoch(self, epoch: int) -> dict[str, float]: """Execute a single training epoch. Samples configurations from the flow, computes the combined loss, updates both flow and NQS parameters, and accumulates unique configurations into the basis. Parameters ---------- epoch : int Current epoch index (0-based). Returns ------- dict Epoch metrics with keys: - ``"teacher_loss"`` : float - ``"physics_loss"`` : float - ``"entropy_loss"`` : float - ``"total_loss"`` : float - ``"mean_energy"`` : float - ``"unique_ratio"`` : float - ``"basis_size"`` : int - ``"temperature"`` : float """ cfg = self.config self.flow.train() self.nqs.train() # Temperature annealing temperature = self._get_temperature(epoch) if hasattr(self.flow, "set_temperature"): self.flow.set_temperature(temperature) epoch_teacher = 0.0 epoch_physics = 0.0 epoch_entropy = 0.0 epoch_total = 0.0 epoch_energy = 0.0 total_samples = 0 total_unique = 0 for _ in range(cfg.num_batches): # Sample from flow sample_result = self.flow.sample(cfg.samples_per_batch) if len(sample_result) == 2: all_configs, unique_configs = sample_result else: # Handle flows that return extra outputs all_configs, unique_configs = sample_result[0], sample_result[1] all_configs = all_configs.to(self.device) unique_configs = unique_configs.to(self.device) total_samples += all_configs.shape[0] total_unique += unique_configs.shape[0] # Accumulate basis self._accumulate_configs(unique_configs) # Compute flow log-probabilities for loss terms # Use continuous log-prob if available, else approximate if hasattr(self.flow, "log_prob_continuous"): # Map discrete configs to continuous space center points y_approx = 2.0 * all_configs - 1.0 # {0,1} -> {-1,+1} log_probs_flow = self.flow.log_prob_continuous(y_approx) else: # Fallback: use uniform log-prob (no teacher signal) log_probs_flow = torch.zeros(all_configs.shape[0], device=self.device) # Compute losses loss = torch.tensor(0.0, device=self.device) teacher_loss_val = 0.0 if cfg.teacher_weight > 0.0: t_loss = compute_teacher_loss(all_configs, log_probs_flow, self.nqs) loss = loss + cfg.teacher_weight * t_loss teacher_loss_val = float(t_loss.detach()) physics_loss_val = 0.0 energy_val = 0.0 if cfg.physics_weight > 0.0: p_loss, energy_val, self.energy_baseline, self._baseline_initialized = ( compute_physics_loss( all_configs, self.nqs, self.hamiltonian, self.device, self.energy_baseline, self._baseline_initialized, cfg.use_energy_baseline, cfg.ema_decay, self.connection_cache, ) ) loss = loss + cfg.physics_weight * p_loss physics_loss_val = float(p_loss.detach()) entropy_loss_val = 0.0 if cfg.entropy_weight > 0.0: e_loss = compute_entropy_loss(log_probs_flow) loss = loss + cfg.entropy_weight * e_loss entropy_loss_val = float(e_loss.detach()) # Backward pass and optimiser steps self.flow_optimizer.zero_grad() self.nqs_optimizer.zero_grad() if loss.requires_grad: loss.backward() self.flow_optimizer.step() self.nqs_optimizer.step() epoch_teacher += teacher_loss_val epoch_physics += physics_loss_val epoch_entropy += entropy_loss_val epoch_total += float(loss.detach()) if loss.requires_grad else float(loss) epoch_energy += energy_val n_batches = max(cfg.num_batches, 1) unique_ratio = total_unique / max(total_samples, 1) basis_size = ( self.accumulated_basis.shape[0] if self.accumulated_basis is not None else 0 ) return { "teacher_loss": epoch_teacher / n_batches, "physics_loss": epoch_physics / n_batches, "entropy_loss": epoch_entropy / n_batches, "total_loss": epoch_total / n_batches, "mean_energy": epoch_energy / n_batches, "unique_ratio": unique_ratio, "basis_size": basis_size, "temperature": temperature, }
[docs] def train(self, progress: bool = True) -> dict[str, list]: """Run the full training loop. Trains for up to ``config.num_epochs`` epochs, with early stopping when the unique-configuration ratio converges (change less than ``config.convergence_threshold`` for two consecutive epochs after ``config.min_epochs``). Parameters ---------- progress : bool, optional If ``True``, log epoch-level metrics at INFO level (default ``True``). Returns ------- dict Training history with keys matching the epoch metrics, each mapping to a list of per-epoch values: - ``"teacher_loss"`` : list of float - ``"physics_loss"`` : list of float - ``"entropy_loss"`` : list of float - ``"total_loss"`` : list of float - ``"mean_energy"`` : list of float - ``"unique_ratio"`` : list of float - ``"basis_size"`` : list of int - ``"temperature"`` : list of float """ history: dict[str, list] = { "teacher_loss": [], "physics_loss": [], "entropy_loss": [], "total_loss": [], "mean_energy": [], "unique_ratio": [], "basis_size": [], "temperature": [], } prev_unique_ratio = 0.0 for epoch in range(self.config.num_epochs): metrics = self._train_epoch(epoch) for key in history: history[key].append(metrics[key]) if progress: logger.info( "Epoch %3d | loss=%.4f | energy=%.6f | " "unique_ratio=%.4f | basis=%d | temp=%.3f", epoch, metrics["total_loss"], metrics["mean_energy"], metrics["unique_ratio"], metrics["basis_size"], metrics["temperature"], ) # Convergence check if epoch >= self.config.min_epochs: delta = abs(metrics["unique_ratio"] - prev_unique_ratio) if delta < self.config.convergence_threshold: logger.info( "Converged at epoch %d (unique_ratio delta=%.6f < %.6f).", epoch, delta, self.config.convergence_threshold, ) break prev_unique_ratio = metrics["unique_ratio"] return history