neuralqx.experimental.solver.stmh_solver module
- serialize_STMultiMCState(vstate)
Serialize an STMultiMCState into a plain Python dictionary. Each head's MCState is serialized with serialize_MCState. Although the parameters are shared across heads, we intentionally keep the per-head payload format because it is robust and backward-compatible.
- Return type:
dict
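The per-head payload format can be pictured with a toy sketch. It uses a hypothetical stand-in, `toy_serialize_head`, in place of serialize_MCState, and the key names (`"n_heads"`, `"heads"`, `"head"`, `"params"`) are illustrative assumptions; the real checkpoint layout may differ. The point is that every head entry carries its own copy of the shared parameters, so each entry stays self-contained.

```python
import copy
from typing import Any, Dict, List

Params = Dict[str, List[float]]


def toy_serialize_head(params: Params, head: int) -> Dict[str, Any]:
    # Hypothetical stand-in for serialize_MCState: one self-contained
    # payload per head, including a full copy of the (shared) parameters.
    return {"head": head, "params": copy.deepcopy(params)}


def toy_serialize_multi(shared_params: Params, n_heads: int) -> Dict[str, Any]:
    # Per-head payload format: the shared parameters are duplicated in
    # every entry instead of being stored once at the top level.
    return {
        "n_heads": n_heads,
        "heads": [toy_serialize_head(shared_params, k) for k in range(n_heads)],
    }


shared = {"trunk": [0.1, -0.2], "readout": [0.5, 0.3, -0.7]}
payload = toy_serialize_multi(shared, n_heads=3)
```

The cost of this choice is redundant storage of the shared parameters; the benefit is that each entry reads like an ordinary single-state checkpoint.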
- deserialize_STMultiMCState(template, state_dict, *, force_load_mpi=False)
Deserialize a shared-parameter multi-head state from a checkpoint dictionary.
- Parameters:
  - template (STMultiMCState) – Existing STMultiMCState used to reconstruct each contained MCState.
  - state_dict (dict) – Raw dict loaded from disk.
  - force_load_mpi (bool) – Forwarded to deserialize_MCState.
- Return type:
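A template-driven load can be sketched as follows. This is a toy, not deserialize_STMultiMCState: it assumes a per-head payload layout with hypothetical keys (`"heads"`, `"params"`), and the two consistency checks (head count matches the template, all per-head parameter copies agree because the parameters are shared) are illustrative assumptions about what such a loader might verify.

```python
from typing import Any, Dict, List

Params = Dict[str, List[float]]


def toy_deserialize_multi(template_params: Params, n_heads: int,
                          state_dict: Dict[str, Any]) -> Params:
    # The template fixes the expected structure: head count and parameter keys.
    heads = state_dict["heads"]
    if len(heads) != n_heads:
        raise ValueError(f"checkpoint has {len(heads)} heads, template expects {n_heads}")
    restored = []
    for payload in heads:
        params = payload["params"]
        if set(params) != set(template_params):
            raise KeyError("parameter keys do not match the template")
        restored.append(params)
    # Parameters are shared, so every per-head copy should be identical.
    if any(p != restored[0] for p in restored[1:]):
        raise ValueError("per-head parameter copies disagree")
    return restored[0]


template = {"trunk": [0.0, 0.0], "readout": [0.0, 0.0, 0.0]}
saved = {"trunk": [0.1, -0.2], "readout": [0.5, 0.3, -0.7]}
ckpt = {"n_heads": 2, "heads": [{"head": 0, "params": dict(saved)},
                                {"head": 1, "params": dict(saved)}]}
params = toy_deserialize_multi(template, 2, ckpt)
```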
- class STMultiSolver(lqx, output_path=None, auxiliary_path=None, *, clean_up=False, seed=None)
Bases: Solver
Solver for the single-trunk multi-head (ST-MH) variational ansatz.
Compared to MultiSolver (MT-MH / independent networks), this solver builds one shared multi-head Flax model and exposes one MCState per head through head-selector wrappers. The variational state is an STMultiMCState, and optimization is performed by SingleTrunkMultiHeadVMC, which aggregates the energy and orthogonality gradients into a single update of the shared parameters.
- set_network(network, *, n_heads=None, head_models=None, diff_invariant=False, symmetries=None, lambda_ortho=None, canonical_state=0, sync_model_state=False, **kwargs)
Configure a shared ST-MH network and, optionally, explicit per-head view models.
- Parameters:
  - network (Module) – Shared ST-MH base model (typically your wrapper that supports head=...).
  - n_heads (Optional[int]) – Number of heads. If omitted, the solver tries to infer it from network.n_heads.
  - head_models (Optional[Sequence[Module]]) – Optional explicit scalar-output head models, one per head. If provided, they are used directly, and network is treated as the shared base model for logging only.
  - diff_invariant, symmetries – Same semantics as in MultiSolver. For ST-MH, the diffeo wrapping is applied to the head models during initialize_vmc.
  - lambda_ortho (Optional[float]) – Strength of the orthogonality/fidelity penalty.
  - canonical_state (int) – Forwarded to STMultiMCState.
  - sync_model_state (bool) – Forwarded to STMultiMCState.
- Return type:
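A head-selector view turns a shared model that accepts `head=...` into one scalar-output callable per head, all reading the same parameters. The toy below uses plain Python and `functools.partial`; `ToySharedModel`, its parameter layout, and `make_head_views` are hypothetical and only mirror the documented roles of n_heads (inferred from network.n_heads when omitted) and head_models (used directly when given).

```python
import functools
import math
from typing import Callable, Dict, List, Optional, Sequence

Params = Dict[str, List[float]]


class ToySharedModel:
    """Hypothetical single-trunk multi-head model: one trunk, one readout weight per head."""

    def __init__(self, n_heads: int):
        self.n_heads = n_heads

    def __call__(self, params: Params, x: Sequence[float], *, head: int) -> float:
        # Shared trunk feature, identical for all heads.
        feature = math.tanh(sum(w * xi for w, xi in zip(params["trunk"], x)))
        # Head-specific scalar readout (e.g. a log-amplitude).
        return params["readout"][head] * feature


def make_head_views(network, n_heads: Optional[int] = None,
                    head_models: Optional[Sequence[Callable]] = None) -> List[Callable]:
    if head_models is not None:
        # Explicit per-head models take precedence.
        return list(head_models)
    if n_heads is None:
        # Fallback: infer the head count from network.n_heads.
        n_heads = getattr(network, "n_heads", None)
        if n_heads is None:
            raise ValueError("n_heads not given and network has no n_heads attribute")
    # One view per head: the same network with the head index fixed.
    return [functools.partial(network, head=k) for k in range(n_heads)]


net = ToySharedModel(n_heads=3)
params = {"trunk": [0.4, -0.1], "readout": [1.0, 0.5, -2.0]}
views = make_head_views(net)
values = [view(params, [1.0, 2.0]) for view in views]
```

Because every view wraps the same network and receives the same params, one update of params moves all heads at once, which is what distinguishes the single-trunk design from independent per-state networks.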
- expect(operator, *, state_idx=None, n_samples=None, n_chains=None, use_same_variables=True, use_same_sampler=True, **kwargs)
Multi-head-aware expectation helper mirroring MultiSolver.expect.
- Return type:
Union[Stats, List[Stats]]
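The Union[Stats, List[Stats]] return type suggests one result per head or a single result for a selected head. The sketch below assumes that state_idx=None means "all heads" and an integer selects one head; this mapping, `ToyStats`, and `toy_expect` are hypothetical, and the statistics are a plain Monte Carlo mean with a naive standard error that ignores autocorrelation between samples.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Sequence, Union


@dataclass
class ToyStats:
    # Minimal stand-in for a statistics object: mean and standard error.
    mean: float
    error_of_mean: float


def _stats(local_values: Sequence[float]) -> ToyStats:
    # Requires at least two samples (unbiased variance divides by n - 1).
    n = len(local_values)
    mean = sum(local_values) / n
    var = sum((v - mean) ** 2 for v in local_values) / (n - 1)
    # Naive error estimate; assumes uncorrelated samples.
    return ToyStats(mean, math.sqrt(var / n))


def toy_expect(per_head_samples: Sequence[Sequence[float]],
               state_idx: Optional[int] = None) -> Union[ToyStats, List[ToyStats]]:
    if state_idx is None:
        return [_stats(s) for s in per_head_samples]
    return _stats(per_head_samples[state_idx])


samples = [[1.0, 2.0, 3.0], [-1.0, -1.0, -1.0]]
all_heads = toy_expect(samples)
head1 = toy_expect(samples, state_idx=1)
```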
- initialize_vmc(*args, **kwargs)
Initialise the shared-parameter ST-MH variational state and driver.
Builds one MCState per head (a head-selector view), then wraps them in an STMultiMCState and constructs a SingleTrunkMultiHeadVMC driver.
- Return type:
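SingleTrunkMultiHeadVMC is described as aggregating energy and orthogonality gradients into one shared-parameter update. A minimal sketch of such an aggregation, assuming a total objective sum_k E_k + lambda_ortho * sum_{i<j} F_ij, where F_ij is the normalized fidelity between heads i and j, and plain gradient descent. The actual weighting, penalty form, and any preconditioning used by the driver may differ; `fidelity`, `aggregate_gradient`, and `sgd_step` are hypothetical helpers, and the per-term gradients are supplied directly rather than computed.

```python
from typing import List, Sequence

Vector = List[float]


def fidelity(a: Sequence[complex], b: Sequence[complex]) -> float:
    # Normalized overlap |<a|b>|^2 / (<a|a><b|b>) between two unnormalized states.
    ab = sum(x.conjugate() * y for x, y in zip(a, b))
    aa = sum(abs(x) ** 2 for x in a)
    bb = sum(abs(y) ** 2 for y in b)
    return abs(ab) ** 2 / (aa * bb)


def aggregate_gradient(energy_grads: Sequence[Vector],
                       ortho_grads: Sequence[Vector],
                       lambda_ortho: float) -> Vector:
    # Every gradient is taken with respect to the same shared parameter
    # vector, so they can be summed component-wise into one direction.
    dim = len(energy_grads[0])
    total = [0.0] * dim
    for g in energy_grads:  # one energy gradient per head
        for i in range(dim):
            total[i] += g[i]
    for g in ortho_grads:  # one penalty gradient per head pair (i < j)
        for i in range(dim):
            total[i] += lambda_ortho * g[i]
    return total


def sgd_step(params: Vector, grad: Vector, lr: float) -> Vector:
    # Single shared update: all heads move together.
    return [p - lr * g for p, g in zip(params, grad)]


theta = [0.5, -0.25]
g = aggregate_gradient([[1.0, 0.0], [0.0, 2.0]], [[0.5, 0.5]], lambda_ortho=2.0)
theta_new = sgd_step(theta, g, lr=0.1)
```

With two heads there is a single pair, so one penalty gradient is added, scaled by lambda_ortho, on top of the two energy gradients.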
- export_state(*, state=None, silent=False, marker='', **kwargs)
Export a shared-parameter ST-MH checkpoint to disk.
- import_state(state_path, *, force_load_mpi=False)
Import a shared-parameter ST-MH checkpoint from disk.
- Return type:
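A disk round trip for a checkpoint dictionary can be sketched with the standard library. This assumes pickle as the on-disk format and a filename built from an optional marker suffix; both are guesses for illustration, and `toy_export_state` / `toy_import_state` are hypothetical, not the solver's export_state / import_state. Only unpickle files you trust.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def toy_export_state(state_dict: Dict[str, Any], output_dir: str, marker: str = "") -> str:
    # Hypothetical naming scheme: the marker is appended to a fixed stem.
    path = os.path.join(output_dir, f"stmh_state{marker}.pkl")
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(state_dict, f)
    # Write to a temporary file, then rename, so an interrupted write
    # does not leave a truncated file under the final name.
    os.replace(tmp_path, path)
    return path


def toy_import_state(state_path: str) -> Dict[str, Any]:
    with open(state_path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as out_dir:
    ckpt = {"n_heads": 2, "heads": [{"head": 0, "params": {"w": [1.0]}},
                                    {"head": 1, "params": {"w": [1.0]}}]}
    saved_path = toy_export_state(ckpt, out_dir, marker="_step100")
    restored = toy_import_state(saved_path)
    saved_name = os.path.basename(saved_path)
```

The write-then-rename step is a common safeguard for checkpoints; it is a design choice of this sketch, not a documented property of export_state.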