Source code for qvartools.flows.networks.discrete_flow

"""
discrete_flow --- RealNVP normalizing flow for discrete configurations
======================================================================

Implements a RealNVP-style normalizing flow that maps samples from a
multi-modal Gaussian prior through a sequence of affine coupling layers
to produce continuous outputs, which are then discretised into binary
configurations via thresholding.

The multi-modal prior (mixture of Gaussians centred at +/- 1) ensures
uniform coverage of both ``{0, 1}`` values after discretisation.
"""

from __future__ import annotations

import logging
import math

import torch
import torch.nn as nn

from qvartools.flows.networks.coupling_network import CouplingNetwork, MultiModalPrior

__all__ = [
    "DiscreteFlowSampler",
]

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# DiscreteFlowSampler
# ---------------------------------------------------------------------------


[docs] class DiscreteFlowSampler(nn.Module): """RealNVP normalizing flow mapping continuous latent to discrete configs. Uses alternating binary masks to split dimensions across coupling layers. The multi-modal prior ensures uniform coverage of both ``{0, 1}`` values. Continuous outputs are clamped to ``[-1, 1]`` and discretised by thresholding at zero. Parameters ---------- num_sites : int Number of binary sites in each configuration. num_coupling_layers : int, optional Number of affine coupling layers (default ``6``). hidden_dims : list of int, optional Hidden-layer sizes for the coupling networks (default ``[128, 128]``). prior_std : float, optional Standard deviation of the mixture-of-Gaussians prior components (default ``1.0``). n_mc_samples : int, optional Number of Monte Carlo samples for discrete log-probability estimation (default ``100``). Attributes ---------- num_sites : int Number of sites. num_coupling_layers : int Number of coupling layers. prior : MultiModalPrior The mixture-of-Gaussians prior. masks : list of torch.Tensor Binary masks for each coupling layer. coupling_nets : nn.ModuleList Coupling networks for each layer. n_mc_samples : int Number of MC samples for discrete probability estimation. Examples -------- >>> flow = DiscreteFlowSampler(num_sites=10, num_coupling_layers=4) >>> configs, unique = flow.sample(batch_size=256) >>> configs.shape torch.Size([256, 10]) """ def __init__( self, num_sites: int, num_coupling_layers: int = 6, hidden_dims: list[int] | None = None, prior_std: float = 1.0, n_mc_samples: int = 100, ) -> None: super().__init__() if num_sites < 1: raise ValueError(f"num_sites must be >= 1, got {num_sites}") if num_coupling_layers < 1: raise ValueError( f"num_coupling_layers must be >= 1, got {num_coupling_layers}" ) if hidden_dims is None: hidden_dims = [128, 128] self.num_sites: int = num_sites self.num_coupling_layers: int = num_coupling_layers self.n_mc_samples: int = n_mc_samples self.prior: MultiModalPrior = MultiModalPrior( num_sites=num_sites, std=prior_std ) # Build alternating masks: even layers mask first half, # odd layers mask second half. self.masks: list[torch.Tensor] = [] coupling_nets: list[CouplingNetwork] = [] for layer_idx in range(num_coupling_layers): mask = torch.zeros(num_sites) if layer_idx % 2 == 0: mask[: num_sites // 2] = 1.0 else: mask[num_sites // 2 :] = 1.0 self.masks.append(mask) masked_dim = int(mask.sum().item()) unmasked_dim = num_sites - masked_dim coupling_nets.append( CouplingNetwork( input_dim=masked_dim, hidden_dims=hidden_dims, output_dim=unmasked_dim, ) ) self.coupling_nets: nn.ModuleList = nn.ModuleList(coupling_nets) # Register masks as buffers so they move with the model for i, mask in enumerate(self.masks): self.register_buffer(f"mask_{i}", mask) def _get_mask(self, layer_idx: int) -> torch.Tensor: """Retrieve the mask buffer for a given layer. Parameters ---------- layer_idx : int Index of the coupling layer. Returns ------- torch.Tensor Binary mask of shape ``(num_sites,)``. """ return getattr(self, f"mask_{layer_idx}") def _forward_flow(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Run the forward (generative) direction of the flow. Parameters ---------- z : torch.Tensor Latent samples from the prior, shape ``(batch, num_sites)``. Returns ------- y : torch.Tensor Transformed samples, shape ``(batch, num_sites)``. log_det_jacobian : torch.Tensor Sum of log absolute determinant of the Jacobian for each sample, shape ``(batch,)``. """ y = z log_det = torch.zeros(z.shape[0], device=z.device) for layer_idx in range(self.num_coupling_layers): mask = self._get_mask(layer_idx) mask_b = mask.bool() # Split into masked (fixed) and unmasked (transformed) parts masked_x = y[:, mask_b] unmasked_x = y[:, ~mask_b] scale, shift = self.coupling_nets[layer_idx](masked_x) # Affine transform: y_unmasked = unmasked * exp(scale) + shift transformed = unmasked_x * torch.exp(scale) + shift log_det = log_det + scale.sum(dim=-1) # Reassemble y_new = torch.empty_like(y) y_new[:, mask_b] = masked_x y_new[:, ~mask_b] = transformed y = y_new return y, log_det def _inverse_flow(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Run the inverse (inference) direction of the flow. Parameters ---------- y : torch.Tensor Data-space samples, shape ``(batch, num_sites)``. Returns ------- z : torch.Tensor Latent-space samples, shape ``(batch, num_sites)``. log_det_jacobian : torch.Tensor Sum of log |det J^{-1}| for each sample, shape ``(batch,)``. """ z = y log_det = torch.zeros(y.shape[0], device=y.device) for layer_idx in range(self.num_coupling_layers - 1, -1, -1): mask = self._get_mask(layer_idx) mask_b = mask.bool() masked_x = z[:, mask_b] unmasked_x = z[:, ~mask_b] scale, shift = self.coupling_nets[layer_idx](masked_x) # Inverse affine: z_unmasked = (unmasked - shift) * exp(-scale) inv_transformed = (unmasked_x - shift) * torch.exp(-scale) log_det = log_det - scale.sum(dim=-1) z_new = torch.empty_like(z) z_new[:, mask_b] = masked_x z_new[:, ~mask_b] = inv_transformed z = z_new return z, log_det
[docs] def sample_continuous(self, batch_size: int) -> torch.Tensor: """Sample continuous outputs from the flow, clamped to [-1, 1]. Parameters ---------- batch_size : int Number of samples to draw. Returns ------- torch.Tensor Continuous samples clamped to ``[-1, 1]``, shape ``(batch_size, num_sites)``. """ device = next(self.parameters()).device self.prior.device = device z = self.prior.sample(batch_size) y, _ = self._forward_flow(z) return torch.clamp(y, -1.0, 1.0)
[docs] @staticmethod def discretize(y: torch.Tensor) -> torch.Tensor: """Discretise continuous outputs to binary {0, 1} by thresholding. Values at or above zero are mapped to 1; values below zero are mapped to 0. Parameters ---------- y : torch.Tensor Continuous tensor, shape ``(..., num_sites)``. Returns ------- torch.Tensor Binary tensor with values in ``{0, 1}``, same shape as *y*, dtype ``torch.float32``. """ return (y >= 0.0).float()
[docs] def sample(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]: """Sample discrete binary configurations. Draws continuous samples from the flow, discretises them, and returns both the full batch and the unique configurations. Parameters ---------- batch_size : int Number of samples to draw. Returns ------- configs : torch.Tensor All discrete configurations, shape ``(batch_size, num_sites)``. unique_configs : torch.Tensor Unique configurations, shape ``(n_unique, num_sites)`` where ``n_unique <= batch_size``. """ y = self.sample_continuous(batch_size) configs = self.discretize(y) unique_configs = torch.unique(configs, dim=0) return configs, unique_configs
[docs] def log_prob_continuous(self, y: torch.Tensor) -> torch.Tensor: """Compute log-probability in continuous space via change of variables. Uses the inverse flow to map data-space samples back to the prior, then applies the change-of-variables formula: ``log p(y) = log p_prior(z) + log |det J^{-1}|``. Parameters ---------- y : torch.Tensor Continuous data-space samples, shape ``(batch, num_sites)``. Returns ------- torch.Tensor Log-probabilities, shape ``(batch,)``. """ z, log_det_inv = self._inverse_flow(y) self.prior.device = y.device log_pz = self.prior.log_prob(z) return log_pz + log_det_inv
[docs] def log_prob_discrete(self, x: torch.Tensor) -> torch.Tensor: """Estimate discrete log-probability via Monte Carlo integration. For each discrete configuration ``x`` in ``{0, 1}^n``, the probability is estimated by integrating the continuous density over the corresponding Voronoi cell (``[-1, 0)`` for 0 and ``[0, 1]`` for 1). This is done by sampling uniform noise within each cell and averaging the continuous density. Parameters ---------- x : torch.Tensor Discrete configurations with values in ``{0, 1}``, shape ``(batch, num_sites)``. Returns ------- torch.Tensor Estimated log-probabilities, shape ``(batch,)``. """ batch = x.shape[0] device = x.device n_mc = self.n_mc_samples # Expand x for MC samples: (batch, n_mc, num_sites) x_expanded = x.unsqueeze(1).expand(batch, n_mc, self.num_sites) # Uniform noise within the Voronoi cell of each discrete value # site == 0 -> sample from [-1, 0), site == 1 -> sample from [0, 1] noise = torch.rand(batch, n_mc, self.num_sites, device=device) y_mc = torch.where( x_expanded == 1.0, noise, # [0, 1] for site == 1 noise - 1.0, # [-1, 0) for site == 0 ) # Flatten for log_prob_continuous: (batch * n_mc, num_sites) y_flat = y_mc.reshape(batch * n_mc, self.num_sites) log_probs = self.log_prob_continuous(y_flat) log_probs = log_probs.reshape(batch, n_mc) # Monte Carlo estimate: log E[p(y)] = logsumexp(log p(y)) - log(n_mc) log_prob_est = torch.logsumexp(log_probs, dim=1) - math.log(n_mc) return log_prob_est
[docs] def forward( self, batch_size: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Forward pass: sample and compute log-probabilities. Parameters ---------- batch_size : int Number of samples to draw. Returns ------- configs : torch.Tensor Discrete configurations, shape ``(batch_size, num_sites)``. unique_configs : torch.Tensor Unique configurations, shape ``(n_unique, num_sites)``. log_probs : torch.Tensor Continuous log-probabilities at the pre-discretisation points, shape ``(batch_size,)``. """ device = next(self.parameters()).device self.prior.device = device z = self.prior.sample(batch_size) y, log_det = self._forward_flow(z) y_clamped = torch.clamp(y, -1.0, 1.0) configs = self.discretize(y_clamped) unique_configs = torch.unique(configs, dim=0) log_pz = self.prior.log_prob(z) log_probs = log_pz + log_det return configs, unique_configs, log_probs