Source code for qvartools.nqs.neural_state

"""
base --- Abstract neural quantum state base class
==================================================

Provides the :class:`NeuralQuantumState` ABC that every concrete NQS
architecture must implement.  Subclasses define :meth:`log_amplitude`
and :meth:`phase`; the base class assembles these into the full
wavefunction, probabilities, and normalised probabilities.
"""

from __future__ import annotations

from abc import ABC, abstractmethod

import torch
import torch.nn as nn

__all__ = [
    "NeuralQuantumState",
]


[docs] class NeuralQuantumState(nn.Module, ABC): """Abstract base class for all neural quantum state ansaetze. Every subclass must implement :meth:`log_amplitude` and :meth:`phase`. The base class provides convenience methods for evaluating the full log-wavefunction, the wavefunction itself, Born-rule probabilities, and normalised probabilities over a discrete basis set. Parameters ---------- num_sites : int Number of lattice / orbital sites (input dimensionality). local_dim : int, optional Dimension of the local Hilbert space on each site (default ``2`` for spin-1/2 / qubit systems). complex_output : bool, optional If ``True`` the NQS represents a complex-valued wavefunction with a non-trivial phase network. If ``False`` the phase is identically zero and :meth:`log_psi` returns a single real tensor. Attributes ---------- num_sites : int Number of sites. local_dim : int Local Hilbert-space dimension. complex_output : bool Whether the NQS has a non-trivial phase. """ def __init__( self, num_sites: int, local_dim: int = 2, complex_output: bool = False, ) -> None: super().__init__() if num_sites < 1: raise ValueError(f"num_sites must be >= 1, got {num_sites}") if local_dim < 2: raise ValueError(f"local_dim must be >= 2, got {local_dim}") self.num_sites: int = num_sites self.local_dim: int = local_dim self.complex_output: bool = complex_output # ------------------------------------------------------------------ # Abstract interface # ------------------------------------------------------------------
[docs] @abstractmethod def log_amplitude(self, x: torch.Tensor) -> torch.Tensor: """Compute the log-amplitude ln|psi(x)| for each configuration. Parameters ---------- x : torch.Tensor Batch of configurations, shape ``(batch, num_sites)``. Returns ------- torch.Tensor Log-amplitudes, shape ``(batch,)``. """
[docs] @abstractmethod def phase(self, x: torch.Tensor) -> torch.Tensor: """Compute the phase arg(psi(x)) for each configuration. Parameters ---------- x : torch.Tensor Batch of configurations, shape ``(batch, num_sites)``. Returns ------- torch.Tensor Phases in radians, shape ``(batch,)``. Must be identically zero when ``complex_output is False``. """
# ------------------------------------------------------------------ # Concrete methods # ------------------------------------------------------------------
[docs] def log_psi( self, x: torch.Tensor ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Compute the log-wavefunction. For real-valued NQS (``complex_output is False``), returns only the log-amplitude. For complex-valued NQS, returns a tuple of (log_amplitude, phase). Parameters ---------- x : torch.Tensor Batch of configurations, shape ``(batch, num_sites)``. Returns ------- torch.Tensor or tuple of torch.Tensor If ``complex_output is False``: log-amplitude tensor of shape ``(batch,)``. If ``complex_output is True``: tuple ``(log_amp, phase)`` each of shape ``(batch,)``. """ log_amp = self.log_amplitude(x) if not self.complex_output: return log_amp return log_amp, self.phase(x)
[docs] def psi(self, x: torch.Tensor) -> torch.Tensor: """Evaluate the full wavefunction psi(x). Computes ``exp(log_amp) * exp(i * phase)``. When ``complex_output is False`` the result is real-valued (dtype matches input); otherwise it is complex. Parameters ---------- x : torch.Tensor Batch of configurations, shape ``(batch, num_sites)``. Returns ------- torch.Tensor Wavefunction values, shape ``(batch,)``. Complex dtype when ``complex_output is True``. """ log_amp = self.log_amplitude(x) if not self.complex_output: return torch.exp(log_amp) phi = self.phase(x) amplitude = torch.exp(log_amp) return amplitude * torch.exp(1j * phi.to(torch.complex64))
[docs] def probability(self, x: torch.Tensor) -> torch.Tensor: """Compute the Born-rule probability |psi(x)|^2 (unnormalised). Parameters ---------- x : torch.Tensor Batch of configurations, shape ``(batch, num_sites)``. Returns ------- torch.Tensor Unnormalised probabilities ``exp(2 * log_amplitude(x))``, shape ``(batch,)``. """ return torch.exp(2.0 * self.log_amplitude(x))
[docs] def normalized_probability( self, x: torch.Tensor, basis_set: torch.Tensor ) -> torch.Tensor: """Compute normalised Born-rule probabilities over a basis set. The normalisation constant Z is computed as the sum of ``|psi(s)|^2`` over every configuration *s* in *basis_set*. Parameters ---------- x : torch.Tensor Configurations to evaluate, shape ``(batch, num_sites)``. basis_set : torch.Tensor Complete (or reference) set of configurations used to compute the partition function, shape ``(n_basis, num_sites)``. Returns ------- torch.Tensor Normalised probabilities, shape ``(batch,)``. Each entry is ``|psi(x_i)|^2 / Z``. """ log_amp_x = self.log_amplitude(x) log_amp_basis = self.log_amplitude(basis_set) # Compute log(Z) = log(sum exp(2 * log_amp)) via logsumexp for # numerical stability. log_z = torch.logsumexp(2.0 * log_amp_basis, dim=0) return torch.exp(2.0 * log_amp_x - log_z)
[docs] @staticmethod def encode_configuration(config: torch.Tensor) -> torch.Tensor: """Convert a configuration tensor to float for network input. Parameters ---------- config : torch.Tensor Configuration tensor of any integer or float dtype, shape ``(..., num_sites)``. Returns ------- torch.Tensor Float tensor with the same shape, dtype ``torch.float32``. """ return config.to(torch.float32)
# ------------------------------------------------------------------ # Forward (default delegates to log_psi) # ------------------------------------------------------------------
[docs] def forward( self, x: torch.Tensor ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Forward pass --- delegates to :meth:`log_psi`. Parameters ---------- x : torch.Tensor Batch of configurations, shape ``(batch, num_sites)``. Returns ------- torch.Tensor or tuple of torch.Tensor Same as :meth:`log_psi`. """ return self.log_psi(x)