neuralqx.experimental.solver.stmh_solver module

serialize_STMultiMCState(vstate)

Serialize a STMultiMCState into a plain Python dictionary.

Each head's MCState is serialized with serialize_MCState. Although the parameters are shared across heads, the per-head payload format is kept intentionally because it is robust and backward-compatible.

Return type:

dict
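A minimal sketch of the per-head payload layout described above. It assumes a hypothetical `serialize_head` callable standing in for serialize_MCState and hypothetical top-level keys `"n_heads"` and `"heads"`. The checkpoint keys actually written by neuralqx are not documented on this page and may differ.

```python
from typing import Any, Callable, Dict, List, Sequence


def serialize_multi_head(
    heads: Sequence[Any],
    serialize_head: Callable[[Any], Dict[str, Any]],
) -> Dict[str, Any]:
    """Serialize every head view into its own payload (hypothetical layout)."""
    # Even though all heads share one parameter set, each head is written out
    # as a complete single-state payload, so any entry can be read back by a
    # loader that only understands the single-state format.
    payloads: List[Dict[str, Any]] = [serialize_head(h) for h in heads]
    return {"n_heads": len(payloads), "heads": payloads}


# Usage with a toy stand-in for serialize_MCState (a shallow dict copy).
toy_heads = [{"head": 0, "params": {"w": 1.0}}, {"head": 1, "params": {"w": 1.0}}]
state = serialize_multi_head(toy_heads, lambda h: dict(h))
print(state["n_heads"])  # 2
```

In this sketched layout the shared parameters are duplicated in every payload, trading disk space for compatibility with single-state loaders.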

deserialize_STMultiMCState(template, state_dict, *, force_load_mpi=False)

Deserialize a shared-parameter multi-head state from a checkpoint dictionary.

Parameters:
  • template (STMultiMCState) – Existing STMultiMCState used to reconstruct each contained MCState.

  • state_dict (dict) – Raw dict loaded from disk.

  • force_load_mpi (bool) – Forwarded to deserialize_MCState.

Return type:

STMultiMCState
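The reconstruction step can be sketched as pairing each template head with its stored payload. This assumes the same hypothetical `"heads"` layout as a plain list, and a hypothetical `deserialize_head(template, payload, force_load_mpi=...)` callable standing in for deserialize_MCState. The head-count check is an illustrative safeguard, not documented behavior.

```python
from typing import Any, Callable, Dict, List, Sequence


def deserialize_multi_head(
    template_heads: Sequence[Any],
    state_dict: Dict[str, Any],
    deserialize_head: Callable[..., Any],
    *,
    force_load_mpi: bool = False,
) -> List[Any]:
    """Rebuild one state per head from a per-head checkpoint (hypothetical layout)."""
    payloads = state_dict["heads"]
    if len(payloads) != len(template_heads):
        raise ValueError(
            f"checkpoint has {len(payloads)} heads, template has {len(template_heads)}"
        )
    # Each template head supplies the structure (model, sampler, ...) and the
    # matching payload supplies the stored values; force_load_mpi is passed
    # through unchanged, mirroring the documented forwarding.
    return [
        deserialize_head(tmpl, payload, force_load_mpi=force_load_mpi)
        for tmpl, payload in zip(template_heads, payloads)
    ]


def toy_deserialize(template, payload, *, force_load_mpi=False):
    # Toy stand-in: overlay the stored values on a copy of the template.
    return {**template, **payload}


restored = deserialize_multi_head(
    [{"head": 0}, {"head": 1}],
    {"n_heads": 2, "heads": [{"w": 0.5}, {"w": 0.5}]},
    toy_deserialize,
)
print(restored[1])  # {'head': 1, 'w': 0.5}
```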

class STMultiSolver(lqx, output_path=None, auxiliary_path=None, *, clean_up=False, seed=None)

Bases: Solver

Solver for the single-trunk multi-head (ST-MH) variational ansatz.

Compared to MultiSolver (MT-MH, independent networks), this solver builds one shared multi-head Flax model and exposes one MCState per head through head-selector wrappers. The variational state is an STMultiMCState, and optimization is performed by SingleTrunkMultiHeadVMC, which aggregates the energy and orthogonality gradients into a single update of the shared parameters.
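The head-selector idea can be illustrated without Flax: one shared parameter container, a model function that takes a `head` argument, and lightweight view objects that fix `head`. The names `shared_model` and `HeadView` and the trunk/head parameter layout are hypothetical; the real solver wraps a Flax module and builds MCState objects instead.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


def shared_model(params: Dict[str, Any], x: float, head: int) -> float:
    # Shared trunk: one feature computed from the common weight.
    feature = params["trunk"] * x
    # Head-specific readout: each head has its own scalar weight.
    return params["heads"][head] * feature


@dataclass
class HeadView:
    """Scalar-output view of the shared model with the head index fixed."""
    model: Callable[[Dict[str, Any], float, int], float]
    params: Dict[str, Any]  # the SAME dict object is shared by every view
    head: int

    def __call__(self, x: float) -> float:
        return self.model(self.params, x, self.head)


params = {"trunk": 2.0, "heads": [1.0, -1.0, 0.5]}
views: List[HeadView] = [HeadView(shared_model, params, h) for h in range(3)]
print([v(1.0) for v in views])  # [2.0, -2.0, 1.0]

# One update to the shared trunk changes every head at once.
params["trunk"] = 3.0
print([v(1.0) for v in views])  # [3.0, -3.0, 1.5]
```

Because the views hold a reference to one parameter container rather than copies, a single optimizer step on it moves all heads together, which is the property that distinguishes ST-MH from independent networks.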

property lambda_ortho: float
property n_heads: int
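The single shared update performed by SingleTrunkMultiHeadVMC can be pictured as summing, over heads, each head's energy gradient plus a lambda_ortho-weighted orthogonality gradient, all taken with respect to the same parameter vector. This is a simplified sketch under that assumption only: the plain gradient-descent step, the learning rate, and the function name `aggregate_update` are hypothetical, and the actual driver may weight, precondition, or normalize the terms differently.

```python
from typing import List, Sequence

Vector = List[float]


def aggregate_update(
    params: Vector,
    energy_grads: Sequence[Vector],
    ortho_grads: Sequence[Vector],
    lambda_ortho: float,
    learning_rate: float,
) -> Vector:
    """Return new shared parameters after one combined gradient step."""
    if len(energy_grads) != len(ortho_grads):
        raise ValueError("need one energy and one orthogonality gradient per head")
    total = [0.0] * len(params)
    for g_e, g_o in zip(energy_grads, ortho_grads):
        # Every head contributes to the SAME parameter vector.
        for i in range(len(params)):
            total[i] += g_e[i] + lambda_ortho * g_o[i]
    # A single update is applied once, after all heads are accumulated.
    return [p - learning_rate * g for p, g in zip(params, total)]


new_params = aggregate_update(
    params=[1.0, 1.0],
    energy_grads=[[1.0, 0.0], [0.0, 1.0]],
    ortho_grads=[[0.5, 0.5], [0.5, -0.5]],
    lambda_ortho=2.0,
    learning_rate=0.1,
)
print(new_params)  # approximately [0.7, 0.9]
```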
set_network(network, *, n_heads=None, head_models=None, diff_invariant=False, symmetries=None, lambda_ortho=None, canonical_state=0, sync_model_state=False, **kwargs)

Configure a shared ST-MH network and (optionally) explicit per-head view models.

Parameters:
  • network (Module) – Shared ST-MH base model (typically a wrapper that accepts a head=... argument).

  • n_heads (Optional[int]) – Number of heads. If omitted, the solver tries to infer it from network.n_heads.

  • head_models (Optional[Sequence[Module]]) – Optional explicit scalar-output head models, one per head. If provided, they are used directly, and network is treated as the shared base model for logging purposes only.

  • diff_invariant, symmetries – Same semantics as in MultiSolver. For ST-MH, diffeo wrapping is applied to the head models during initialize_vmc.

  • lambda_ortho (Optional[float]) – Orthogonality/fidelity penalty strength.

  • canonical_state (int) – Forwarded to STMultiMCState.

  • sync_model_state (bool) – Forwarded to STMultiMCState.

Return type:

None
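How n_heads and head_models interact can be sketched as a small resolution step, assuming the precedence stated above: explicit head_models are used as-is; otherwise n_heads comes from the argument or from network.n_heads. The function name `resolve_heads`, the `make_view` callback, and the error messages are hypothetical, and the consistency check between n_heads and len(head_models) is an assumption rather than documented behavior.

```python
from typing import Any, Callable, List, Optional, Sequence


def resolve_heads(
    network: Any,
    *,
    n_heads: Optional[int] = None,
    head_models: Optional[Sequence[Any]] = None,
    make_view: Callable[[Any, int], Any],
) -> List[Any]:
    """Return one scalar-output model per head (hypothetical helper)."""
    if head_models is not None:
        # Explicit per-head models win; `network` is then kept only for logging.
        if n_heads is not None and n_heads != len(head_models):
            raise ValueError("n_heads does not match len(head_models)")
        return list(head_models)
    if n_heads is None:
        # Fall back to an `n_heads` attribute on the shared network, if present.
        n_heads = getattr(network, "n_heads", None)
    if n_heads is None:
        raise ValueError("n_heads not given and network has no n_heads attribute")
    # Build a head-selector view for each head of the shared network.
    return [make_view(network, h) for h in range(n_heads)]


class ToyNet:
    n_heads = 3


views = resolve_heads(ToyNet(), make_view=lambda net, h: ("view", h))
print(views)  # [('view', 0), ('view', 1), ('view', 2)]
```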

expect(operator, *, state_idx=None, n_samples=None, n_chains=None, use_same_variables=True, use_same_sampler=True, **kwargs)

Multi-head-aware expectation helper that mirrors MultiSolver.expect.

Return type:

Union[Stats, List[Stats]]
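The Union[Stats, List[Stats]] return type suggests a dispatch on state_idx. A minimal sketch, assuming (this page does not confirm it) that state_idx=None evaluates every head and returns a list while an integer selects a single head. `Stats` here is a hypothetical stand-in dataclass, not the library's Stats type, and `estimate` is a hypothetical per-head estimator.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Union


@dataclass
class Stats:
    # Hypothetical stand-in holding a mean and its statistical error.
    mean: float
    error: float


def expect_heads(
    estimate: Callable[[int], Stats],
    n_heads: int,
    state_idx: Optional[int] = None,
) -> Union[Stats, List[Stats]]:
    """Evaluate one head, or all heads when state_idx is None (assumed)."""
    if state_idx is None:
        return [estimate(h) for h in range(n_heads)]
    if not 0 <= state_idx < n_heads:
        raise IndexError(f"state_idx {state_idx} out of range for {n_heads} heads")
    return estimate(state_idx)


def toy_estimate(head: int) -> Stats:
    # Toy estimator: a fake energy that decreases with the head index.
    return Stats(mean=-1.0 - head, error=0.01)


print(expect_heads(toy_estimate, 2))               # [Stats(mean=-1.0, error=0.01), Stats(mean=-2.0, error=0.01)]
print(expect_heads(toy_estimate, 2, state_idx=1))  # Stats(mean=-2.0, error=0.01)
```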

initialize_vmc(*args, **kwargs)

Initialize the shared-parameter ST-MH variational state and driver.

Builds one MCState per head (a head-selector view), wraps them in an STMultiMCState, and constructs the SingleTrunkMultiHeadVMC driver.

Return type:

None

export_state(*, state=None, silent=False, marker='', **kwargs)

Export a shared-parameter ST-MH checkpoint to disk.

Return type:

Optional[str]

import_state(state_path, *, force_load_mpi=False)

Import a shared-parameter ST-MH checkpoint from disk.

Return type:

STMultiMCState
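A disk round trip for a checkpoint dictionary can be sketched with the standard library. This assumes a pickle file and a hypothetical `stmh_state{marker}.pkl` naming pattern; the file format, naming, and MPI handling actually used by export_state and import_state are not documented on this page and may differ. The helper names `export_dict` and `import_dict` are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Optional


def export_dict(
    state: Dict[str, Any], directory: str, *, marker: str = "", silent: bool = False
) -> Optional[str]:
    """Write a checkpoint dict to disk and return the file path (hypothetical)."""
    path = os.path.join(directory, f"stmh_state{marker}.pkl")
    with open(path, "wb") as fh:
        pickle.dump(state, fh)
    if not silent:
        print(f"saved checkpoint to {path}")
    return path


def import_dict(path: str) -> Dict[str, Any]:
    """Load a checkpoint dict previously written by export_dict."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    saved = export_dict({"n_heads": 2, "heads": [{}, {}]}, tmp, marker="_final", silent=True)
    loaded = import_dict(saved)
    print(os.path.basename(saved), loaded["n_heads"])  # stmh_state_final.pkl 2
```

If pickle is used for checkpoints, only load files from trusted sources, since unpickling can execute arbitrary code.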

plot_results(*, silent_plot=False, with_inset=True, **kwargs)

Plot and save optimization curves.

This mirrors MultiSolver.plot_results but labels the per-state curves as heads.

Return type:

None