Source code for qvartools.flows.training.loss_functions

"""
loss_functions --- Loss computation for physics-guided flow training
====================================================================

Standalone loss functions and supporting utilities extracted from the
physics-guided training loop:

* :func:`compute_teacher_loss` --- KL-divergence teacher loss.
* :func:`compute_physics_loss` --- Variational energy loss with EMA
  baseline for variance reduction.
* :func:`compute_entropy_loss` --- Negative-entropy regularisation.
* :func:`compute_local_energy` --- Per-configuration local energy via
  the Hamiltonian connections.
* :class:`ConnectionCache` --- LRU-style cache for Hamiltonian
  connection lookups.
"""

from __future__ import annotations

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

from qvartools.hamiltonians.hamiltonian import Hamiltonian

__all__ = [
    "ConnectionCache",
    "compute_teacher_loss",
    "compute_physics_loss",
    "compute_entropy_loss",
    "compute_local_energy",
]


# ---------------------------------------------------------------------------
# Connection cache
# ---------------------------------------------------------------------------


[docs] class ConnectionCache: """LRU-style cache for Hamiltonian connection lookups. Stores the result of ``hamiltonian.get_connections(config)`` keyed by the configuration as a tuple of integers. Evicts the oldest entries when ``max_size`` is exceeded. Parameters ---------- max_size : int Maximum number of cached entries (default ``100000``). Attributes ---------- max_size : int Capacity limit. """ def __init__(self, max_size: int = 100000) -> None: self.max_size: int = max_size self._cache: Dict[tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
[docs] def get( self, config: torch.Tensor, hamiltonian: Hamiltonian, ) -> Tuple[torch.Tensor, torch.Tensor]: """Retrieve connections for *config*, computing and caching if absent. Parameters ---------- config : torch.Tensor Single configuration, shape ``(num_sites,)``. hamiltonian : Hamiltonian The Hamiltonian to query for connections. Returns ------- connected_configs : torch.Tensor Connected configurations, shape ``(n_conn, num_sites)``. matrix_elements : torch.Tensor Corresponding matrix elements, shape ``(n_conn,)``. """ key = tuple(config.long().tolist()) if key in self._cache: return self._cache[key] connected, elements = hamiltonian.get_connections(config) if len(self._cache) >= self.max_size: # Evict oldest entry oldest_key = next(iter(self._cache)) del self._cache[oldest_key] self._cache[key] = (connected, elements) return connected, elements
[docs] def clear(self) -> None: """Remove all cached entries.""" self._cache.clear()
def __len__(self) -> int: """Return the number of cached entries. Returns ------- int Current number of entries in the cache. """ return len(self._cache)
# --------------------------------------------------------------------------- # Local energy computation # ---------------------------------------------------------------------------
[docs] def compute_local_energy( configs: torch.Tensor, nqs: nn.Module, hamiltonian: Hamiltonian, device: torch.device, connection_cache: Optional[ConnectionCache] = None, ) -> torch.Tensor: """Compute the local energy E_loc(x) for each configuration. ``E_loc(x) = H_{x,x} + sum_{x' != x} H_{x,x'} * psi(x') / psi(x)`` Optimised to minimise CPU-GPU transfers and batch all NQS evaluations into a single call. Parameters ---------- configs : torch.Tensor Configurations, shape ``(batch, num_sites)``. nqs : nn.Module Neural quantum state with a ``log_amplitude(x)`` method. hamiltonian : Hamiltonian The Hamiltonian operator. device : torch.device Torch device for computation. connection_cache : ConnectionCache or None, optional Optional cache for Hamiltonian connections. Returns ------- torch.Tensor Local energies, shape ``(batch,)``. """ batch = configs.shape[0] # --- 1. Evaluate log_amplitude for all input configs (single call) --- with torch.no_grad(): log_amp_x = nqs.log_amplitude(configs) # (batch,) # --- 2. Batch diagonal computation (single vectorised call) ---------- # diagonal_elements_batch handles device conversion internally, # so pass configs directly (avoids unnecessary CPU↔GPU transfer) diag_all = hamiltonian.diagonal_elements_batch(configs) e_loc = diag_all.to(device).float() # (batch,) # --- 3. Gather all off-diagonal connections -------------------------- # Move to CPU once for Numba-based get_connections configs_cpu = configs.cpu() all_connected: list[torch.Tensor] = [] all_elements: list[torch.Tensor] = [] # owner_idx[i] stores which original config the i-th connected config # belongs to, so we can scatter the results back. owner_indices: list[int] = [] for idx in range(batch): config_cpu = configs_cpu[idx] if connection_cache is not None: connected, elements = connection_cache.get( config_cpu, hamiltonian ) else: connected, elements = hamiltonian.get_connections(config_cpu) if connected.numel() > 0: all_connected.append(connected) all_elements.append(elements) owner_indices.extend([idx] * connected.shape[0]) # --- 4. Single batched NQS evaluation for all connections ------------ if all_connected: all_connected_cat = torch.cat(all_connected, dim=0) # (N_total, sites) all_elements_cat = torch.cat(all_elements, dim=0) # (N_total,) connected_dev = all_connected_cat.to(device).float() elements_dev = all_elements_cat.to(device).float() owner_dev = torch.tensor( owner_indices, device=device, dtype=torch.long ) with torch.no_grad(): log_amp_conn = nqs.log_amplitude(connected_dev) # (N_total,) # psi(x') / psi(x) = exp(log_amp(x') - log_amp(x)) ratios = torch.exp(log_amp_conn - log_amp_x[owner_dev]) contributions = elements_dev * ratios # (N_total,) # Scatter-add contributions back to corresponding configs e_loc.scatter_add_(0, owner_dev, contributions) return e_loc
# --------------------------------------------------------------------------- # Loss functions # ---------------------------------------------------------------------------
[docs] def compute_teacher_loss( configs: torch.Tensor, log_probs_flow: torch.Tensor, nqs: nn.Module, ) -> torch.Tensor: """Compute the teacher (KL divergence) loss. ``L_teacher = -sum_x p_nqs(x) * log p_flow(x)`` The NQS probabilities are detached (treated as fixed targets). Parameters ---------- configs : torch.Tensor Sampled configurations, shape ``(batch, num_sites)``. log_probs_flow : torch.Tensor Flow log-probabilities, shape ``(batch,)``. nqs : nn.Module Neural quantum state with a ``log_amplitude(x)`` method. Returns ------- torch.Tensor Scalar teacher loss. """ with torch.no_grad(): log_amp = nqs.log_amplitude(configs) log_prob_nqs = 2.0 * log_amp # log |psi|^2 (unnormalised) # Normalise within the batch log_z = torch.logsumexp(log_prob_nqs, dim=0) weights = torch.exp(log_prob_nqs - log_z) return -(weights * log_probs_flow).sum()
[docs] def compute_physics_loss( configs: torch.Tensor, nqs: nn.Module, hamiltonian: Hamiltonian, device: torch.device, energy_baseline: float, baseline_initialized: bool, use_energy_baseline: bool, ema_decay: float, connection_cache: Optional[ConnectionCache] = None, ) -> Tuple[torch.Tensor, float, float, bool]: """Compute the variational energy (physics) loss. ``L_physics = sum_x |psi(x)|^2 * E_loc(x) / Z`` Uses a running EMA baseline for variance reduction when enabled. Parameters ---------- configs : torch.Tensor Sampled configurations, shape ``(batch, num_sites)``. nqs : nn.Module Neural quantum state with a ``log_amplitude(x)`` method. hamiltonian : Hamiltonian The Hamiltonian operator. device : torch.device Torch device for computation. energy_baseline : float Current EMA energy baseline value. baseline_initialized : bool Whether the baseline has been initialised. use_energy_baseline : bool Whether to apply variance reduction via EMA baseline. ema_decay : float Exponential moving average decay for the baseline. connection_cache : ConnectionCache or None, optional Optional cache for Hamiltonian connections. Returns ------- loss : torch.Tensor Scalar physics loss. mean_energy : float Mean local energy (for logging). updated_baseline : float Updated EMA energy baseline. updated_initialized : bool Whether the baseline is now initialised. """ e_loc = compute_local_energy( configs, nqs, hamiltonian, device, connection_cache ) log_amp = nqs.log_amplitude(configs) log_prob = 2.0 * log_amp log_z = torch.logsumexp(log_prob, dim=0) weights = torch.exp(log_prob - log_z) mean_energy = float((weights.detach() * e_loc).sum()) # Variance reduction with EMA baseline updated_baseline = energy_baseline updated_initialized = baseline_initialized if use_energy_baseline: if not baseline_initialized: updated_baseline = mean_energy updated_initialized = True else: updated_baseline = ( ema_decay * energy_baseline + (1.0 - ema_decay) * mean_energy ) centred_e = e_loc - updated_baseline else: centred_e = e_loc loss = (weights * centred_e).sum() return loss, mean_energy, updated_baseline, updated_initialized
[docs] def compute_entropy_loss( log_probs_flow: torch.Tensor, ) -> torch.Tensor: """Compute the negative entropy of the flow distribution. ``L_entropy = sum_x p_flow(x) * log p_flow(x) = -H[p_flow]`` Minimising this loss maximises the entropy. Parameters ---------- log_probs_flow : torch.Tensor Flow log-probabilities, shape ``(batch,)``. Returns ------- torch.Tensor Scalar entropy loss (negative entropy). """ return log_probs_flow.mean()