neuralqx.experimental.solver.multi_state_solver module

Multi-state solver utilities for training and monitoring multiple variational states.

This module provides serialization helpers for MultiMCState and a Solver subclass that trains several MCState instances jointly using MultiStateVMC. It supports per-state expectation evaluation, multi-state export and import, per-state logging of final constraint estimates, and plotting of per-state optimisation curves when these are present in the runtime log.

serialize_MultiMCState(vstate)

Serialize a MultiMCState into a plain Python dictionary.

The resulting dictionary is suitable for writing to disk with the neuraLQX serialization utilities. Each contained MCState is serialized using serialize_MCState.

Parameters:

vstate (MultiMCState) – Multi-state variational state to serialize.

Return type:

dict

Returns:

Dictionary representation of the multi-state variational state, including the per-state payloads.
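The payload layout can be pictured with the minimal sketch below. It does not use neuralqx: serialize_single is a hypothetical stand-in for serialize_MCState, each state is modelled as a plain dict, and the top-level keys "n_states" and "states" are assumptions chosen for illustration, not the library's actual format.

```python
from typing import Any, Callable, Dict, List

def serialize_single(state: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical stand-in for serialize_MCState: each "state" is already a plain dict here.
    return dict(state)

def serialize_multi(
    states: List[Dict[str, Any]],
    serialize_one: Callable[[Dict[str, Any]], Dict[str, Any]] = serialize_single,
) -> Dict[str, Any]:
    # One top-level dictionary holding the state count and one payload per state.
    # The key names "n_states" and "states" are hypothetical.
    return {
        "n_states": len(states),
        "states": [serialize_one(s) for s in states],
    }

payload = serialize_multi([{"params": [0.1, 0.2]}, {"params": [0.3]}])
print(payload["n_states"])  # 2
```

Storing the count next to the list lets a reader check the payload against a template before rebuilding any state.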

deserialize_MultiMCState(template, state_dict, *, force_load_mpi=False)

Deserialize a MultiMCState from a dictionary into a new MultiMCState instance.

The function uses an existing template multi-state to provide the structure and implementation details required by deserialize_MCState. For backward compatibility, a dictionary that represents a single MCState is accepted and wrapped into a single-state MultiMCState.

Parameters:
  • template (MultiMCState) – Existing multi state used as a template for reconstructing each state.

  • state_dict (dict) – Dictionary payload previously produced by serialization utilities.

  • force_load_mpi (bool) – If True, attempt to load MPI-related fields even if they differ from the current run configuration.

Return type:

MultiMCState

Returns:

Newly reconstructed MultiMCState (neuralqx.experimental.vqs.MultiMCState).

Raises:

ValueError – If the number of saved states does not match template.n_states.
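The backward-compatibility rule and the state-count check can be sketched as follows. This is an illustration only: template_n_states replaces the real template object, payloads are plain dicts, and detecting a legacy single-state payload by the absence of a "states" key is an assumption of the sketch, not the documented detection logic.

```python
from typing import Any, Dict, List

def deserialize_multi(template_n_states: int, state_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Backward compatibility: a payload without a "states" list is treated as a
    # single MCState payload and wrapped into a one-element list.
    payloads = list(state_dict["states"]) if "states" in state_dict else [state_dict]
    if len(payloads) != template_n_states:
        raise ValueError(
            f"checkpoint holds {len(payloads)} state(s), template expects {template_n_states}"
        )
    # The real function would rebuild each entry with deserialize_MCState,
    # using the matching template state; here the payloads are copied as-is.
    return [dict(p) for p in payloads]

multi = deserialize_multi(2, {"n_states": 2, "states": [{"a": 1}, {"a": 2}]})
single = deserialize_multi(1, {"a": 3})  # legacy single-state payload
```

Checking the count before rebuilding anything means a mismatched checkpoint fails early, instead of partially overwriting the template's states.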

class MultiSolver(lqx, output_path=None, auxiliary_path=None, *, clean_up=False, seed=None)

Bases: Solver

Solver subclass that trains multiple variational states jointly.

This solver is a drop-in variant of Solver that runs MultiStateVMC with a MultiMCState backend. It supports configuring multiple networks, building one MCState per network during VMC initialization, evaluating expectations per state or for all states, exporting and importing multi-state checkpoints, and logging final constraint results both aggregated and per state.

Returns:

None.

property lambda_ortho: float

Return the orthogonalization strength used between the networks.

Returns:

Orthogonalization factor used by the multi-state driver.

set_network(network, *, diff_invariant=False, symmetries=None, lambda_ortho=None, **kwargs)

Configure one or more networks for a multi-state simulation.

A single model or a sequence of models can be provided. The solver stores the models as a list and logs both an aggregate network summary and per-state network-type entries. When diffeomorphism invariance is enabled (diff_invariant=True), each model is wrapped using the supplied symmetry information.

Parameters:
  • network (Union[Module, Sequence[Module]]) – A single Flax module or a sequence of Flax modules, one per state.

  • diff_invariant (bool) – If True, enable diffeomorphism invariant wrapping of the models.

  • symmetries (Optional[Sequence[Any]]) – Optional symmetry data forwarded to the model wrapper.

  • lambda_ortho (Optional[float]) – Optional override for the orthogonalization strength used during multi-state optimisation.

  • kwargs – Additional keyword arguments accepted for forward compatibility.

Return type:

None

Returns:

None.
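The input normalization described for set_network (single model or sequence, stored as a list, optional lambda_ortho override, aggregate plus per-state log entries) is sketched below. NetworkConfigSketch, its log keys, and the default lambda_ortho of 1.0 are hypothetical; diffeomorphism-invariant wrapping is not modelled, and plain Python objects stand in for Flax modules.

```python
from typing import Any, List, Optional, Sequence, Union

class NetworkConfigSketch:
    """Hypothetical stand-in for the network handling in MultiSolver.set_network."""

    def __init__(self, lambda_ortho: float = 1.0) -> None:  # default value is hypothetical
        self.networks: List[Any] = []
        self.lambda_ortho = lambda_ortho
        self.log: dict = {}

    def set_network(
        self,
        network: Union[Any, Sequence[Any]],
        *,
        lambda_ortho: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        # A single model becomes a one-element list; a list or tuple is copied into a list.
        models = list(network) if isinstance(network, (list, tuple)) else [network]
        self.networks = models
        if lambda_ortho is not None:
            self.lambda_ortho = float(lambda_ortho)
        # Aggregate summary plus one network-type entry per state (key names are hypothetical).
        self.log["network"] = f"{len(models)} network(s)"
        for i, model in enumerate(models):
            self.log[f"network_type_{i}"] = type(model).__name__

cfg = NetworkConfigSketch()
cfg.set_network(3)                     # one model -> one state
cfg.set_network([1, 2.0], lambda_ortho=0.5)  # two models -> two states
```

Normalizing to a list up front lets the rest of the solver treat the one-state and many-state cases identically.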

expect(operator, *, state_idx=None, n_samples=None, n_chains=None, use_same_variables=True, use_same_sampler=True, **kwargs)

Compute expectation values in a multi-state-aware way.

If state_idx is None, compute expectations for all states and return a list of Stats. If state_idx is an integer, compute the expectation for that state only and return a single Stats object. An optional n_samples temporarily overrides the number of samples used by the selected state or states; this discards and regenerates the Monte Carlo samples.

Parameters:
  • operator (Union[List[Union[AbstractOperator, AbstractObservable]], AbstractOperator, AbstractObservable]) – Operator or observable, or a list of them, to evaluate.

  • state_idx (Optional[int]) – Optional state index to select a single state. If None, evaluate all states.

  • n_samples (Optional[int]) – Optional temporary sample count override for the evaluation.

  • n_chains (Optional[int]) – Not supported in this method.

  • use_same_variables (bool) – Accepted for API compatibility and ignored.

  • use_same_sampler (bool) – Accepted for API compatibility and ignored.

  • kwargs – Additional keyword arguments accepted for API compatibility and ignored.

Return type:

Union[Stats, List[Stats]]

Returns:

A Stats object if state_idx is provided, otherwise a list of Stats.
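The dispatch on state_idx and the temporary n_samples override can be sketched as below. FakeState, expect_sketch, and the toy evaluate callback are hypothetical; floats stand in for Stats objects, and restoring the previous sample count in a finally block is one plausible way to make the override temporary, not necessarily how MultiSolver does it.

```python
from typing import Any, Callable, List, Optional, Union

class FakeState:
    """Hypothetical per-state object with a mutable sample count."""

    def __init__(self, n_samples: int) -> None:
        self.n_samples = n_samples

def expect_sketch(
    states: List[FakeState],
    evaluate: Callable[[FakeState, Any], float],
    operator: Any,
    *,
    state_idx: Optional[int] = None,
    n_samples: Optional[int] = None,
) -> Union[float, List[float]]:
    # None selects every state; an integer selects exactly one.
    selected = states if state_idx is None else [states[state_idx]]
    results = []
    for st in selected:
        previous = st.n_samples
        try:
            if n_samples is not None:
                # In a real MCState, changing n_samples invalidates cached samples,
                # so the evaluation below would draw fresh ones.
                st.n_samples = n_samples
            results.append(evaluate(st, operator))
        finally:
            st.n_samples = previous  # restore the configured sample count
    return results if state_idx is None else results[0]

states = [FakeState(100), FakeState(200)]
count_samples = lambda st, op: float(st.n_samples)  # toy "expectation"
print(expect_sketch(states, count_samples, None))  # [100.0, 200.0]
print(expect_sketch(states, count_samples, None, state_idx=1, n_samples=50))  # 50.0
```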

initialize_vmc(*args, **kwargs)

Initialise the multi-state variational state and VMC driver.

This method requires that the sampler, optimizer, and networks have already been set. It builds one MCState per network, each with a distinct seed, collects the holomorphicity flags, constructs a MultiMCState, logs parameter counts and ratios, and then creates a MultiStateVMC driver with a per-state SR preconditioner.

Parameters:
  • args – Positional arguments accepted for compatibility.

  • kwargs – Keyword arguments. If present, lambda_ortho overrides the stored orthogonalization factor.

Return type:

None

Returns:

None.

Raises:
  • PermissionError – If sampler, optimizer, or networks are not configured.

  • ValueError – If no networks are available for initialization.
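The documented preconditions and the one-seed-per-state rule can be sketched in plain Python. plan_multi_state_init is a hypothetical helper, the base_seed + index seeding scheme is an assumption (the source only says the seeds are distinct), and the actual construction of MCState, MultiMCState, and MultiStateVMC is not modelled.

```python
from typing import Any, List, Optional

def plan_multi_state_init(
    sampler: Any,
    optimizer: Any,
    networks: Optional[List[Any]],
    base_seed: int = 0,
) -> List[int]:
    # Mirror the documented preconditions before building anything.
    if sampler is None or optimizer is None or networks is None:
        raise PermissionError("set sampler, optimizer and networks before initialize_vmc")
    if len(networks) == 0:
        raise ValueError("no networks available for initialization")
    # One distinct seed per state; base_seed + index is a hypothetical scheme.
    return [base_seed + i for i in range(len(networks))]

print(plan_multi_state_init("sampler", "optimizer", ["net_a", "net_b"], base_seed=7))  # [7, 8]
```

Distinct seeds keep the states from starting with identical parameters or identical Markov chains, which would make orthogonalization between them ineffective at the start of training.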

export_state(*, state=None, silent=False, marker='', **kwargs)

Export a MultiMCState checkpoint to disk in an MPI-safe manner.

The method serializes each contained MCState and writes a single multi-state payload to the solver output directory on the global MPI master rank. Other ranks participate via barriers and perform no file I/O.

Parameters:
  • state (Optional[MultiMCState]) – Optional MultiMCState to export. If None, exports the current solver state.

  • silent (bool) – If True, suppress user facing printing on successful export.

  • marker (str) – Optional string appended to the filename for easier identification.

  • kwargs – Additional keyword arguments accepted for forward compatibility.

Return type:

Optional[str]

Returns:

None.

Raises:

TypeError – If the provided or current state is not a MultiMCState.
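The "master writes, everyone synchronizes" pattern described above is sketched below. SerialComm is a hypothetical single-process stand-in that mimics the rank attribute and Barrier() method of an mpi4py communicator; the JSON format, the multi_state{marker}.json filename, and returning the path on the master rank are choices of this sketch, not the solver's documented behaviour.

```python
import json
import os
from typing import Any, Dict, Optional

class SerialComm:
    """Hypothetical single-process stand-in for an mpi4py-style communicator."""
    rank = 0

    def Barrier(self) -> None:
        pass  # nothing to synchronize with a single process

def export_multi_payload(
    payload: Dict[str, Any],
    output_dir: str,
    *,
    comm: Any = None,
    marker: str = "",
    silent: bool = False,
) -> Optional[str]:
    if comm is None:
        comm = SerialComm()
    # Type check before any rank starts waiting on barriers.
    if not isinstance(payload, dict) or "states" not in payload:
        raise TypeError("expected a multi-state payload")
    path = None
    comm.Barrier()
    if comm.rank == 0:
        # Hypothetical filename scheme; the marker is appended to the stem.
        path = os.path.join(output_dir, f"multi_state{marker}.json")
        with open(path, "w") as fh:
            json.dump(payload, fh)
        if not silent:
            print(f"exported {len(payload['states'])} state(s) to {path}")
    comm.Barrier()  # non-master ranks wait here until the file is written
    return path
```

Raising the TypeError before the first barrier matters under MPI: if only some ranks failed after entering a barrier, the others would block indefinitely.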

import_state(state_path, *, force_load_mpi=False)

Import a MultiMCState checkpoint from disk.

The payload is loaded on the global MPI master rank and broadcast to all ranks. This method requires that a template MultiMCState already exists, which is created by calling initialize_vmc after configuring networks, sampler, and optimiser.

Parameters:
  • state_path (str) – Path to the serialized checkpoint file.

  • force_load_mpi (bool) – If True, attempt to load MPI-related fields even if they differ from the current run configuration.

Return type:

MultiMCState

Returns:

Reconstructed MultiMCState instance.

Raises:

RuntimeError – If no MultiMCState template exists in the solver.
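Loading on the master rank and broadcasting can be sketched as follows. SerialComm is a hypothetical single-process stand-in mimicking mpi4py's rank attribute and bcast(obj, root=0); the JSON file format is an assumption, and the final reconstruction step (deserialize_MultiMCState with the template) is only indicated in a comment.

```python
import json
from typing import Any, Dict, Optional

class SerialComm:
    """Hypothetical single-process stand-in for an mpi4py-style communicator."""
    rank = 0

    def bcast(self, obj: Any, root: int = 0) -> Any:
        return obj  # with one process, the broadcast value is the input

def import_multi_payload(state_path: str, template: Optional[Any], comm: Any = None) -> Dict[str, Any]:
    if comm is None:
        comm = SerialComm()
    if template is None:
        raise RuntimeError("no MultiMCState template: call initialize_vmc before importing")
    payload = None
    if comm.rank == 0:
        # Only the master rank touches the file system.
        with open(state_path) as fh:
            payload = json.load(fh)
    # Every rank receives the master's payload.
    payload = comm.bcast(payload, root=0)
    # The real method would now rebuild the states via deserialize_MultiMCState(template, payload).
    return payload
```

Reading on one rank and broadcasting avoids many processes hitting the same file at once and guarantees that every rank starts from byte-identical data.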

plot_results(*, silent_plot=False, with_inset=True, **kwargs)

Plot and save the minimisation curves for the simulation.

This method exports a graph image when possible, then plots the objective curve with error bars. If multi-state series are present in the runtime log, it plots one curve per state; otherwise it falls back to a single objective curve. An optional inset can be added for the final iterations, and the runtime log is serialized alongside the saved plot image.

Parameters:
  • silent_plot (bool) – If True, do not display the plot interactively.

  • with_inset (bool) – If True, include an inset showing the final portion of the curves.

  • kwargs – Additional keyword arguments accepted for forward compatibility.

Return type:

None

Returns:

None.
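The series selection and inset range described for plot_results can be sketched without any plotting library. The key prefix "objective_state_", the fallback key "objective", and the 10% inset fraction are hypothetical; error bars and the actual drawing and saving are left out, so this shows only how per-state curves might be chosen over the single objective curve.

```python
from typing import Dict, List, Tuple

PREFIX = "objective_state_"  # hypothetical key prefix for per-state series

def select_series(runtime_log: Dict[str, List[float]]) -> Dict[str, List[float]]:
    per_state = {k: v for k, v in runtime_log.items() if k.startswith(PREFIX)}
    if per_state:
        # Order numerically so that state 10 comes after state 2.
        return dict(sorted(per_state.items(), key=lambda kv: int(kv[0][len(PREFIX):])))
    # Fallback: a single aggregate objective curve.
    return {"objective": runtime_log["objective"]}

def inset_window(n_iter: int, fraction: float = 0.1) -> Tuple[int, int]:
    # Iteration range [start, n_iter) shown in the inset: the last `fraction`
    # of the run, and at least one iteration.
    width = max(1, int(n_iter * fraction))
    return max(0, n_iter - width), n_iter
```

Each selected series would then be drawn as its own curve, with the inset reusing the same series sliced to the inset window.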