"""
connection_cache --- Hash-based cache for Hamiltonian connections
================================================================
Provides :class:`ConnectionCache`, a dictionary-backed cache that maps
configuration hashes to their Hamiltonian-connected configurations and
matrix elements. This avoids redundant calls to
:meth:`Hamiltonian.get_connections` when the same configuration is
encountered multiple times during iterative basis expansion or sampling.
The hash function converts a binary occupation vector to a unique integer
by interpreting it as a binary number (via powers of 2). A powers tensor
is computed once and reused across all lookups for efficiency.
Eviction follows **LRU** (least-recently-used) order: the entry that was
neither accessed nor inserted most recently is evicted first.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import torch
if TYPE_CHECKING:
from qvartools.hamiltonians.hamiltonian import Hamiltonian
__all__ = [
"ConnectionCache",
]
logger = logging.getLogger(__name__)
[docs]
class ConnectionCache:
"""Hash-based LRU cache for Hamiltonian connections.
Stores ``(connected_configs, matrix_elements)`` tuples keyed by the
integer hash of each configuration. Provides O(1) lookup and
LRU eviction when the cache exceeds ``max_size``.
Parameters
----------
max_size : int, optional
Maximum number of entries the cache may hold (default ``100000``).
When the cache is full, the **least-recently-used** entry is
evicted on the next :meth:`put` or :meth:`get_or_compute` call.
Attributes
----------
max_size : int
Maximum cache capacity.
Examples
--------
>>> cache = ConnectionCache(max_size=1000)
>>> config = torch.tensor([1, 0, 1, 0])
>>> cache.put(config, connected, elements)
>>> result = cache.get(config)
"""
def __init__(self, max_size: int = 100_000) -> None:
if max_size < 1:
raise ValueError(f"max_size must be >= 1, got {max_size}")
self.max_size: int = max_size
self._cache: dict[int | tuple[int, ...], tuple[torch.Tensor, torch.Tensor]] = {}
self._hits: int = 0
self._misses: int = 0
self._powers: torch.Tensor | None = None
self._powers_n: int = 0
self._powers_device: torch.device = torch.device("cpu")
def _get_powers(self, n: int, device: torch.device) -> torch.Tensor:
"""Return the cached powers-of-2 tensor for *n* sites.
The tensor is created once and reused. If *n* or *device* change,
the tensor is rebuilt (or moved) accordingly.
"""
if self._powers is not None and self._powers_n == n:
if self._powers.device == device:
return self._powers
self._powers = self._powers.to(device)
self._powers_device = device
return self._powers
self._powers = torch.tensor(
[1 << k for k in range(n - 1, -1, -1)],
dtype=torch.int64,
device=device,
)
self._powers_n = n
self._powers_device = device
return self._powers
_MAX_SITES_INT64 = 63 # int64 can represent 2^0 .. 2^62 without overflow
def _hash(self, config: torch.Tensor) -> int | tuple[int, ...]:
"""Convert a binary configuration to a hashable key.
For ``n_sites <= 63`` returns a single ``int`` (powers-of-2 hash).
For ``n_sites >= 64`` falls back to a ``tuple[int, ...]`` to avoid
int64 overflow.
"""
n = config.shape[0]
if n > self._MAX_SITES_INT64:
return tuple(config.long().tolist())
powers = self._get_powers(n, config.device)
return int((config.to(torch.int64) * powers).sum().item())
[docs]
def hash_batch(self, configs: torch.Tensor) -> torch.Tensor | list[tuple[int, ...]]:
"""Hash a batch of configurations.
For ``n_sites <= 63`` uses a single matmul and returns an int64
tensor. For ``n_sites >= 64`` falls back to per-config tuple
keys and returns a list.
Parameters
----------
configs : torch.Tensor
Binary configurations, shape ``(n_configs, n_sites)``.
Returns
-------
torch.Tensor or list
Integer hashes (tensor) or tuple keys (list).
"""
n_sites = configs.shape[1]
if n_sites > self._MAX_SITES_INT64:
return [tuple(row.long().tolist()) for row in configs]
powers = self._get_powers(n_sites, configs.device)
return configs.to(torch.int64) @ powers
[docs]
def get_batch(
self, configs: torch.Tensor
) -> list[tuple[torch.Tensor, torch.Tensor] | None]:
"""Look up cached connections for a batch of configurations.
Parameters
----------
configs : torch.Tensor
Binary configurations, shape ``(n_configs, n_sites)``.
Returns
-------
list
One entry per config: ``(connected, elements)`` if cached,
``None`` otherwise. Cache hits are promoted to
most-recently-used.
"""
raw_hashes = self.hash_batch(configs)
hashes = (
raw_hashes.tolist() if isinstance(raw_hashes, torch.Tensor) else raw_hashes
)
results: list[tuple[torch.Tensor, torch.Tensor] | None] = []
for h in hashes:
key = int(h) if isinstance(h, (int, float)) else h
if key in self._cache:
self._hits += 1
self._touch(key)
results.append(self._cache[key])
else:
self._misses += 1
results.append(None)
return results
def _touch(self, key: int) -> None:
"""Move *key* to the end of the dict (mark as most-recently-used)."""
value = self._cache.pop(key)
self._cache[key] = value
def _evict(self) -> None:
"""Evict the least-recently-used entry (first in dict order)."""
oldest_key = next(iter(self._cache))
del self._cache[oldest_key]
[docs]
def get(
self,
config: torch.Tensor,
hamiltonian: Hamiltonian | None = None,
) -> tuple[torch.Tensor, torch.Tensor] | None:
"""Look up cached connections for a configuration.
On a cache hit the entry is promoted to most-recently-used.
If *hamiltonian* is provided the call behaves like
:meth:`get_or_compute` for backward compatibility: a cache miss
triggers ``hamiltonian.get_connections(config)``, caches the
result, and returns it (never returns ``None``).
Parameters
----------
config : torch.Tensor
Binary configuration vector, shape ``(n_sites,)``.
hamiltonian : Hamiltonian or None, optional
If given, compute and cache on miss instead of returning
``None``.
Returns
-------
tuple of (torch.Tensor, torch.Tensor) or None
``(connected_configs, matrix_elements)`` if found (or
computed), otherwise ``None`` when *hamiltonian* is not
provided.
"""
key = self._hash(config)
if key in self._cache:
self._hits += 1
self._touch(key)
return self._cache[key]
self._misses += 1
if hamiltonian is not None:
connected, elements = hamiltonian.get_connections(config)
self.put(config, connected, elements)
return connected, elements
return None
[docs]
def put(
self,
config: torch.Tensor,
connections: torch.Tensor,
elements: torch.Tensor,
) -> None:
"""Store connections for a configuration in the cache.
If the cache is at capacity, the least-recently-used entry is
evicted.
Parameters
----------
config : torch.Tensor
Binary configuration vector, shape ``(n_sites,)``.
connections : torch.Tensor
Connected configurations, shape ``(n_conn, n_sites)``.
elements : torch.Tensor
Matrix elements, shape ``(n_conn,)``.
"""
key = self._hash(config)
if key in self._cache:
# Update existing entry and promote to most-recent
self._cache.pop(key)
elif len(self._cache) >= self.max_size:
self._evict()
self._cache[key] = (connections.clone(), elements.clone())
[docs]
def get_or_compute(
self,
config: torch.Tensor,
hamiltonian: Hamiltonian,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Retrieve connections, computing and caching if absent.
This unifies the lookup-or-compute pattern: if the configuration
is already cached its entry is returned (and promoted to
most-recently-used); otherwise ``hamiltonian.get_connections``
is called, the result is cached, and then returned.
Parameters
----------
config : torch.Tensor
Single configuration, shape ``(num_sites,)``.
hamiltonian : Hamiltonian
The Hamiltonian to query for connections on a cache miss.
Returns
-------
connected_configs : torch.Tensor
Connected configurations, shape ``(n_conn, num_sites)``.
matrix_elements : torch.Tensor
Corresponding matrix elements, shape ``(n_conn,)``.
"""
result = self.get(config)
if result is not None:
return result
connected, elements = hamiltonian.get_connections(config)
self.put(config, connected, elements)
return connected, elements
[docs]
def clear(self) -> None:
"""Remove all entries from the cache and reset statistics."""
self._cache.clear()
self._hits = 0
self._misses = 0
[docs]
def stats(self) -> dict[str, Any]:
"""Return cache performance statistics.
Returns
-------
dict
Dictionary with keys:
- ``"hits"`` : int --- number of successful lookups.
- ``"misses"`` : int --- number of failed lookups.
- ``"hit_rate"`` : float --- fraction of lookups that were hits
(0.0 if no lookups have been made).
- ``"size"`` : int --- current number of cached entries.
"""
total = self._hits + self._misses
hit_rate = self._hits / total if total > 0 else 0.0
return {
"hits": self._hits,
"misses": self._misses,
"hit_rate": hit_rate,
"size": len(self._cache),
}
def __len__(self) -> int:
"""Return the number of cached entries."""
return len(self._cache)
def __contains__(self, config: torch.Tensor) -> bool:
"""Check whether a configuration is in the cache.
Parameters
----------
config : torch.Tensor
Binary configuration vector, shape ``(n_sites,)``.
Returns
-------
bool
``True`` if the configuration hash is in the cache.
"""
key = self._hash(config)
return key in self._cache
def __repr__(self) -> str:
return (
f"ConnectionCache(size={len(self._cache)}, "
f"max_size={self.max_size}, "
f"hit_rate={self.stats()['hit_rate']:.2%})"
)