Source code for qvartools.diag.selection.diversity_selection

"""
diversity_selection --- Diversity-aware basis configuration selection
====================================================================

Implements excitation-rank bucketing, Hamming-distance filtering, and
optional determinantal point process (DPP) selection to build a compact,
diverse subset of computational-basis configurations for projected
eigenvalue problems.

Classes
-------
DiversityConfig
    Dataclass holding all hyperparameters for diversity selection.
DiversitySelector
    Stateful selector that buckets, ranks, and filters configurations.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass

import torch

from qvartools.diag.selection.excitation_rank import (  # noqa: F401
    bitpack_configs,
    bitpacked_hamming,
    compute_excitation_rank,
    compute_hamming_distance,
)

__all__ = [
    "DiversityConfig",
    "DiversitySelector",
    "compute_excitation_rank",
    "compute_hamming_distance",
    "bitpack_configs",
]

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Configuration dataclass
# ---------------------------------------------------------------------------


[docs] @dataclass(frozen=True) class DiversityConfig: """Hyperparameters for diversity-aware configuration selection. The selector partitions configurations by excitation rank relative to a reference state and allocates per-rank quotas according to the given fractions. Within each rank bucket, configurations are selected by NQS importance weight and filtered to enforce a minimum pairwise Hamming distance. Parameters ---------- max_configs : int Maximum number of configurations to select. rank_0_fraction : float Fraction of ``max_configs`` allocated to rank-0 (reference) configs. rank_1_fraction : float Fraction allocated to single-excitation configurations. rank_2_fraction : float Fraction allocated to double-excitation configurations. rank_3_fraction : float Fraction allocated to triple-excitation configurations. rank_4_plus_fraction : float Fraction allocated to quadruple-and-higher excitations. min_hamming_distance : int Minimum pairwise Hamming distance between any two selected configs. use_dpp_selection : bool If ``True``, use a determinantal point process kernel for maximal diversity within each rank bucket (more expensive). dpp_kernel_scale : float Length-scale parameter for the DPP Gaussian kernel. Examples -------- >>> cfg = DiversityConfig(max_configs=500, min_hamming_distance=3) >>> cfg.max_configs 500 """ max_configs: int = 1000 rank_0_fraction: float = 0.05 rank_1_fraction: float = 0.15 rank_2_fraction: float = 0.40 rank_3_fraction: float = 0.25 rank_4_plus_fraction: float = 0.15 min_hamming_distance: int = 2 use_dpp_selection: bool = False dpp_kernel_scale: float = 0.1 def __post_init__(self) -> None: """Validate fraction constraints after initialization.""" total = ( self.rank_0_fraction + self.rank_1_fraction + self.rank_2_fraction + self.rank_3_fraction + self.rank_4_plus_fraction ) if abs(total - 1.0) > 1e-6: raise ValueError(f"Rank fractions must sum to 1.0, got {total:.6f}") if self.max_configs < 1: raise ValueError(f"max_configs must be >= 1, got {self.max_configs}") if self.min_hamming_distance < 0: raise ValueError( f"min_hamming_distance must be >= 0, got {self.min_hamming_distance}" )
# --------------------------------------------------------------------------- # Diversity selector # ---------------------------------------------------------------------------
[docs] class DiversitySelector: """Select a diverse, representative subset of configurations. Configurations are bucketed by excitation rank relative to a reference state (e.g. Hartree-Fock). Within each bucket, the top-weighted configurations are greedily selected while enforcing a minimum pairwise Hamming distance. An optional DPP kernel can be applied for maximal coverage. Parameters ---------- config : DiversityConfig Selection hyperparameters. reference : torch.Tensor Reference configuration (e.g. Hartree-Fock ground state), shape ``(n_orbitals,)`` with binary entries. n_orbitals : int Number of orbitals (must match ``reference.shape[0]``). Raises ------ ValueError If ``reference`` length does not match ``n_orbitals``. Examples -------- >>> cfg = DiversityConfig(max_configs=100) >>> ref = torch.tensor([1, 1, 1, 0, 0, 0]) >>> selector = DiversitySelector(cfg, ref, n_orbitals=6) >>> selected, stats = selector.select(all_configs, weights) """ def __init__( self, config: DiversityConfig, reference: torch.Tensor, n_orbitals: int, ) -> None: if reference.shape[0] != n_orbitals: raise ValueError( f"Reference length {reference.shape[0]} does not match " f"n_orbitals={n_orbitals}" ) self._config = config self._reference = reference.clone() self._n_orbitals = n_orbitals
[docs] def select( self, configs: torch.Tensor, weights: torch.Tensor | None = None, ) -> tuple[torch.Tensor, dict]: """Select a diverse subset of configurations. Parameters ---------- configs : torch.Tensor Pool of candidate configurations, shape ``(n_pool, n_orbitals)`` with binary entries. weights : torch.Tensor or None, optional NQS importance weights for each configuration, shape ``(n_pool,)``. If ``None``, uniform weights are used. Returns ------- selected : torch.Tensor Selected configurations, shape ``(n_selected, n_orbitals)``. stats : dict Selection statistics with keys: - ``"total_pool"`` : int -- size of the input pool. - ``"n_selected"`` : int -- number of selected configurations. - ``"rank_counts"`` : dict -- count per excitation rank in pool. - ``"rank_selected"`` : dict -- count per rank in selection. - ``"rank_quotas"`` : dict -- allocated quota per rank. """ n_pool = configs.shape[0] if weights is None: weights = torch.ones(n_pool, dtype=torch.float64, device=configs.device) # --- Compute excitation ranks for all configs (vectorized) --- ranks = (configs != self._reference.unsqueeze(0)).sum(dim=1) # --- Build rank buckets --- rank_fractions = { 0: self._config.rank_0_fraction, 1: self._config.rank_1_fraction, 2: self._config.rank_2_fraction, 3: self._config.rank_3_fraction, } # Rank 4+ is a catch-all bucket max_cfg = self._config.max_configs rank_quotas: dict[int, int] = {} for rank_val, frac in rank_fractions.items(): rank_quotas[rank_val] = max(1, int(frac * max_cfg)) rank_quotas[4] = max(1, int(self._config.rank_4_plus_fraction * max_cfg)) # Adjust quotas so they sum to max_configs total_allocated = sum(rank_quotas.values()) if total_allocated != max_cfg: # Add/subtract difference from the largest bucket (rank 2) rank_quotas[2] = rank_quotas[2] + (max_cfg - total_allocated) # --- Bucket and select --- selected_indices: list[int] = [] rank_pool_counts: dict[int, int] = {} rank_selected_counts: dict[int, int] = {} # Bit-pack for fast Hamming if we have a minimum distance constraint packed = ( bitpack_configs(configs) if self._config.min_hamming_distance > 0 else None ) for rank_key, quota in sorted(rank_quotas.items()): if rank_key < 4: bucket_mask = ranks == rank_key else: bucket_mask = ranks >= 4 bucket_indices = bucket_mask.nonzero(as_tuple=False).squeeze(-1) rank_pool_counts[rank_key] = int(bucket_indices.shape[0]) if bucket_indices.shape[0] == 0: rank_selected_counts[rank_key] = 0 continue # Sort by weight descending within bucket bucket_weights = weights[bucket_indices] sorted_order = torch.argsort(bucket_weights, descending=True) sorted_bucket = bucket_indices[sorted_order] if self._config.use_dpp_selection and packed is not None: bucket_selected = self._dpp_select(sorted_bucket, packed, quota) else: bucket_selected = self._greedy_select(sorted_bucket, packed, quota) selected_indices.extend(bucket_selected) rank_selected_counts[rank_key] = len(bucket_selected) if len(selected_indices) == 0: logger.warning("No configurations selected; returning empty tensor.") empty = torch.empty( 0, self._n_orbitals, dtype=configs.dtype, device=configs.device ) stats = { "total_pool": n_pool, "n_selected": 0, "rank_counts": rank_pool_counts, "rank_selected": rank_selected_counts, "rank_quotas": dict(rank_quotas), } return empty, stats idx_tensor = torch.tensor( selected_indices, dtype=torch.long, device=configs.device ) selected = configs[idx_tensor] stats = { "total_pool": n_pool, "n_selected": len(selected_indices), "rank_counts": rank_pool_counts, "rank_selected": rank_selected_counts, "rank_quotas": dict(rank_quotas), } logger.info( "Diversity selection: %d / %d configs selected (quotas: %s)", stats["n_selected"], stats["total_pool"], stats["rank_quotas"], ) return selected, stats
def _greedy_select( self, sorted_indices: torch.Tensor, packed: torch.Tensor | None, quota: int, ) -> list[int]: """Greedily select configs by weight, enforcing minimum Hamming distance. Parameters ---------- sorted_indices : torch.Tensor Candidate indices sorted by weight (descending), shape ``(n,)``. packed : torch.Tensor or None Bit-packed configurations for fast Hamming, or ``None`` if no distance constraint. quota : int Maximum number of configurations to select from this bucket. Returns ------- list of int Selected global indices. """ selected: list[int] = [] min_dist = self._config.min_hamming_distance for i in range(sorted_indices.shape[0]): if len(selected) >= quota: break candidate = int(sorted_indices[i].item()) if min_dist > 0 and packed is not None and len(selected) > 0: # Check Hamming distance against all already-selected cand_idx = torch.tensor( [candidate] * len(selected), dtype=torch.long, device=packed.device, ) sel_idx = torch.tensor( selected, dtype=torch.long, device=packed.device, ) distances = bitpacked_hamming(packed, cand_idx, sel_idx) if distances.min().item() < min_dist: continue selected.append(candidate) return selected def _dpp_select( self, sorted_indices: torch.Tensor, packed: torch.Tensor, quota: int, ) -> list[int]: """Select configs using a DPP-inspired greedy kernel for maximal coverage. Uses a Gaussian kernel over Hamming distances to define a quality- diversity trade-off. The selection greedily picks the item that maximises the log-determinant of the kernel sub-matrix. Parameters ---------- sorted_indices : torch.Tensor Candidate indices sorted by weight (descending), shape ``(n,)``. packed : torch.Tensor Bit-packed configurations for fast Hamming computation. quota : int Maximum number of configurations to select. Returns ------- list of int Selected global indices. """ n_candidates = min(sorted_indices.shape[0], quota * 5) candidates = sorted_indices[:n_candidates] n = candidates.shape[0] if n == 0: return [] # Build pairwise Hamming distance matrix for candidates row_idx = torch.arange(n, device=packed.device).repeat_interleave(n) col_idx = torch.arange(n, device=packed.device).repeat(n) cand_a = candidates[row_idx] cand_b = candidates[col_idx] dists = bitpacked_hamming(packed, cand_a, cand_b).float().reshape(n, n) # Gaussian kernel: K_ij = exp(-d_ij^2 / (2 * scale^2)) scale = self._config.dpp_kernel_scale * self._n_orbitals kernel = torch.exp(-dists.pow(2) / (2.0 * scale * scale)) # Greedy DPP selection (approximate log-det maximization) selected: list[int] = [] remaining = set(range(n)) for _ in range(min(quota, n)): best_idx = -1 best_score = -float("inf") for r in remaining: if len(selected) == 0: score = float(kernel[r, r].item()) else: sel_tensor = torch.tensor( selected, dtype=torch.long, device=kernel.device ) k_ss = kernel[sel_tensor][:, sel_tensor] k_new_row = kernel[r, sel_tensor].unsqueeze(0) k_new_diag = kernel[r, r].unsqueeze(0).unsqueeze(0) augmented = torch.cat( [ torch.cat([k_ss, k_new_row.T], dim=1), torch.cat([k_new_row, k_new_diag], dim=1), ], dim=0, ) sign, logabsdet = torch.linalg.slogdet(augmented) score = ( float(logabsdet.item()) if sign.item() > 0 else -float("inf") ) if score > best_score: best_score = score best_idx = r if best_idx < 0: break selected.append(best_idx) remaining.discard(best_idx) # Map local indices back to global return [int(candidates[s].item()) for s in selected]