"""
dense --- Dense feedforward NQS architectures
==============================================
Provides :class:`DenseNQS` (a standard fully connected NQS) and
:class:`SignedDenseNQS` (a dense NQS with explicit sign structure via
separate amplitude and sign heads sharing a feature extractor).
Also provides :func:`compile_nqs`, a utility that applies
``torch.compile`` to any NQS model with graceful fallback on error.
"""
from __future__ import annotations
import logging
import torch
import torch.nn as nn
from qvartools.nqs.neural_state import NeuralQuantumState
__all__ = [
"DenseNQS",
"SignedDenseNQS",
"compile_nqs",
]
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# compile_nqs utility
# ---------------------------------------------------------------------------
[docs]
def compile_nqs(
    model: nn.Module,
    mode: str = "reduce-overhead",
) -> nn.Module:
    """Wrap an NQS model with ``torch.compile``, falling back to the input.

    Parameters
    ----------
    model : nn.Module
        The neural quantum state model handed to ``torch.compile``.
    mode : str, optional
        Forwarded unchanged as the ``mode`` keyword of ``torch.compile``
        (default ``"reduce-overhead"``; ``"max-autotune"`` is another
        common choice).

    Returns
    -------
    nn.Module
        Whatever ``torch.compile`` returned, or ``model`` itself when any
        exception was raised during the attempt.

    Notes
    -----
    Only exceptions raised inside this function are caught.
    NOTE(review): ``torch.compile`` is understood to defer the actual
    compilation to the first forward call, so failures at that point would
    not be covered by this fallback -- confirm against the PyTorch docs.

    Examples
    --------
    >>> nqs = DenseNQS(num_sites=10, hidden_dims=[64, 32])
    >>> nqs = compile_nqs(nqs, mode="max-autotune")
    """
    # Start from the fallback so there is a single exit point.
    result = model
    try:
        result = torch.compile(model, mode=mode)
        logger.info("Successfully compiled NQS model with mode='%s'.", mode)
    except Exception as err:  # noqa: BLE001 - deliberate best-effort fallback
        logger.warning("torch.compile failed (%s). Returning uncompiled model.", err)
        # Any failure inside the attempt means the caller gets the original.
        result = model
    return result  # type: ignore[return-value]
