Source code for qvartools.diag.eigen.projected_hamiltonian

"""
projected_hamiltonian --- Build projected Hamiltonian in a sampled basis
========================================================================

Constructs the projected Hamiltonian matrix H_ij = <x_i|H|x_j> in a
sampled computational-basis subset.  Uses the Hamiltonian's
:meth:`~qvartools.hamiltonians.hamiltonian.Hamiltonian.get_connections` interface
for sparse construction and hash-based lookup for O(1) basis membership
testing.

Classes
-------
ProjectedHamiltonianConfig
    Hyperparameters controlling the projected construction.
ProjectedHamiltonianBuilder
    Builds the sparse projected Hamiltonian from a basis set.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass

import scipy.sparse
import torch

from qvartools.hamiltonians.hamiltonian import Hamiltonian

__all__ = [
    "ProjectedHamiltonianConfig",
    "ProjectedHamiltonianBuilder",
]

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class ProjectedHamiltonianConfig:
    """Immutable settings for assembling a projected Hamiltonian.

    Parameters
    ----------
    use_sparse : bool
        When ``True`` (the default) a sparse CSR matrix is requested.
        When ``False`` a dense NumPy array is requested, which is only
        practical for small bases.
    batch_size : int
        How many basis states are handled per batch while the matrix is
        assembled.  Bigger batches need more memory but may run faster.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.

    Examples
    --------
    >>> cfg = ProjectedHamiltonianConfig(use_sparse=False)
    >>> cfg.use_sparse
    False
    """

    use_sparse: bool = True
    batch_size: int = 1000

    def __post_init__(self) -> None:
        """Check the field values once the dataclass has been populated."""
        # Keep the comparison in its original ``< 1`` form so that exotic
        # values (e.g. NaN) are treated exactly as before.
        batch_too_small = self.batch_size < 1
        if not batch_too_small:
            return
        raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")

# --------------------------------------------------------------------------- # Hash-based index lookup # --------------------------------------------------------------------------- def _build_config_index(basis_states: torch.Tensor) -> dict[int, int]: """Build a hash map from configuration hash to basis index. Each configuration is converted to a unique integer hash by treating the binary vector as digits in a base equal to the maximum value + 1. Parameters ---------- basis_states : torch.Tensor Basis configurations, shape ``(n_basis, n_sites)``. Returns ------- dict Mapping from configuration hash to row/column index. """ n_basis, n_sites = basis_states.shape # Use a large prime-based hash to avoid collisions # Treat each config as a number in base (max_val+1) max_val = int(basis_states.max().item()) + 1 multipliers = torch.tensor( [max_val**i for i in range(n_sites - 1, -1, -1)], dtype=torch.int64, device=basis_states.device, ) hashes = (basis_states.to(torch.int64) * multipliers.unsqueeze(0)).sum(dim=1) return {int(hashes[i].item()): i for i in range(n_basis)} def _config_hash(config: torch.Tensor, max_val: int) -> int: """Compute the hash of a single configuration. Parameters ---------- config : torch.Tensor Configuration vector, shape ``(n_sites,)``. max_val : int Base for the positional encoding (typically ``max(config) + 1`` or ``2`` for binary). Returns ------- int Integer hash of the configuration. """ n_sites = config.shape[0] h = 0 for i in range(n_sites): h = h * max_val + int(config[i].item()) return h # --------------------------------------------------------------------------- # Builder # ---------------------------------------------------------------------------
class ProjectedHamiltonianBuilder:
    """Build the projected Hamiltonian H_ij = <x_i|H|x_j> in a sampled basis.

    Uses the Hamiltonian's ``get_connections()`` method for efficient sparse
    construction: for each basis state |x_j>, the connected states and their
    matrix elements are retrieved, and a hash-based lookup determines which
    connections lie within the sampled basis.

    Parameters
    ----------
    hamiltonian : Hamiltonian
        The Hamiltonian operator (must implement ``diagonal_element`` and
        ``get_connections``).  If it also provides
        ``diagonal_elements_batch`` and/or ``_config_hash_batch``, those are
        used in place of the per-state fallbacks.
    config : ProjectedHamiltonianConfig or None, optional
        Construction hyperparameters.  If ``None``, defaults are used.

    Examples
    --------
    >>> from qvartools.hamiltonians.spin import HeisenbergHamiltonian
    >>> ham = HeisenbergHamiltonian(num_sites=6, J=1.0)
    >>> builder = ProjectedHamiltonianBuilder(ham)
    >>> basis = torch.tensor([[1,0,1,0,1,0], [0,1,0,1,0,1]])
    >>> H_proj = builder.build(basis)
    >>> H_proj.shape
    (2, 2)
    """

    # Emit a debug progress message every this many processed basis states.
    _PROGRESS_INTERVAL = 1000

    def __init__(
        self,
        hamiltonian: Hamiltonian,
        config: ProjectedHamiltonianConfig | None = None,
    ) -> None:
        self._hamiltonian = hamiltonian
        self._config = config if config is not None else ProjectedHamiltonianConfig()

    @staticmethod
    def _as_real_float(value) -> float:
        """Return the real part of *value* (when it has one) as a float."""
        return float(value.real if hasattr(value, "real") else value)

    def _diagonal_values(self, basis_states: torch.Tensor) -> list[float]:
        """Return the diagonal entry <x_j|H|x_j> for every basis state."""
        ham = self._hamiltonian
        n_basis = basis_states.shape[0]

        if hasattr(ham, "diagonal_elements_batch"):
            # One bulk conversion instead of one ``.item()`` call per state.
            # The real part is taken so that complex-valued batches behave
            # like the per-state fallback below rather than raising in
            # ``float()``.
            flat = ham.diagonal_elements_batch(basis_states).reshape(-1).tolist()
            return [self._as_real_float(flat[j]) for j in range(n_basis)]

        # Fallback for Hamiltonians without a batch diagonal.
        return [
            self._as_real_float(ham.diagonal_element(basis_states[j]))
            for j in range(n_basis)
        ]

    def build(
        self,
        basis_states: torch.Tensor,
    ) -> scipy.sparse.csr_matrix:
        """Construct the projected Hamiltonian matrix.

        Diagonal entries come from the Hamiltonian's batch diagonal routine
        when available (per-state otherwise).  Off-diagonal entries are
        found by hashing each connected state and matching it against the
        sorted hashes of the basis with ``torch.searchsorted``.

        This method always returns a CSR matrix; it does not consult
        ``config.use_sparse`` or ``config.batch_size``.

        Parameters
        ----------
        basis_states : torch.Tensor
            Sampled basis configurations, shape ``(n_basis, n_sites)``.
            Each row is a computational-basis state.

        Returns
        -------
        scipy.sparse.csr_matrix
            Projected Hamiltonian of shape ``(n_basis, n_basis)``.  Entries
            that are exactly zero are not stored.

        Raises
        ------
        ValueError
            If ``basis_states`` is not 2-D, has fewer than 1 row, or has the
            wrong number of columns.
        """
        if basis_states.ndim != 2:
            raise ValueError(
                "basis_states must be 2-D with shape (n_basis, n_sites), "
                f"got shape {tuple(basis_states.shape)}."
            )
        n_basis, n_sites = basis_states.shape
        if n_basis < 1:
            raise ValueError("basis_states must contain at least one state.")
        if n_sites != self._hamiltonian.num_sites:
            raise ValueError(
                f"basis_states has {n_sites} sites, but Hamiltonian expects "
                f"{self._hamiltonian.num_sites}."
            )

        logger.info(
            "Building projected Hamiltonian: %d basis states, %d sites",
            n_basis,
            n_sites,
        )

        ham = self._hamiltonian

        # Hash-based index of the basis; ``max_val`` is the positional base
        # used when connected states are hashed one at a time.
        max_val = max(int(basis_states.max().item()) + 1, 2)
        config_index = _build_config_index(basis_states)

        # COO triplets; duplicate (row, col) pairs are summed by SciPy when
        # the CSR matrix is assembled.
        rows: list[int] = []
        cols: list[int] = []
        vals: list[float] = []

        # --- Diagonal ---
        for j, diag_float in enumerate(self._diagonal_values(basis_states)):
            if abs(diag_float) > 0:
                rows.append(j)
                cols.append(j)
                vals.append(diag_float)

        # --- Off-diagonal: connected states ---
        # Move the basis to CPU once, before the per-state loop.
        basis_cpu = basis_states.detach().cpu()

        # Hashes sorted ascending, with the basis index of each hash kept in
        # the same order, so ``searchsorted`` can locate candidates.
        sorted_pairs = sorted(config_index.items())
        sorted_hash_keys = torch.tensor(
            [key for key, _ in sorted_pairs], dtype=torch.int64
        )
        sorted_hash_vals = torch.tensor(
            [index for _, index in sorted_pairs], dtype=torch.int64
        )
        n_hash = sorted_hash_keys.shape[0]

        # Loop-invariant: does the Hamiltonian supply its own batch hasher?
        # NOTE(review): that hasher must produce the same encoding as
        # ``_build_config_index`` for matches to be found -- confirm.
        use_batch_hash = hasattr(ham, "_config_hash_batch")

        for j in range(n_basis):
            connected, elements = ham.get_connections(basis_cpu[j])
            if connected.numel() == 0:
                continue

            if use_batch_hash:
                conn_hashes = ham._config_hash_batch(connected)
            else:
                conn_hashes = torch.tensor(
                    [_config_hash(state, max_val) for state in connected],
                    dtype=torch.int64,
                )

            # ``searchsorted`` may return ``n_hash`` for hashes larger than
            # every key; clamp so the equality test below can index safely.
            positions = torch.searchsorted(sorted_hash_keys, conn_hashes)
            positions = positions.clamp(max=n_hash - 1)
            matched_mask = sorted_hash_keys[positions] == conn_hashes

            if matched_mask.any():
                # Bulk-convert matches instead of calling ``.item()`` per hit.
                matched_rows = sorted_hash_vals[positions[matched_mask]].tolist()
                matched_els = elements[matched_mask].tolist()
                for row, raw_el in zip(matched_rows, matched_els):
                    mel = float(raw_el)
                    if abs(mel) > 0:
                        rows.append(row)
                        cols.append(j)
                        vals.append(mel)

            if j > 0 and j % self._PROGRESS_INTERVAL == 0:
                logger.debug(
                    "Processed %d / %d states (%d non-zeros so far)",
                    j,
                    n_basis,
                    len(vals),
                )

        H_proj = scipy.sparse.csr_matrix(
            (vals, (rows, cols)),
            shape=(n_basis, n_basis),
        )

        # ``n_basis >= 1`` is guaranteed by the validation above, so the
        # density denominator is never zero.
        logger.info(
            "Projected Hamiltonian: shape (%d, %d), nnz=%d, density=%.4f",
            n_basis,
            n_basis,
            H_proj.nnz,
            H_proj.nnz / (n_basis * n_basis),
        )

        return H_proj