neuralqx.experimental.vqs.mc.mc_state.stmh_state module

class STMultiMCState(states, *, canonical_state=0, sync_model_state=False)

Bases: object

A container for multiple head-specific MCState objects that share one parameter pytree.

This class is intentionally small and driver-oriented. It assumes you have already created one MCState per head (typically from STMHHeadView models), all with identical parameter-tree structure. It then synchronizes their parameters so that they act as different views of the same ST-MH model.

The existing MultiMCState returns a list of parameter pytrees, one per state. That is exactly right for MT-MH (independent networks per head), but it does not fit the ST-MH Ansatz, where a shared trunk and all heads form one joint parameter set. This class instead exposes a single parameter pytree to the base VMC driver, so optimizer state and updates are computed only once.
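To make the "optimizer state computed only once" point concrete, here is a toy comparison in plain Python. It is a sketch under stated assumptions, not the neuralqx or NetKet code: flat dicts of floats stand in for JAX parameter pytrees, `init_velocity` and `momentum_step` are hypothetical helpers standing in for a momentum-SGD optimizer, and the gradient values are made up.

```python
# Toy comparison of optimizer state for MT-MH (list of pytrees) versus
# ST-MH (one joint pytree). Hypothetical helpers; flat dicts of floats
# stand in for real JAX parameter pytrees.

def init_velocity(params):
    # Optimizer state mirrors the parameter pytree: one buffer per leaf.
    return {name: 0.0 for name in params}

def momentum_step(params, grads, velocity, lr=0.1, beta=0.9):
    # Heavy-ball momentum update, applied leaf by leaf.
    new_velocity = {k: beta * velocity[k] + grads[k] for k in params}
    new_params = {k: params[k] - lr * new_velocity[k] for k in params}
    return new_params, new_velocity

# MT-MH: K independent networks, each holding its own copy of a "trunk".
mt_params = [{"trunk": 1.0, "head": 0.5}, {"trunk": 1.0, "head": -0.5}]
mt_opt_state = [init_velocity(p) for p in mt_params]  # K optimizer states

# ST-MH: one joint pytree holding the shared trunk and every head.
st_params = {"trunk": 1.0, "head_0": 0.5, "head_1": -0.5}
st_opt_state = init_velocity(st_params)  # a single optimizer state

# One driver step on the joint pytree updates trunk and heads together.
st_grads = {"trunk": 1.0, "head_0": 2.0, "head_1": -2.0}
st_params, st_opt_state = momentum_step(st_params, st_grads, st_opt_state)
print(len(mt_opt_state), st_params)
```

In the MT-MH layout the trunk is duplicated and updated K times; in the ST-MH layout it appears once and is updated once.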

Parameters:
  • states (list[MCState]) – List of head-specific MCState objects. They must share the same Hilbert space and the same parameter-tree structure.

  • canonical_state (int) – Index of the state whose parameters are treated as the canonical source of truth.

  • sync_model_state (bool) – If True, broadcast_from_canonical() also copies model_state from the canonical state to all other states. Keep this False unless you really need mutable model-state collections and understand the implications.
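The preconditions on `states` and `canonical_state` could be checked up front along the lines below. This is a hypothetical sketch, not the library's validation code: `ToyState`, `tree_structure`, and `check_head_states` are illustrative names, parameters are nested dicts and lists, and Hilbert spaces are compared with `==`. With real JAX pytrees one would compare the results of `jax.tree_util.tree_structure(...)` instead of the hand-written helper.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyState:
    """Hypothetical stand-in for an MCState: only the fields we check."""
    hilbert: Any
    parameters: Any


def tree_structure(tree):
    # Replace every leaf with None so only the keys and nesting remain.
    if isinstance(tree, dict):
        return {k: tree_structure(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_structure(v) for v in tree)
    return None


def check_head_states(states, canonical_state=0):
    if not states:
        raise ValueError("need at least one head state")
    if not 0 <= canonical_state < len(states):
        raise ValueError(f"canonical_state={canonical_state} is out of range")
    ref = states[canonical_state]
    ref_tree = tree_structure(ref.parameters)
    for i, state in enumerate(states):
        if state.hilbert != ref.hilbert:
            raise ValueError(f"state {i} has a different Hilbert space")
        if tree_structure(state.parameters) != ref_tree:
            raise ValueError(f"state {i} has a different parameter tree")


# Two head states built from the same joint parameter set pass silently.
shared = {"trunk": {"w": 1.0}, "heads": [{"a": 0.5}, {"a": -0.5}]}
check_head_states([ToyState("spin-1/2, N=4", shared), ToyState("spin-1/2, N=4", shared)])
```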

property parameters
property hilbert
property n_states: int
property canonical_state: int
property model
property sampler
property n_samples
property n_samples_per_rank
property chain_length
property n_discard_per_chain
property samples
to_array(normalize=True)
broadcast_from_canonical()
Return type:

None
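The behaviour described for `broadcast_from_canonical()` and the `sync_model_state` flag can be sketched as follows. This assumes each state exposes assignable `parameters` and `model_state` attributes, as NetKet-style states do; `ToyHeadState` and the free function `broadcast_from_canonical` are hypothetical and only mirror the documented semantics, not the actual method body.

```python
from dataclasses import dataclass, field


@dataclass
class ToyHeadState:
    """Hypothetical stand-in for a head-specific MCState."""
    parameters: dict
    model_state: dict = field(default_factory=dict)


def broadcast_from_canonical(states, canonical_state=0, sync_model_state=False):
    src = states[canonical_state]
    for i, state in enumerate(states):
        if i == canonical_state:
            continue
        # JAX arrays are immutable, so every head can reference the same
        # parameter pytree; this toy never mutates it in place either.
        state.parameters = src.parameters
        if sync_model_state:
            # Opt-in: also overwrite mutable collections (e.g. batch stats).
            state.model_state = src.model_state


heads = [ToyHeadState({"w": 1.0}, {"count": 3}), ToyHeadState({"w": 0.0}, {"count": 7})]
broadcast_from_canonical(heads)
print(heads[1].parameters, heads[1].model_state)  # {'w': 1.0} {'count': 7}
```

With the default `sync_model_state=False`, only parameters are propagated and each head keeps its own model state.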

reset()
Return type:

None

sample(**kwargs)
expect(O)
make_shared_stmh_state(head_states, *, canonical_state=0, sync_model_state=False)

Build a shared-parameter ST-MH container from already-created head MCState objects.

Return type:

STMultiMCState

Usage:
  1. Build one ST-MH base Flax model and K head views using STMHHeadView.

  2. Create K MCState objects exactly as you do today (one per head view).

  3. Call make_shared_stmh_state(head_states).

  4. Pass the result to SingleTrunkMultiHeadVMC.
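Step 1 presumes that every head view carries the full ST-MH parameter tree and differs only in which head it reads out, which is why the K MCState objects end up with identical parameter-tree structure. The toy below sketches that idea in plain Python. `base_apply` and `ToyHeadView` are hypothetical stand-ins, not the `STMHHeadView` API, and the tanh trunk with scalar readouts is an arbitrary illustrative choice.

```python
import math


def base_apply(params, x):
    # Shared trunk: one tiny feature map used by every head (toy choice).
    h = math.tanh(params["trunk"]["w"] * x)
    # Each head is a scalar readout of the same trunk feature.
    return [head["a"] * h for head in params["heads"]]


class ToyHeadView:
    """Hypothetical head view: takes the full params, returns one head."""

    def __init__(self, head_index):
        self.head_index = head_index

    def apply(self, params, x):
        return base_apply(params, x)[self.head_index]


# One joint parameter pytree (trunk + all heads) serves every view.
params = {"trunk": {"w": 0.5}, "heads": [{"a": 1.0}, {"a": -2.0}, {"a": 3.0}]}
views = [ToyHeadView(k) for k in range(len(params["heads"]))]
outputs = [view.apply(params, 1.0) for view in views]
print(outputs)
```

Because each view consumes the same `params`, changing the trunk weight changes every head's output, which is the coupling that a single shared parameter pytree is meant to preserve.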