"""
hi_nqs_sqd --- HI+NQS+SQD: iterative self-consistent NQS-SQD loop
====================================================================
Iterative pipeline that trains an autoregressive transformer NQS to
sample configurations, solves via subspace diagonalisation (SQD), feeds
the eigenvector back as a teacher signal, and repeats until convergence.
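
At each iteration the new NQS samples are merged into a cumulative basis,
which is diagonalised in batches with the internal GPU solver.  When
``use_ibm_solver`` is enabled and ``qiskit_addon_sqd`` is importable, each
batch is instead converted to IBM SQD format, passed through
``qiskit_addon_sqd`` configuration recovery, and solved with that package's
``solve_fermion``.

External dependencies (``qiskit_addon_sqd``) are optional.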

Classes
-------
HINQSSQDConfig
    Frozen dataclass holding the pipeline hyper-parameters.

Functions
---------
run_hi_nqs_sqd
    Execute the full HI+NQS+SQD pipeline.
"""
from __future__ import annotations
import logging
import math
import time
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from qvartools._utils.formatting.bitstring_format import (
configs_to_ibm_format,
vectorized_dedup,
)
from qvartools._utils.gpu.diagnostics import gpu_solve_fermion
from qvartools.nqs.transformer.autoregressive import AutoregressiveTransformer
from qvartools.solvers.solver import SolverResult
# Names exported by ``from <this module> import *``.
__all__ = [
    "HINQSSQDConfig",
    "run_hi_nqs_sqd",
]
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Optional dependency guards
# ---------------------------------------------------------------------------
# ``qiskit_addon_sqd`` is optional.  When it cannot be imported, the two
# entry points are bound to ``None`` and ``_IBM_SQD_AVAILABLE`` is ``False``,
# so callers can test the flag before using either name.
try:
    from qiskit_addon_sqd.configuration_recovery import (
        recover_configurations,  # type: ignore[import-untyped]
    )
    from qiskit_addon_sqd.fermion import (
        solve_fermion as ibm_solve_fermion,  # type: ignore[import-untyped]
    )
except ImportError:
    _IBM_SQD_AVAILABLE = False
    recover_configurations = None  # type: ignore[assignment]
    ibm_solve_fermion = None  # type: ignore[assignment]
else:
    # Both imports succeeded.
    _IBM_SQD_AVAILABLE = True
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class HINQSSQDConfig:
    """Immutable hyper-parameters for the HI+NQS+SQD pipeline.

    All fields have defaults, so ``HINQSSQDConfig()`` is a valid
    configuration.  Instances are frozen: assigning to a field after
    construction raises ``dataclasses.FrozenInstanceError``.

    Parameters
    ----------
    n_iterations : int
        Maximum number of outer self-consistent iterations (default ``10``).
        The loop stops earlier once the energy change between two
        consecutive iterations drops below ``energy_tol``.
    n_samples_per_iter : int
        NQS samples drawn per iteration (default ``10_000``).
    n_batches : int
        Diagonalisation batches solved per iteration (default ``5``).
    max_configs_per_batch : int
        Maximum configs retained per batch (default ``5000``).  When the
        cumulative basis is larger, each batch is a random subset of this
        size.
    energy_tol : float
        Convergence threshold in Hartree (default ``1e-5``).
    nqs_lr : float
        NQS optimiser learning rate (default ``1e-3``).
    nqs_train_epochs : int
        NQS training epochs per iteration (default ``50``).
    embed_dim : int
        Transformer embedding dimension (default ``64``).
    n_heads : int
        Number of attention heads (default ``4``).
    n_layers : int
        Number of transformer layers per channel (default ``4``).
    temperature : float
        NQS sampling temperature (default ``1.0``).
    use_ibm_solver : bool
        Use IBM ``solve_fermion`` when available (default ``False``).
        Set to ``True`` only if ``qiskit_addon_sqd`` is installed with a
        compatible API version.
    device : str
        Torch device string (default ``"cpu"``).
    """

    # --- outer loop / sampling ---
    n_iterations: int = 10
    n_samples_per_iter: int = 10_000

    # --- batch diagonalisation ---
    n_batches: int = 5
    max_configs_per_batch: int = 5000
    energy_tol: float = 1e-5

    # --- NQS teacher training ---
    nqs_lr: float = 1e-3
    nqs_train_epochs: int = 50

    # --- transformer architecture ---
    embed_dim: int = 64
    n_heads: int = 4
    n_layers: int = 4

    # --- runtime options ---
    temperature: float = 1.0
    use_ibm_solver: bool = False
    device: str = "cpu"
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _train_nqs_teacher(
    nqs: AutoregressiveTransformer,
    configs: torch.Tensor,
    coeffs: np.ndarray,
    n_orb: int,
    lr: float,
    epochs: int,
) -> list[float]:
    """Train NQS using eigenvector coefficients as teacher signal.

    Minimises the weighted negative log-likelihood
    ``-sum_x p_teacher(x) * log q(x)`` with ``p_teacher(x) = |c_x|^2 / Z``,
    which equals the KL divergence from the teacher distribution to the NQS
    distribution up to a constant that does not depend on the NQS.

    Parameters
    ----------
    nqs : AutoregressiveTransformer
        Transformer NQS.  Must expose ``log_prob(alpha, beta)``; it is
        trained in place and left in ``eval`` mode on return.
    configs : torch.Tensor
        Basis configurations, shape ``(n_basis, 2*n_orb)``.  The first
        ``n_orb`` columns are passed as the alpha channel, the rest as beta.
    coeffs : np.ndarray
        Eigenvector coefficients, one per row of ``configs``.  Singleton
        axes are dropped, so ``(n_basis,)``, ``(n_basis, 1)`` and
        ``(1, n_basis)`` are all accepted.
    n_orb : int
        Spatial orbitals per spin channel.
    lr : float
        Learning rate.
    epochs : int
        Training epochs.

    Returns
    -------
    list of float
        Per-epoch loss values.

    Raises
    ------
    ValueError
        If ``coeffs`` does not provide exactly one finite coefficient per
        row of ``configs``.
    """
    n_basis = int(configs.shape[0])

    # Without this check a mismatched array (e.g. a column vector or a 2-D
    # amplitude matrix) would broadcast against the 1-D log-probs and train
    # on a silently wrong loss.
    amplitudes = np.atleast_1d(np.squeeze(np.asarray(coeffs)))
    if amplitudes.ndim != 1 or amplitudes.shape[0] != n_basis:
        raise ValueError(
            "coeffs must provide one coefficient per configuration: expected "
            f"{n_basis} values, got array of shape {np.shape(coeffs)}"
        )

    # Build teacher distribution: p(x) = |c_x|^2 / Z
    weights = np.abs(amplitudes) ** 2
    if not np.all(np.isfinite(weights)):
        # A NaN/inf weight would give a NaN loss and poison the parameters.
        raise ValueError("coeffs must contain only finite values")
    total = weights.sum()
    if total > 0:
        weights = weights / total

    device = next(nqs.parameters()).device
    weights_t = torch.from_numpy(weights).float().to(device)
    configs_dev = configs.to(device).long()
    alpha = configs_dev[:, :n_orb]
    beta = configs_dev[:, n_orb:]

    optimiser = torch.optim.Adam(nqs.parameters(), lr=lr)
    losses: list[float] = []
    nqs.train()
    for _epoch in range(epochs):
        optimiser.zero_grad()
        log_probs = nqs.log_prob(alpha, beta)
        # Weighted NLL: - sum_x p_teacher(x) * log q(x)
        loss = -(weights_t * log_probs).sum()
        loss.backward()
        optimiser.step()
        losses.append(float(loss.item()))
    nqs.eval()
    return losses
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
def _resolve_orbital_counts(
    hamiltonian: Any,
    mol_info: dict[str, Any],
) -> tuple[int, int, int]:
    """Resolve ``(n_orbitals, n_alpha, n_beta)`` for the system.

    Each value is read from ``mol_info`` first.  When it is absent (or
    ``None``), the same-named attribute of ``hamiltonian.integrals`` is used
    if that object and attribute exist.

    Raises
    ------
    ValueError
        If any of the three values cannot be resolved.
    """
    integrals = getattr(hamiltonian, "integrals", None)
    counts: dict[str, Any] = {}
    for key in ("n_orbitals", "n_alpha", "n_beta"):
        value = mol_info.get(key)
        if value is None and integrals is not None:
            # getattr with a default: a missing attribute must surface as
            # the ValueError below, not as an AttributeError.
            value = getattr(integrals, key, None)
        counts[key] = value

    n_orb = counts["n_orbitals"]
    n_alpha = counts["n_alpha"]
    n_beta = counts["n_beta"]
    if n_orb is None or n_alpha is None or n_beta is None:
        raise ValueError(
            "n_orbitals, n_alpha, and n_beta must be provided via mol_info "
            "or hamiltonian.integrals. Got: "
            f"n_orbitals={n_orb}, n_alpha={n_alpha}, n_beta={n_beta}"
        )
    return n_orb, n_alpha, n_beta


def _prepare_initial_basis(
    initial_basis: torch.Tensor,
    n_qubits: int,
    device: torch.device,
) -> torch.Tensor:
    """Validate a warm-start basis and return it de-duplicated as ``long``.

    Raises
    ------
    ValueError
        If ``initial_basis`` has a floating-point/complex dtype, is not of
        shape ``(n_configs, n_qubits)``, or contains values other than 0/1.
    """
    # Validate raw input before any cast (fail-fast)
    if initial_basis.is_floating_point() or initial_basis.is_complex():
        raise ValueError(
            f"initial_basis must be integer or bool dtype (binary occupations), "
            f"got {initial_basis.dtype}"
        )
    if initial_basis.ndim != 2 or initial_basis.shape[1] != n_qubits:
        raise ValueError(
            f"initial_basis must have shape (n_configs, {n_qubits}), "
            f"but got {tuple(initial_basis.shape)}"
        )
    if not torch.all((initial_basis == 0) | (initial_basis == 1)):
        raise ValueError("initial_basis must contain only binary values {0, 1}")
    return torch.unique(initial_basis.to(dtype=torch.long, device=device), dim=0)


def run_hi_nqs_sqd(
    hamiltonian: Any,
    mol_info: dict[str, Any],
    config: HINQSSQDConfig | None = None,
    *,
    initial_basis: torch.Tensor | None = None,
) -> SolverResult:
    """Execute the HI+NQS+SQD pipeline.

    Each outer iteration (1) samples configurations from the NQS and merges
    them into a cumulative basis, (2) diagonalises ``n_batches`` batches of
    that basis, (3) updates the orbital occupancies from the last finite
    batch, (4) retrains the NQS on the eigenvector of the lowest-energy
    batch, and (5) stops once the iteration energy changes by less than
    ``energy_tol``.

    Parameters
    ----------
    hamiltonian : Hamiltonian
        Molecular Hamiltonian.  If it has an ``integrals`` attribute, that
        object is used as a fallback source for orbital/electron counts and
        (IBM solver only) for ``h1e`` / ``h2e``.
    mol_info : dict
        Molecular metadata.  ``"n_orbitals"``, ``"n_alpha"`` and
        ``"n_beta"`` are read from here, falling back to the same-named
        attributes of ``hamiltonian.integrals``.  ``"n_qubits"`` is
        optional and defaults to ``2 * n_orbitals``.
    config : HINQSSQDConfig or None
        Pipeline configuration.  ``None`` uses ``HINQSSQDConfig()``.
    initial_basis : torch.Tensor or None, optional
        Pre-computed configurations to seed the cumulative basis
        (e.g., from NF+DCI Stage 1-2). Shape ``(n_configs, n_qubits)``.
        If ``None`` (default), starts from an empty basis.

    Returns
    -------
    SolverResult
        Energy, timing, convergence, and per-iteration metadata.

    Raises
    ------
    ValueError
        If the orbital/electron counts cannot be resolved, or if
        ``initial_basis`` has wrong shape, non-binary values, or
        floating-point/complex dtype.
    RuntimeError
        If all diagonalisation batches produce non-finite energies.
    """
    cfg = config if config is not None else HINQSSQDConfig()

    n_orb, n_alpha, n_beta = _resolve_orbital_counts(hamiltonian, mol_info)
    n_qubits: int = mol_info.get("n_qubits", 2 * n_orb)
    device = torch.device(cfg.device)
    logger.info(
        "run_hi_nqs_sqd: %d orbitals, %d alpha, %d beta",
        n_orb,
        n_alpha,
        n_beta,
    )
    t_start = time.perf_counter()

    # --- Build NQS ---
    nqs = AutoregressiveTransformer(
        n_orbitals=n_orb,
        n_alpha=n_alpha,
        n_beta=n_beta,
        embed_dim=cfg.embed_dim,
        n_heads=cfg.n_heads,
        n_layers=cfg.n_layers,
    ).to(device)
    nqs.eval()

    # --- Occupancies (uniform prior) ---
    occ_alpha = np.full(n_orb, n_alpha / n_orb)
    occ_beta = np.full(n_orb, n_beta / n_orb)

    # --- Cumulative basis (warm-start from initial_basis if provided) ---
    if initial_basis is not None:
        cumulative_basis = _prepare_initial_basis(initial_basis, n_qubits, device)
        logger.info(
            "Warm-starting with %d initial basis configs", cumulative_basis.shape[0]
        )
    else:
        cumulative_basis = torch.zeros(0, n_qubits, dtype=torch.long, device=device)

    energy_history: list[float] = []
    best_energy = float("inf")
    converged = False

    for iteration in range(cfg.n_iterations):
        logger.info("HI+NQS+SQD iteration %d / %d", iteration + 1, cfg.n_iterations)

        # --- NQS sampling ---
        with torch.no_grad():
            # Cast once so both de-duplication branches below operate on the
            # same integer dtype as the cumulative basis.
            new_configs = nqs.sample(
                cfg.n_samples_per_iter, temperature=cfg.temperature
            ).to(device=device, dtype=torch.long)

        # Deduplicate against cumulative basis (numpy for vectorized_dedup)
        if cumulative_basis.shape[0] > 0:
            cb_np = cumulative_basis.cpu().numpy()
            nc_np = new_configs.cpu().numpy()
            deduped_np = vectorized_dedup(cb_np, nc_np)
            unique_new = torch.from_numpy(deduped_np).long().to(device)
        else:
            unique_new = torch.unique(new_configs, dim=0)
        cumulative_basis = torch.cat([cumulative_basis, unique_new], dim=0)
        cumulative_basis = torch.unique(cumulative_basis, dim=0)
        logger.info(
            "  sampled %d, %d unique new, cumulative %d",
            cfg.n_samples_per_iter,
            unique_new.shape[0],
            cumulative_basis.shape[0],
        )

        # --- Batch diagonalisation ---
        batch_energies: list[float] = []
        best_coeffs: np.ndarray | None = None
        best_batch_configs: torch.Tensor | None = None
        best_batch_energy = float("inf")
        latest_occs: Any = None
        for batch_idx in range(cfg.n_batches):
            # Random subset when the basis exceeds the per-batch cap;
            # otherwise every batch is the full cumulative basis.
            if cumulative_basis.shape[0] > cfg.max_configs_per_batch:
                indices = torch.randperm(cumulative_basis.shape[0])[
                    : cfg.max_configs_per_batch
                ]
                batch_configs = cumulative_basis[indices]
            else:
                batch_configs = cumulative_basis

            # Optional IBM configuration recovery
            if _IBM_SQD_AVAILABLE and cfg.use_ibm_solver:
                ibm_data = configs_to_ibm_format(batch_configs, n_orb, n_qubits)
                n_samples = ibm_data.shape[0]
                uniform_probs = np.ones(n_samples) / n_samples
                refined_matrix, _ = recover_configurations(
                    ibm_data,
                    uniform_probs,
                    (occ_alpha, occ_beta),
                    num_elec_a=n_alpha,
                    num_elec_b=n_beta,
                )
                e_b, sci_state, occs_b, _ = ibm_solve_fermion(
                    refined_matrix,
                    hcore=hamiltonian.integrals.h1e,
                    eri=hamiltonian.integrals.h2e,
                )
                # NOTE(review): these amplitudes describe the subspace built
                # from ``refined_matrix``, while teacher training below pairs
                # them with ``batch_configs``.  Confirm against the installed
                # qiskit_addon_sqd version that the two line up row-for-row.
                coeffs_b = sci_state.amplitudes
            else:
                e_b, coeffs_b, occs_b = gpu_solve_fermion(batch_configs, hamiltonian)

            e_b = float(e_b)
            if not math.isfinite(e_b):
                logger.warning(
                    "Non-finite energy %.4e in batch %d, skipping", e_b, batch_idx
                )
                continue
            batch_energies.append(e_b)
            latest_occs = occs_b
            if e_b < best_batch_energy:
                best_batch_energy = e_b
                best_coeffs = np.asarray(coeffs_b)
                best_batch_configs = batch_configs

        if not batch_energies:
            raise RuntimeError(
                f"All {cfg.n_batches} batches produced non-finite energies "
                f"at iteration {iteration + 1}. Check Hamiltonian integrals."
            )
        iter_energy = float(np.min(batch_energies))
        energy_history.append(iter_energy)
        best_energy = min(best_energy, iter_energy)

        # --- Update occupancies ---
        # Only a 2-tuple is accepted; anything else leaves the previous
        # occupancies in place.
        if isinstance(latest_occs, tuple) and len(latest_occs) == 2:
            occ_alpha = np.clip(np.asarray(latest_occs[0], dtype=np.float64), 0.0, 1.0)
            occ_beta = np.clip(np.asarray(latest_occs[1], dtype=np.float64), 0.0, 1.0)

        # --- NQS teacher training ---
        if best_coeffs is not None and best_batch_configs is not None:
            _train_nqs_teacher(
                nqs,
                best_batch_configs,
                best_coeffs,
                n_orb,
                lr=cfg.nqs_lr,
                epochs=cfg.nqs_train_epochs,
            )
        logger.info(
            "  energy=%.8f  best=%.8f  basis=%d",
            iter_energy,
            best_energy,
            cumulative_basis.shape[0],
        )

        # --- Convergence ---
        if len(energy_history) >= 2:
            delta = abs(energy_history[-1] - energy_history[-2])
            if delta < cfg.energy_tol:
                converged = True
                logger.info("  converged: |dE|=%.2e", delta)
                break

    wall_time = time.perf_counter() - t_start
    return SolverResult(
        energy=best_energy,
        diag_dim=int(cumulative_basis.shape[0]),
        wall_time=wall_time,
        method="HI+NQS+SQD",
        converged=converged,
        metadata={
            "energy_history": energy_history,
            "n_iterations": len(energy_history),
            "final_basis_size": int(cumulative_basis.shape[0]),
        },
    )