neuralqx.experimental.driver.mvmc module

Multi-state VMC driver with an optional orthogonality regularisation term.

This module extends a standard variational Monte Carlo (VMC) optimisation loop to train multiple independent MCState instances jointly via a MultiMCState container.

At each optimisation step the driver computes an energy gradient for every state. When enabled, an additional pairwise penalty based on a fidelity-like overlap estimator is added to encourage mutual orthogonality between different states.

A dedicated JAX kernel is provided to estimate the pairwise overlap quantity and its gradients with respect to both states’ parameters using a shared joint sample batch.

fidelity_expect_and_grad_joint(apply_fun_i, apply_fun_j, machine_pow_i, machine_pow_j, params_i, model_state_i, params_j, model_state_j, sigma_i, sigma_j)

Estimate a fidelity-like overlap quantity and its gradients for two variational states.

This JIT-compiled kernel forms a joint configuration batch by concatenating samples from two states and evaluates a local ratio estimator on crossed configurations. The result is a scalar overlap measure (real-valued) together with auxiliary statistics and gradients with respect to both parameter sets.

For the common choice machine_pow == 2 (Born sampling), the expectation value matches the normalised pure-state fidelity

\[F_{ij} = \frac{|\langle \psi_i | \psi_j \rangle|^2} {\langle \psi_i | \psi_i \rangle\,\langle \psi_j | \psi_j \rangle}.\]

More generally (for machine_pow != 2), the returned quantity is still a well-defined expectation value under the chosen sampling distributions, but it no longer equals the Hilbert-space fidelity.
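The Born-sampling identity above can be checked on a toy problem. The sketch below is not the module's kernel: it assumes the crossed-ratio estimator has the product form psi_j(s)/psi_i(s) * psi_i(t)/psi_j(t), with s drawn from |psi_i|^2 and t from |psi_j|^2, and it replaces Monte Carlo sampling with an exact sum over four configurations. The amplitudes and the helper names fidelity and crossed_ratio_expectation are made up for illustration.

```python
# Exact check on a 4-configuration toy space (no sampling, no NetKet).
# psi_i and psi_j are made-up, unnormalised, nonzero complex amplitudes.

def fidelity(psi_i, psi_j):
    """Normalised pure-state fidelity |<i|j>|^2 / (<i|i><j|j>)."""
    overlap = sum(a.conjugate() * b for a, b in zip(psi_i, psi_j))
    norm_i = sum(abs(a) ** 2 for a in psi_i)
    norm_j = sum(abs(b) ** 2 for b in psi_j)
    return abs(overlap) ** 2 / (norm_i * norm_j)


def crossed_ratio_expectation(psi_i, psi_j):
    """Exact expectation of [psi_j(s)/psi_i(s)] * [psi_i(t)/psi_j(t)]
    for independent s ~ |psi_i|^2 and t ~ |psi_j|^2 (machine_pow == 2)."""
    norm_i = sum(abs(a) ** 2 for a in psi_i)
    norm_j = sum(abs(b) ** 2 for b in psi_j)
    # E_{s ~ p_i}[psi_j(s) / psi_i(s)] = <i|j> / <i|i>
    e_i = sum(abs(a) ** 2 / norm_i * (b / a) for a, b in zip(psi_i, psi_j))
    # E_{t ~ p_j}[psi_i(t) / psi_j(t)] = <j|i> / <j|j>
    e_j = sum(abs(b) ** 2 / norm_j * (a / b) for a, b in zip(psi_i, psi_j))
    # The product is real up to rounding; keep the real part.
    return (e_i * e_j).real


psi_i = [1.0 + 0.5j, 0.3 - 0.2j, -0.7 + 0.1j, 0.4 + 0.4j]
psi_j = [0.2 - 0.1j, 1.1 + 0.3j, 0.5 - 0.6j, -0.3 + 0.2j]

f_exact = fidelity(psi_i, psi_j)
f_ratio = crossed_ratio_expectation(psi_i, psi_j)
print(f_exact, f_ratio)  # the two values agree up to rounding
```

Under these assumptions the two numbers coincide because each factor of the product reduces to one of the two overlaps in the numerator of the fidelity, divided by the matching norm.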

MPI semantics

The value and its statistics are computed with netket.jax.expect(), whose reductions are MPI-aware.

For gradients, note that netket.jax.expect() uses an MPI mean inside its custom VJP rule. As a consequence, gradients produced by automatic differentiation are scaled by 1 / n_ranks on each rank. To obtain rank-independent gradients suitable for parameter updates, this function performs an MPI sum over ranks on the AD gradients.
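The scaling argument can be illustrated without MPI. The sketch below simulates three ranks in a single process, assuming equal batch sizes and a made-up loss, the mean of theta**2 * x over all samples. It does not call NetKet or any MPI library; local_grad and the batch values are hypothetical.

```python
# Simulated "ranks": each holds its own equally sized batch of samples.
batches = [[1.0, 2.0], [3.0, 5.0], [4.0, 6.0]]
n_ranks = len(batches)
theta = 0.5


def local_grad(batch, theta):
    """d/dtheta of mean_x(theta**2 * x) over one rank's local batch."""
    return 2.0 * theta * sum(batch) / len(batch)


# If the forward pass ends in a mean over ranks, differentiating on each
# rank yields only that rank's share, scaled by 1 / n_ranks.
per_rank_ad = [local_grad(b, theta) / n_ranks for b in batches]

# Summing the per-rank results (the role of the MPI sum) restores the
# gradient of the global loss, identical on every rank.
grad_summed = sum(per_rank_ad)

# Reference: gradient computed directly on the pooled samples.
pooled = [x for b in batches for x in b]
grad_reference = 2.0 * theta * sum(pooled) / len(pooled)

print(grad_summed, grad_reference)
```

Without the final sum, each simulated rank would hold a different, smaller gradient, which is the situation the MPI sum in this function is described as correcting.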

Parameters:
  • apply_fun_i – Apply function for state i mapping variables and samples to log(psi).

  • apply_fun_j – Apply function for state j mapping variables and samples to log(psi).

  • machine_pow_i – Sampling power used for state i in the joint log density.

  • machine_pow_j – Sampling power used for state j in the joint log density.

  • params_i – Parameter pytree for state i.

  • model_state_i – Model-state pytree for state i (everything except parameters).

  • params_j – Parameter pytree for state j.

  • model_state_j – Model-state pytree for state j (everything except parameters).

  • sigma_i – Samples from state i (may include chain dimensions).

  • sigma_j – Samples from state j (may include chain dimensions).

Returns:

(fid_val, fid_stats, grads) where fid_val is a real scalar, fid_stats is a NetKet statistics object, and grads is a dict-like pytree with keys "i" and "j" containing gradients for state i and state j.

class MultiStateVMC(variational_state, hamiltonian, optimizer, *, preconditioner=<netket.optimizer.preconditioner.IdentityPreconditioner object>, lambda_ortho=1.0)

Bases: VMC

VMC driver for MultiMCState with an optional orthogonality penalty.

For each contained state, this driver computes energy statistics and gradients of the Hamiltonian objective. If lambda_ortho is nonzero and more than one state is present, it additionally applies a pairwise penalty based on a fidelity-like overlap estimator to encourage the states to become mutually orthogonal.
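As a rough illustration of how such a penalty behaves, the sketch below sums a fidelity over all unordered pairs of states and scales it by lambda_ortho. The exact functional form used by the driver is not stated here, so the sum over pairs i < j is an assumption, as are the helper names fidelity and overlap_penalty and the real-valued toy vectors.

```python
from itertools import combinations


def fidelity(psi_a, psi_b):
    """Normalised fidelity for real-valued toy state vectors."""
    overlap = sum(a * b for a, b in zip(psi_a, psi_b))
    norm_a = sum(a * a for a in psi_a)
    norm_b = sum(b * b for b in psi_b)
    return overlap ** 2 / (norm_a * norm_b)


def overlap_penalty(states, lambda_ortho):
    """Assumed form: lambda_ortho * sum of F_ij over pairs i < j."""
    if lambda_ortho == 0 or len(states) < 2:
        return 0.0  # penalty disabled or nothing to orthogonalise
    pair_sum = sum(fidelity(a, b) for a, b in combinations(states, 2))
    return lambda_ortho * pair_sum


orthogonal = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
overlapping = orthogonal + [[1.0, 1.0, 0.0]]

penalty_orthogonal = overlap_penalty(orthogonal, lambda_ortho=2.0)
penalty_overlapping = overlap_penalty(overlapping, lambda_ortho=2.0)
print(penalty_orthogonal, penalty_overlapping)
```

In this toy setting the penalty vanishes exactly when every pair is orthogonal and grows with the overlap, so minimising it alongside the energies pushes the states apart.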

Preconditioning is applied independently per state. Interpreted as a joint optimisation problem over all parameters, this corresponds to a block-diagonal preconditioner.
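The block-diagonal reading can be verified on a small example. The sketch below uses explicit 2x2 matrices as stand-ins for the per-state preconditioners (in practice a preconditioner is a callable rather than a stored matrix) and compares per-state application with a single block-diagonal matrix acting on the concatenated gradient. All names and values are hypothetical.

```python
def matvec(matrix, vector):
    """Dense matrix-vector product on nested lists."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def block_diag(blocks):
    """Assemble square blocks into one block-diagonal matrix."""
    size = sum(len(b) for b in blocks)
    out = [[0.0] * size for _ in range(size)]
    offset = 0
    for block in blocks:
        for r, row in enumerate(block):
            for c, value in enumerate(row):
                out[offset + r][offset + c] = value
        offset += len(block)
    return out


# One stand-in preconditioner matrix and one gradient per state.
precond = [[[2.0, 0.0], [0.0, 0.5]], [[1.0, 1.0], [0.0, 1.0]]]
grads = [[1.0, 2.0], [3.0, 4.0]]

# Independent per-state preconditioning, then concatenation.
per_state = [matvec(p, g) for p, g in zip(precond, grads)]
flat_per_state = [x for update in per_state for x in update]

# Joint view: block-diagonal matrix applied to the stacked gradient.
joint_grad = [x for g in grads for x in g]
joint = matvec(block_diag(precond), joint_grad)

print(flat_per_state, joint)
```

Both routes give the same update vector, which also makes explicit what is left out: the off-diagonal blocks coupling the parameters of different states are zero.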

Parameters:
  • variational_state (MultiMCState) – Multi-state variational object containing multiple independent states.

  • hamiltonian (Union[AbstractOperator, list]) – Operator or list of operators defining the objective for each state.

  • optimizer (Any) – Optimiser used to update parameters from the (preconditioned) gradients.

  • preconditioner (Callable[[VariationalState, Any, Any | None], Any]) – Preconditioner applied per state to transform raw gradients.

  • lambda_ortho (float) – Strength of the orthogonality penalty. Set to 0 to disable.

Raises:

TypeError – If any operator acts on a Hilbert space incompatible with the variational state.

property states: list[MCState]

Return the list of underlying per-state MCState instances.

This is a convenience view into the states managed by the multi-state variational state.

Returns:

The list of contained Monte Carlo variational states.

estimate(observables)

Estimate observables and flatten per-state results into a logging-friendly dictionary.

This overrides the base driver’s estimate() method to return a flat mapping from string keys to scalar statistics objects or values.

For each observable leaf that yields a list of per-state statistics, entries are created as:

"<name> (state i)" -> Stats

Optionally, a simple aggregate curve is added as "<name>/sum_mean" containing the sum of per-state means when available. This aggregate is intended for quick diagnostics and is not a statistically rigorous combination of per-state uncertainties.
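The key layout described above can be sketched with plain dictionaries. This is not the driver's code: ToyStats is a hypothetical stand-in for a statistics object with a mean attribute, flatten_per_state is an illustrative helper, and zero-based state indices are an assumption.

```python
from dataclasses import dataclass


@dataclass
class ToyStats:
    """Hypothetical stand-in for a per-state statistics object."""
    mean: float
    error_of_mean: float


def flatten_per_state(results):
    """Flatten {name: [stats per state]} into a flat logging dict."""
    flat = {}
    for name, value in results.items():
        if isinstance(value, (list, tuple)):
            for i, stats in enumerate(value):
                flat[f"{name} (state {i})"] = stats
            means = [getattr(s, "mean", None) for s in value]
            # Add the aggregate only when every state exposes a mean.
            if all(m is not None for m in means):
                flat[f"{name}/sum_mean"] = sum(means)
        else:
            flat[name] = value  # scalar entries pass through unchanged
    return flat


results = {
    "Energy": [ToyStats(-1.0, 0.01), ToyStats(-0.5, 0.02)],
    "step": 3,
}
flat = flatten_per_state(results)
print(sorted(flat))
```

The aggregate only adds the means and carries no error estimate, consistent with its role as a quick diagnostic.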

Parameters:

observables – Observable or pytree of observables to estimate. If None, an empty set of observables is assumed.

Returns:

Flat dictionary mapping names to per-state statistics objects or scalar values.