# ---------------------------------------------------------------------------
# Helper: build a stack of Linear + activation layers
# ---------------------------------------------------------------------------
def _build_mlp(
input_dim: int,
hidden_dims: list[int],
output_dim: int,
activation: nn.Module = nn.ReLU(),
output_activation: nn.Module | None = None,
) -> nn.Sequential:
"""Build a simple MLP as an ``nn.Sequential``.
Parameters
----------
input_dim : int
Size of the input feature vector.
hidden_dims : list of int
Sizes of each hidden layer.
output_dim : int
Size of the output layer.
activation : nn.Module, optional
Activation function applied after each hidden layer
(default :class:`~torch.nn.ReLU`).
output_activation : nn.Module or None, optional
Optional activation applied after the output layer.
Returns
-------
nn.Sequential
The assembled MLP.
"""
layers: list[nn.Module] = []
prev_dim = input_dim
for h_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, h_dim))
layers.append(activation)
prev_dim = h_dim
layers.append(nn.Linear(prev_dim, output_dim))
if output_activation is not None:
layers.append(output_activation)
return nn.Sequential(*layers)
# ---------------------------------------------------------------------------
# DenseNQS
# ---------------------------------------------------------------------------
[docs]
class DenseNQS(NeuralQuantumState):
    """Fully connected feedforward neural quantum state.

    The amplitude network maps a configuration vector to a scalar
    log-amplitude via a stack of Linear + ReLU layers, followed by a
    final Linear + Tanh layer whose output is scaled by a learnable
    ``log_amp_scale`` parameter.

    If ``complex_output`` is ``True``, a separate phase network of the
    same depth produces the wavefunction phase in ``(-pi, pi)``.

    Parameters
    ----------
    num_sites : int
        Number of lattice / orbital sites.
    hidden_dims : list of int, optional
        Hidden-layer sizes for the amplitude (and phase) networks
        (default ``[128, 64]``).
    complex_output : bool, optional
        Whether to include a phase network (default ``False``).

    Attributes
    ----------
    amplitude_net : nn.Sequential
        The amplitude MLP (output before scaling).
    log_amp_scale : nn.Parameter
        Learnable scalar that multiplies the Tanh output.
    phase_net : nn.Sequential or None
        Phase MLP when ``complex_output is True``, else ``None``.

    Examples
    --------
    >>> nqs = DenseNQS(num_sites=10, hidden_dims=[64, 32])
    >>> x = torch.randint(0, 2, (8, 10)).float()
    >>> log_amp = nqs.log_amplitude(x)  # shape (8,)
    """

    def __init__(
        self,
        num_sites: int,
        hidden_dims: list[int] | None = None,
        complex_output: bool = False,
    ) -> None:
        super().__init__(
            num_sites=num_sites,
            local_dim=2,
            complex_output=complex_output,
        )
        if hidden_dims is None:
            hidden_dims = [128, 64]

        # Amplitude network: Input -> [Linear+ReLU]... -> Linear -> Tanh
        self.amplitude_net: nn.Sequential = _build_mlp(
            input_dim=num_sites,
            hidden_dims=hidden_dims,
            output_dim=1,
            activation=nn.ReLU(),
            output_activation=nn.Tanh(),
        )

        # Learnable scale for the log-amplitude
        self.log_amp_scale: nn.Parameter = nn.Parameter(torch.tensor(1.0))

        # Phase network (optional); same layout as the amplitude network.
        self.phase_net: nn.Sequential | None = None
        if complex_output:
            self.phase_net = _build_mlp(
                input_dim=num_sites,
                hidden_dims=hidden_dims,
                output_dim=1,
                activation=nn.ReLU(),
                output_activation=nn.Tanh(),  # output in (-1, 1)
            )

    def log_amplitude(self, x: torch.Tensor) -> torch.Tensor:
        """Compute log-amplitude ln|psi(x)|.

        The raw amplitude network output (bounded in ``(-1, 1)`` by Tanh)
        is multiplied by the learnable ``log_amp_scale``.

        Parameters
        ----------
        x : torch.Tensor
            Batch of configurations, shape ``(batch, num_sites)``.

        Returns
        -------
        torch.Tensor
            Log-amplitudes, shape ``(batch,)``.
        """
        x = self.encode_configuration(x)
        raw = self.amplitude_net(x).squeeze(-1)  # (batch,)
        return raw * self.log_amp_scale

    def phase(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the wavefunction phase.

        Returns zeros for real-valued NQS.  For complex NQS the phase
        network output (in ``(-1, 1)``) is scaled by ``pi`` so the phase
        lies in ``(-pi, pi)``.

        Parameters
        ----------
        x : torch.Tensor
            Batch of configurations, shape ``(batch, num_sites)``.

        Returns
        -------
        torch.Tensor
            Phases in radians, shape ``(batch,)``.  For a real-valued NQS
            the zeros carry the dtype of the model parameters (float32
            unless the module has been cast), on the device of ``x``.
        """
        if self.phase_net is None:
            # Follow the parameter dtype rather than hard-coding float32 so
            # that a model cast with .double()/.half() returns a phase whose
            # dtype matches log_amplitude().
            return torch.zeros(
                x.shape[0],
                device=x.device,
                dtype=self.log_amp_scale.dtype,
            )
        x = self.encode_configuration(x)
        raw = self.phase_net(x).squeeze(-1)  # (batch,)
        return raw * torch.pi
# ---------------------------------------------------------------------------
# SignedDenseNQS
# ---------------------------------------------------------------------------
[docs]
class SignedDenseNQS(NeuralQuantumState):
    """Dense NQS with explicit sign structure.

    Uses a shared feature extractor whose output feeds into two heads:

    * **Amplitude head** --- produces the log-amplitude via Softplus to
      ensure non-negative output.
    * **Sign head** --- produces a logit whose sigmoid is thresholded at
      0.5 to yield a phase of either 0 (positive) or pi (negative).

    Feature caching avoids redundant computation when :meth:`log_amplitude`
    and :meth:`phase` are called on the same input within one evaluation.
    A cache hit requires the *same tensor object*, unchanged since the
    features were computed, the same autograd mode, and unchanged
    feature-extractor parameters.

    Parameters
    ----------
    num_sites : int
        Number of lattice / orbital sites.
    hidden_dims : list of int, optional
        Hidden-layer sizes for the shared feature extractor
        (default ``[128, 64]``).  Must contain at least one entry; the
        last entry is the feature dimension seen by both heads.

    Attributes
    ----------
    feature_net : nn.Sequential
        Shared feature extractor.
    amplitude_head : nn.Sequential
        Maps features to log-amplitude (Softplus output).
    sign_head : nn.Linear
        Maps features to a sign logit.

    Examples
    --------
    >>> nqs = SignedDenseNQS(num_sites=10)
    >>> x = torch.randint(0, 2, (8, 10)).float()
    >>> log_amp = nqs.log_amplitude(x)  # shape (8,)
    >>> phi = nqs.phase(x)  # shape (8,), values in {0, pi}
    """

    def __init__(
        self,
        num_sites: int,
        hidden_dims: list[int] | None = None,
    ) -> None:
        super().__init__(
            num_sites=num_sites,
            local_dim=2,
            complex_output=True,
        )
        if hidden_dims is None:
            hidden_dims = [128, 64]

        # The last hidden size is the width of the shared feature vector.
        feature_dim = hidden_dims[-1]

        # Shared feature extractor; a one-element ``hidden_dims`` leaves no
        # hidden layers, i.e. a single Linear + ReLU.
        self.feature_net: nn.Sequential = _build_mlp(
            input_dim=num_sites,
            hidden_dims=hidden_dims[:-1],
            output_dim=feature_dim,
            activation=nn.ReLU(),
            output_activation=nn.ReLU(),
        )

        # Amplitude head: features -> Softplus -> scalar
        self.amplitude_head: nn.Sequential = nn.Sequential(
            nn.Linear(feature_dim, 1),
            nn.Softplus(),
        )

        # Sign head: features -> logit
        self.sign_head: nn.Linear = nn.Linear(feature_dim, 1)

        # Feature cache.  The input tensor itself is retained so its id
        # cannot be recycled by an unrelated tensor while the entry lives.
        self._cached_input_id: int | None = None
        self._cached_input: torch.Tensor | None = None
        self._cached_state: tuple | None = None
        self._cached_features: torch.Tensor | None = None

    @staticmethod
    def _tensor_version(t: torch.Tensor) -> int | None:
        """Return the in-place version counter of ``t`` (``None`` if untracked)."""
        # Inference tensors do not track a version counter and raise on access.
        return None if t.is_inference() else t._version

    def _cache_state(self, x: torch.Tensor) -> tuple:
        """Snapshot everything besides identity that makes cached features valid."""
        return (
            self._tensor_version(x),  # in-place edits of the input
            torch.is_grad_enabled(),  # no-grad features carry no graph
            tuple(  # optimizer steps / in-place weight loads
                self._tensor_version(p) for p in self.feature_net.parameters()
            ),
        )

    def _get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Compute (or retrieve cached) shared features.

        Parameters
        ----------
        x : torch.Tensor
            Batch of configurations, shape ``(batch, num_sites)``.

        Returns
        -------
        torch.Tensor
            Feature vectors, shape ``(batch, feature_dim)``.
        """
        state = self._cache_state(x)
        if (
            self._cached_features is not None
            and self._cached_input is x
            and self._cached_state == state
        ):
            return self._cached_features

        encoded = self.encode_configuration(x)
        features = self.feature_net(encoded)

        self._cached_input_id = id(x)
        self._cached_input = x
        self._cached_state = state
        self._cached_features = features
        return features

    def clear_feature_cache(self) -> None:
        """Clear the feature cache.

        Call this to release the cached batch and features, or in cases
        the automatic checks cannot see (for example, calling
        ``backward`` twice on the same input object without any
        parameter update in between).
        """
        self._cached_input_id = None
        self._cached_input = None
        self._cached_state = None
        self._cached_features = None

    def log_amplitude(self, x: torch.Tensor) -> torch.Tensor:
        """Compute log-amplitude from the amplitude head.

        The Softplus activation ensures the raw amplitude is non-negative;
        the log-amplitude is the logarithm of that value (with ``1e-12``
        added before the log to keep it finite).

        Parameters
        ----------
        x : torch.Tensor
            Batch of configurations, shape ``(batch, num_sites)``.

        Returns
        -------
        torch.Tensor
            Log-amplitudes, shape ``(batch,)``.
        """
        features = self._get_features(x)
        # Softplus output is already positive; take log for log-amplitude
        amp = self.amplitude_head(features).squeeze(-1)  # (batch,)
        return torch.log(amp + 1e-12)

    def phase(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the sign-derived phase.

        The sign logit is passed through a sigmoid.  Values of at least
        0.5 correspond to a positive sign (phase = 0); values below 0.5
        correspond to a negative sign (phase = pi).

        During training, a soft interpolation is used for gradient flow:
        ``phase = pi * (1 - sigmoid(sign_logit))``.

        Parameters
        ----------
        x : torch.Tensor
            Batch of configurations, shape ``(batch, num_sites)``.

        Returns
        -------
        torch.Tensor
            Phase values, shape ``(batch,)``.  During training these are
            continuous in ``(0, pi)``; at eval they snap to ``{0, pi}``.
        """
        features = self._get_features(x)
        sign_logit = self.sign_head(features).squeeze(-1)  # (batch,)

        if self.training:
            # Soft version for gradient flow
            return torch.pi * (1.0 - torch.sigmoid(sign_logit))

        # Hard threshold at eval time
        positive = torch.sigmoid(sign_logit) >= 0.5
        return torch.where(
            positive,
            torch.zeros_like(sign_logit),
            torch.full_like(sign_logit, torch.pi),
        )