Source code for qvartools.nqs.transformer.autoregressive

"""
transformer --- Autoregressive transformer NQS with KV cache
=============================================================

Provides :class:`TransformerBlock` and :class:`AutoregressiveTransformer`,
an autoregressive transformer architecture for neural quantum states with
separate alpha and beta spin channels.  The beta channel cross-attends to
the alpha channel, enabling spin-spin correlations.

Key features:

* Pre-norm transformer blocks with optional cross-attention.
* Particle-conserving autoregressive sampling (enforces exact electron
  counts in each spin channel).
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from qvartools.nqs.transformer.attention import CausalSelfAttention, CrossAttention

__all__ = [
    "TransformerBlock",
    "AutoregressiveTransformer",
]


# ---------------------------------------------------------------------------
# TransformerBlock
# ---------------------------------------------------------------------------


class TransformerBlock(nn.Module):
    """Pre-norm transformer block with optional cross-attention.

    Architecture (pre-norm style):

    1. LayerNorm -> CausalSelfAttention -> residual
    2. (Optional) LayerNorm -> CrossAttention -> residual
    3. LayerNorm -> FFN (Linear -> GELU -> Linear) -> residual

    Parameters
    ----------
    embed_dim : int
        Embedding dimensionality.
    n_heads : int
        Number of attention heads.
    ffn_dim : int
        Hidden dimensionality of the feed-forward network.
    dropout : float, optional
        Dropout probability (default ``0.0``).
    has_cross_attn : bool, optional
        Whether to include a cross-attention sub-layer
        (default ``False``).

    Attributes
    ----------
    self_attn : CausalSelfAttention
        Causal self-attention layer.
    cross_attn : CrossAttention or None
        Cross-attention layer (only if ``has_cross_attn``).
    ffn : nn.Sequential
        Two-layer feed-forward network with GELU activation.
    """

    def __init__(
        self,
        embed_dim: int,
        n_heads: int,
        ffn_dim: int,
        dropout: float = 0.0,
        has_cross_attn: bool = False,
    ) -> None:
        super().__init__()

        # Self-attention
        self.ln_sa: nn.LayerNorm = nn.LayerNorm(embed_dim)
        self.self_attn: CausalSelfAttention = CausalSelfAttention(
            embed_dim, n_heads, dropout=dropout
        )

        # Cross-attention (optional)
        self.cross_attn: CrossAttention | None = None
        self.ln_ca: nn.LayerNorm | None = None
        if has_cross_attn:
            self.ln_ca = nn.LayerNorm(embed_dim)
            self.cross_attn = CrossAttention(embed_dim, n_heads, dropout=dropout)

        # Feed-forward network
        self.ln_ffn: nn.LayerNorm = nn.LayerNorm(embed_dim)
        self.ffn: nn.Sequential = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        cross_kv: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass through the transformer block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape ``(batch, seq_len, embed_dim)``.
        cross_kv : torch.Tensor or None, optional
            Key/value tensor for cross-attention, shape
            ``(batch, kv_len, embed_dim)``.  Required when
            ``has_cross_attn`` is ``True``.

        Returns
        -------
        torch.Tensor
            Output tensor, shape ``(batch, seq_len, embed_dim)``.
        """
        # Self-attention with residual
        x = x + self.self_attn(self.ln_sa(x))

        # Cross-attention with residual (optional)
        if self.cross_attn is not None and cross_kv is not None:
            assert self.ln_ca is not None
            x = x + self.cross_attn(self.ln_ca(x), cross_kv)

        # FFN with residual
        x = x + self.ffn(self.ln_ffn(x))

        return x


# ---------------------------------------------------------------------------
# AutoregressiveTransformer
# ---------------------------------------------------------------------------


