neuralqx.driver.multi_state_vmc module

Multi-state VMC driver with an optional orthogonality regularisation term.

This module extends a standard variational Monte Carlo (VMC) optimisation loop to train multiple independent MCState instances jointly via a MultiMCState container.

At each optimisation step the driver computes an energy gradient for every state. When enabled, an additional pairwise penalty based on a fidelity-like overlap estimator is added to encourage mutual orthogonality between different states.

A dedicated JAX kernel is provided to estimate the pairwise overlap quantity and its gradients with respect to both states’ parameters using a shared joint sample batch.

fidelity_expect_and_grad_joint(apply_fun_i, apply_fun_j, machine_pow_i, machine_pow_j, params_i, model_state_i, params_j, model_state_j, sigma_i, sigma_j)

Estimate a fidelity-like overlap quantity and its gradients for two variational states.

This JIT-compiled kernel forms a joint configuration batch by concatenating samples from two states and evaluates a local ratio estimator on crossed configurations. The result is a real-valued scalar overlap measure, together with auxiliary statistics and the gradients with respect to both parameter sets.
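The following is a minimal NumPy sketch of a joint-batch estimator of this kind. It makes two assumptions that this page does not spell out. First, the two sample sets are flattened over chain dimensions and concatenated along the site axis into pairs (x, y). Second, the local estimator is the crossed ratio ψ_j(x)ψ_i(y) / (ψ_i(x)ψ_j(y)), and its real part is averaged. The lookup-table wavefunctions, the exact Born sampler and the helper names (make_log_psi, sample_born, fidelity_estimate) are hypothetical stand-ins, not this module's API. Gradients and NetKet statistics are omitted.

```python
import numpy as np

N_SITES = 3  # toy system: 3 binary sites, 2**3 = 8 basis states


def config_index(sigma):
    """Map binary configurations of shape (..., N_SITES) to integer basis indices."""
    weights = 2 ** np.arange(N_SITES - 1, -1, -1)
    return sigma @ weights


def make_log_psi(amplitudes):
    """Hypothetical 'apply function': sigma -> log(psi(sigma)) via a lookup table."""
    log_amps = np.log(amplitudes)
    return lambda sigma: log_amps[config_index(sigma)]


def sample_born(amplitudes, n_chains, n_per_chain, rng):
    """Exact Born samples from |psi|^2 / <psi|psi>, shaped (n_chains, n_per_chain, N_SITES)."""
    probs = np.abs(amplitudes) ** 2
    probs /= probs.sum()
    idx = rng.choice(len(amplitudes), size=(n_chains, n_per_chain), p=probs)
    # Unpack integer indices back into bit strings (most significant bit first).
    return (idx[..., None] >> np.arange(N_SITES - 1, -1, -1)) & 1


def fidelity_estimate(log_psi_i, log_psi_j, sigma_i, sigma_j):
    """Average the real part of the crossed ratio over a joint batch of (x, y) pairs."""
    # Flatten any chain dimensions to (n_samples, N_SITES).
    x = sigma_i.reshape(-1, N_SITES)
    y = sigma_j.reshape(-1, N_SITES)
    # Joint batch (assumed layout): each row is one pair (x, y) drawn from p_i(x) p_j(y).
    joint = np.concatenate([x, y], axis=-1)
    x, y = joint[:, :N_SITES], joint[:, N_SITES:]
    # Crossed ratio psi_j(x) psi_i(y) / (psi_i(x) psi_j(y)), computed in log space.
    log_ratio = log_psi_j(x) + log_psi_i(y) - log_psi_i(x) - log_psi_j(y)
    return np.mean(np.exp(log_ratio).real)


def exact_fidelity(a, b):
    """|<a|b>|^2 / (<a|a> <b|b>) by direct summation over the basis."""
    return np.abs(np.vdot(a, b)) ** 2 / (np.vdot(a, a).real * np.vdot(b, b).real)


rng = np.random.default_rng(0)
psi_i = rng.normal(size=8) + 1j * rng.normal(size=8)
psi_j = rng.normal(size=8) + 1j * rng.normal(size=8)
sigma_i = sample_born(psi_i, 4, 50_000, rng)  # 4 chains x 50k samples each
sigma_j = sample_born(psi_j, 4, 50_000, rng)
est = fidelity_estimate(make_log_psi(psi_i), make_log_psi(psi_j), sigma_i, sigma_j)
print(est, exact_fidelity(psi_i, psi_j))
```

With exact Born samples, the Monte Carlo average agrees with the directly computed fidelity up to statistical error. Working in log space mirrors how apply functions return log(psi).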

For the common choice machine_pow_i == machine_pow_j == 2 (Born sampling for both states), the expectation value equals the normalised pure-state fidelity

\[F_{ij} = \frac{|\langle \psi_i | \psi_j \rangle|^2} {\langle \psi_i | \psi_i \rangle\,\langle \psi_j | \psi_j \rangle}.\]
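Assume the estimator averages the crossed ratio \(F^{\mathrm{loc}}(x,y) = \psi_j(x)\psi_i(y) / (\psi_i(x)\psi_j(y))\) over independent pairs drawn from \(p(x,y) = |\psi_i(x)|^2 |\psi_j(y)|^2 / (\langle \psi_i|\psi_i\rangle \langle \psi_j|\psi_j\rangle)\). The exact local estimator is not stated here, so this form is an assumption. Under it, the Born-sampling identity follows directly, because \(|\psi(x)|^2/\psi(x) = \psi^*(x)\):

\[\mathbb{E}_{p}\left[F^{\mathrm{loc}}\right] = \sum_{x,y} \frac{|\psi_i(x)|^2\,|\psi_j(y)|^2}{\langle \psi_i|\psi_i\rangle\,\langle \psi_j|\psi_j\rangle}\,\frac{\psi_j(x)\,\psi_i(y)}{\psi_i(x)\,\psi_j(y)} = \frac{\sum_x \psi_i^*(x)\,\psi_j(x)\;\sum_y \psi_j^*(y)\,\psi_i(y)}{\langle \psi_i|\psi_i\rangle\,\langle \psi_j|\psi_j\rangle} = \frac{\langle \psi_i|\psi_j\rangle\,\langle \psi_j|\psi_i\rangle}{\langle \psi_i|\psi_i\rangle\,\langle \psi_j|\psi_j\rangle} = F_{ij}.\]

The mean is real, so averaging only the real part of the local estimator keeps it unbiased and discards purely imaginary noise.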

For any other sampling power, the returned quantity is still a consistent expectation value under the chosen sampling distributions, but it is not equal to the Hilbert-space fidelity.

Gradients are produced through NetKet’s VJP wrapper and therefore follow the same distributed semantics as NetKet expectation gradients.

Parameters:
  • apply_fun_i – Apply function for state i mapping variables and samples to log(psi).

  • apply_fun_j – Apply function for state j mapping variables and samples to log(psi).

  • machine_pow_i – Sampling power used for state i in the joint log density.

  • machine_pow_j – Sampling power used for state j in the joint log density.

  • params_i – Parameter pytree for state i.

  • model_state_i – Model-state pytree for state i (everything except parameters).

  • params_j – Parameter pytree for state j.

  • model_state_j – Model-state pytree for state j (everything except parameters).

  • sigma_i – Samples from state i (may include chain dimensions).

  • sigma_j – Samples from state j (may include chain dimensions).

Returns:

(fid_val, fid_stats, grads) where fid_val is a real scalar, fid_stats is a NetKet statistics object, and grads is a dict-like pytree with keys "i" and "j" containing gradients for state i and state j.

class MultiStateVMC(variational_state, hamiltonian, optimizer, *, preconditioner=<netket.optimizer.preconditioner.IdentityPreconditioner object>, lambda_ortho=1.0)

Bases: VMC

VMC driver for MultiMCState with an optional orthogonality penalty.

For each contained state, this driver computes energy statistics and gradients of the Hamiltonian objective. If lambda_ortho is nonzero and more than one state is present, it additionally applies a pairwise penalty based on a fidelity-like overlap estimator to encourage the states to become mutually orthogonal.
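The exact weighting of the penalty is not given here. One plausible reading, stated as an assumption, sums the energies and penalises every unordered pair of states once:

\[\mathcal{L} = \sum_k E_k + \lambda \sum_{k<l} F_{kl}, \qquad \nabla_{\theta_k}\mathcal{L} = \nabla_{\theta_k} E_k + \lambda \sum_{l \neq k} \nabla_{\theta_k} F_{kl},\]

where \(\lambda\) is lambda_ortho. The NumPy sketch below uses a hypothetical helper, assemble_gradients, with flat arrays standing in for parameter pytrees. It routes the "i" and "j" entries of each pairwise gradient dict back to the two states involved.

```python
import itertools

import numpy as np


def assemble_gradients(energy_grads, pair_grads, lambda_ortho):
    """Hypothetical helper: add the pairwise overlap penalty to per-state energy gradients.

    energy_grads: list with one gradient array per state.
    pair_grads: dict mapping (k, l), k < l, to {"i": grad wrt state k, "j": grad wrt state l}.
    """
    grads = [np.array(g, dtype=float) for g in energy_grads]  # copies, inputs untouched
    if lambda_ortho == 0 or len(grads) < 2:
        return grads  # penalty disabled: pure energy gradients
    for (k, l), g in pair_grads.items():
        grads[k] += lambda_ortho * np.asarray(g["i"])
        grads[l] += lambda_ortho * np.asarray(g["j"])
    return grads


# Toy data: three states with two parameters each.
energy_grads = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([1.0, 1.0])]
pair_grads = {
    (k, l): {"i": np.full(2, 0.1 * (k + 1)), "j": np.full(2, 0.1 * (l + 1))}
    for k, l in itertools.combinations(range(3), 2)
}
total = assemble_gradients(energy_grads, pair_grads, lambda_ortho=2.0)
print(total)
```

Each state appears in \(n-1\) pairs, so with \(n\) states the penalty adds \(n(n-1)/2\) overlap evaluations per step.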

Preconditioning is applied independently per state. Interpreted as a joint optimisation problem over all parameters, this corresponds to a block-diagonal preconditioner.
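The block-diagonal interpretation can be checked numerically. The NumPy sketch below uses a hypothetical SR-like preconditioner that solves \(S_k x_k = g_k\) separately for each state. The driver's default is the identity, which is trivially block-diagonal. The per-state solves produce the same update as one joint solve with the \(S_k\) placed on the diagonal and zeros elsewhere, because no matrix couples parameters of different states.

```python
import numpy as np

rng = np.random.default_rng(1)


def random_spd(n, rng):
    """Random symmetric positive-definite matrix (stand-in for a per-state metric)."""
    a = rng.normal(size=(n, n))
    return a @ a.T + n * np.eye(n)


sizes = [3, 2]  # hypothetical parameter counts for two states
S_blocks = [random_spd(n, rng) for n in sizes]
g_blocks = [rng.normal(size=n) for n in sizes]

# Per-state preconditioning: each state solves only its own linear system.
per_state = np.concatenate([np.linalg.solve(S, g) for S, g in zip(S_blocks, g_blocks)])

# Joint view: a single block-diagonal matrix acting on all parameters at once.
total = sum(sizes)
S_joint = np.zeros((total, total))
offset = 0
for S in S_blocks:
    n = S.shape[0]
    S_joint[offset:offset + n, offset:offset + n] = S
    offset += n
joint = np.linalg.solve(S_joint, np.concatenate(g_blocks))

print(np.allclose(per_state, joint))
```

The per-state form is also cheaper. It solves several small systems instead of one large one, so the cost scales with the largest block rather than with the total parameter count.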

Parameters:
  • variational_state (MultiMCState) – Multi-state variational object containing multiple independent states.

  • hamiltonian (Union[AbstractOperator, list]) – Operator or list of operators defining the shared objective.

  • optimizer (Any) – Optimiser used to update parameters from the (preconditioned) gradients.

  • preconditioner (Callable[[VariationalState, Any, Any | None], Any]) – Preconditioner applied per state to transform raw gradients.

  • lambda_ortho (float) – Strength of the orthogonality penalty. Set to 0 to disable.

Raises:

TypeError – If any operator acts on a Hilbert space incompatible with the variational state.

property states: list[MCState]

Return the list of underlying per-state MCState instances.

This is a convenience view into the states managed by the multi-state variational state.

Returns:

The list of contained Monte Carlo variational states.

estimate(observables)

Estimate observables and flatten per-state results into a logging-friendly dictionary.

This overrides the base driver's estimate() method to return a flat mapping from string keys to statistics objects or scalar values.

For each observable leaf that yields a list of per-state statistics, entries are created as:

"<name> (state i)" -> Stats

When per-state means are available, an aggregate entry "<name>/sum_mean" is also added, containing the sum of the per-state means. This aggregate is intended for quick diagnostics and does not combine the per-state uncertainties in a statistically rigorous way.
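A minimal sketch of this flattening follows. It assumes each observable's result is either a list of per-state statistics objects exposing a .mean attribute or a plain value, and that states are numbered from 0. Stats and flatten_per_state are hypothetical stand-ins, not NetKet's or this module's classes.

```python
from dataclasses import dataclass


@dataclass
class Stats:  # hypothetical stand-in for a NetKet statistics object
    mean: float
    error_of_mean: float


def flatten_per_state(results):
    """Flatten {name: [Stats, ...] or value} into logging-friendly keys."""
    flat = {}
    for name, value in results.items():
        if isinstance(value, list):
            for i, stats in enumerate(value):
                flat[f"{name} (state {i})"] = stats
            means = [getattr(s, "mean", None) for s in value]
            if means and all(m is not None for m in means):
                # Diagnostic only: errors are NOT propagated into this sum.
                flat[f"{name}/sum_mean"] = sum(means)
        else:
            flat[name] = value
    return flat


flat = flatten_per_state({"Energy": [Stats(-1.0, 0.01), Stats(-0.5, 0.02)]})
print(sorted(flat))
```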

Parameters:

observables – Observable or pytree of observables to estimate. If None, an empty set of observables is assumed.

Returns:

Flat dictionary mapping names to per-state statistics objects or scalar values.