# Source code for qvartools.krylov.expansion.krylov_expand

"""
krylov_expand --- Basis expansion via Hamiltonian connections
============================================================

Provides :func:`expand_basis_via_connections`, which grows a configuration
basis by following the off-diagonal structure of the Hamiltonian.  Starting
from a set of reference configurations, the function collects all
Hamiltonian-connected states, ranks them by coupling strength, and adds
unique new configurations up to a specified cap.

Supports two-hop expansion: first hop discovers singles/doubles from seed
configs, second hop expands from those to reach up to quadruples.
"""

from __future__ import annotations

import logging
from typing import Any

import numpy as np
import torch

__all__ = [
    "expand_basis_via_connections",
]

logger = logging.getLogger(__name__)


def _select_reference_configs(
    basis: torch.Tensor, hamiltonian: Any, n_ref: int
) -> torch.Tensor:
    """Pick the basis rows whose diagonal energy is smallest.

    Parameters
    ----------
    basis : torch.Tensor
        Current basis configurations, one configuration per row.
    hamiltonian : Hamiltonian
        The system Hamiltonian; only ``diagonal_element`` is called, once
        per basis row.
    n_ref : int
        Requested number of reference configurations.  Values larger than
        the basis size are capped at the basis size.

    Returns
    -------
    torch.Tensor
        The ``min(n_ref, n_basis)`` rows with the lowest diagonal energy,
        ordered from lowest to highest energy.  When every row is
        requested, ``basis`` itself is returned unchanged (no energies are
        evaluated and no reordering happens).
    """
    total = basis.shape[0]
    keep = min(n_ref, total)

    # Ranking is pointless when the whole basis is requested.
    if keep == total:
        return basis

    energies = torch.zeros(total, dtype=torch.float64)
    for row, config in enumerate(basis):
        energies[row] = hamiltonian.diagonal_element(config)

    # ``largest=False`` makes topk return the smallest entries.
    lowest = energies.topk(keep, largest=False).indices
    return basis[lowest]


def expand_basis_via_connections(
    basis: torch.Tensor,
    hamiltonian: Any,
    max_new: int = 500,
    n_ref: int | None = None,
    coupling_rank: bool = True,
) -> torch.Tensor:
    """Expand a configuration basis by following Hamiltonian connections.

    Selects ``n_ref`` reference configurations from the basis, collects all
    states connected to them via the Hamiltonian, removes duplicates and
    states already in the basis, and returns the expanded basis with up to
    ``max_new`` new configurations.  When more candidates are found than
    the budget allows, they are cut down by coupling strength (see
    ``coupling_rank``).

    Performs a two-hop expansion when budget allows: the first hop expands
    from the reference configurations, and if it yields fewer than
    ``max_new`` new configurations, a second hop expands from (up to 50 of)
    the first-hop results to fill the remaining budget.

    Parameters
    ----------
    basis : torch.Tensor
        Current basis configurations, shape ``(n_basis, n_sites)`` with
        integer entries.  A ``numpy.ndarray`` is also accepted and is
        converted to a tensor.
    hamiltonian
        The system Hamiltonian.  Must implement ``diagonal_element`` and
        ``get_connections``.
    max_new : int, optional
        Maximum number of new configurations to add (default ``500``).
        Values below 1 leave the basis unchanged.
    n_ref : int or None, optional
        Number of reference configurations.  Defaults to
        ``min(len(basis), 50)``.
    coupling_rank : bool, optional
        If ``True``, rank new configs by max ``|H_ij|`` coupling strength
        and keep top ``max_new`` (default ``True``).  If ``False``, keep
        the first ``max_new`` in discovery order.

    Returns
    -------
    torch.Tensor
        Expanded basis configurations on the CPU with dtype ``int64``,
        shape ``(n_basis + n_added, n_sites)`` where ``n_added <= max_new``.
        The original rows come first, followed by first-hop and then
        second-hop additions.
    """
    # Normalise the input before any early return so that every exit path
    # hands back the same kind of object (a CPU int64 tensor), including
    # the "nothing to do" paths below.
    if isinstance(basis, np.ndarray):
        basis = torch.from_numpy(basis).long()
    basis = basis.cpu().long()

    if basis.shape[0] == 0:
        logger.warning("expand_basis_via_connections: empty basis.")
        return basis

    if max_new < 1:
        return basis

    # Byte representation of each row serves as a hashable identity for
    # de-duplication; the helper below extends this set in place.
    existing_keys = {row.tobytes() for row in basis.numpy()}

    if n_ref is None:
        n_ref = min(len(basis), 50)

    refs = _select_reference_configs(basis, hamiltonian, n_ref)

    # First hop
    new_map, new_configs_map = _collect_connections(refs, hamiltonian, existing_keys)

    if not new_map:
        logger.debug(
            "expand_basis_via_connections: no new configs from %d refs.",
            refs.shape[0],
        )
        return basis

    keys_list = list(new_map.keys())
    new_configs_list = [new_configs_map[k] for k in keys_list]
    new_couplings = np.array([new_map[k] for k in keys_list])
    new_tensor = torch.stack(new_configs_list)

    new_tensor, new_couplings = _truncate_by_coupling(
        new_tensor, new_couplings, max_new, coupling_rank
    )

    expanded = torch.cat([basis, new_tensor], dim=0)

    # Second hop (if budget allows)
    if len(new_tensor) > 0 and max_new > len(new_tensor):
        remaining = max_new - len(new_tensor)
        # Reaching this branch means the first hop found fewer candidates
        # than the budget, so nothing was truncated or reordered: the seeds
        # below are the first (up to 50) first-hop results in discovery
        # order.
        second_refs = new_tensor[: min(50, len(new_tensor))]
        # ``existing_keys`` already contains the first-hop keys, so the
        # second hop can only return configurations not seen so far.
        hop2_map, hop2_configs_map = _collect_connections(
            second_refs, hamiltonian, existing_keys
        )
        if hop2_map:
            keys2 = list(hop2_map.keys())
            hop2_tensor = torch.stack([hop2_configs_map[k] for k in keys2])
            hop2_couplings = np.array([hop2_map[k] for k in keys2])
            hop2_tensor, _ = _truncate_by_coupling(
                hop2_tensor, hop2_couplings, remaining, coupling_rank
            )
            if len(hop2_tensor) > 0:
                expanded = torch.cat([expanded, hop2_tensor], dim=0)

    n_added = expanded.shape[0] - basis.shape[0]
    logger.info(
        "expand_basis_via_connections: added %d configs (basis %d -> %d, %d refs).",
        n_added,
        basis.shape[0],
        expanded.shape[0],
        refs.shape[0],
    )

    return expanded