[docs] class AutoregressiveTransformer(nn.Module): """Autoregressive transformer NQS with alpha/beta spin channels. Models the joint probability of occupying orbitals by factorising it autoregressively: .. math:: p(\\mathbf{x}) = \\prod_{i=1}^{N_{\\text{orb}}} p(x^\\alpha_i | x^\\alpha_{<i}) \\;\\prod_{i=1}^{N_{\\text{orb}}} p(x^\\beta_i | x^\\beta_{<i}, \\mathbf{x}^\\alpha) The alpha channel uses causal self-attention only. The beta channel uses causal self-attention *plus* cross-attention to the full alpha representation, enabling spin-spin correlations. Sampling enforces particle conservation: exactly ``n_alpha`` electrons in the alpha channel and ``n_beta`` in the beta channel. Parameters ---------- n_orbitals : int Number of spatial orbitals per spin channel. n_alpha : int Number of alpha electrons. n_beta : int Number of beta electrons. embed_dim : int, optional Embedding dimensionality (default ``64``). n_heads : int, optional Number of attention heads (default ``4``). n_layers : int, optional Number of transformer layers per channel (default ``4``). dropout : float, optional Dropout probability (default ``0.0``). Attributes ---------- alpha_blocks : nn.ModuleList Transformer blocks for the alpha channel (self-attention only). beta_blocks : nn.ModuleList Transformer blocks for the beta channel (self + cross-attention). alpha_head : nn.Linear Output head producing alpha occupation logits. beta_head : nn.Linear Output head producing beta occupation logits. Examples -------- >>> model = AutoregressiveTransformer( ... n_orbitals=6, n_alpha=2, n_beta=2, embed_dim=32, n_heads=4 ... ) >>> configs = model.sample(n_samples=16) # shape (16, 12) """ def __init__( self, n_orbitals: int, n_alpha: int, n_beta: int, embed_dim: int = 64, n_heads: int = 4, n_layers: int = 4, dropout: float = 0.0, ) -> None: super().__init__() if n_alpha > n_orbitals: raise ValueError( f"n_alpha ({n_alpha}) cannot exceed n_orbitals ({n_orbitals})." ) if n_beta > n_orbitals: raise ValueError( f"n_beta ({n_beta}) cannot exceed n_orbitals ({n_orbitals})." ) self.n_orbitals: int = n_orbitals self.n_alpha: int = n_alpha self.n_beta: int = n_beta self.embed_dim: int = embed_dim self.n_heads: int = n_heads self.n_layers: int = n_layers ffn_dim = 4 * embed_dim # Token embedding: occupation {0, 1} -> embed_dim # We use 3 tokens: 0 (unoccupied), 1 (occupied), 2 (start token) self.token_embed: nn.Embedding = nn.Embedding(3, embed_dim) # Positional embedding for each orbital position self.pos_embed_alpha: nn.Parameter = nn.Parameter( torch.randn(1, n_orbitals, embed_dim) * 0.02 ) self.pos_embed_beta: nn.Parameter = nn.Parameter( torch.randn(1, n_orbitals, embed_dim) * 0.02 ) # Alpha transformer blocks (self-attention only) self.alpha_blocks: nn.ModuleList = nn.ModuleList( [ TransformerBlock( embed_dim, n_heads, ffn_dim, dropout=dropout, has_cross_attn=False, ) for _ in range(n_layers) ] ) # Beta transformer blocks (self-attention + cross-attention to alpha) self.beta_blocks: nn.ModuleList = nn.ModuleList( [ TransformerBlock( embed_dim, n_heads, ffn_dim, dropout=dropout, has_cross_attn=True, ) for _ in range(n_layers) ] ) # Output heads self.ln_alpha: nn.LayerNorm = nn.LayerNorm(embed_dim) self.alpha_head: nn.Linear = nn.Linear(embed_dim, 1) self.ln_beta: nn.LayerNorm = nn.LayerNorm(embed_dim) self.beta_head: nn.Linear = nn.Linear(embed_dim, 1) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _embed_sequence( self, tokens: torch.Tensor, pos_embed: torch.Tensor, ) -> torch.Tensor: """Embed a token sequence with positional encoding. Parameters ---------- tokens : torch.Tensor Integer token indices, shape ``(batch, seq_len)``. pos_embed : torch.Tensor Positional embeddings, shape ``(1, max_len, embed_dim)``. Returns ------- torch.Tensor Embedded sequence, shape ``(batch, seq_len, embed_dim)``. """ seq_len = tokens.shape[1] tok_emb = self.token_embed(tokens) # (batch, seq_len, embed_dim) return tok_emb + pos_embed[:, :seq_len, :] def _run_alpha(self, alpha_tokens: torch.Tensor) -> torch.Tensor: """Run the alpha transformer stack. Parameters ---------- alpha_tokens : torch.Tensor Alpha token sequence, shape ``(batch, seq_len)``. Returns ------- torch.Tensor Alpha representations, shape ``(batch, seq_len, embed_dim)``. """ h = self._embed_sequence(alpha_tokens, self.pos_embed_alpha) for block in self.alpha_blocks: h = block(h) return h def _run_beta( self, beta_tokens: torch.Tensor, alpha_repr: torch.Tensor, ) -> torch.Tensor: """Run the beta transformer stack with cross-attention to alpha. Parameters ---------- beta_tokens : torch.Tensor Beta token sequence, shape ``(batch, seq_len)``. alpha_repr : torch.Tensor Alpha representations for cross-attention, shape ``(batch, n_orbitals, embed_dim)``. Returns ------- torch.Tensor Beta representations, shape ``(batch, seq_len, embed_dim)``. """ h = self._embed_sequence(beta_tokens, self.pos_embed_beta) for block in self.beta_blocks: h = block(h, cross_kv=alpha_repr) return h def _enable_cache(self) -> None: """Enable KV cache in all causal self-attention layers.""" for block in self.alpha_blocks: block.self_attn.enable_cache() # type: ignore[union-attr,operator] for block in self.beta_blocks: block.self_attn.enable_cache() # type: ignore[union-attr,operator] def _disable_cache(self) -> None: """Disable KV cache in all causal self-attention layers.""" for block in self.alpha_blocks: block.self_attn.disable_cache() # type: ignore[union-attr,operator] for block in self.beta_blocks: block.self_attn.disable_cache() # type: ignore[union-attr,operator] def _clear_cache(self) -> None: """Clear KV cache in all causal self-attention layers.""" for block in self.alpha_blocks: block.self_attn.clear_cache() # type: ignore[union-attr,operator] for block in self.beta_blocks: block.self_attn.clear_cache() # type: ignore[union-attr,operator] # ------------------------------------------------------------------ # Log probability # ------------------------------------------------------------------
[docs] def log_prob( self, alpha: torch.Tensor, beta: torch.Tensor, ) -> torch.Tensor: """Compute the log-probability of a configuration. Parameters ---------- alpha : torch.Tensor Alpha spin-orbital occupations, shape ``(batch, n_orbitals)`` with entries in ``{0, 1}``. beta : torch.Tensor Beta spin-orbital occupations, shape ``(batch, n_orbitals)`` with entries in ``{0, 1}``. Returns ------- torch.Tensor Log-probabilities, shape ``(batch,)``. """ batch = alpha.shape[0] device = alpha.device # Build alpha input: [start_token, x_1, ..., x_{N-1}] # (teacher forcing: shift right, prepend start token) start_tok = torch.full( (batch, 1), fill_value=2, dtype=torch.long, device=device ) alpha_input = torch.cat([start_tok, alpha[:, :-1].long()], dim=1) # Alpha forward alpha_repr = self._run_alpha(alpha_input) alpha_logits = self.alpha_head(self.ln_alpha(alpha_repr)).squeeze(-1) # (batch, n_orbitals) # Alpha log-prob via binary cross-entropy alpha_log_probs = -F.binary_cross_entropy_with_logits( alpha_logits, alpha.float(), reduction="none" ) # (batch, n_orbitals) # Build beta input: [start_token, x_1, ..., x_{N-1}] beta_input = torch.cat([start_tok, beta[:, :-1].long()], dim=1) # Beta forward with cross-attention to alpha beta_repr = self._run_beta(beta_input, alpha_repr) beta_logits = self.beta_head(self.ln_beta(beta_repr)).squeeze(-1) beta_log_probs = -F.binary_cross_entropy_with_logits( beta_logits, beta.float(), reduction="none" ) # (batch, n_orbitals) # Total log-prob = sum over all orbital positions return alpha_log_probs.sum(dim=-1) + beta_log_probs.sum(dim=-1)
# ------------------------------------------------------------------ # Autoregressive sampling # ------------------------------------------------------------------
[docs] @torch.no_grad() def sample( self, n_samples: int, temperature: float = 1.0, ) -> torch.Tensor: """Generate particle-conserving configurations autoregressively. Samples alpha orbitals first (enforcing exactly ``n_alpha`` electrons), then samples beta orbitals with cross-attention to alpha (enforcing exactly ``n_beta`` electrons). The returned configuration is ``[alpha, beta]`` concatenated along the orbital axis. KV caching is used for efficient autoregressive generation. Parameters ---------- n_samples : int Number of configurations to generate. temperature : float, optional Sampling temperature. Values > 1 increase randomness; values < 1 sharpen the distribution (default ``1.0``). Returns ------- torch.Tensor Sampled configurations, shape ``(n_samples, 2 * n_orbitals)`` with entries in ``{0, 1}``. The first ``n_orbitals`` columns are alpha occupations and the last ``n_orbitals`` are beta. """ device = next(self.parameters()).device n_orb = self.n_orbitals alpha_config = torch.zeros(n_samples, n_orb, dtype=torch.long, device=device) beta_config = torch.zeros(n_samples, n_orb, dtype=torch.long, device=device) # --- Sample alpha channel --- self._enable_cache() try: alpha_config = self._sample_channel( alpha_config, n_electrons=self.n_alpha, pos_embed=self.pos_embed_alpha, blocks=self.alpha_blocks, head=self.alpha_head, ln=self.ln_alpha, temperature=temperature, cross_kv=None, ) # Get full alpha representation for beta cross-attention self._clear_cache() start_tok = torch.full( (n_samples, 1), fill_value=2, dtype=torch.long, device=device ) alpha_input = torch.cat([start_tok, alpha_config[:, :-1]], dim=1) alpha_repr = self._run_alpha(alpha_input) # --- Sample beta channel --- self._clear_cache() beta_config = self._sample_channel( beta_config, n_electrons=self.n_beta, pos_embed=self.pos_embed_beta, blocks=self.beta_blocks, head=self.beta_head, ln=self.ln_beta, temperature=temperature, cross_kv=alpha_repr, ) finally: self._disable_cache() return torch.cat([alpha_config, beta_config], dim=-1)
def _sample_channel( self, config: torch.Tensor, n_electrons: int, pos_embed: torch.Tensor, blocks: nn.ModuleList, head: nn.Linear, ln: nn.LayerNorm, temperature: float, cross_kv: torch.Tensor | None, ) -> torch.Tensor: """Autoregressively sample one spin channel with particle conservation. Parameters ---------- config : torch.Tensor Pre-allocated config tensor to fill, shape ``(n_samples, n_orbitals)``. n_electrons : int Exact number of electrons to place. pos_embed : torch.Tensor Positional embeddings for this channel. blocks : nn.ModuleList Transformer blocks for this channel. head : nn.Linear Output head producing logits. ln : nn.LayerNorm Layer norm before the output head. temperature : float Sampling temperature. cross_kv : torch.Tensor or None Cross-attention key/value from the other channel (alpha repr for beta sampling, ``None`` for alpha sampling). Returns ------- torch.Tensor Filled configuration, shape ``(n_samples, n_orbitals)``. """ n_samples = config.shape[0] n_orb = self.n_orbitals device = config.device electrons_placed = torch.zeros(n_samples, dtype=torch.long, device=device) start_tok = torch.full( (n_samples, 1), fill_value=2, dtype=torch.long, device=device ) for pos in range(n_orb): # Current token: start token for pos=0, else previous occupation if pos == 0: current_tok = start_tok else: current_tok = config[:, pos - 1 : pos] # (n_samples, 1) # Embed and add positional encoding h = self.token_embed(current_tok) + pos_embed[:, pos : pos + 1, :] # Run through transformer blocks for block in blocks: h = block(h, cross_kv=cross_kv) # Get logit for this position logit = head(ln(h)).squeeze(-1).squeeze(-1) # (n_samples,) # Apply temperature if temperature != 1.0: logit = logit / temperature # Compute occupation probability prob_occupied = torch.sigmoid(logit) # Enforce particle conservation constraints remaining_positions = n_orb - pos electrons_needed = n_electrons - electrons_placed # Must occupy: not enough remaining positions for remaining electrons must_occupy = electrons_needed >= remaining_positions # Cannot occupy: already placed all electrons cannot_occupy = electrons_needed <= 0 # Clamp probabilities prob_occupied = torch.where( must_occupy, torch.ones_like(prob_occupied), prob_occupied, ) prob_occupied = torch.where( cannot_occupy, torch.zeros_like(prob_occupied), prob_occupied, ) # Sample occupation = torch.bernoulli(prob_occupied).long() config[:, pos] = occupation electrons_placed = electrons_placed + occupation return config
[docs] def forward( self, alpha: torch.Tensor, beta: torch.Tensor, ) -> torch.Tensor: """Forward pass --- delegates to :meth:`log_prob`. Parameters ---------- alpha : torch.Tensor Alpha occupations, shape ``(batch, n_orbitals)``. beta : torch.Tensor Beta occupations, shape ``(batch, n_orbitals)``. Returns ------- torch.Tensor Log-probabilities, shape ``(batch,)``. """ return self.log_prob(alpha, beta)