neuralqx.experimental.solver.multi_state_solver module
Multi-state solver utilities for training and monitoring multiple variational states.
This module provides serialization helpers for MultiMCState and a Solver subclass that trains several MCState instances jointly using MultiStateVMC. It supports per-state expectation evaluation, multi-state export and import, per-state logging of final constraint estimates, and plotting of per-state optimisation curves when these are present in the runtime log.
- serialize_MultiMCState(vstate)
Serialize a MultiMCState into a plain Python dictionary. The resulting dictionary is suitable for writing to disk with the neuraLQX serialization utilities. Each contained MCState is serialized using serialize_MCState.
- Parameters:
vstate (MultiMCState) – Multi-state variational state to serialize.
- Return type:
dict
- Returns:
Dictionary representation of the multi-state variational state, including the per-state payloads.
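The following is a minimal sketch of how a multi-state payload can wrap one payload per contained state. The stand-in classes ToyMCState and ToyMultiMCState, the helper toy_serialize_mcstate (a placeholder for serialize_MCState), and the dictionary keys "n_states" and "states" are all hypothetical; the actual neuraLQX payload layout may differ.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyMCState:
    """Hypothetical stand-in for a single MCState: just a parameter dict."""
    parameters: Dict[str, float]


@dataclass
class ToyMultiMCState:
    """Hypothetical stand-in for MultiMCState: an ordered list of states."""
    states: List[ToyMCState] = field(default_factory=list)

    @property
    def n_states(self) -> int:
        return len(self.states)


def toy_serialize_mcstate(state: ToyMCState) -> Dict[str, Any]:
    # Placeholder for serialize_MCState: copy the parameters into a plain dict.
    return {"parameters": dict(state.parameters)}


def toy_serialize_multi(vstate: ToyMultiMCState) -> Dict[str, Any]:
    # One payload per contained state, plus the count used to validate on load.
    return {
        "n_states": vstate.n_states,
        "states": [toy_serialize_mcstate(s) for s in vstate.states],
    }


multi = ToyMultiMCState([ToyMCState({"w": 0.1}), ToyMCState({"w": -0.3})])
payload = toy_serialize_multi(multi)
print(payload)
# {'n_states': 2, 'states': [{'parameters': {'w': 0.1}}, {'parameters': {'w': -0.3}}]}
```

Storing the state count next to the list makes it cheap to check, before any reconstruction, that a checkpoint matches the template it is loaded into.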
- deserialize_MultiMCState(template, state_dict, *, force_load_mpi=False)
Deserialize a MultiMCState from a dictionary into a new MultiMCState instance. The function uses an existing template multi-state to provide the structure and implementation details required by deserialize_MCState. For backward compatibility, a dictionary that represents a single MCState is also accepted and wrapped into a single-state MultiMCState.
- Parameters:
template (MultiMCState) – Existing multi-state used as a template for reconstructing each state.
state_dict (dict) – Dictionary payload previously produced by the serialization utilities.
force_load_mpi (bool) – If True, attempt to load MPI-related fields even if they differ from the current run configuration.
- Return type:
MultiMCState
- Returns:
Newly reconstructed neuralqx.experimental.vqs.MultiMCState.
- Raises:
ValueError – If the number of saved states does not match template.n_states.
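The backward-compatibility rule and the state-count check described above can be sketched as follows. This assumes a hypothetical payload layout in which a multi-state payload carries a "states" list and anything else is treated as a legacy single-state payload; toy_deserialize_multi is an illustrative name, not part of neuraLQX, and it returns the per-state payloads rather than building real MCState objects.

```python
from typing import Any, Dict, List


def toy_deserialize_multi(template_n_states: int, state_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return one per-state payload per state, applying the checks described above.

    Hypothetical layout: a multi-state payload has a "states" list; any other
    dictionary is treated as a legacy single-state payload.
    """
    if "states" in state_dict:
        per_state = list(state_dict["states"])
    else:
        # Backward compatibility: wrap a single-state payload in a one-element list.
        per_state = [state_dict]

    if len(per_state) != template_n_states:
        raise ValueError(
            f"Checkpoint holds {len(per_state)} state(s) but the template has {template_n_states}."
        )
    return per_state


legacy = {"parameters": {"w": 0.5}}
print(toy_deserialize_multi(1, legacy))  # [{'parameters': {'w': 0.5}}]
```

In this sketch a legacy payload only loads into a one-state template; loading it into a larger template fails the count check, which matches the documented ValueError condition.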
- class MultiSolver(lqx, output_path=None, auxiliary_path=None, *, clean_up=False, seed=None)
Bases: Solver
Solver subclass that trains multiple variational states jointly.
This solver is a drop-in variant of Solver that runs MultiStateVMC with a MultiMCState backend. It supports configuring multiple networks, building one MCState per network during VMC initialization, evaluating expectations per state or for all states, exporting and importing multi-state checkpoints, and logging final constraint results both aggregated and per state.
- Returns:
None.
- property lambda_ortho: float
Return the orthogonalization strength used between the networks.
- Returns:
Orthogonalization factor used by the multi state driver.
- set_network(network, *, diff_invariant=False, symmetries=None, lambda_ortho=None, **kwargs)
Configure one or more networks for a multi-state simulation.
A single model or a sequence of models can be provided. The solver stores the models as a list and logs both an aggregate network summary and one network-type entry per state. When diffeomorphism invariance is enabled, each model is wrapped using the configured symmetry information.
- Parameters:
network (Union[Module, Sequence[Module]]) – A single Flax module or a sequence of Flax modules, one per state.
diff_invariant (bool) – If True, wrap each model so that it is diffeomorphism invariant.
symmetries (Optional[Sequence[Any]]) – Optional symmetry data forwarded to the model wrapper.
lambda_ortho (Optional[float]) – Optional override for the orthogonalization strength used during multi-state optimisation.
kwargs – Additional keyword arguments accepted for forward compatibility.
- Return type:
None
- Returns:
None.
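A minimal sketch of the normalisation and logging steps described above, assuming plain Python objects in place of Flax modules. The helper names normalize_networks and network_log_entries and the log keys "network" and "network_state_<i>" are hypothetical, not the neuraLQX implementation.

```python
from typing import Any, Dict, List, Sequence, Union


def normalize_networks(network: Union[Any, Sequence[Any]]) -> List[Any]:
    # Only lists and tuples count as "many models"; anything else is one model.
    # Checking concrete types avoids treating strings or other sequences as model lists.
    if isinstance(network, (list, tuple)):
        return list(network)
    return [network]


def network_log_entries(models: List[Any]) -> Dict[str, str]:
    # One aggregate summary entry plus one entry per state, keyed by state index.
    entries = {"network": ", ".join(type(m).__name__ for m in models)}
    for i, model in enumerate(models):
        entries[f"network_state_{i}"] = type(model).__name__
    return entries


class MLP:  # placeholder for a Flax module
    pass


class RBM:  # placeholder for a Flax module
    pass


models = normalize_networks([MLP(), RBM()])
print(network_log_entries(models))
# {'network': 'MLP, RBM', 'network_state_0': 'MLP', 'network_state_1': 'RBM'}
```

Normalising to a list up front lets the rest of the solver iterate over states uniformly, whether the user passed one model or several.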
- expect(operator, *, state_idx=None, n_samples=None, n_chains=None, use_same_variables=True, use_same_sampler=True, **kwargs)
Compute expectation values in a multi-state-aware way.
If state_idx is None, compute the expectation for every state and return a list of Stats objects. If state_idx is an integer, compute the expectation for that single state and return a single Stats object. An optional n_samples temporarily overrides the number of samples used by the selected state or states; this discards and regenerates the Monte Carlo samples.
- Parameters:
operator (Union[List[Union[AbstractOperator, AbstractObservable]], AbstractOperator, AbstractObservable]) – Operator or observable, or a list of them, to evaluate.
state_idx (Optional[int]) – Optional index selecting a single state. If None, all states are evaluated.
n_samples (Optional[int]) – Optional temporary sample-count override for the evaluation.
n_chains (Optional[int]) – Not supported; providing it raises NotImplementedError.
use_same_variables (bool) – Accepted for API compatibility and ignored.
use_same_sampler (bool) – Accepted for API compatibility and ignored.
kwargs – Additional keyword arguments accepted for API compatibility and ignored.
- Return type:
Union[Stats, List[Stats]]
- Returns:
A Stats object if state_idx is provided, otherwise a list of Stats objects.
- Raises:
RuntimeError – If the variational state has not been initialised.
NotImplementedError – If n_chains is provided.
IndexError – If state_idx is out of range for the current number of states.
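The dispatch rules above (all states versus one state, plus the three documented errors) can be sketched with plain callables standing in for per-state Stats computations. dispatch_expect is a hypothetical helper, and the order of the checks is an assumption; the sketch also rejects negative indices rather than wrapping around, which may differ from the real method.

```python
from typing import Callable, List, Optional, Sequence, TypeVar, Union

T = TypeVar("T")


def dispatch_expect(
    per_state_expect: Optional[Sequence[Callable[[], T]]],
    state_idx: Optional[int] = None,
    n_chains: Optional[int] = None,
) -> Union[T, List[T]]:
    """Route an expectation request to one state or to all of them.

    Each element of per_state_expect is a zero-argument callable standing in
    for "compute the Stats of this operator for state i".
    """
    if per_state_expect is None:
        # Corresponds to calling expect() before initialize_vmc().
        raise RuntimeError("The variational state has not been initialised.")
    if n_chains is not None:
        raise NotImplementedError("n_chains is not supported in multi-state mode.")
    if state_idx is None:
        # All states: one result per state, in state order.
        return [compute() for compute in per_state_expect]
    if not 0 <= state_idx < len(per_state_expect):
        raise IndexError(
            f"state_idx={state_idx} is out of range for {len(per_state_expect)} states."
        )
    return per_state_expect[state_idx]()


# Toy per-state "energies" in place of real Stats objects.
energies = [lambda: -1.0, lambda: -0.5]
print(dispatch_expect(energies))               # [-1.0, -0.5]
print(dispatch_expect(energies, state_idx=1))  # -0.5
```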
- initialize_vmc(*args, **kwargs)
Initialise the multi-state variational state and the VMC driver.
This method requires that the sampler, optimizer, and networks have already been set. It builds one MCState per network, each with a distinct seed, collects the holomorphicity flags, constructs a MultiMCState, logs parameter counts and ratios, and then creates a MultiStateVMC driver with a per-state stochastic reconfiguration (SR) preconditioner.
- Parameters:
args – Positional arguments accepted for compatibility.
kwargs – Keyword arguments. If present, lambda_ortho overrides the stored orthogonalization factor.
- Return type:
None
- Returns:
None.
- Raises:
PermissionError – If sampler, optimizer, or networks are not configured.
ValueError – If no networks are available for initialization.
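The preconditions, per-state seeding, and lambda_ortho override described above can be sketched as three small helpers. The names check_ready, per_state_seeds, and resolve_lambda_ortho are hypothetical; the simple offset seeding scheme is chosen only for illustration (a JAX-based code might instead split a PRNG key), and treating an explicit lambda_ortho=None as "not given" is an assumption.

```python
from typing import Any, List, Optional


def check_ready(sampler: Any, optimizer: Any, networks: Optional[List[Any]]) -> None:
    """Mirror the documented preconditions of initialize_vmc."""
    if sampler is None or optimizer is None or networks is None:
        raise PermissionError(
            "Configure the sampler, optimizer, and networks before initialize_vmc()."
        )
    if len(networks) == 0:
        raise ValueError("No networks are available for initialization.")


def per_state_seeds(base_seed: int, n_states: int) -> List[int]:
    """Give each state its own seed (simple offset scheme, for illustration only)."""
    return [base_seed + i for i in range(n_states)]


def resolve_lambda_ortho(stored: float, **kwargs: Any) -> float:
    """A lambda_ortho keyword overrides the stored factor; None counts as absent."""
    override = kwargs.get("lambda_ortho")
    return stored if override is None else float(override)


networks = ["net_a", "net_b", "net_c"]  # placeholders for Flax modules
check_ready(sampler="sampler", optimizer="optimizer", networks=networks)
print(per_state_seeds(1234, len(networks)))        # [1234, 1235, 1236]
print(resolve_lambda_ortho(10.0))                  # 10.0
print(resolve_lambda_ortho(10.0, lambda_ortho=2))  # 2.0
```

Distinct seeds matter here because identically seeded states would start from the same parameters and samples, which defeats the purpose of optimising several states jointly.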
- export_state(*, state=None, silent=False, marker='', **kwargs)
Export a MultiMCState checkpoint to disk in an MPI-safe manner.
The method serializes each contained MCState and writes a single multi-state payload to the solver output directory on the global MPI master rank. Other ranks participate via barriers and perform no file I/O.
- Parameters:
state (Optional[MultiMCState]) – MultiMCState to export. If None, the current solver state is exported.
silent (bool) – If True, suppress user-facing printing on successful export.
marker (str) – Optional string appended to the filename for easier identification.
kwargs – Additional keyword arguments accepted for forward compatibility.
- Return type:
None
- Returns:
None.
- Raises:
TypeError – If the provided or current state is not a MultiMCState.
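The "write on the master rank, synchronise everyone with barriers" pattern described above can be sketched as follows. SerialComm is a hypothetical single-process stand-in exposing the same method names as mpi4py's MPI.COMM_WORLD (Get_rank and barrier), so the sketch runs without MPI. The filename stem "multi_state", the use of pickle, and export_payload itself are assumptions, not the neuraLQX file format.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


class SerialComm:
    """Single-process stand-in for mpi4py.MPI.COMM_WORLD (same method names)."""

    def Get_rank(self) -> int:
        return 0

    def barrier(self) -> None:
        pass


def export_payload(comm: Any, payload: Dict[str, Any], out_dir: str, marker: str = "") -> str:
    """Write payload once, from rank 0, while all ranks synchronise on barriers."""
    # Hypothetical naming scheme: the marker is appended to a fixed stem.
    path = os.path.join(out_dir, f"multi_state{marker}.pkl")
    comm.barrier()  # every rank has finished building its copy of the payload
    if comm.Get_rank() == 0:
        # Only the global master rank touches the filesystem.
        with open(path, "wb") as fh:
            pickle.dump(payload, fh)
    comm.barrier()  # no rank proceeds until the file has been written
    return path


with tempfile.TemporaryDirectory() as out_dir:
    written = export_payload(SerialComm(), {"n_states": 2, "states": [{}, {}]}, out_dir, marker="_final")
    print(os.path.basename(written))  # multi_state_final.pkl
```

In an MPI run the same function would be called on every rank with mpi4py.MPI.COMM_WORLD as comm; the second barrier keeps non-master ranks from, for example, trying to read the checkpoint before it exists.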
- import_state(state_path, *, force_load_mpi=False)
Import a MultiMCState checkpoint from disk.
The payload is loaded on the global MPI master rank and broadcast to all ranks. This method requires that a template MultiMCState already exists, which is created by calling initialize_vmc after configuring networks, sampler, and optimiser.
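- Parameters:
state_path – Path to the multi-state checkpoint file to load.
force_load_mpi (bool) – If True, attempt to load MPI-related fields even if they differ from the current run configuration.
- Return type:
MultiMCState
- Returns: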
Reconstructed MultiMCState instance.
- Raises:
RuntimeError – If no MultiMCState template exists in the solver.
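The load-then-broadcast pattern described above can be sketched as follows. SerialComm is again a hypothetical single-process stand-in using the method names of mpi4py's MPI.COMM_WORLD (Get_rank and bcast(obj, root=0)); import_payload, the pickle format, and passing the template as an argument are assumptions made for illustration. The sketch stops at the raw payload, where the real method would go on to rebuild a MultiMCState from the template.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Optional


class SerialComm:
    """Single-process stand-in for mpi4py.MPI.COMM_WORLD (same method names)."""

    def Get_rank(self) -> int:
        return 0

    def bcast(self, obj: Any, root: int = 0) -> Any:
        return obj


def import_payload(comm: Any, state_path: str, template: Optional[Any]) -> Dict[str, Any]:
    """Load on rank 0, then broadcast the payload to every rank."""
    if template is None:
        # Fail fast on every rank before any file I/O happens.
        raise RuntimeError("No MultiMCState template: call initialize_vmc() first.")
    payload = None
    if comm.Get_rank() == 0:
        with open(state_path, "rb") as fh:
            payload = pickle.load(fh)
    # Non-master ranks pass None and receive the master's object.
    return comm.bcast(payload, root=0)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pkl")
    with open(path, "wb") as fh:
        pickle.dump({"n_states": 2}, fh)
    loaded = import_payload(SerialComm(), path, template=object())
    print(loaded)  # {'n_states': 2}
```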
- plot_results(*, silent_plot=False, with_inset=True, **kwargs)
Plot and save the minimisation curves for the simulation.
This method exports a graph image when possible, then plots the objective curve with error bars. If multi-state series are present in the runtime log, it plots one curve per state; otherwise it falls back to a single objective curve. An optional inset view can be added for the final iterations, and the runtime log is serialized alongside the saved plot image.
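The per-state versus single-curve fallback described above can be sketched as a pure selection step, independent of any plotting library. The runtime-log layout (keys "objective" and "objective_state_<i>", each mapping to "mean" and "error" lists), the helpers select_curves and inset_window, and the 10% inset fraction are all hypothetical.

```python
from typing import Dict, List, Tuple

# (label, per-iteration means, per-iteration error bars)
Curve = Tuple[str, List[float], List[float]]


def select_curves(log: Dict[str, Dict[str, List[float]]]) -> List[Curve]:
    """One curve per state if per-state series exist, otherwise the single objective.

    Hypothetical log layout: per-state series use keys "objective_state_<i>",
    the aggregate series uses "objective"; each maps to {"mean": [...], "error": [...]}.
    """
    per_state = sorted(
        (k for k in log if k.startswith("objective_state_")),
        key=lambda k: int(k.rsplit("_", 1)[1]),  # numeric order: state_2 before state_10
    )
    keys = per_state if per_state else ["objective"]
    return [(k, log[k]["mean"], log[k]["error"]) for k in keys]


def inset_window(n_iter: int, fraction: float = 0.1) -> Tuple[int, int]:
    """Index range [start, n_iter) of the final iterations shown in an inset."""
    width = max(1, int(n_iter * fraction))
    return max(0, n_iter - width), n_iter


log = {
    "objective": {"mean": [3.0, 1.0], "error": [0.2, 0.1]},
    "objective_state_1": {"mean": [2.0, 0.8], "error": [0.2, 0.1]},
    "objective_state_0": {"mean": [1.0, 0.2], "error": [0.1, 0.05]},
}
print([label for label, _, _ in select_curves(log)])  # ['objective_state_0', 'objective_state_1']
print(inset_window(200))  # (180, 200)
```

Each selected curve could then be drawn with matplotlib's Axes.errorbar(x, mean, yerr=error, label=label), and the inset could reuse the same curves restricted to the index range returned by inset_window.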