neuralqx.solver.solver module
Concrete solver implementation for neuraLQX Variational Monte Carlo workflows.
This module provides a production-ready Solver that implements the abstract solver API defined by AbstractSolver. It orchestrates the full optimisation lifecycle around an LQX model:
configuration of the sampler, optimiser (and SR preconditioner), and neural network,
construction of a Monte Carlo variational state (MCState),
construction and execution of a VMC driver (VMC),
robust finalisation (checkpoint export, logging, plotting, and live-monitoring shutdown),
safe continuation from a previous checkpoint (including MPI-aware state broadcast).
Run structure
The main optimisation entry point (Solver.run()) is intentionally split into three phases:
Preparation (Solver._prepare_run()): ensures driver/state initialisation for fresh runs, normalises callbacks and attaches optional live monitoring, and records the observables to be evaluated during optimisation.
Execution (Solver._execute_run()): delegates to the underlying VMC driver loop and tracks the performed iteration count when available.
Finalisation (Solver._finalise_run()): mirrors driver state back into solver state, computes robust final observable/constraint estimates, logs trailing averages, exports checkpoints, plots and serialises results (rank-0), and synchronises ranks.
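The phase split can be illustrated with a toy skeleton. This is a minimal sketch, not the neuraLQX code: ToyRun is a hypothetical class, the driver step is replaced by a counter, and only the single-rank Ctrl+C behaviour (mark as aborted, still finalise) is modelled.

```python
class ToyRun:
    """Hypothetical skeleton of the prepare/execute/finalise split (not the neuraLQX code)."""

    def __init__(self, n_iters):
        self.n_iters = n_iters
        self.performed = 0        # iterations completed so far
        self.aborted = False
        self.events = []          # which phases ran, for inspection

    def _prepare_run(self, callbacks):
        # Normalise a single callable into a list of callbacks.
        self.events.append("prepare")
        return [callbacks] if callable(callbacks) else list(callbacks)

    def _execute_run(self, callbacks):
        self.events.append("execute")
        for step in range(self.n_iters):
            self.performed += 1   # stand-in for one VMC driver step
            # A callback returning False requests early termination.
            if not all(cb(step, {}, None) for cb in callbacks):
                break

    def _finalise_run(self):
        marker = "AbortedState" if self.aborted else "FinalState"
        self.events.append(f"finalise:{marker}:{self.performed}")

    def run(self, callbacks=()):
        callbacks = self._prepare_run(callbacks)
        try:
            self._execute_run(callbacks)
        except KeyboardInterrupt:
            # Single-rank behaviour: mark the run as aborted but still finalise.
            self.aborted = True
        self._finalise_run()
```

With this layout, finalisation runs whether the loop completes, stops early because a callback returned False, or is interrupted, which is the property the three-phase split is meant to guarantee.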
MPI and interrupts
This implementation treats Ctrl+C differently depending on the MPI context:
Single-rank mode: Ctrl+C triggers a graceful abort that still finalises, exports an “aborted” marker checkpoint, and attempts to plot/log safely.
Multi-rank mode: Ctrl+C triggers a hard MPI abort to avoid deadlocks from rank-skewed interrupts during collectives. A lightweight abort marker file is written by rank-0 when possible.
Serialisation
Export/import uses an MCState-level serialisation format and ensures that all MPI ranks synchronise around serialisation steps. This prevents pending JAX device computations (including MPI-backed collectives) from blocking serialisation on some ranks while other ranks proceed, which could otherwise deadlock the job.
- reject_outliers(data, m=2.0)
Reject statistical outliers from a 1D data array using a robust median absolute deviation filter.
This helper is primarily used for plotting (e.g. determining inset y-limits), where a handful of extreme values can make the view uninformative. The method:
computes the median of the data,
computes the median absolute deviation (MAD),
computes a scaled deviation score for each entry,
and returns only those values whose scaled deviation is below the threshold m.
Notes
This filter is robust to heavy-tailed distributions and is less sensitive to extreme values than mean/std-based rules.
If MAD is zero (for example, when all values are identical), the function falls back to a denominator of 1.0 to avoid division by zero; in the all-identical case every point is kept.
- Parameters:
data – 1D array-like of numeric values to be filtered.
m (float) – Outlier rejection threshold in units of MAD. Larger values keep more points.
- Returns:
A NumPy array containing only the inlier values.
- Raises:
ValueError – If data cannot be converted to a numeric NumPy array.
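A minimal sketch of the median/MAD filter described above, assuming the score is the unscaled |x - median| / MAD with a strict "< m" comparison; the actual helper may differ in such details.

```python
import numpy as np


def reject_outliers(data, m=2.0):
    """Keep entries whose |x - median| / MAD is below m (sketch)."""
    # np.asarray raises ValueError for non-numeric input such as strings.
    arr = np.asarray(data, dtype=float)
    dev = np.abs(arr - np.median(arr))   # absolute deviations from the median
    mad = np.median(dev)                 # median absolute deviation
    # Fall back to 1.0 when MAD is zero to avoid division by zero.
    score = dev / (mad if mad else 1.0)
    return arr[score < m]
```

For example, reject_outliers([1, 2, 3, 100]) keeps [1.0, 2.0, 3.0]: the median is 2.5, the MAD is 1.0, and the value 100 scores 97.5.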
- class Solver(lqx, output_path=None, auxiliary_path=None, *, clean_up=False, seed=None)
Bases: AbstractSolver
Concrete neuraLQX solver implementing a VMC-based optimisation workflow.
This class is the main user-facing solver implementation that executes variational optimisation for an LQX model. It inherits core wiring, output/log bookkeeping, and safety flags from AbstractSolver and implements the operational API:
set_sampler() builds and attaches a NetKet-compatible sampler (including specialised gauge samplers when supported by the Hilbert space).
set_optimizer() builds an Optax optimiser and configures stochastic reconfiguration (SR), including the SR linear solver backend.
set_network() installs a neural network Ansatz and optionally wraps it with symmetry/group projection when running diffeomorphism-invariant simulations.
initialize_vmc() constructs the MCState and the VMC driver, including SR preconditioning.
run() executes an optimisation loop with robust interruption handling, finalisation, state export, and plotting.
continue_simulation() resumes from an in-memory or on-disk checkpoint while preserving runtime logs and ensuring driver/state consistency.
Key design goals
Robustness and reproducibility:
Every solver run has a unique hash and a deterministic seed (unless the user overrides it).
State export/import is MPI-aware and uses broadcast to ensure all ranks reconstruct the same state.
Finalisation is defensive: I/O and plotting failures do not crash the solver after a successful run; they are treated as best-effort.
Clarity of run semantics:
The run pipeline is explicitly split into preparation, execution, and finalisation phases.
Continuations do not implicitly reset logs or state unless explicitly requested.
MPI-safe interruption:
Single-rank runs attempt graceful abort and still export a final checkpoint.
Multi-rank runs hard-abort on Ctrl+C to avoid deadlocks from rank-skewed collectives.
- param lqx:
The model interface implementing AbstractLqxInterface.
- param output_path:
Optional base directory for solver outputs.
- param auxiliary_path:
Optional grouping component appended under output_path.
- param clean_up:
If True, remove empty stale directories before starting.
- param seed:
Optional explicit seed controlling solver reproducibility.
- return:
None.
- raises ValueError:
If output_path is provided but invalid.
- raises OSError:
If output directories cannot be created.
- run(n_iters, *, silent_print=False, silent_plot=True, timer=False, live_monitoring=False, callbacks=<function Solver.<lambda>>, observables=None, **kwargs)
Run an optimisation/minimisation simulation.
This method is the primary optimisation entry point. The implementation is structured into three phases to keep the execution path explicit and robust:
Preparation (_prepare_run()): for fresh runs, ensures the VMC driver/state are initialised; normalises callbacks into a list; optionally attaches a live monitoring callback; stores the observable registry and logs basic run metadata.
Execution (_execute_run()): calls the underlying VMC driver loop with consistent parameters and tracks the performed iteration count when available.
Finalisation (_finalise_run()): mirrors driver state back into solver state, computes final expectation/constraint estimates, logs trailing statistics, exports a checkpoint, and plots results.
Interrupt handling
KeyboardInterrupt (Ctrl+C):
Single-rank: marks the run as aborted, writes an abort marker, and still finalises safely.
Multi-rank: writes a lightweight abort marker from rank-0 when possible and then hard-aborts the MPI job to prevent deadlocks.
Other exceptions:
Marks run as aborted and writes a lightweight abort marker on rank-0 when possible.
Re-raises the exception to keep the error visible.
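The interrupt policy can be summarised as a small dispatch function. This is a hypothetical helper for illustration only: write_marker and hard_abort are injected callables, and in a real MPI job hard_abort might wrap mpi4py's MPI.COMM_WORLD.Abort(1), which terminates all ranks and does not return.

```python
def handle_interrupt(exc, n_ranks, rank, write_marker, hard_abort):
    """Apply the documented interrupt policy (hypothetical helper).

    Returns "finalise" when the caller should still run finalisation.
    """

    def best_effort_marker():
        try:
            write_marker()
        except OSError:
            pass  # the marker is lightweight and must never block the abort path

    if isinstance(exc, KeyboardInterrupt):
        if n_ranks == 1:
            best_effort_marker()   # single rank: graceful abort ...
            return "finalise"      # ... followed by normal finalisation
        if rank == 0:
            best_effort_marker()
        hard_abort()               # avoid deadlocks from rank-skewed interrupts
        return "aborted"           # only reached when hard_abort is a test stub
    # Any other exception: best-effort marker on rank 0, then re-raise.
    if rank == 0:
        best_effort_marker()
    raise exc                      # keep the original error visible
```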
- param n_iters:
Number of optimisation iterations to execute.
- param silent_print:
If True, suppress user-facing printing during finalisation.
- param silent_plot:
If True, suppress interactive plot display (figures may still be saved).
- param timer:
If True, enable timing/profiling integration in the driver loop.
- param live_monitoring:
If True, attach a live monitoring callback for real-time feedback.
- param callbacks:
Callback or list of callbacks invoked after each iteration. Each callback must accept (iteration_index, log_dict, driver) and return a boolean; returning False requests early termination.
- type observables: dict[str, AbstractOperator] | None
- param observables:
Optional mapping of observable names to operators/observables to evaluate during optimisation.
- param kwargs:
Implementation-specific run options. Internal continuation mode may set _ct=True.
- return:
None.
- raises RuntimeError:
If required solver components are not initialised and cannot be initialised.
- raises Exception:
Propagates any unexpected exception raised by the driver loop or callbacks.
- continue_simulation(*, state_path=None, n_iters=500, silent_print=False, silent_plot=False, timer=False, live_monitoring=False, callbacks=<function Solver.<lambda>>, observables=None, force_load_mpi=False, **kwargs)
Continue a previously started simulation by resuming from memory or from a serialised checkpoint.
This method supports two continuation sources:
Resume from disk (state_path provided):
Ensures a template variational state/driver exists (needed for safe deserialisation).
Imports a full MCState from disk in an MPI-safe manner.
Installs the loaded state into the solver and updates the driver reference.
Resume from memory (state_path is None):
If a driver exists, reuse it.
Otherwise, initialise the driver/state using the current solver configuration.
Log semantics
The runtime log is preserved across continuations unless it is missing, in which case a new log is created. Observables default to the previously registered set unless explicitly overridden.
- param state_path:
Optional path to a previously exported MCState checkpoint.
- param n_iters:
Number of additional optimisation iterations to execute.
- param silent_print:
If True, suppress printing of final log output.
- param silent_plot:
If True, suppress interactive plot display.
- param timer:
If True, enable timing/profiling integration in the driver.
- param live_monitoring:
If True, attach a live monitoring callback for this continuation run.
- param callbacks:
Callback or list of callbacks invoked after each iteration.
- type observables: dict[str, AbstractOperator] | None
- param observables:
Optional observable registry for this continuation run. If None, reuse the last registered observables (if any).
- param force_load_mpi:
If True, allow best-effort reconstruction even if the MPI configuration differs from the environment used when saving.
- param kwargs:
Implementation-specific options. May include a diagonal_shift override.
- return:
None.
- raises RuntimeError:
If continuation is requested but solver components are not initialised and cannot be initialised.
- raises OSError:
If state_path is provided but the file cannot be read.
- raises ValueError:
If the checkpoint is invalid or incompatible with the current environment.
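The continuation rules above can be condensed into a small planning function. It is a hypothetical sketch that only returns step descriptions; the flags has_driver and has_log stand in for whatever state checks the real solver performs.

```python
def plan_continuation(state_path, has_driver, has_log,
                      observables=None, last_observables=None):
    """Return the continuation steps and the observable registry to use."""
    steps = []
    if state_path is not None:
        if not has_driver:
            # A template state/driver is needed for safe deserialisation.
            steps.append("initialise template driver/state")
        steps.append(f"import MCState from {state_path}")
        steps.append("install loaded state and update driver reference")
    elif has_driver:
        steps.append("reuse existing driver")
    else:
        steps.append("initialise driver/state from current configuration")
    # The runtime log survives continuations unless it is missing.
    steps.append("keep runtime log" if has_log else "create new runtime log")
    # Observables default to the previously registered set.
    chosen = last_observables if observables is None else observables
    return steps, chosen
```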
- export_state(*, state=None, silent=False, marker='', **kwargs)
Export (serialise) a complete Monte Carlo variational state to disk.
This method exports the full MCState (including sampler configuration/state) so that a run can be reproduced even if the checkpoint was taken before convergence.
MPI synchronisation rationale
Serialisation converts JAX arrays into host buffers. This conversion blocks until all outstanding device work is complete, which may include MPI-backed collectives issued during the last VMC step. To avoid deadlocks and ensure all ranks reach a consistent point:
All ranks serialise and synchronise on an MPI barrier.
Rank-0 writes the checkpoint to disk.
All ranks synchronise again before returning.
File naming
The exported filename includes:
the solver hash,
a timestamp,
an optional marker suffix (e.g. “FinalState”, “AbortedState”).
- param state:
Optional explicit state to export. If None, export the solver’s current state.
- param silent:
If True, suppress user-facing confirmation messages.
- param marker:
Optional label appended to the exported filename for easy identification.
- param kwargs:
Implementation-specific export options (currently unused).
- return:
The exported file path if the implementation chooses to return it; otherwise None.
- raises RuntimeError:
If no variational state exists and state is None.
- raises OSError:
If the checkpoint cannot be written to disk (rank-0).
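A minimal sketch of the serialise, barrier, rank-0 write, barrier sequence, assuming an mpi4py-style communicator (a real run would pass MPI.COMM_WORLD; SingleRankComm is a hypothetical single-process stand-in). pickle stands in for the MCState serialisation, and the filename pattern is a hypothetical illustration of "hash, timestamp, marker", not the format neuraLQX actually writes.

```python
import os
import pickle
import time


class SingleRankComm:
    """Stand-in for an mpi4py communicator with only the methods used here."""

    def Get_rank(self):
        return 0

    def Barrier(self):
        pass


def export_state(state, out_dir, solver_hash, comm, marker=""):
    # Every rank serialises first. With real JAX arrays this conversion to
    # host bytes waits for outstanding device work, including MPI collectives.
    payload = pickle.dumps(state)            # stand-in for MCState serialisation
    comm.Barrier()                           # all ranks reach a consistent point
    path = None
    if comm.Get_rank() == 0:
        stamp = time.strftime("%Y%m%d-%H%M%S")
        suffix = f"_{marker}" if marker else ""
        path = os.path.join(out_dir, f"{solver_hash}_{stamp}{suffix}.pkl")
        with open(path, "wb") as fh:         # only rank 0 writes to disk
            fh.write(payload)
    comm.Barrier()                           # no rank returns before the write
    return path
```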
- import_state(state_path, *, force_load_mpi=False)
Import and deserialise a variational state from a checkpoint file.
This method loads a full serialised MCState in an MPI-safe way:
Rank-0 reads the checkpoint bytes from disk.
The raw checkpoint payload is broadcast to all ranks.
Each rank reconstructs a new MCState using deserialize_MCState().
Template requirement
Deserialisation may rely on a “template” MCState for shape/dtype/model reconstruction. This implementation expects that self.variational_state exists when this method is called. Continuation logic ensures a template is created before import if needed.
- param state_path:
Path to a previously exported checkpoint file.
- param force_load_mpi:
If True, force best-effort reconstruction even if the MPI configuration differs from the environment used when saving.
- return:
A reconstructed MCState instance.
- raises OSError:
If the checkpoint file cannot be read on rank-0.
- raises RuntimeError:
If a required template state is not available for deserialisation.
- raises ValueError:
If the checkpoint payload is invalid or incompatible.
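A minimal sketch of the read-on-rank-0, broadcast, reconstruct-everywhere pattern, assuming an mpi4py-style communicator (comm.bcast(obj, root=0) is the mpi4py object broadcast; SingleRankComm is a hypothetical single-process stand-in). The template is modelled as a plain dict of defaults, and pickle stands in for deserialize_MCState().

```python
import pickle


class SingleRankComm:
    """Stand-in for an mpi4py communicator with only the methods used here."""

    def Get_rank(self):
        return 0

    def bcast(self, obj, root=0):
        return obj


def import_state(state_path, comm, template):
    if template is None:
        raise RuntimeError("a template state is required for deserialisation")
    raw = None
    if comm.Get_rank() == 0:
        with open(state_path, "rb") as fh:   # only rank 0 reads from disk
            raw = fh.read()
    raw = comm.bcast(raw, root=0)            # every rank receives the same bytes
    # Stand-in for deserialize_MCState(): overlay saved values on the template.
    restored = pickle.loads(raw)
    return {**template, **restored}
```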
- plot_results(*, silent_plot=False, with_inset=True, **kwargs)
Plot and save the main minimisation curve and export auxiliary run artefacts.
This method produces:
a plot of the constraint expectation value versus iteration, with error bars,
an optional inset zoom into the final iterations (useful for convergence inspection),
a saved image of the underlying graph when supported by the graph object,
a serialised copy of the runtime log data on disk.
Plot content
The primary curve uses the runtime log series under "Constraint":
x-axis: iteration index
y-axis: mean value
error bars: sigma
If the model provides an exact diagonalisation ground energy, a horizontal reference line is drawn.
Inset behaviour
When with_inset=True, the method:
zooms into the last ~20 iterations,
sanitises the data via reject_outliers() for stable y-limits,
draws an inset with matching content and marks the zoomed region in the main plot.
- param silent_plot:
If True, suppress interactive plot display (figures are still saved).
- param with_inset:
If True, include an inset zoom of the final part of the trajectory.
- param kwargs:
Implementation-specific plotting options (currently unused).
- return:
None.
- raises KeyError:
If expected log keys (e.g. "Constraint") are missing from the runtime log.
- raises OSError:
If output images or log serialisation cannot be written to disk.
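The inset limit logic can be sketched as follows, assuming the inset covers the last 20 iterations, outliers are removed separately from the lower and upper error-bar ends, and 10% padding is added. The window length, padding, and the helper name inset_window are illustrative choices, not taken from the implementation; reject_outliers is repeated so the block runs on its own.

```python
import numpy as np


def reject_outliers(data, m=2.0):
    # Median/MAD filter, as sketched for the module-level helper.
    arr = np.asarray(data, dtype=float)
    dev = np.abs(arr - np.median(arr))
    mad = np.median(dev)
    return arr[dev / (mad if mad else 1.0) < m]


def inset_window(means, sigmas, last=20, pad=0.1):
    """Hypothetical inset x/y limits for the tail of a convergence curve."""
    means = np.asarray(means, dtype=float)
    sigmas = np.asarray(sigmas, dtype=float)
    start = max(len(means) - last, 0)
    tail_m, tail_s = means[start:], sigmas[start:]
    # Filter the error-bar extremes so a single spike cannot stretch the view.
    lo = reject_outliers(tail_m - tail_s)
    hi = reject_outliers(tail_m + tail_s)
    ymin, ymax = lo.min(), hi.max()
    margin = pad * (ymax - ymin if ymax > ymin else 1.0)
    return (start, len(means) - 1), (ymin - margin, ymax + margin)
```

The returned limits could then be applied to an inset axis with Matplotlib's Axes.set_xlim and Axes.set_ylim.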
- plot_observables(*observables, **kwargs)
Plot one or more logged observables from the runtime log.
For each requested observable key, this method attempts to read:
the Mean series,
the Sigma series,
and produces a plot of mean versus iteration with a shaded mean ± sigma band.
Observable lookup
Observable data is retrieved from the solver runtime log using the provided keys. If an observable is missing, a ValueError is raised with a user-facing message.
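A minimal sketch of the lookup and band computation, assuming the runtime log behaves like a mapping from observable name to a record with "Mean" and "Sigma" sequences; observable_band is a hypothetical helper, not part of the documented API.

```python
import numpy as np


def observable_band(log_data, key):
    """Return iterations, mean, and the mean - sigma / mean + sigma band."""
    if key not in log_data:
        raise ValueError(f"Observable '{key}' was not found in the runtime log.")
    series = log_data[key]
    mean = np.asarray(series["Mean"], dtype=float)
    sigma = np.asarray(series["Sigma"], dtype=float)
    iters = np.arange(len(mean))
    return iters, mean, mean - sigma, mean + sigma
```

The result maps directly onto Matplotlib: ax.plot(iters, mean) for the curve and ax.fill_between(iters, lower, upper, alpha=0.3) for the shaded band.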
- set_sampler(sampler_type, *, number_of_chains=16, number_of_sweeps=20, machine_power=2, reset_chains=True, number_of_samples=512, **kwargs)
Configure and attach a sampler to the solver.
This method builds a sampler using the neuraLQX sampler factory and stores:
the primary sampler used by the variational state,
an auxiliary sampler copy (when provided by the factory),
the sampler initialisation kwargs for later reconstruction/logging,
the number of samples requested per chain.
Gauge sampler compatibility
If the requested sampler name indicates a gauge sampler (substring “gauge”), this method enforces that the underlying Hilbert space is gauge invariant. Otherwise it raises NotImplementedError to avoid silent misuse.
Logging
Sampler metadata is exported to the solver logger. Sampler kwargs are also recorded under “Sampler Configs”, excluding keys that are already covered by the sampler metadata.
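The compatibility check and the kwargs filtering can be sketched as one function. This is hypothetical: the case-insensitive substring match, the boolean hilbert_is_gauge_invariant flag, and the metadata_keys set are assumptions about how the checks could be expressed.

```python
def check_sampler_request(sampler_type, hilbert_is_gauge_invariant,
                          sampler_kwargs, metadata_keys):
    """Validate a gauge-sampler request and return the kwargs worth logging."""
    # Gauge samplers are only allowed on gauge-invariant Hilbert spaces.
    if "gauge" in sampler_type.lower() and not hilbert_is_gauge_invariant:
        raise NotImplementedError(
            f"Sampler '{sampler_type}' requires a gauge-invariant Hilbert space."
        )
    # Record only the kwargs not already covered by the sampler metadata.
    return {k: v for k, v in sampler_kwargs.items() if k not in metadata_keys}
```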
- param sampler_type:
Human-readable sampler identifier supported by neuraLQX.
- param number_of_chains:
Total number of Monte Carlo chains to use.
- param number_of_sweeps:
Number of sweeps per sampling step.
- param machine_power:
Exponent used in the sampling probability distribution (backend-specific).
- param reset_chains:
If True, reset chains at each iteration (backend-specific).
- param number_of_samples:
Number of samples per chain.
- param kwargs:
Sampler-specific parameters (e.g. d_max, rules, probabilities, Hamiltonian, temperature ladder).
- return:
None.
- raises NotImplementedError:
If a gauge sampler is requested for a non-gauge-invariant Hilbert space.
- raises ValueError:
If sampler configuration parameters are invalid for the chosen sampler type.
- set_optimizer(optimizer_type='Adam', *, use_sr, preconditioner_solver='Conjugate Gradient', preconditioner=None, learning_rate=None, scheduler_type=None, scheduler_parameters=None, diagonal_shift=None, **kwargs)
Configure and attach an optimiser and the SR preconditioner linear solver.
This method constructs:
an Optax optimiser via the neuraLQX Optimizer builder,
a linear solver backend for stochastic reconfiguration (SR) via Solvers,
and stores a diagonal shift value used later when constructing the SR preconditioner.
Logging
Optimiser metadata and solver metadata are exported to the solver logger under “Optimizer Configs”.
- param optimizer_type:
Name of the optimiser supported by neuraLQX (e.g. “Adam”).
- param use_sr:
Required keyword-only argument. If True, a stochastic reconfiguration (SR) preconditioner is used.
- param preconditioner_solver:
Name of the linear solver used inside SR (e.g. “Conjugate Gradient”).
- type preconditioner: AbstractLinearPreconditioner | None
- param preconditioner:
The preconditioner object used in natural gradient descent. If not specified, SR is used by default.
- param learning_rate:
Optional constant learning rate. If None, a sensible default is chosen by the optimiser builder.
- param scheduler_type:
Optional learning-rate schedule identifier.
- param scheduler_parameters:
Optional mapping of schedule parameters passed to the schedule builder.
- param diagonal_shift:
Optional diagonal shift regularisation used by SR. If None, defaults to 0.1.
- param kwargs:
Additional optimiser-specific configuration parameters forwarded to the builder.
- return:
None.
- raises ValueError:
If the optimiser or schedule configuration is invalid.
- raises RuntimeError:
If the optimiser or solver backend cannot be constructed.
- set_network(network, *, diff_invariant=False, symmetries=None, chunk_size=None, **kwargs)
Configure and attach a neural network Ansatz to the solver.
This method:
records network type and attributes to the solver logger,
stores a copy of the original network for internal use when needed,
optionally wraps the network with a group/symmetry projector when running a diffeomorphism invariant simulation,
stores the final network module and updates internal flags.
Diffeomorphism-invariant mode
If diff_invariant=True, the solver marks the run as diffeomorphism invariant and wraps the network using wrap_model(). The wrapper applies a group-averaging/projection procedure over the provided symmetry descriptors.
- param network:
Neural network module (typically a subclass of flax.linen.Module).
- param diff_invariant:
If True, enable diffeomorphism-invariant (group-averaged) simulation mode.
- param symmetries:
Optional symmetry descriptors used by the projection wrapper.
- param chunk_size:
If specified, expectation values and gradients are computed in chunks of at most chunk_size elements.
- param kwargs:
Implementation-specific parameters (currently unused).
- return:
None.
- raises ValueError:
If the symmetry configuration is inconsistent with the graph/model.
- raises RuntimeError:
If the network wrapper cannot be constructed.
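Group averaging can be illustrated on a toy amplitude, assuming the symmetry group is given as a list of index permutations acting on a configuration vector. This sketch of psi_sym(x) = (1/|G|) Σ_g psi(g·x) is not wrap_model(), whose symmetry descriptors and averaging details are not specified here; group_average is a hypothetical name.

```python
import numpy as np


def group_average(psi, permutations):
    """Return psi_sym(x) = (1/|G|) * sum_g psi(x[g]) over index permutations g."""
    perms = [np.asarray(p) for p in permutations]

    def psi_sym(x):
        x = np.asarray(x)
        # x[p] applies the permutation p to the configuration x.
        return sum(psi(x[p]) for p in perms) / len(perms)

    return psi_sym
```

NetKet-style Flax networks usually return log-amplitudes, so a wrapper around such a network would average exp(log psi), e.g. with a log-sum-exp, rather than the raw outputs. The set of permutations must also be closed under composition for the result to be exactly invariant.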
- expect(operator, *, n_samples=None, n_chains=None, use_same_variables=True, use_same_sampler=True, **kwargs)
Compute the expectation value of an operator or list of operators.
This method provides explicit sampling semantics to avoid accidental misuse.
Supported modes
Alias mode (default): If only operator is provided, this is a direct alias to variational_state.expect(operator).
Alias mode with temporary sample-count override: If n_samples is provided and n_chains is not, this method temporarily overrides variational_state.n_samples, warns the user that samples will be regenerated, evaluates the expectation, and restores the original sample count.
Unsupported mode (planned)
If n_chains is provided, the intended behaviour is to build a temporary state with a modified sampler and draw fresh samples. This mode is not implemented and currently raises NotImplementedError.
Ignored arguments
In alias modes, arguments that would only apply to “fresh sampling” (such as use_same_variables, use_same_sampler, and additional kwargs) are ignored with a warning.
- type operator: Union[list[Union[AbstractOperator, AbstractObservable]], AbstractOperator, AbstractObservable]
- param operator:
Operator/observable or a list of operators/observables to evaluate.
- param n_samples:
Optional number of samples per chain (alias mode override).
- param n_chains:
Optional number of chains (fresh sampling request; not implemented).
- param use_same_variables:
Intended for fresh sampling mode; ignored in alias modes.
- param use_same_sampler:
Intended for fresh sampling mode; ignored in alias modes.
- param kwargs:
Additional options; ignored in alias modes.
- rtype: Stats
- return:
A Stats-like object containing the estimated expectation value(s).
- raises RuntimeError:
If the variational state is not initialised.
- raises NotImplementedError:
If n_chains is provided (fresh sampling mode not implemented).
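The mode selection can be sketched as follows. This is a hypothetical re-statement: DummyState stands in for an MCState, and warning only when a fresh-sampling option differs from its default is an assumption about when the "ignored" warning is emitted.

```python
import warnings


class DummyState:
    """Hypothetical stand-in for an MCState: expect() echoes the sample count."""

    def __init__(self, n_samples):
        self.n_samples = n_samples

    def expect(self, operator):
        return (operator, self.n_samples)


def expect(state, operator, *, n_samples=None, n_chains=None,
           use_same_variables=True, use_same_sampler=True, **kwargs):
    if state is None:
        raise RuntimeError("the variational state is not initialised")
    if n_chains is not None:
        raise NotImplementedError("fresh sampling via n_chains is not implemented")
    if kwargs or not (use_same_variables and use_same_sampler):
        warnings.warn("fresh-sampling options are ignored in alias mode")
    if n_samples is None:
        return state.expect(operator)        # pure alias mode
    warnings.warn("overriding n_samples: samples will be regenerated")
    original = state.n_samples
    state.n_samples = n_samples
    try:
        return state.expect(operator)
    finally:
        state.n_samples = original           # restore even if expect() raises
```

The try/finally block is what guarantees that a temporary override never leaks into later optimisation steps, even when the evaluation itself fails.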
- initialize_vmc(*args, **kwargs)
Initialise the variational state and VMC driver.
This method constructs the objects required to execute optimisation:
Validates configuration: requires that the sampler, optimiser, and network have been configured via set_sampler(), set_optimizer(), and set_network().
Builds the Monte Carlo variational state: constructs MCState using the configured sampler and network and stores it in the solver. It also records the number of trainable parameters and the ratio of parameter count to Hilbert dimension (when estimable).
Builds the VMC driver: constructs VMC using the model constraint as the optimisation target, the configured optimiser, and an SR preconditioner with:
the diagonal shift taken from diagonal_shift,
the holomorphic flag inferred via NetKet’s holomorphic diagnostic,
the linear solver backend taken from preconditioner_solver.
- Parameters:
args – Implementation-specific initialisation options (currently unused).
kwargs – Implementation-specific initialisation options (may be ignored by this implementation).
- Returns:
None.
- Raises:
PermissionError – If sampler, optimiser, or network have not been configured.
RuntimeError – If the variational state or driver cannot be constructed.
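The configuration check and the parameter-count ratio can be sketched as follows, assuming parameters are stored as a nested mapping of arrays (as in Flax) and that the Hilbert dimension is None when it cannot be estimated. All three helpers are hypothetical.

```python
from collections.abc import Mapping

import numpy as np


def count_parameters(params):
    """Count scalar entries in a nested mapping of arrays (flax-style params)."""
    if isinstance(params, Mapping):
        return sum(count_parameters(v) for v in params.values())
    return int(np.size(params))


def check_configured(sampler, optimizer, network):
    missing = [name for name, obj in (("sampler", sampler),
                                      ("optimizer", optimizer),
                                      ("network", network)) if obj is None]
    if missing:
        raise PermissionError(f"configure {', '.join(missing)} before initialize_vmc()")


def parameter_ratio(params, hilbert_dim):
    # Only meaningful when the Hilbert dimension can be estimated.
    return count_parameters(params) / hilbert_dim if hilbert_dim else None
```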