# Source code for qvartools.nqs.architectures.complex_nqs

"""
complex_nqs --- Complex-valued dense NQS architecture
=====================================================

Provides :class:`ComplexNQS`, a dense neural quantum state (NQS) with a shared
feature extractor and separate amplitude and phase heads.
:class:`RBMQuantumState` is also re-exported from this module.
"""

from __future__ import annotations

import torch
import torch.nn as nn

from qvartools.nqs.architectures.rbm import RBMQuantumState  # noqa: F401
from qvartools.nqs.neural_state import NeuralQuantumState

# Public names exported by this module: ``ComplexNQS`` is defined below and
# ``RBMQuantumState`` is re-exported from ``qvartools.nqs.architectures.rbm``.
__all__ = [
    "ComplexNQS",
    "RBMQuantumState",
]


# ---------------------------------------------------------------------------
# ComplexNQS
# ---------------------------------------------------------------------------


class ComplexNQS(NeuralQuantumState):
    """Complex-valued NQS with shared feature extractor.

    A shared MLP backbone produces feature vectors that are fed into two
    independent heads:

    * **Amplitude head** --- maps features to a scalar log-amplitude.
    * **Phase head** --- maps features to a phase in ``(-pi, pi)``.

    Feature caching avoids redundant computation when :meth:`log_amplitude`
    and :meth:`phase` are called sequentially on the same input tensor.
    The cached features are reused only when the *same tensor object* is
    passed again, it has not been modified in place, the backbone parameters
    have not been modified in place, and the autograd grad mode is unchanged.

    Parameters
    ----------
    num_sites : int
        Number of lattice / orbital sites.
    hidden_dims : list of int, optional
        Hidden-layer sizes for the shared feature extractor
        (default ``[128, 64]``).  An empty list is accepted; the backbone is
        then an empty ``nn.Sequential`` and both heads act directly on the
        encoded configuration.

    Attributes
    ----------
    feature_net : nn.Sequential
        Shared feature MLP.
    amplitude_head : nn.Linear
        Linear projection from features to scalar log-amplitude.
    phase_head : nn.Sequential
        Maps features to phase via Linear + Tanh (scaled by pi).

    Notes
    -----
    The cache keeps a reference to the most recent input tensor and to its
    features (including their autograd graph, if any) until a different
    input is evaluated or :meth:`clear_feature_cache` is called.

    Examples
    --------
    >>> nqs = ComplexNQS(num_sites=10, hidden_dims=[64, 32])
    >>> x = torch.randint(0, 2, (8, 10)).float()
    >>> log_amp = nqs.log_amplitude(x)  # shape (8,)
    >>> phi = nqs.phase(x)              # shape (8,), in (-pi, pi)
    """

    def __init__(
        self,
        num_sites: int,
        hidden_dims: list[int] | None = None,
    ) -> None:
        super().__init__(
            num_sites=num_sites,
            local_dim=2,
            complex_output=True,
        )

        if hidden_dims is None:
            hidden_dims = [128, 64]

        # Build shared feature extractor: Linear -> ReLU per hidden layer.
        layers: list[nn.Module] = []
        prev_dim = num_sites
        for h_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.ReLU())
            prev_dim = h_dim
        self.feature_net: nn.Sequential = nn.Sequential(*layers)

        # Width of the backbone output.  Using ``prev_dim`` instead of
        # ``hidden_dims[-1]`` keeps an empty ``hidden_dims`` valid.
        feature_dim = prev_dim

        # Amplitude head: single linear -> scalar
        self.amplitude_head: nn.Linear = nn.Linear(feature_dim, 1)

        # Phase head: linear -> tanh (output in (-1, 1), scaled by pi)
        self.phase_head: nn.Sequential = nn.Sequential(
            nn.Linear(feature_dim, 1),
            nn.Tanh(),
        )

        # Feature cache.  ``_cached_input_id`` is kept for introspection
        # only; cache validity is decided by ``_cached_input`` (object
        # identity) together with ``_cached_state``.
        self._cached_input_id: int | None = None
        self._cached_input: torch.Tensor | None = None
        self._cached_state: tuple[int, bool, tuple[int, ...]] | None = None
        self._cached_features: torch.Tensor | None = None

    def _feature_cache_state(
        self, x: torch.Tensor
    ) -> tuple[int, bool, tuple[int, ...]]:
        """Return a snapshot of everything the cached features depend on.

        The snapshot contains the in-place version counter of ``x``, the
        current autograd grad mode, and the in-place version counters of the
        ``feature_net`` parameters.  Any difference from the stored snapshot
        means the cached features may be stale.
        """
        # NOTE(review): only ``feature_net`` parameters are tracked; this
        # assumes ``encode_configuration`` has no trainable state -- confirm
        # against the base class.
        return (
            x._version,
            torch.is_grad_enabled(),
            tuple(p._version for p in self.feature_net.parameters()),
        )

    def _get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Compute (or retrieve cached) shared features.

        Parameters
        ----------
        x : torch.Tensor
            Batch of configurations, shape ``(batch, num_sites)``.

        Returns
        -------
        torch.Tensor
            Feature vectors, shape ``(batch, feature_dim)``.
        """
        state = self._feature_cache_state(x)

        # Compare object identity against a retained reference rather than a
        # bare ``id(x)``: an id can be reused by a new tensor once the old
        # one has been freed, which would hand back another batch's features.
        if (
            self._cached_features is not None
            and self._cached_input is x
            and self._cached_state == state
        ):
            return self._cached_features

        encoded = self.encode_configuration(x)
        features = self.feature_net(encoded)

        self._cached_input_id = id(x)
        self._cached_input = x
        self._cached_state = state
        self._cached_features = features
        return features

    def clear_feature_cache(self) -> None:
        """Clear the feature cache.

        Drops the references to the cached input and features.  Stale entries
        are already detected automatically in the common cases (new input
        tensor, in-place change of the input or of the backbone parameters,
        different grad mode); call this to release the retained tensors or
        to force recomputation explicitly.
        """
        self._cached_input_id = None
        self._cached_input = None
        self._cached_state = None
        self._cached_features = None

    def log_amplitude(self, x: torch.Tensor) -> torch.Tensor:
        """Compute log-amplitude from shared features.

        Parameters
        ----------
        x : torch.Tensor
            Batch of configurations, shape ``(batch, num_sites)``.

        Returns
        -------
        torch.Tensor
            Log-amplitudes, shape ``(batch,)``.
        """
        features = self._get_features(x)
        return self.amplitude_head(features).squeeze(-1)

    def phase(self, x: torch.Tensor) -> torch.Tensor:
        """Compute phase from shared features.

        The Tanh output (in ``(-1, 1)``) is scaled by ``pi`` to produce
        a phase in ``(-pi, pi)``.

        Parameters
        ----------
        x : torch.Tensor
            Batch of configurations, shape ``(batch, num_sites)``.

        Returns
        -------
        torch.Tensor
            Phases in radians, shape ``(batch,)``.
        """
        features = self._get_features(x)
        raw = self.phase_head(features).squeeze(-1)
        return raw * torch.pi