# Source code for qvartools.samplers.sampler

"""
base --- Abstract sampler interface and result dataclass
========================================================

Defines the :class:`Sampler` ABC and the :class:`SamplerResult` immutable
dataclass that standardise sampler outputs across all sampling strategies.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

import torch

__all__ = [
    "SamplerResult",
    "Sampler",
]


# ---------------------------------------------------------------------------
# SamplerResult dataclass
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class SamplerResult:
    """Immutable container for sampler output.

    Parameters
    ----------
    configs : torch.Tensor
        Sampled configurations, shape ``(n_samples, n_sites)``.  Each row
        is a computational-basis state with entries in ``{0, 1}``.
    counts : dict
        Mapping from bitstring representation to occurrence count.  Keys
        are strings of ``'0'`` and ``'1'`` characters.
    metadata : dict
        Additional sampler-specific information (e.g. unique count,
        sampling time, flow parameters).
    log_probs : torch.Tensor or None, optional
        Log-probabilities of the sampled configurations, shape
        ``(n_samples,)``.  ``None`` when the sampler does not provide
        probabilities.
    wall_time : float, optional
        Wall-clock time in seconds spent generating the samples
        (default ``0.0``).

    Attributes
    ----------
    configs : torch.Tensor
    counts : dict
    metadata : dict
    log_probs : torch.Tensor or None
    wall_time : float

    Notes
    -----
    ``frozen=True`` only blocks rebinding of the attributes on an
    instance; the tensors and dictionaries they refer to are ordinary
    mutable objects.  ``counts`` and ``metadata`` use a default factory,
    so every instance created without them receives its own empty dict.

    Examples
    --------
    >>> result = SamplerResult(configs=torch.zeros(10, 4, dtype=torch.int64))
    >>> result.configs.shape
    torch.Size([10, 4])
    """

    # The only required field: the sampled basis states.
    configs: torch.Tensor
    # default_factory (not a literal ``{}``) so instances never share a dict.
    counts: dict[str, int] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)
    # ``None`` signals that the sampler did not report probabilities.
    log_probs: torch.Tensor | None = None
    wall_time: float = 0.0
# ---------------------------------------------------------------------------
# Sampler ABC
# ---------------------------------------------------------------------------
class Sampler(ABC):
    """Abstract base class for all configuration samplers.

    Every subclass must implement :meth:`sample`, which draws a set of
    computational-basis configurations and returns a
    :class:`SamplerResult`.  Because :meth:`sample` is abstract, this
    class cannot be instantiated directly.
    """

    @abstractmethod
    def sample(self, n_samples: int) -> SamplerResult:
        """Draw configuration samples.

        Parameters
        ----------
        n_samples : int
            Number of samples to draw.

        Returns
        -------
        SamplerResult
            Sampled configurations, counts, and metadata.
        """