def _collect_connections(
    refs: torch.Tensor,
    hamiltonian: Any,
    existing_keys: set,
) -> tuple[dict, dict]:
    """Collect connected configurations from reference states.

    For each reference configuration, retrieves Hamiltonian-connected states
    via ``hamiltonian.get_connections`` and tracks the maximum coupling
    strength ``|H_ij|`` for each new configuration.

    Parameters
    ----------
    refs : torch.Tensor
        Reference configurations, shape ``(n_ref, n_sites)``.
    hamiltonian
        The system Hamiltonian (must implement ``get_connections``, which is
        called with one reference row and must return a
        ``(connected, elements)`` pair).  ``elements`` may be ``None``, in
        which case every connection gets coupling ``1.0``.
    existing_keys : set
        Set of byte-keys for configurations already in the basis.
        **Modified in-place**: newly discovered keys are added.

    Returns
    -------
    new_map : dict
        Mapping ``config_key (bytes) -> max |H_ij|`` coupling strength.
    new_configs_map : dict
        Mapping ``config_key (bytes) -> config tensor``.

    Notes
    -----
    A reference whose ``get_connections`` call raises is skipped (logged at
    debug level) rather than aborting the whole expansion.
    """
    new_map: dict[bytes, float] = {}
    new_configs_map: dict[bytes, torch.Tensor] = {}

    for ref in refs:
        try:
            connected, elements = hamiltonian.get_connections(ref)
        except Exception as e:
            # Deliberate best-effort: one failing reference must not stop
            # the remaining references from contributing.
            logger.debug("get_connections failed: %s", e)
            continue

        if connected is None or len(connected) == 0:
            continue

        connected = connected.cpu().long()
        # Convert once; rows of this array provide the byte keys.
        connected_np = connected.numpy()

        if elements is not None:
            # Take the modulus before converting to a Python float:
            # float() on a complex value would drop the imaginary part and
            # under-report the coupling of complex matrix elements.
            strengths = np.abs(elements.detach().cpu().numpy())
        else:
            strengths = np.ones(len(connected))

        for i, row in enumerate(connected_np):
            key = row.tobytes()
            if key in existing_keys:
                continue

            coupling = float(strengths[i])
            # Keep the strongest coupling seen for this configuration.
            if key not in new_map or coupling > new_map[key]:
                new_map[key] = coupling
                new_configs_map[key] = connected[i]

    existing_keys.update(new_map.keys())
    return new_map, new_configs_map


def _truncate_by_coupling(
    tensor: torch.Tensor,
    couplings: np.ndarray,
    max_new: int,
    coupling_rank: bool,
) -> tuple[torch.Tensor, np.ndarray]:
    """Keep at most ``max_new`` configurations ranked by coupling strength.

    Parameters
    ----------
    tensor : torch.Tensor
        Candidate configurations, shape ``(n_candidates, n_sites)``.
    couplings : np.ndarray
        Coupling strengths, shape ``(n_candidates,)``.
    max_new : int
        Maximum number of configurations to retain.  Negative values are
        treated as ``0``.
    coupling_rank : bool
        If ``True``, sort by descending coupling and keep the top
        ``max_new``.  If ``False``, keep the first ``max_new`` in their
        original order.

    Returns
    -------
    tensor : torch.Tensor
        Retained configurations, shape ``(<= max_new, n_sites)``.  When the
        input already fits within ``max_new`` it is returned unchanged and
        in its original order (no sorting is applied).
    couplings : np.ndarray
        Corresponding coupling strengths.
    """
    # A negative cap would otherwise be interpreted by the slice below as
    # "drop the last |max_new| entries" instead of "keep nothing".
    max_new = max(max_new, 0)

    if len(tensor) <= max_new:
        return tensor, couplings

    if coupling_rank:
        # The reversed view has negative strides, which torch cannot wrap;
        # copy() yields a contiguous index array.
        top_idx = np.argsort(couplings)[::-1][:max_new].copy()
    else:
        top_idx = np.arange(max_new)

    return tensor[torch.from_numpy(top_idx)], couplings[top_idx]