neuralqx.experimental.vqs.mc.mc_state.stmh_state module
- class STMultiMCState(states, *, canonical_state=0, sync_model_state=False)
Bases: object

A container for multiple head-specific MCState objects that share one parameter pytree.

This class is intentionally small and driver-oriented. It assumes you have already created one MCState per head (typically from STMHHeadView models), all with an identical parameter-tree structure. It then synchronizes their parameters so that they act as different views of the same ST-MH model.

The current MultiMCState returns a list of parameter pytrees (one per state). That is exactly right for MT-MH (independent networks), but it does not match the ST-MH Ansatz, in which the shared trunk and the heads form one joint parameter set. This class instead exposes a single parameter pytree to the base VMC driver, so optimizer state and updates are computed only once.
- Parameters:
  - states (list[MCState]) – List of head-specific MCState objects. They must share the same Hilbert space and the same parameter-tree structure.
  - canonical_state (int) – Index of the state whose parameters are treated as the canonical source of truth.
  - sync_model_state (bool) – If True, broadcast_from_canonical() also copies model_state from the canonical state to all other states. Keep this False unless you really need mutable model-state collections and understand the implications.
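The synchronization contract described above (one canonical state, one shared parameter pytree, optional model_state copying) can be sketched in pure Python. This is not the neuralqx implementation: ToyHeadState and SharedParamsContainer are hypothetical stand-ins, and the sketch assumes that assigning parameters on the container writes the same pytree to every head state.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyHeadState:
    """Hypothetical stand-in for a head-specific MCState."""
    parameters: dict
    model_state: dict = field(default_factory=dict)


class SharedParamsContainer:
    """Hypothetical sketch of the STMultiMCState synchronization idea."""

    def __init__(self, states, *, canonical_state=0, sync_model_state=False):
        if not states:
            raise ValueError("need at least one head state")
        if not 0 <= canonical_state < len(states):
            raise IndexError("canonical_state out of range")
        self.states = list(states)
        self.canonical_state = canonical_state
        self.sync_model_state = sync_model_state
        # Start consistent: every head sees the canonical parameters.
        self.broadcast_from_canonical()

    @property
    def parameters(self):
        # A single pytree for the driver, not a list of per-head pytrees.
        return self.states[self.canonical_state].parameters

    @parameters.setter
    def parameters(self, new_params):
        # One optimizer update is written to every head view.
        for s in self.states:
            s.parameters = new_params

    def broadcast_from_canonical(self):
        canon = self.states[self.canonical_state]
        for i, s in enumerate(self.states):
            if i == self.canonical_state:
                continue
            # Sharing one reference is safe when leaves are immutable (as JAX
            # arrays are); the toy lists used below are mutable, so don't mutate them.
            s.parameters = canon.parameters
            if self.sync_model_state:
                s.model_state = copy.deepcopy(canon.model_state)


heads = [ToyHeadState({"trunk": [1.0], "head": [0.1]}),
         ToyHeadState({"trunk": [9.0], "head": [9.9]})]
shared = SharedParamsContainer(heads, canonical_state=0)
shared.parameters = {"trunk": [2.0], "head": [0.2]}  # one update, all heads see it
```

Because the driver only ever sees shared.parameters, it keeps a single optimizer state rather than one per head.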
- property parameters
- property hilbert
- property model
- property sampler
- property n_samples
- property n_samples_per_rank
- property chain_length
- property n_discard_per_chain
- property samples
- to_array(normalize=True)
- sample(**kwargs)
- expect(O)
Build a shared-parameter ST-MH container from already-created head MCState objects.
- Return type:
- Usage:
  1. Build one ST-MH base Flax model and K head views using STMHHeadView.
  2. Create K MCState objects in the usual way (one per head view).
  3. Call make_shared_stmh_state(head_states).
  4. Pass the result to SingleTrunkMultiHeadVMC.
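The four usage steps, together with the requirement that all head states share the same Hilbert space and parameter-tree structure, can be sketched without the real library. Here make_shared_state_sketch and tree_structure are hypothetical stand-ins for make_shared_stmh_state and a pytree-structure check (in JAX one would typically compare jax.tree_util.tree_structure results), and SimpleNamespace objects stand in for MCState instances.

```python
from types import SimpleNamespace


def tree_structure(tree):
    """Hashable description of a nested dict/list/tuple pytree's shape."""
    if isinstance(tree, dict):
        return ("dict", tuple((k, tree_structure(tree[k])) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(tree_structure(x) for x in tree))
    return "leaf"


def make_shared_state_sketch(head_states, canonical_state=0):
    """Hypothetical stand-in for make_shared_stmh_state: validate, then share."""
    canon = head_states[canonical_state]
    ref = tree_structure(canon.parameters)
    for i, s in enumerate(head_states):
        if s.hilbert != canon.hilbert:
            raise ValueError(f"head {i} uses a different Hilbert space")
        if tree_structure(s.parameters) != ref:
            raise ValueError(f"head {i} has a different parameter-tree structure")
    for s in head_states:
        s.parameters = canon.parameters  # every head now views the same pytree
    return SimpleNamespace(states=head_states, parameters=canon.parameters)


K = 3
# Steps 1-2 (stand-ins): one state per head view, each holding the full joint pytree.
head_states = [
    SimpleNamespace(
        hilbert="spin-1/2 x 4",
        parameters={"trunk": {"w": [0.0, 0.0]}, "heads": [[float(k)] for _ in range(K)]},
    )
    for k in range(K)
]
shared = make_shared_state_sketch(head_states)  # step 3
# Step 4 would pass `shared` to SingleTrunkMultiHeadVMC (not reproduced here).
```

Validating the structure up front turns a silent mismatch between head views into an immediate error, before any parameters are overwritten.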