"""
nf_skqd --- NF-SKQD: Normalizing Flow Sample-based Krylov Quantum Diagonalization
==================================================================================
A faithful classical analog of the quantum SKQD algorithm (Yu et al., 2025),
replacing quantum circuit time evolution with NF distribution evolution.
Quantum SKQD
Krylov subspace = {|psi_0>, U|psi_0>, U^2|psi_0>, ..., U^k|psi_0>}
where U = exp(-iH dt) via Trotter circuits. Sample each U^k|psi_0>
to obtain bitstrings, combine ALL bitstrings into a cumulative basis,
project H, and diagonalize.
NF-SKQD
Krylov subspace = {NF_0, NF_1, NF_2, ..., NF_k}.
NF_0 is an untrained (random / HF-biased) distribution. NF_{k+1} is
obtained by partially updating NF_k toward the current ground-state
eigenvector |Phi_k>. Sample each NF_k, add to a cumulative basis, and
diagonalize.
Key design principles
1. **Cumulative basis** -- configs from previous powers are kept; they are
   only dropped when the basis would exceed ``max_basis_size``.
2. **Partial NF update** -- only a few gradient steps per power (mimics
   a small Trotter time step).
3. **No H-connection expansion** -- the basis grows purely from NF
   sampling.
4. **Energy monotonicity** -- as long as the basis only grows, each power
   can only improve or maintain the energy.
References
----------
.. [1] Yu et al. (2025) "Quantum-Centric Algorithm for Sample-Based Krylov
Diagonalization", arXiv:2501.09702
.. [2] Pellow-Jarman et al. (2025) "HIVQE", arXiv:2503.06292
.. [3] Robledo-Moreno et al. (2024) "Chemistry beyond exact solutions"
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable
import numpy as np
import torch
from qvartools._utils.hashing.config_hash import config_integer_hash
from qvartools.solvers.solver import Solver, SolverResult
__all__ = [
"NFSKQDConfig",
"NFSKQDSolver",
]
# ---------------------------------------------------------------------------
# Flow model protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class FlowModel(Protocol):
    """Duck-typed interface an object must expose to be used as *flow_model*.

    The protocol is ``runtime_checkable``, so ``isinstance(obj, FlowModel)``
    succeeds for any object that provides the three methods below.
    """

    def sample(self, n: int, **kwargs: Any) -> Any:
        """Draw *n* samples; extra keyword options are implementation-defined."""

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        """Return the log-probabilities the model assigns to *x*."""

    def parameters(self) -> Any:
        """Return the trainable parameters (these are handed to the optimizer)."""
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class NFSKQDConfig:
    """Immutable bundle of hyper-parameters consumed by the NF-SKQD solver.

    Parameters
    ----------
    n_krylov_powers : int
        How many Krylov "time steps" (outer iterations) to run.
    n_samples_per_power : int
        How many samples are requested from the flow at every power.
    nf_steps_per_power : int
        Optimizer steps applied to the flow per power (the partial update).
    nf_lr : float
        Learning rate of the Adam optimizer that trains the flow.
    wf_weight : float
        Coefficient of the wavefunction-matching (KL) loss term.
    energy_weight : float
        Coefficient of the REINFORCE energy loss term.
    entropy_weight : float
        Coefficient of the entropy-regularisation loss term.
    initial_temperature : float
        Sampling temperature used at the first power.
    final_temperature : float
        Sampling temperature used at the last power.
    warmup_powers : int
        Count of leading powers for which the flow update is skipped.
    max_basis_size : int
        Upper bound on the size of the cumulative basis.
    convergence_threshold : float
        Early-stopping threshold (Ha) on the energy change between powers.
    """

    # Outer loop and sampling budget
    n_krylov_powers: int = 10
    n_samples_per_power: int = 2000

    # Flow optimisation
    nf_steps_per_power: int = 20
    nf_lr: float = 1e-3

    # Loss-term coefficients
    wf_weight: float = 1.0
    energy_weight: float = 0.1
    entropy_weight: float = 0.05

    # Temperature schedule end points
    initial_temperature: float = 2.0
    final_temperature: float = 0.5

    # Schedule, basis cap and stopping criterion
    warmup_powers: int = 0
    max_basis_size: int = 10_000
    convergence_threshold: float = 1e-6
# ---------------------------------------------------------------------------
# Solver
# ---------------------------------------------------------------------------
class NFSKQDSolver(Solver):
    """NF-SKQD: faithful NF analog of quantum SKQD with cumulative basis.

    Each Krylov power *k*:

    1. Sample from the current NF distribution NF_k.
    2. Add unique, particle-number-valid configs to the cumulative basis.
    3. Diagonalize the projected Hamiltonian (Rayleigh--Ritz).
    4. Partially update the NF toward the ground-state eigenvector.

    Parameters
    ----------
    flow_model : object
        A normalizing-flow model exposing ``sample()``, ``log_prob()``,
        and ``parameters()`` methods.
    config : NFSKQDConfig, optional
        Solver hyper-parameters. Uses defaults when omitted.
    """

    # Projected Hamiltonians up to this dimension are solved with dense
    # ``numpy.linalg.eigh``; larger ones use ``scipy.sparse.linalg.eigsh``.
    _DENSE_EIGH_MAX_DIM = 2000

    def __init__(
        self,
        flow_model: FlowModel,
        config: NFSKQDConfig | None = None,
    ) -> None:
        self.flow = flow_model
        self.config = config if config is not None else NFSKQDConfig()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def solve(self, hamiltonian: Any, mol_info: dict[str, Any]) -> SolverResult:
        """Run NF-SKQD and return the ground-state energy estimate.

        Parameters
        ----------
        hamiltonian : MolecularHamiltonian
            Hamiltonian exposing ``n_orbitals``, ``n_alpha``, ``n_beta`` and the
            methods ``get_hf_state``, ``matrix_elements_fast``,
            ``diagonal_element``, and ``diagonal_elements_batch``.
        mol_info : dict
            Molecular metadata (unused by this solver, kept for API compat).

        Returns
        -------
        SolverResult
            ``energy`` is the lowest eigenvalue seen over all powers (``None``
            if no power was run). ``diag_dim`` is the final basis size.
            ``metadata`` holds ``n_krylov_powers`` (powers actually executed),
            ``energy_history``, ``basis_size_history`` and
            ``samples_per_power`` (number of *new* configs added per power).
        """
        t0 = time.perf_counter()
        cfg = self.config
        device = "cpu"
        n_orb = hamiltonian.n_orbitals
        n_alpha = hamiltonian.n_alpha
        n_beta = hamiltonian.n_beta
        optimizer = torch.optim.Adam(self.flow.parameters(), lr=cfg.nf_lr)

        # Seed the basis with the HF state; it always stays at row 0.
        hf = hamiltonian.get_hf_state().unsqueeze(0).to(device)
        cumulative_basis = hf.clone()
        basis_hashes: set[int] = set(config_integer_hash(cumulative_basis))

        energy_history: list[float] = []
        basis_size_history: list[int] = []
        samples_per_power: list[int] = []
        prev_energy = float("inf")
        best_energy = float("inf")
        converged = False

        for k in range(cfg.n_krylov_powers):
            # Linear temperature annealing from initial to final temperature.
            progress = k / max(cfg.n_krylov_powers - 1, 1)
            temperature = cfg.initial_temperature + progress * (
                cfg.final_temperature - cfg.initial_temperature
            )

            # Step 1: sample from the current NF and keep only configs with
            # the right number of alpha / beta electrons.
            raw_configs = self._sample_configs(
                cfg.n_samples_per_power, temperature, device
            )
            alpha_counts = raw_configs[:, :n_orb].sum(dim=1)
            beta_counts = raw_configs[:, n_orb:].sum(dim=1)
            valid = (alpha_counts == n_alpha) & (beta_counts == n_beta)

            cumulative_basis, n_new = self._extend_basis(
                cumulative_basis, basis_hashes, raw_configs[valid]
            )
            samples_per_power.append(n_new)

            # Enforce the hard cap on the basis size.
            if len(cumulative_basis) > cfg.max_basis_size:
                cumulative_basis = self._truncate_basis(
                    cumulative_basis, hf, cfg.max_basis_size
                )
                basis_hashes = set(config_integer_hash(cumulative_basis))

            # Step 2: Rayleigh--Ritz in the current basis.
            e0, psi0 = self._diagonalize(hamiltonian, cumulative_basis)
            energy_history.append(e0)
            basis_size_history.append(len(cumulative_basis))
            if e0 < best_energy:
                best_energy = e0

            # Step 3: partial NF update (skipped during warm-up and while the
            # basis contains a single configuration).
            if len(cumulative_basis) >= 2 and k >= cfg.warmup_powers:
                _evolve_nf(
                    self.flow,
                    cumulative_basis,
                    psi0,
                    e0,
                    hamiltonian,
                    optimizer,
                    cfg,
                )

            # Step 4: convergence check (never on the very first power, where
            # there is no previous energy to compare against).
            delta_e = abs(e0 - prev_energy)
            prev_energy = e0
            if delta_e < cfg.convergence_threshold and k > 0:
                converged = True
                break

        wall_time = time.perf_counter() - t0
        return SolverResult(
            diag_dim=len(cumulative_basis),
            wall_time=wall_time,
            method="NF-SKQD",
            converged=converged,
            energy=best_energy if best_energy < float("inf") else None,
            metadata={
                # One energy is recorded per executed power, so this is also
                # correct (0) when ``cfg.n_krylov_powers`` is 0.
                "n_krylov_powers": len(energy_history),
                "energy_history": energy_history,
                "basis_size_history": basis_size_history,
                "samples_per_power": samples_per_power,
            },
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _sample_configs(
        self, n_samples: int, temperature: float, device: str
    ) -> torch.Tensor:
        """Draw samples from the flow and return them as an integer tensor."""
        with torch.no_grad():
            try:
                sample_out = self.flow.sample(n_samples, temperature=temperature)
            except TypeError:
                # Best effort: the flow does not take a ``temperature`` keyword.
                sample_out = self.flow.sample(n_samples)

            if isinstance(sample_out, torch.Tensor):
                # The flow returned the configuration batch directly.
                configs = sample_out
            elif sample_out[0].dim() == 1:
                # First element is 1-D, so the configs are the second element.
                configs = sample_out[1]
            else:
                configs = sample_out[0]
            return configs.long().to(device)

    @staticmethod
    def _extend_basis(
        basis: torch.Tensor,
        basis_hashes: set[int],
        candidates: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Append not-yet-seen *candidates* to *basis*.

        *basis_hashes* is updated in place. Returns the (possibly extended)
        basis and the number of configurations that were added.
        """
        if len(candidates) == 0:
            return basis, 0

        unique_rows = torch.unique(candidates, dim=0)
        fresh_idx: list[int] = []
        for idx, h in enumerate(config_integer_hash(unique_rows)):
            if h not in basis_hashes:
                fresh_idx.append(idx)
                basis_hashes.add(h)
        if not fresh_idx:
            return basis, 0

        fresh = unique_rows[torch.tensor(fresh_idx, dtype=torch.long)]
        return torch.cat([basis, fresh], dim=0), len(fresh_idx)

    @staticmethod
    def _truncate_basis(
        basis: torch.Tensor, hf: torch.Tensor, max_size: int
    ) -> torch.Tensor:
        """Shrink *basis* to the HF row plus its most recently added rows.

        The HF row is always retained, so the result has at least one row even
        when *max_size* is below 1.
        """
        # Explicit start index: a ``[-n:]`` slice with n == 0 would select the
        # whole tensor instead of nothing.
        n_keep = max(max_size - 1, 0)
        tail = basis[len(basis) - n_keep :]
        return torch.cat([hf, tail], dim=0)

    def _diagonalize(
        self, hamiltonian: Any, basis: torch.Tensor
    ) -> tuple[float, np.ndarray]:
        """Return the lowest eigenvalue and its eigenvector in *basis*."""
        if len(basis) < 2:
            # Single configuration: the energy is just the diagonal element.
            return float(hamiltonian.diagonal_element(basis[0])), np.array([1.0])

        H_proj = hamiltonian.matrix_elements_fast(basis)
        H_np = H_proj.cpu().numpy().astype(np.float64)
        # Symmetrize before handing the matrix to a symmetric eigensolver.
        H_np = 0.5 * (H_np + H_np.T)

        if len(H_np) <= self._DENSE_EIGH_MAX_DIM:
            eigenvalues, eigenvectors = np.linalg.eigh(H_np)
        else:
            from scipy.sparse import csr_matrix
            from scipy.sparse.linalg import eigsh

            eigenvalues, eigenvectors = eigsh(csr_matrix(H_np), k=1, which="SA")
        return float(eigenvalues[0]), eigenvectors[:, 0]
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _evolve_nf(
    flow: FlowModel,
    basis_configs: torch.Tensor,
    psi0: np.ndarray,
    e0: float,
    hamiltonian: Any,
    optimizer: torch.optim.Optimizer,
    cfg: NFSKQDConfig,
) -> None:
    """Take a few optimizer steps that nudge the flow toward the eigenvector.

    The minimized objective is a weighted sum of three terms, all evaluated on
    the basis configurations:

    * **Wavefunction matching** -- negative log-likelihood weighted by the
      normalized ``psi0**2`` (the KL(|Phi|^2 || p_NF) term).
    * **REINFORCE energy** -- log-probabilities weighted by ``psi0**2`` times
      the diagonal-energy advantage ``H_ii - e0``.
    * **Entropy regularisation** -- the mean log-probability over the basis.

    Exactly ``cfg.nf_steps_per_power`` steps are taken, with the gradient norm
    clipped to 1.0 before each step; the flow is modified in place and nothing
    is returned.
    """
    configs_f32 = basis_configs.float()

    # Target distribution: squared eigenvector amplitudes, normalized to one.
    amp_sq = torch.from_numpy(np.square(psi0)).float()
    target = amp_sq / amp_sq.sum()

    # Diagonal energies are constants of the optimisation, so no graph needed.
    with torch.no_grad():
        diag_np = np.asarray(
            hamiltonian.diagonal_elements_batch(basis_configs),
            dtype=np.float64,
        )
        advantage = torch.from_numpy(diag_np).float() - e0

    n_steps = cfg.nf_steps_per_power
    for _ in range(n_steps):
        optimizer.zero_grad()

        # NOTE(review): presumably one log-probability per basis row -- confirm.
        logp = flow.log_prob(configs_f32)

        wf_term = -torch.sum(target * logp)
        energy_term = torch.sum(target * advantage * logp)
        entropy_term = torch.mean(logp)

        total = (
            cfg.wf_weight * wf_term
            + cfg.energy_weight * energy_term
            + cfg.entropy_weight * entropy_term
        )
        total.backward()

        torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1.0)
        optimizer.step()