neuralqx.experimental.driver.mvmc module
Multi-state VMC driver with an optional orthogonality regularisation term.
This module extends a standard variational Monte Carlo (VMC) optimisation loop to train multiple independent MCState instances jointly via a MultiMCState container.
At each optimisation step the driver computes an energy gradient for every state. When enabled, an additional pairwise penalty based on a fidelity-like overlap estimator is added to encourage mutual orthogonality between different states.
A dedicated JAX kernel is provided to estimate the pairwise overlap quantity and its gradients with respect to both states’ parameters using a shared joint sample batch.
- fidelity_expect_and_grad_joint(apply_fun_i, apply_fun_j, machine_pow_i, machine_pow_j, params_i, model_state_i, params_j, model_state_j, sigma_i, sigma_j)
Estimate a fidelity-like overlap quantity and its gradients for two variational states.
This JIT-compiled kernel forms a joint configuration batch by concatenating samples from two states and evaluates a local ratio estimator on crossed configurations. The result is a scalar overlap measure (real-valued) together with auxiliary statistics and gradients with respect to both parameter sets.
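The crossed-configuration estimator described above can be sketched in plain JAX for Born sampling (both sampling powers equal to 2). This is a minimal illustration, not the module's kernel: the toy three-spin space, the lookup-table log-amplitudes and the helper names (`make_table_log_psi`, `born_log_prob`, `fidelity_local`, `joint_expectation`, `mc_estimate`) are all hypothetical. Each pair (x, y), with x drawn from state i and y from state j, contributes the ratio psi_j(x) psi_i(y) / (psi_i(x) psi_j(y)). Because the space is tiny, the sketch averages its real part exactly over all pairs and also estimates it from samples drawn with `jax.random.categorical`.

```python
import itertools

import jax
import jax.numpy as jnp

N_SPINS = 3
# All 2**N_SPINS configurations of ±1 spins in binary order (hypothetical toy space).
CONFIGS = jnp.array(list(itertools.product([-1, 1], repeat=N_SPINS)))
WEIGHTS = 2 ** jnp.arange(N_SPINS - 1, -1, -1)  # MSB-first binary place values


def make_table_log_psi(key):
    """Hypothetical toy state: one complex log-amplitude per configuration."""
    k_re, k_im = jax.random.split(key)
    n = CONFIGS.shape[0]
    table = jax.random.normal(k_re, (n,)) + 1j * jax.random.normal(k_im, (n,))

    def log_psi(sigma):
        # Map ±1 spins to the row index of CONFIGS, then look up log(psi).
        idx = ((sigma + 1) // 2) @ WEIGHTS
        return table[idx]

    return log_psi


def born_log_prob(log_psi, sigma, machine_pow=2):
    # Unnormalised log-density machine_pow * Re[log psi]; machine_pow=2 is Born sampling.
    return machine_pow * jnp.real(log_psi(sigma))


def fidelity_local(log_psi_i, log_psi_j, sigma_i, sigma_j):
    # Crossed ratio psi_j(x) psi_i(y) / (psi_i(x) psi_j(y)) with x = sigma_i, y = sigma_j.
    return jnp.exp(
        log_psi_j(sigma_i) - log_psi_i(sigma_i) + log_psi_i(sigma_j) - log_psi_j(sigma_j)
    )


def exact_fidelity(log_psi_i, log_psi_j):
    psi_i = jnp.exp(log_psi_i(CONFIGS))
    psi_j = jnp.exp(log_psi_j(CONFIGS))
    overlap = jnp.vdot(psi_i, psi_j)  # <psi_i|psi_j>; vdot conjugates its first argument
    return jnp.abs(overlap) ** 2 / (jnp.vdot(psi_i, psi_i).real * jnp.vdot(psi_j, psi_j).real)


def joint_expectation(log_psi_i, log_psi_j):
    # Exact average of Re[F_loc] over all pairs (x, y) weighted by p_i(x) * p_j(y).
    p_i = jax.nn.softmax(born_log_prob(log_psi_i, CONFIGS))
    p_j = jax.nn.softmax(born_log_prob(log_psi_j, CONFIGS))
    n = CONFIGS.shape[0]
    x = jnp.repeat(CONFIGS, n, axis=0)  # x index = k // n
    y = jnp.tile(CONFIGS, (n, 1))       # y index = k % n
    w = jnp.repeat(p_i, n) * jnp.tile(p_j, n)
    return jnp.sum(w * jnp.real(fidelity_local(log_psi_i, log_psi_j, x, y)))


def mc_estimate(key, log_psi_i, log_psi_j, n_samples=20_000):
    # Draw x ~ p_i and y ~ p_j independently, then average the real part of the crossed ratio.
    k_i, k_j = jax.random.split(key)
    idx_i = jax.random.categorical(k_i, born_log_prob(log_psi_i, CONFIGS), shape=(n_samples,))
    idx_j = jax.random.categorical(k_j, born_log_prob(log_psi_j, CONFIGS), shape=(n_samples,))
    f_loc = fidelity_local(log_psi_i, log_psi_j, CONFIGS[idx_i], CONFIGS[idx_j])
    return jnp.mean(jnp.real(f_loc))


log_psi_i = make_table_log_psi(jax.random.PRNGKey(0))
log_psi_j = make_table_log_psi(jax.random.PRNGKey(1))
print(exact_fidelity(log_psi_i, log_psi_j), joint_expectation(log_psi_i, log_psi_j))
print(mc_estimate(jax.random.PRNGKey(2), log_psi_i, log_psi_j))
```

The exact joint average matches the closed-form fidelity. Under Born sampling the second moment of the crossed ratio equals 1, so its variance is bounded and the Monte Carlo error shrinks like one over the square root of the number of samples.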
For the common choice machine_pow_i == machine_pow_j == 2 (Born sampling), the expectation value equals the normalised pure-state fidelity
\[F_{ij} = \frac{|\langle \psi_i | \psi_j \rangle|^2}{\langle \psi_i | \psi_i \rangle\,\langle \psi_j | \psi_j \rangle}.\]
For other sampling powers, the returned quantity is still a consistent expectation value under the chosen sampling distributions, but it does not coincide with the Hilbert-space fidelity.
MPI semantics
The value and statistics are computed using netket.jax.expect(), which internally uses MPI-aware reductions for statistics.
For gradients, note that netket.jax.expect() uses an MPI mean inside its custom VJP rule. As a consequence, the gradients produced by automatic differentiation on each rank are scaled by 1 / n_ranks. To obtain rank-independent gradients suitable for parameter updates, this function performs an MPI sum of the AD gradients over ranks.
- Parameters:
apply_fun_i – Apply function for state i mapping variables and samples to log(psi).
apply_fun_j – Apply function for state j mapping variables and samples to log(psi).
machine_pow_i – Sampling power used for state i in the joint log density.
machine_pow_j – Sampling power used for state j in the joint log density.
params_i – Parameter pytree for state i.
model_state_i – Model-state pytree for state i (everything except parameters).
params_j – Parameter pytree for state j.
model_state_j – Model-state pytree for state j (everything except parameters).
sigma_i – Samples from state i (may include chain dimensions).
sigma_j – Samples from state j (may include chain dimensions).
- Returns:
(fid_val, fid_stats, grads), where fid_val is a real scalar, fid_stats is a NetKet statistics object, and grads is a dict-like pytree with keys "i" and "j" containing the gradients for state i and state j, respectively.
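The rank rescaling described under MPI semantics can be illustrated on a single process. This is a hedged simulation only: no MPI is involved, equal batch sizes per rank are assumed, and the toy objective `local_loss` is hypothetical. Each simulated rank holds its own copy of the parameters and its own sample batch, and the objective is the cross-rank mean of the per-rank means, mirroring an MPI mean. Differentiating with respect to the stacked copies gives each rank its local gradient scaled by 1 / n_ranks. Summing over ranks then recovers the gradient of the pooled objective.

```python
import jax
import jax.numpy as jnp

N_RANKS = 4
N_LOCAL = 16  # samples per simulated rank (equal batch sizes assumed)


def local_loss(theta, samples):
    # Hypothetical per-sample objective, averaged over one rank's batch.
    return jnp.mean(jnp.sin(samples @ theta) ** 2)


def cross_rank_mean(theta_per_rank, samples_per_rank):
    # Mirrors an MPI mean: average the per-rank means across ranks.
    per_rank = jax.vmap(local_loss)(theta_per_rank, samples_per_rank)
    return jnp.mean(per_rank)


key_theta, key_samples = jax.random.split(jax.random.PRNGKey(0))
theta = jax.random.normal(key_theta, (3,))
samples = jax.random.normal(key_samples, (N_RANKS, N_LOCAL, 3))

# Every simulated rank holds an identical copy of the parameters.
theta_per_rank = jnp.broadcast_to(theta, (N_RANKS, 3))

# Row r is what AD yields on rank r: its local gradient scaled by 1 / N_RANKS.
per_rank_grads = jax.grad(cross_rank_mean)(theta_per_rank, samples)

# Analogue of the MPI sum over ranks.
summed = per_rank_grads.sum(axis=0)

# Reference: gradient of the same objective with all samples pooled on one rank.
pooled = jax.grad(local_loss)(theta, samples.reshape(-1, 3))
print(jnp.allclose(summed, pooled, rtol=1e-4, atol=1e-6))
```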
- class MultiStateVMC(variational_state, hamiltonian, optimizer, *, preconditioner=<netket.optimizer.preconditioner.IdentityPreconditioner object>, lambda_ortho=1.0)
Bases: VMC
VMC driver for MultiMCState with an optional orthogonality penalty.
For each contained state, this driver computes energy statistics and gradients of the Hamiltonian objective. If lambda_ortho is nonzero and more than one state is present, it additionally applies a pairwise penalty based on a fidelity-like overlap estimator to encourage the states to become mutually orthogonal.
Preconditioning is applied independently per state. Interpreted as a joint optimisation problem over all parameters, this corresponds to a block-diagonal preconditioner.
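The update structure described above can be illustrated with a small dense-vector sketch. It is a hedged illustration, not the driver's implementation: the penalised objective sum_i E_i + lambda_ortho * sum_{i<j} F_ij is an assumed form of the pairwise penalty, energies and fidelities are computed exactly instead of by Monte Carlo, and the names `total_loss`, `identity_preconditioner` and `optimisation_step` are hypothetical. What it shows is the block-diagonal structure: the gradient of the joint objective is split into one block per state, each block is preconditioned on its own, and only then are the parameters updated.

```python
import jax
import jax.numpy as jnp

DIM = 8  # toy Hilbert-space dimension (assumption)
N_STATES = 2
LAMBDA_ORTHO = 1.0

key_h, key_p = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_h, (DIM, DIM))
H = (A + A.T) / 2  # hypothetical real-symmetric Hamiltonian


def wavefunction(params):
    # Real parameters a, b define psi = exp(a + i b) on every basis state.
    return jnp.exp(params["a"] + 1j * params["b"])


def energy(params):
    psi = wavefunction(params)
    return jnp.real(jnp.vdot(psi, H @ psi) / jnp.vdot(psi, psi))


def fidelity(params_i, params_j):
    psi_i, psi_j = wavefunction(params_i), wavefunction(params_j)
    num = jnp.abs(jnp.vdot(psi_i, psi_j)) ** 2
    return num / jnp.real(jnp.vdot(psi_i, psi_i) * jnp.vdot(psi_j, psi_j))


def total_loss(all_params, lambda_ortho):
    # Assumed form: sum of energies plus lambda_ortho times the sum of pairwise fidelities.
    loss = sum(energy(p) for p in all_params)
    for i in range(len(all_params)):
        for j in range(i + 1, len(all_params)):
            loss = loss + lambda_ortho * fidelity(all_params[i], all_params[j])
    return loss


def identity_preconditioner(vstate, gradient, step=None):
    # Same call shape as the documented preconditioner type; returns the gradient unchanged.
    return gradient


def optimisation_step(all_params, lambda_ortho, learning_rate=1e-2,
                      preconditioner=identity_preconditioner):
    grads = jax.grad(total_loss)(all_params, lambda_ortho)  # list: one block per state
    new_params = []
    for params, g in zip(all_params, grads):
        g = preconditioner(None, g)  # each block is preconditioned independently
        new_params.append(jax.tree_util.tree_map(lambda p, dp: p - learning_rate * dp, params, g))
    return new_params


keys = jax.random.split(key_p, 2 * N_STATES)
all_params = [
    {"a": 0.1 * jax.random.normal(keys[2 * s], (DIM,)),
     "b": 0.1 * jax.random.normal(keys[2 * s + 1], (DIM,))}
    for s in range(N_STATES)
]

loss_before = total_loss(all_params, LAMBDA_ORTHO)
for _ in range(50):
    all_params = optimisation_step(all_params, LAMBDA_ORTHO)
print(loss_before, total_loss(all_params, LAMBDA_ORTHO))
```

With lambda_ortho set to 0, the penalty drops out and each state's gradient block reduces to its own energy gradient, which matches the "Set to 0 to disable" behaviour of the parameter.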
- Parameters:
variational_state (MultiMCState) – Multi-state variational object containing multiple independent states.
hamiltonian (Union[AbstractOperator, list]) – Operator or list of operators defining the objective for each state.
optimizer (Any) – Optimiser used to update parameters from the (preconditioned) gradients.
preconditioner (Callable[[VariationalState, Any, Any | None], Any]) – Preconditioner applied per state to transform raw gradients.
lambda_ortho (float) – Strength of the orthogonality penalty. Set to 0 to disable.
- Raises:
TypeError – If any operator acts on a Hilbert space incompatible with the variational state.
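The preconditioner parameter is typed as Callable[[VariationalState, Any, Any | None], Any], and its default is NetKet's IdentityPreconditioner. As a hedged sketch, assuming the driver calls it as preconditioner(vstate, gradient, step) once per state as that type suggests, a hypothetical gradient-clipping preconditioner could look like this. The class name `ClipNormPreconditioner` is invented for illustration and is not part of NetKet or this module.

```python
import jax
import jax.numpy as jnp


class ClipNormPreconditioner:
    """Hypothetical preconditioner: rescale a gradient pytree to a maximum global norm."""

    def __init__(self, max_norm: float = 1.0):
        self.max_norm = max_norm

    def __call__(self, vstate, gradient, step=None):
        # vstate and step are accepted for signature compatibility but unused here.
        leaves = jax.tree_util.tree_leaves(gradient)
        norm = jnp.sqrt(sum(jnp.sum(jnp.abs(leaf) ** 2) for leaf in leaves))
        scale = jnp.minimum(1.0, self.max_norm / (norm + 1e-12))
        return jax.tree_util.tree_map(lambda g: g * scale, gradient)


precond = ClipNormPreconditioner(max_norm=0.5)
grad = {"dense": {"kernel": jnp.full((3, 2), 2.0), "bias": jnp.ones((2,))}}
clipped = precond(None, grad)
```

Because preconditioning is applied per state, the norm cap would act on each state's gradient block separately rather than on the concatenated gradient of all states.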
- property states: list[MCState]
Return the list of underlying per-state MCState instances.
This is a convenience view into the states managed by the multi-state variational state.
- Returns:
The list of contained Monte Carlo variational states.
- estimate(observables)
Estimate observables and flatten per-state results into a logging-friendly dictionary.
This overrides the base driver’s estimate() method to return a flat mapping from string keys to scalar statistics objects or values.
For each observable leaf that yields a list of per-state statistics, entries are created as:
"<name> (state i)" -> Stats
Optionally, a simple aggregate curve is added as "<name>/sum_mean", containing the sum of the per-state means when available. This aggregate is intended for quick diagnostics and is not a statistically rigorous combination of per-state uncertainties.
- Parameters:
observables – Observable or pytree of observables to estimate. If None, an empty set of observables is assumed.
- Returns:
Flat dictionary mapping names to per-state statistics objects or scalar values.
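The flattening rule above can be sketched in plain Python. This is a minimal illustration under stated assumptions: `FakeStats` stands in for NetKet's statistics objects (only its mean attribute is used), state indices are assumed to start at 0, the input is assumed to already map each observable name to a list of per-state results, and `flatten_per_state` is a hypothetical helper rather than the method's actual code.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeStats:
    """Hypothetical stand-in for a NetKet statistics object."""
    mean: float
    error_of_mean: float = 0.0


def flatten_per_state(per_state_results: dict[str, list[Any]]) -> dict[str, Any]:
    flat: dict[str, Any] = {}
    for name, stats_list in per_state_results.items():
        for i, stats in enumerate(stats_list):
            flat[f"{name} (state {i})"] = stats
        # Aggregate diagnostic: plain sum of per-state means, with no error propagation.
        means = [getattr(s, "mean", None) for s in stats_list]
        if means and all(m is not None for m in means):
            flat[f"{name}/sum_mean"] = sum(means)
    return flat


results = {"Energy": [FakeStats(-1.25), FakeStats(-0.5)]}
flat = flatten_per_state(results)
print(sorted(flat))  # ['Energy (state 0)', 'Energy (state 1)', 'Energy/sum_mean']
```

The sum_mean entry adds means only; it deliberately ignores error_of_mean, consistent with the warning that the aggregate is not a rigorous combination of uncertainties.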