neuralqx.solver.solver module

Concrete solver implementation for neuraLQX Variational Monte Carlo workflows.

This module provides a production-ready Solver that implements the abstract solver API defined by AbstractSolver. It orchestrates the full optimisation lifecycle around an LQX model:

  • configuration of sampler, optimiser (and SR preconditioner), and neural network,

  • construction of a Monte-Carlo variational state (MCState),

  • construction and execution of a VMC driver (VMC),

  • robust finalisation (checkpoint export, logging, plotting, and live monitoring shutdown),

  • safe continuation from a previous checkpoint (including MPI-aware state broadcast).

Run structure

The main optimisation entry point (Solver.run()) is intentionally split into three phases:

  1. Preparation (Solver._prepare_run()):

    • Ensures driver/state initialisation for fresh runs.

    • Normalises callbacks and attaches optional live monitoring.

    • Records observables to be evaluated during optimisation.

  2. Execution (Solver._execute_run()):

    • Delegates to the underlying VMC driver loop.

    • Tracks performed iteration count when available.

  3. Finalisation (Solver._finalise_run()):

    • Mirrors driver state back into solver state.

    • Computes robust final observable/constraint estimates.

    • Logs trailing averages and exports checkpoints.

    • Plots and serialises results (rank-0) and synchronises ranks.

MPI and interrupts

This implementation treats Ctrl+C differently depending on the MPI context:

  • Single-rank mode: Ctrl+C triggers a graceful abort that still finalises, exports an “aborted” marker checkpoint, and attempts to plot/log safely.

  • Multi-rank mode: Ctrl+C triggers a hard MPI abort to avoid deadlocks from rank-skewed interrupts during collectives. A lightweight abort marker file is written by rank-0 when possible.
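
The two policies above can be sketched as a small dispatch function. This is only an illustration of the described behaviour, not the solver's code: write_marker, finalise, and hard_abort are hypothetical hooks, and in a real MPI job the hard abort would typically be something like mpi4py's MPI.COMM_WORLD.Abort(1).

```python
def handle_keyboard_interrupt(n_ranks, rank, write_marker, finalise, hard_abort):
    """Apply the Ctrl+C policy: graceful on one rank, hard abort on several.

    All hook arguments are hypothetical stand-ins, not neuraLQX API.
    """
    if n_ranks == 1:
        # Single-rank: record the abort, then still run finalisation.
        write_marker("aborted")
        finalise(aborted=True)
        return "graceful"
    if rank == 0:
        try:
            # Multi-rank: the marker is best-effort and must never block the abort.
            write_marker("aborted")
        except OSError:
            pass
    # Other ranks may be stuck inside a collective, so tear the whole job down.
    hard_abort()
    return "hard-abort"


# Simulate a single-rank interrupt and record what happens.
events = []
try:
    raise KeyboardInterrupt
except KeyboardInterrupt:
    mode = handle_keyboard_interrupt(
        n_ranks=1,
        rank=0,
        write_marker=lambda label: events.append(("marker", label)),
        finalise=lambda aborted: events.append(("finalise", aborted)),
        hard_abort=lambda: events.append(("abort",)),
    )
```

The hard abort in the multi-rank branch avoids waiting on ranks that never received the interrupt and are still blocked in a collective.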

Serialisation

Export/import uses an MCState-level serialisation format and ensures that all MPI ranks synchronise around serialisation steps. This prevents situations where JAX device computations (including MPI-backed collectives) would otherwise block serialisation on some ranks.

reject_outliers(data, m=2.0)

Reject statistical outliers from a 1D data array using a robust median absolute deviation filter.

This helper is primarily used for plotting (e.g. determining inset y-limits) where a handful of extreme values can make the view uninformative. The method computes:

  • the median of the data,

  • the median absolute deviation (MAD),

  • a scaled deviation score for each entry,

and returns only those values whose scaled deviation is below a threshold m.

Notes

  • This filter is robust to heavy-tailed distributions and is less sensitive to extreme values than mean/std-based rules.

  • If MAD is zero (all values identical), the function falls back to a denominator of 1.0 to avoid division by zero and will keep all points.

Parameters:
  • data – 1D array-like of numeric values to be filtered.

  • m (float) – Outlier rejection threshold in units of MAD. Larger values keep more points.

Returns:

A NumPy array containing only the inlier values.

Raises:

ValueError – If data cannot be converted to a numeric NumPy array.
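
A minimal sketch of the MAD filter described above, assuming the scaled deviation is |x - median| / MAD. It uses only the standard library and returns a list so that it runs without dependencies, whereas the documented helper returns a NumPy array. The name reject_outliers_sketch is illustrative.

```python
from statistics import median


def reject_outliers_sketch(data, m=2.0):
    """Keep values whose MAD-scaled deviation from the median is below m."""
    values = [float(x) for x in data]  # raises ValueError for non-numeric strings
    if not values:
        return []
    center = median(values)
    deviations = [abs(x - center) for x in values]
    mad = median(deviations)
    # Fall back to a denominator of 1.0 when MAD is zero.
    scale = mad if mad > 0 else 1.0
    return [x for x, d in zip(values, deviations) if d / scale < m]


series = [1.0, 1.1, 0.9, 1.05, 0.95, 25.0]
inliers = reject_outliers_sketch(series)  # the spike at 25.0 is dropped
```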

class Solver(lqx, output_path=None, auxiliary_path=None, *, clean_up=False, seed=None)

Bases: AbstractSolver

Concrete neuraLQX solver implementing a VMC-based optimisation workflow.

This class is the main user-facing solver implementation that executes variational optimisation for an LQX model. It inherits core wiring, output/log bookkeeping, and safety flags from AbstractSolver and implements the operational API:

  • set_sampler() builds and attaches a NetKet-compatible sampler (including specialised gauge samplers when supported by the Hilbert space).

  • set_optimizer() builds an Optax optimiser and configures stochastic reconfiguration (SR), including the SR linear solver backend.

  • set_network() installs a neural network Ansatz and optionally wraps it with symmetry/group projection when running diffeomorphism-invariant simulations.

  • initialize_vmc() constructs the MCState and the VMC driver, including SR preconditioning.

  • run() executes an optimisation loop with robust interruption handling, finalisation, state export, and plotting.

  • continue_simulation() resumes from an in-memory or on-disk checkpoint while preserving runtime logs and ensuring driver/state consistency.

Key design goals

Robustness and reproducibility:

  • Every solver run has a unique hash and a deterministic seed (unless the user overrides it).

  • State export/import is MPI-aware and uses broadcast to ensure all ranks reconstruct the same state.

  • Finalisation is defensive: I/O and plotting failures do not crash the solver after a successful run; they are treated as best-effort.

Clarity of run semantics:

  • The run pipeline is explicitly split into preparation, execution, and finalisation phases.

  • Continuations do not implicitly reset logs or state unless explicitly requested.

MPI-safe interruption:

  • Single-rank runs attempt graceful abort and still export a final checkpoint.

  • Multi-rank runs hard-abort on Ctrl+C to avoid deadlocks from rank-skewed collectives.

type lqx:

AbstractLqxInterface

param lqx:

The model interface implementing AbstractLqxInterface.

type output_path:

str | None

param output_path:

Optional base directory for solver outputs.

type auxiliary_path:

str | None

param auxiliary_path:

Optional grouping component appended under output_path.

type clean_up:

bool | None

param clean_up:

If True, remove empty stale directories before starting.

type seed:

int | None

param seed:

Optional explicit seed controlling solver reproducibility.

return:

None.

raises ValueError:

If output_path is provided but invalid.

raises OSError:

If output directories cannot be created.

run(n_iters, *, silent_print=False, silent_plot=True, timer=False, live_monitoring=False, callbacks=<function Solver.<lambda>>, observables=None, **kwargs)

Run an optimisation/minimisation simulation.

This method is the primary optimisation entry point. The implementation is structured into three phases to keep the execution path explicit and robust:

  1. Preparation (_prepare_run()):

    • For fresh runs, ensures VMC driver/state are initialised.

    • Normalises callbacks into a list.

    • Optionally attaches a live monitoring callback.

    • Stores the observable registry and logs basic run metadata.

  2. Execution (_execute_run()):

    • Calls the underlying VMC driver loop with consistent parameters.

    • Tracks performed iteration count when available.

  3. Finalisation (_finalise_run()):

    • Mirrors driver state back into solver state.

    • Computes final expectation/constraint estimates.

    • Logs trailing statistics, exports a checkpoint, and plots results.

Interrupt handling

  • KeyboardInterrupt (Ctrl+C):

    • Single-rank: marks the run as aborted, writes an abort marker, and still finalises safely.

    • Multi-rank: writes a lightweight abort marker from rank-0 when possible and then hard-aborts the MPI job to prevent deadlocks.

  • Other exceptions:

    • Marks the run as aborted and writes a lightweight abort marker on rank-0 when possible.

    • Re-raises the exception to keep the error visible.

type n_iters:

int

param n_iters:

Number of optimisation iterations to execute.

type silent_print:

bool

param silent_print:

If True, suppress user-facing printing during finalisation.

type silent_plot:

bool

param silent_plot:

If True, suppress interactive plot display (figures may still be saved).

type timer:

bool

param timer:

If True, enable timing/profiling integration in the driver loop.

type live_monitoring:

bool

param live_monitoring:

If True, attach a live monitoring callback for real-time feedback.

type callbacks:

Callable[[int, dict, AbstractVariationalDriver], bool]

param callbacks:

Callback or list of callbacks invoked after each iteration. Each callback must accept (iteration_index, log_dict, driver) and return a boolean. Returning False requests early termination.

type observables:

dict[str, AbstractOperator] | None

param observables:

Optional mapping of observable names to operators/observables to evaluate during optimisation.

type kwargs:

param kwargs:

Implementation-specific run options. Internal continuation mode may set _ct=True.

rtype:

None

return:

None.

raises RuntimeError:

If required solver components are not initialised and cannot be initialised.

raises Exception:

Propagates any unexpected exception raised by the driver loop or callbacks.
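
As an illustration of the callback contract (iteration_index, log_dict, driver) -> bool, the sketch below builds an early-stopping callback and exercises it in a toy loop. The toy loop, the "Constraint" log key holding a plain float, and the helper names are assumptions made for this example. In an actual run the log entries may be statistics objects, which is why the callback falls back to a .mean attribute when one is present.

```python
def make_early_stop(key="Constraint", threshold=0.02):
    """Return a callback that requests termination once |value| <= threshold."""

    def callback(iteration_index, log_dict, driver):
        value = log_dict.get(key)
        if value is None:
            return True  # nothing logged under this key: keep going
        # Accept either a plain number or an object exposing `.mean`.
        value = getattr(value, "mean", value)
        return abs(value) > threshold  # False requests early termination

    return callback


def run_toy_loop(n_iters, callbacks):
    """Hypothetical stand-in for a driver loop; returns iterations performed."""
    performed = 0
    for step in range(n_iters):
        log = {"Constraint": 1.0 / (step + 1) ** 2}  # synthetic decaying value
        performed += 1
        if not all(cb(step, log, None) for cb in callbacks):
            break
    return performed


performed = run_toy_loop(100, [make_early_stop(threshold=0.02)])
```

With a real solver, such a callback would be passed as callbacks=make_early_stop(...) to run().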

continue_simulation(*, state_path=None, n_iters=500, silent_print=False, silent_plot=False, timer=False, live_monitoring=False, callbacks=<function Solver.<lambda>>, observables=None, force_load_mpi=False, **kwargs)

Continue a previously started simulation by resuming from memory or from a serialised checkpoint.

This method supports two continuation sources:

  1. Resume from disk (state_path provided):

    • Ensures a template variational state/driver exists (needed for safe deserialisation).

    • Imports a full MCState from disk in an MPI-safe manner.

    • Installs the loaded state into the solver and updates the driver reference.

  2. Resume from memory (state_path is None):

    • If a driver exists, reuse it.

    • Otherwise, initialise the driver/state using current solver configuration.

Log semantics

The runtime log is preserved across continuations unless it is missing, in which case a new log is created. Observables default to the previously registered set unless explicitly overridden.

type state_path:

str | None

param state_path:

Optional path to a previously exported MCState checkpoint.

type n_iters:

int

param n_iters:

Number of additional optimisation iterations to execute.

type silent_print:

bool

param silent_print:

If True, suppress printing of final log output.

type silent_plot:

bool

param silent_plot:

If True, suppress interactive plot display.

type timer:

bool

param timer:

If True, enable timing/profiling integration in the driver.

type live_monitoring:

bool

param live_monitoring:

If True, attach a live monitoring callback for this continuation run.

type callbacks:

Callable[[int, dict, AbstractVariationalDriver], bool]

param callbacks:

Callback or list of callbacks invoked after each iteration.

type observables:

dict[str, AbstractOperator] | None

param observables:

Optional observable registry for this continuation run. If None, reuse the last registered observables (if any).

type force_load_mpi:

bool

param force_load_mpi:

If True, allow best-effort reconstruction even if MPI configuration differs from the environment used when saving.

type kwargs:

param kwargs:

Implementation-specific options. May include diagonal_shift override.

rtype:

None

return:

None.

raises RuntimeError:

If continuation is requested but solver components are not initialised and cannot be initialised.

raises OSError:

If state_path is provided but the file cannot be read.

raises ValueError:

If the checkpoint is invalid or incompatible with the current environment.

export_state(*, state=None, silent=False, marker='', **kwargs)

Export (serialise) a complete Monte Carlo variational state to disk.

This method exports the full MCState (including sampler configuration/state) so that a run can be reproduced even if the checkpoint was taken before convergence.

MPI synchronisation rationale

Serialisation converts JAX arrays into host buffers. This conversion blocks until all outstanding device work is complete, which may include MPI-backed collectives issued during the last VMC step. To avoid deadlocks and ensure all ranks reach a consistent point:

  • All ranks serialise and synchronise on an MPI barrier.

  • Rank-0 writes the checkpoint to disk.

  • All ranks synchronise again before returning.
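
The barrier/write/barrier pattern can be sketched as follows. This is a hedged illustration rather than the solver's code. pickle stands in for the MCState-level serialisation format, and SingleProcessComm is a hypothetical stand-in that mimics the small subset of an mpi4py communicator used here (Get_rank, Barrier), so the example runs without MPI.

```python
import os
import pickle
import tempfile


class SingleProcessComm:
    """Hypothetical stand-in for an mpi4py communicator in a one-rank job."""

    def Get_rank(self):
        return 0

    def Barrier(self):
        pass  # nothing to wait for with a single rank


def export_checkpoint(comm, state, path):
    # Every rank serialises, so every rank finishes its pending work.
    payload = pickle.dumps(state)
    comm.Barrier()  # all ranks have reached a consistent point
    if comm.Get_rank() == 0:
        with open(path, "wb") as fh:  # only rank-0 touches the file system
            fh.write(payload)
    comm.Barrier()  # nobody continues before the file exists
    return path if comm.Get_rank() == 0 else None


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "demo_FinalState.bin")  # illustrative name only
    written = export_checkpoint(SingleProcessComm(), {"params": [0.1, 0.2]}, target)
    with open(written, "rb") as fh:
        restored = pickle.loads(fh.read())
```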

File naming

The exported filename includes:

  • the solver hash,

  • a timestamp,

  • an optional marker suffix (e.g. “FinalState”, “AbortedState”).

type state:

MCState | None

param state:

Optional explicit state to export. If None, export the solver’s current state.

type silent:

bool

param silent:

If True, suppress user-facing confirmation messages.

type marker:

str

param marker:

Optional label appended to the exported filename for easy identification.

type kwargs:

param kwargs:

Implementation-specific export options (currently unused).

rtype:

str | None

return:

The exported file path if the implementation chooses to return it; otherwise None.

raises RuntimeError:

If no variational state exists and state is None.

raises OSError:

If the checkpoint cannot be written to disk (rank-0).

import_state(state_path, *, force_load_mpi=False)

Import and deserialise a variational state from a checkpoint file.

This method loads a full serialised MCState in an MPI-safe way:

  • Rank-0 reads the checkpoint bytes from disk.

  • The raw checkpoint payload is broadcast to all ranks.

  • Each rank reconstructs a new MCState using deserialize_MCState().

Template requirement

Deserialisation may rely on a “template” MCState for shape/dtype/model reconstruction. This implementation expects that self.variational_state exists when calling this method. Continuation logic ensures a template is created before import if needed.
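
A minimal sketch of the read/broadcast/reconstruct sequence, under stated assumptions: pickle replaces the real serialisation format and deserialize_MCState(), a plain dict plays the role of the template state, and SingleProcessComm is a hypothetical stand-in for an mpi4py communicator (Get_rank, bcast) so the example runs on one process.

```python
import os
import pickle
import tempfile


class SingleProcessComm:
    """Hypothetical stand-in for an mpi4py communicator in a one-rank job."""

    def Get_rank(self):
        return 0

    def bcast(self, obj, root=0):
        return obj  # with one rank, broadcasting is the identity


def import_checkpoint(comm, path, template):
    payload = None
    if comm.Get_rank() == 0:
        with open(path, "rb") as fh:  # only rank-0 reads from disk
            payload = fh.read()
    payload = comm.bcast(payload, root=0)  # every rank receives the same bytes
    loaded = pickle.loads(payload)  # every rank rebuilds its own copy
    # Toy compatibility check against the template (structure only).
    if set(loaded) != set(template):
        raise ValueError("checkpoint is incompatible with the template state")
    return loaded


with tempfile.TemporaryDirectory() as tmp:
    checkpoint = os.path.join(tmp, "checkpoint.bin")
    with open(checkpoint, "wb") as fh:
        fh.write(pickle.dumps({"params": [0.1, 0.2], "sampler_state": 7}))
    template = {"params": [0.0, 0.0], "sampler_state": 0}
    state = import_checkpoint(SingleProcessComm(), checkpoint, template)
```

Broadcasting the raw bytes rather than letting each rank open the file keeps file-system access on a single rank and guarantees identical input everywhere.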

type state_path:

str

param state_path:

Path to a previously exported checkpoint file.

type force_load_mpi:

bool

param force_load_mpi:

If True, force best-effort reconstruction even if MPI configuration differs from the environment used when saving.

rtype:

MCState

return:

A reconstructed MCState instance.

raises OSError:

If the checkpoint file cannot be read on rank-0.

raises RuntimeError:

If a required template state is not available for deserialisation.

raises ValueError:

If the checkpoint payload is invalid or incompatible.

plot_results(*, silent_plot=False, with_inset=True, **kwargs)

Plot and save the main minimisation curve and export auxiliary run artefacts.

This method produces:

  • a plot of the constraint expectation value versus iteration with error bars,

  • an optional inset zoom into the final iterations (useful for convergence inspection),

  • a saved image of the underlying graph when supported by the graph object,

  • a serialised copy of the runtime log data to disk.

Plot content

  • The primary curve uses the runtime log series under "Constraint":

    • x-axis: iteration index

    • y-axis: mean value

    • error bars: sigma

  • If the model provides an exact diagonalisation ground energy, a horizontal reference line is drawn.

Inset behaviour

When with_inset=True, the method:

  • zooms into the last ~20 iterations,

  • sanitises data via reject_outliers() for stable y-limits,

  • draws an inset with matching content and a marked region in the main plot.

type silent_plot:

bool

param silent_plot:

If True, suppress interactive plot display (figures are still saved).

type with_inset:

bool

param with_inset:

If True, include an inset zoom of the final part of the trajectory.

type kwargs:

param kwargs:

Implementation-specific plotting options (currently unused).

rtype:

None

return:

None.

raises KeyError:

If expected log keys (e.g. "Constraint") are missing from the runtime log.

raises OSError:

If output images or log serialisation cannot be written to disk.

plot_observables(*observables, **kwargs)

Plot one or more logged observables from the runtime log.

For each requested observable key, this method attempts to read:

  • Mean series,

  • Sigma series,

and produces a plot of mean versus iteration with a shaded mean±sigma band.

Observable lookup

Observable data is retrieved from the solver runtime log using the provided keys. If an observable is missing, a ValueError is raised with a user-facing message.

type observables:

str

param observables:

One or more observable names/keys as stored in the runtime log.

type kwargs:

param kwargs:

Implementation-specific plotting options (currently unused).

rtype:

None

return:

None.

raises ValueError:

If an observable key is not found in the runtime log.

set_sampler(sampler_type, *, number_of_chains=16, number_of_sweeps=20, machine_power=2, reset_chains=True, number_of_samples=512, **kwargs)

Configure and attach a sampler to the solver.

This method builds a sampler using the neuraLQX sampler factory and stores:

  • the primary sampler used by the variational state,

  • an auxiliary sampler copy (when provided by the factory),

  • the sampler initialisation kwargs for later reconstruction/logging,

  • the number of samples requested per chain.

Gauge sampler compatibility

If the requested sampler name indicates a gauge sampler (substring “gauge”), this method enforces that the underlying Hilbert space is gauge invariant. Otherwise it raises NotImplementedError to avoid silent misuse.
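
The guard described above amounts to a substring check followed by an explicit failure. The sketch below is illustrative only. The function name, the boolean flag, the sampler names, and the case-insensitive match are assumptions, since the text states only that the substring "gauge" is tested.

```python
def check_sampler_compatibility(sampler_type, hilbert_is_gauge_invariant):
    """Refuse gauge samplers on Hilbert spaces that are not gauge invariant."""
    # Assumption: the match ignores case.
    wants_gauge = "gauge" in sampler_type.lower()
    if wants_gauge and not hilbert_is_gauge_invariant:
        raise NotImplementedError(
            f"Sampler {sampler_type!r} requires a gauge-invariant Hilbert space."
        )


# A non-gauge sampler passes regardless of the Hilbert space.
check_sampler_compatibility("Metropolis Local", hilbert_is_gauge_invariant=False)

# A gauge sampler on an incompatible space fails loudly instead of silently.
try:
    check_sampler_compatibility("Gauge Metropolis", hilbert_is_gauge_invariant=False)
    rejected = False
except NotImplementedError:
    rejected = True
```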

Logging

Sampler metadata is exported to the solver logger. Sampler kwargs are also recorded under “Sampler Configs”, excluding keys that are already covered by the sampler metadata.

type sampler_type:

str

param sampler_type:

Human-readable sampler identifier supported by neuraLQX.

type number_of_chains:

int

param number_of_chains:

Total number of Monte Carlo chains to use.

type number_of_sweeps:

int

param number_of_sweeps:

Number of sweeps per sampling step.

type machine_power:

int

param machine_power:

Exponent used in the sampling probability distribution (backend-specific).

type reset_chains:

bool

param reset_chains:

If True, reset chains at each iteration (backend-specific).

type number_of_samples:

int

param number_of_samples:

Number of samples per chain.

type kwargs:

param kwargs:

Sampler-specific parameters (e.g. d_max, rules, probabilities, Hamiltonian, temperature ladder, etc.).

rtype:

None

return:

None.

raises NotImplementedError:

If a gauge sampler is requested for a non-gauge-invariant Hilbert space.

raises ValueError:

If sampler configuration parameters are invalid for the chosen sampler type.

set_optimizer(optimizer_type='Adam', *, use_sr, preconditioner_solver='Conjugate Gradient', preconditioner=None, learning_rate=None, scheduler_type=None, scheduler_parameters=None, diagonal_shift=None, **kwargs)

Configure and attach an optimiser and the SR preconditioner linear solver.

This method constructs:

  • an Optax optimiser via the neuraLQX Optimizer builder,

  • a linear solver backend for stochastic reconfiguration (SR) via Solvers,

  • and stores a diagonal shift value used later when constructing the SR preconditioner.

Logging

Optimiser metadata and solver metadata are exported to the solver logger under “Optimizer Configs”.

type optimizer_type:

str

param optimizer_type:

Name of the optimiser supported by neuraLQX (e.g. “Adam”).

type preconditioner_solver:

str

param preconditioner_solver:

Name of the linear solver used inside SR (e.g. “Conjugate Gradient”).

type preconditioner:

AbstractLinearPreconditioner | None

param preconditioner:

The preconditioner object used in natural gradient descent. If not specified, SR is used by default.

type learning_rate:

float | None

param learning_rate:

Optional constant learning rate. If None, a sensible default is chosen by the optimiser builder.

type scheduler_type:

str | None

param scheduler_type:

Optional learning-rate schedule identifier.

type scheduler_parameters:

dict[str, Any] | None

param scheduler_parameters:

Optional mapping of schedule parameters passed to the schedule builder.

type diagonal_shift:

float | None

param diagonal_shift:

Optional diagonal shift regularisation used by SR. If None, defaults to 0.1.

type kwargs:

param kwargs:

Additional optimiser-specific configuration parameters forwarded to the builder.

type use_sr:

bool

param use_sr:

If True, a stochastic reconfiguration (SR) preconditioner is used.

rtype:

None

return:

None.

raises ValueError:

If optimiser or schedule configuration is invalid.

raises RuntimeError:

If the optimiser or solver backend cannot be constructed.

set_network(network, *, diff_invariant=False, symmetries=None, chunk_size=None, **kwargs)

Configure and attach a neural network Ansatz to the solver.

This method:

  • records network type and attributes to the solver logger,

  • stores a copy of the original network for internal use when needed,

  • optionally wraps the network with a group/symmetry projector when running a diffeomorphism-invariant simulation,

  • stores the final network module and updates internal flags.

Diffeomorphism-invariant mode

If diff_invariant=True, the solver marks the run as diffeomorphism invariant and wraps the network using wrap_model(). The wrapper applies a group-averaging/projection procedure over the provided symmetry descriptors.

type network:

Module

param network:

Neural network module (typically a subclass of flax.linen.Module).

type diff_invariant:

bool

param diff_invariant:

If True, enable diffeomorphism-invariant (group-averaged) simulation mode.

type symmetries:

Sequence[Any] | None

param symmetries:

Optional symmetry descriptors used by the projection wrapper.

type kwargs:

param kwargs:

Implementation-specific parameters (currently unused).

type chunk_size:

int | None

param chunk_size:

If specified, expectation values and gradients are computed in chunks, with no chunk larger than chunk_size.

rtype:

None

return:

None.

raises ValueError:

If symmetry configuration is inconsistent with the graph/model.

raises RuntimeError:

If the network wrapper cannot be constructed.

expect(operator, *, n_samples=None, n_chains=None, use_same_variables=True, use_same_sampler=True, **kwargs)

Compute the expectation value of an operator or list of operators.

This method provides explicit sampling semantics to avoid accidental misuse.

Supported modes

  1. Alias mode (default): If only operator is provided, this is a direct alias to variational_state.expect(operator).

  2. Alias mode with temporary sample-count override: If n_samples is provided and n_chains is not provided, this method temporarily overrides variational_state.n_samples, warns the user that samples will be regenerated, evaluates the expectation, and restores the original sample count.
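
The temporary override in mode 2 follows a save/override/restore pattern, sketched here with a context manager. ToyState is a hypothetical stand-in for the variational state, and temporary_n_samples is an illustrative helper, not part of the solver API.

```python
import warnings
from contextlib import contextmanager


class ToyState:
    """Hypothetical stand-in exposing `n_samples` and `expect`."""

    def __init__(self, n_samples):
        self.n_samples = n_samples

    def expect(self, operator):
        return f"<{operator}> estimated from {self.n_samples} samples"


@contextmanager
def temporary_n_samples(state, n_samples):
    original = state.n_samples
    warnings.warn("Overriding n_samples: samples will be regenerated.")
    state.n_samples = n_samples
    try:
        yield state
    finally:
        # Restore the original count even if the evaluation raises.
        state.n_samples = original


state = ToyState(n_samples=512)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with temporary_n_samples(state, 4096) as tmp_state:
        result = tmp_state.expect("H")
```

The try/finally ensures that a failed evaluation cannot leave the state with the overridden sample count.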

Unsupported mode (planned)

If n_chains is provided, the intended behaviour is to build a temporary state with a modified sampler and draw fresh samples. This mode is not implemented and currently raises NotImplementedError.

Ignored arguments

In alias modes, arguments that would only apply to “fresh sampling” (such as use_same_variables, use_same_sampler, and additional kwargs) are ignored with a warning.

type operator:

Union[list[Union[AbstractOperator, AbstractObservable]], AbstractOperator, AbstractObservable]

param operator:

Operator/observable or a list of operators/observables to evaluate.

type n_samples:

int | None

param n_samples:

Optional number of samples per chain (alias mode override).

type n_chains:

int | None

param n_chains:

Optional number of chains (fresh sampling request; not implemented).

type use_same_variables:

bool

param use_same_variables:

Intended for fresh sampling mode; ignored in alias modes.

type use_same_sampler:

bool

param use_same_sampler:

Intended for fresh sampling mode; ignored in alias modes.

type kwargs:

param kwargs:

Additional options; ignored in alias modes.

rtype:

Stats

return:

A Stats-like object containing the estimated expectation value(s).

raises RuntimeError:

If the variational state is not initialised.

raises NotImplementedError:

If n_chains is provided (fresh sampling mode not implemented).

initialize_vmc(*args, **kwargs)

Initialise the variational state and VMC driver.

This method constructs the objects required to execute optimisation:

  1. Validates configuration: Requires that sampler, optimiser, and network have been configured via set_sampler(), set_optimizer(), and set_network().

  2. Builds the Monte Carlo variational state: Constructs MCState using the configured sampler and network and stores it in the solver. It also records the number of trainable parameters and the ratio of parameter count to Hilbert dimension (when estimable).

  3. Builds the VMC driver: Constructs VMC using the model constraint as the optimisation target, the configured optimiser, and an SR preconditioner:

    • diagonal shift taken from diagonal_shift,

    • holomorphic flag inferred via NetKet’s holomorphic diagnostic,

    • linear solver backend taken from preconditioner_solver.

Parameters:
  • args – Implementation-specific initialisation options (currently unused).

  • kwargs – Implementation-specific initialisation options (may be ignored by this implementation).

Return type:

None

Returns:

None.

Raises:
  • PermissionError – If sampler, optimiser, or network have not been configured.

  • RuntimeError – If the variational state or driver cannot be constructed.