Single-trunk multi-head variational Monte Carlo (ST-MH)¶
This page documents the experimental single-trunk multi-head (ST-MH) variational Monte Carlo machinery shipped with neuraLQX for joint optimisation of multiple states using a shared parameter set.
Where MultiSolver (the multi-trunk multi-head, MT-MH, style in this codebase) optimises N independent networks
packaged in a MultiMCState with a coupling penalty, the single-trunk implementation optimises one shared Flax model with:
a shared trunk (feature extractor),
\(K\) lightweight heads (one per target state),
one shared optimiser state,
one shared parameter update per iteration.
This is useful when the target manifold is expected to share internal structure, for example
degenerate or nearly degenerate eigenstates,
multiple physical representatives in the same constrained sector,
low-energy manifolds where a shared representation improves sample efficiency and stability.
In neuraLQX this is exposed through:
STMultiSolver (drop-in solver interface),
SingleTrunkMultiHeadVMC (shared-gradient VMC driver),
STMultiMCState (shared-parameter state container),
SingleTrunkMultiHeadLogPsi and head views (Flax wrapper layer).
The ST-MH stack is intentionally designed to keep ordinary NetKet/neuraLQX MCState objects in the loop so sampling and estimator behaviour remain familiar, while changing the parameter geometry and update rule to match the shared-parameter ansatz.
What ST-MH VMC does¶
Assume you want \(K\) variational states \(\{|\psi_k(\Theta)\rangle\}_{k=1}^{K}\) where all states depend on the same parameter set \(\Theta\) through a shared trunk and head-specific readout.
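The ST-MH driver minimises the objective
\[ C(\Theta) = \sum_{k=1}^{K} w_k \, E_k(\Theta) + \lambda_{\mathrm{ortho}} \sum_{i<j} F_{ij}(\Theta) \]
with
\[ E_k(\Theta) = \frac{\langle \psi_k(\Theta) | \hat{C}_k | \psi_k(\Theta) \rangle}{\langle \psi_k(\Theta) | \psi_k(\Theta) \rangle} \]
and pairwise fidelity penalty
\[ F_{ij}(\Theta) = \frac{|\langle \psi_i(\Theta) | \psi_j(\Theta) \rangle|^2}{\langle \psi_i(\Theta) | \psi_i(\Theta) \rangle \, \langle \psi_j(\Theta) | \psi_j(\Theta) \rangle} \]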
Here
\(w_k\) are user-provided energy_weights that are internally normalised to sum to one,
\(\lambda_{\mathrm{ortho}}\) is the orthogonality penalty strength,
\(\hat{C}_k\) is either a shared operator or a per-head operator if a list is supplied.
Interpretation:
The weighted energy term encourages each head to minimise its target operator.
The fidelity term discourages head collapse onto the same state.
The optimisation variable is one shared parameter pytree.
This is the central difference from MT-MH. In MT-MH each state has its own parameter set \(\theta_i\). In ST-MH every gradient contribution must be accumulated into a single \(\nabla_\Theta C\).
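As a plain-Python illustration of how the two terms combine, the sketch below normalises the energy weights and adds the weighted energies to the pairwise penalty. It assumes the penalty is summed over unordered pairs \(i<j\); the function name `stmh_objective` and its arguments are hypothetical and are not part of neuraLQX.

```python
from itertools import combinations


def stmh_objective(energies, fidelities, energy_weights, lambda_ortho):
    """Combine per-head energies and pairwise fidelities into one scalar.

    energies:       list of K per-head energy estimates
    fidelities:     dict mapping (i, j) with i < j to a fidelity value
    energy_weights: list of K non-negative weights (normalised here)
    lambda_ortho:   orthogonality penalty strength
    """
    total = sum(energy_weights)
    weights = [w / total for w in energy_weights]  # normalise to sum to one
    energy_term = sum(w * e for w, e in zip(weights, energies))
    pairs = combinations(range(len(energies)), 2)  # unordered pairs i < j
    ortho_term = sum(fidelities[(i, j)] for i, j in pairs)
    return energy_term + lambda_ortho * ortho_term


value = stmh_objective(
    energies=[1.0, 2.0, 4.0],
    fidelities={(0, 1): 0.5, (0, 2): 0.0, (1, 2): 0.25},
    energy_weights=[2.0, 1.0, 1.0],
    lambda_ortho=2.0,
)
```

In the driver itself the energies and fidelities are Monte Carlo estimates and the quantity of interest is the gradient of this scalar with respect to the shared parameters, not the scalar alone.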
Architecture overview: wrapper, state, driver, solver¶
The ST-MH implementation is split into the following layers.
Flax wrapper layer: Helpers in neuralqx.experimental.nn.projectors.stmh that convert an arbitrary feature-producing trunk into a multi-head log-wavefunction model.
Head-view compatibility models: Expose one selected head as a standard scalar-output model so each head can be handed to MCState.
Shared-parameter state container: Groups the head MCState objects but exposes a single parameters pytree to the driver.
ST-MH VMC driver: Aggregates all energy and orthogonality gradients into one shared update direction.
ST-MH solver: Provides the standard neuraLQX user workflow and checkpointing interface.
A key design goal is minimal disruption of the existing single-state and MT-MH code paths.
Sampling still happens through ordinary MCState objects and the new logic is concentrated in
the shared wrapper, shared container, and shared-gradient driver.
ST-MH ansatz wrapper: SingleTrunkMultiHeadLogPsi and head views¶
The wrapper utilities in neuralqx.experimental.nn.projectors.stmh construct an ST-MH ansatz from a user-defined trunk network.
High-level form¶
Let \(f_\phi(x)\) denote the trunk output (features for configuration \(x\)). The wrapper applies one affine readout per head. For the default complex output case,
\[ \log \psi_k(x) = a_k\big(f_\phi(x)\big) + i \, b_k\big(f_\phi(x)\big), \qquad k = 1, \dots, K, \]
where \(a_k\) and \(b_k\) are implemented as dense layers. All heads share the same trunk features and differ only in the final linear maps.
The wrapper can also produce real outputs if complex_logpsi=False.
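The readout can be pictured with a small NumPy sketch: one affine map produces the real part of every head's log-amplitude and a second one produces the imaginary part. This is only an illustration of the shapes involved, assuming one output column per head; the array names are hypothetical and the actual wrapper uses Flax Dense layers.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_features, n_heads = 5, 8, 3

# Stand-in for the trunk output f_phi(x), shape (batch, n_features).
feats = rng.normal(size=(batch, n_features))

# Affine maps playing the role of a_k (real part) and b_k (imaginary part).
W_a = rng.normal(size=(n_features, n_heads))
bias_a = np.zeros(n_heads)
W_b = rng.normal(size=(n_features, n_heads))
bias_b = np.zeros(n_heads)


def all_heads(feats, complex_logpsi=True):
    """Return log-amplitudes for every head, shape (batch, n_heads)."""
    real = feats @ W_a + bias_a
    if not complex_logpsi:
        return real
    imag = feats @ W_b + bias_b
    return real + 1j * imag


log_psi = all_heads(feats)                             # complex, (batch, n_heads)
log_psi_real = all_heads(feats, complex_logpsi=False)  # real, (batch, n_heads)
log_psi_0 = log_psi[:, 0]                              # head 0 only, (batch,)
```

Because every head reads the same `feats`, any gradient flowing back from any head reaches the trunk parameters, which is what makes the parameter set shared.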
What the trunk should return¶
The preferred convention is a batch-first feature tensor such as
(batch, hidden_dim),
or any batch-first shape (batch, ...) that can be flattened to (batch, F).
The wrapper accepts several output styles and can extract features from
a direct tensor return,
a dict return (for example {"features": ...}),
a tuple/list return (first element by default).
This allows you to reuse existing Flax modules without rewriting them.
If the trunk returns a scalar per sample with shape (batch,) the wrapper lifts it to
(batch, 1). This runs correctly but yields only a one-dimensional shared feature space, which
is usually too restrictive for a useful ST-MH ansatz.
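The feature-extraction rules described above can be sketched in a few lines of NumPy. This is not the wrapper's actual code: the helper name `extract_features` is hypothetical, and the default key `"features"` is an assumption taken from the dict example above.

```python
import numpy as np


def extract_features(out, features_key="features", tuple_index=0, flatten=True):
    """Turn a trunk return value into a (batch, F) feature matrix."""
    if isinstance(out, dict):
        out = out[features_key]      # dict return: pick the feature entry
    elif isinstance(out, (tuple, list)):
        out = out[tuple_index]       # tuple/list return: first element by default
    feats = np.asarray(out)
    if feats.ndim == 1:
        feats = feats[:, None]       # scalar per sample: lift (batch,) to (batch, 1)
    elif flatten and feats.ndim > 2:
        feats = feats.reshape(feats.shape[0], -1)  # flatten trailing dimensions
    return feats


direct = extract_features(np.zeros((4, 16)))
from_dict = extract_features({"features": np.zeros((4, 2, 3))})
from_tuple = extract_features((np.zeros((4, 8)), "aux"))
lifted = extract_features(np.zeros(4))
```

Checking the resulting shape in this way before wiring a trunk into the solver is a quick way to catch layout problems early.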
Public classes and helpers¶
The wrapper module in neuralqx.experimental.nn.projectors.stmh currently provides:
SingleTrunkMultiHeadLogPsi: Main ST-MH wrapper that outputs all heads or a selected head.
STMHHeadView: Compatibility wrapper that exposes one head as a standard scalar-output Flax module. This is the object you pass into MCState.
STMHAllHeadsView: Thin wrapper that always returns all heads, shape (batch, K).
wrap_trunk_as_stmh(...): Convenience constructor around SingleTrunkMultiHeadLogPsi.
make_stmh_head_models(...): Returns a list of STMHHeadView objects, one per head.
SingleTrunkMultiHeadLogPsi constructor parameters¶
The wrapper is intentionally generic. The most important parameters are:
trunk: Any Flax module that maps basis configurations to features.
n_heads: Number of heads \(K\).
latent_dim: Optional projection width. If the extracted feature width differs from latent_dim, the wrapper inserts a learned projection layer trunk_proj before the heads.
trunk_output: How the trunk return value is interpreted. Supported modes are "auto", "features", "dict" and "tuple".
complex_logpsi: If True, the wrapper uses separate real and imaginary head Dense layers and returns complex log-amplitudes.
flatten_features: If True, all trailing dimensions after the batch axis are flattened.
features_key and tuple_index: Selection controls for dict or tuple trunk outputs.
dtype and param_dtype: Forwarded to the internal Dense layers.
Call behaviour and output shapes¶
The wrapper call signature supports two important keyword arguments:
head
return_features
When head=None (default), the output shape is (batch, K) and contains all heads.
When head=i the wrapper returns the selected head only with shape (batch,). This is the
mode used by STMHHeadView and by the solver compatibility wrappers.
When return_features=True the wrapper returns (out, feats) where feats is the extracted
and optionally projected feature matrix used by the heads. This is useful for debugging and for
confirming that the trunk output shape is what you expect.
Examples¶
Wrap a trunk directly:
stmh = SingleTrunkMultiHeadLogPsi(
    trunk=MyTrunk(...),
    n_heads=4,
    complex_logpsi=True,
)
y_all = stmh.apply(params, sigma)        # shape (batch, 4)
y_0 = stmh.apply(params, sigma, head=0)  # shape (batch,)
Create head-view models for MCState:
head_models = make_stmh_head_models(stmh, n_heads=4)
# each head_models[i] is scalar-output and can be passed to MCState
s0 = MCState(sampler, head_models[0], n_samples=2048)
s1 = MCState(sampler, head_models[1], n_samples=2048)
Notes on diffeomorphism-invariant wrapping¶
For ST-MH, the recommended place to apply diffeomorphism projection in solver workflows is at the
head-view level during STMultiSolver.initialize_vmc. This ensures each head-specific scalar
model is projected exactly as in standard single-state usage while still sharing the same underlying
parameter tree.
machine_pow and fidelity interpretation¶
The ST-MH orthogonality penalty is implemented using a joint-sampling overlap estimator that has a clean fidelity interpretation in the usual quantum setting only under a specific condition.
Recommended and default regime¶
Use machine_pow = 2 for all head samplers.
When every head samples from a distribution proportional to \(|\psi|^2\), the estimator used by the orthogonality kernel matches the normalized squared fidelity and the penalty is directly interpretable as orthogonality enforcement in the standard quantum sense.
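The reason \(|\psi|^2\) sampling gives a clean fidelity can be checked on a toy example. The sketch below assumes the joint-sampling estimator has the common product-of-ratios form, that is, the mean of \(\psi_j/\psi_i\) under \(|\psi_i|^2\) times the mean of \(\psi_i/\psi_j\) under \(|\psi_j|^2\); the actual orthogonality kernel in neuraLQX may differ in detail. Exact expectations over a three-configuration space stand in for sample means, and both function names are hypothetical.

```python
def fidelity_from_ratio_means(psi_i, psi_j):
    """Product of two ratio expectations, each taken under one head's |psi|^2."""
    norm_i = sum(abs(a) ** 2 for a in psi_i)
    norm_j = sum(abs(b) ** 2 for b in psi_j)
    # E_{x ~ |psi_i|^2}[psi_j(x) / psi_i(x)]
    mean_ij = sum((abs(a) ** 2 / norm_i) * (b / a) for a, b in zip(psi_i, psi_j))
    # E_{y ~ |psi_j|^2}[psi_i(y) / psi_j(y)]
    mean_ji = sum((abs(b) ** 2 / norm_j) * (a / b) for a, b in zip(psi_i, psi_j))
    return (mean_ij * mean_ji).real


def fidelity_direct(psi_i, psi_j):
    """Normalised squared overlap |<psi_i|psi_j>|^2 / (<psi_i|psi_i><psi_j|psi_j>)."""
    overlap = sum(a.conjugate() * b for a, b in zip(psi_i, psi_j))
    norm_i = sum(abs(a) ** 2 for a in psi_i)
    norm_j = sum(abs(b) ** 2 for b in psi_j)
    return abs(overlap) ** 2 / (norm_i * norm_j)


# Two unnormalised toy states with non-zero amplitudes everywhere.
psi_i = [1 + 0j, 1j, 0.5 + 0j]
psi_j = [1 + 0j, -1j, 2 + 0j]

f_ratio = fidelity_from_ratio_means(psi_i, psi_j)
f_direct = fidelity_direct(psi_i, psi_j)
```

The two values agree because weighting the ratio \(\psi_j/\psi_i\) by \(|\psi_i|^2\) reproduces \(\psi_i^{*}\psi_j\) term by term. With a different sampling power that cancellation no longer happens, which is why the fidelity interpretation is tied to machine_pow = 2.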
What the driver enforces¶
By default SingleTrunkMultiHeadVMC sets enforce_machine_pow_2=True.
This means the driver checks all head samplers and raises a ValueError if any head has a
different machine_pow. This is a safety mechanism for correctness of interpretation.
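A check of this kind can be sketched as follows. The helper `check_machine_pow` is hypothetical and only mirrors the behaviour described above; the samplers are replaced by simple stand-in objects carrying a `machine_pow` attribute.

```python
from types import SimpleNamespace


def check_machine_pow(samplers, required=2):
    """Raise ValueError if any head sampler uses a different machine_pow."""
    mismatched = {
        k: s.machine_pow for k, s in enumerate(samplers) if s.machine_pow != required
    }
    if mismatched:
        raise ValueError(
            f"All head samplers must use machine_pow={required}; "
            f"mismatched heads: {mismatched}"
        )


# Stand-ins for per-head samplers (only the attribute under test is modelled).
aligned = [SimpleNamespace(machine_pow=2) for _ in range(3)]
check_machine_pow(aligned)  # passes silently

mixed = [SimpleNamespace(machine_pow=2), SimpleNamespace(machine_pow=1)]
try:
    check_machine_pow(mixed)
    raised = False
except ValueError:
    raised = True
```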
What happens if you disable the check¶
If enforce_machine_pow_2=False the code still runs. The pairwise penalty becomes a more general
overlap-like surrogate defined by the chosen sampling powers.
This can still act as a useful repulsion regulariser between heads, but
its scale changes with the sampler exponents
its interpretation as normalized fidelity is lost
mixed powers across heads are harder to reason about
The driver clips logged pairwise values into [0,1] for presentation only. That clipping is not a
proof that the estimator remains a true fidelity under nonstandard machine_pow values.
Practical recommendation¶
Unless you intentionally study alternative overlap penalties, keep machine_pow = 2 for every head
and leave enforce_machine_pow_2=True.
Using SR / QGT preconditioning in ST-MH¶
This is one of the most important conceptual differences between MT-MH and the current ST-MH driver.
Euclidean gradient descent path (exact)¶
If the preconditioner is identity_preconditioner then the ST-MH driver follows the exact
Euclidean gradient descent direction of the sampled objective defined by the weighted energy and
orthogonality terms.
This is the cleanest baseline for validating a new ST-MH ansatz.
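The shared update can be sketched with plain dictionaries standing in for the parameter pytree. The example assumes a plain gradient descent step and unordered pair penalties; the function names are hypothetical, and the real driver hands the accumulated gradient to the configured preconditioner and optimiser instead of applying a fixed learning rate.

```python
import numpy as np


def accumulate_shared_gradient(energy_grads, pair_grads, weights, lambda_ortho):
    """Sum per-head and per-pair gradients into one shared gradient dict."""
    total = {
        name: np.zeros_like(g, dtype=float) for name, g in energy_grads[0].items()
    }
    for w, grad in zip(weights, energy_grads):
        for name in total:
            total[name] += w * grad[name]             # weighted energy terms
    for grad in pair_grads:
        for name in total:
            total[name] += lambda_ortho * grad[name]  # orthogonality terms
    return total


def euclidean_step(params, grad, learning_rate):
    """One plain gradient descent step on the shared parameters."""
    return {name: params[name] - learning_rate * grad[name] for name in params}


params = {"trunk": np.array([1.0, 1.0]), "heads": np.array([0.0])}
energy_grads = [
    {"trunk": np.array([1.0, 0.0]), "heads": np.array([2.0])},
    {"trunk": np.array([0.0, 1.0]), "heads": np.array([4.0])},
]
pair_grads = [{"trunk": np.array([1.0, 1.0]), "heads": np.array([0.0])}]

global_grad = accumulate_shared_gradient(
    energy_grads, pair_grads, weights=[0.5, 0.5], lambda_ortho=2.0
)
new_params = euclidean_step(params, global_grad, learning_rate=0.1)
```

Every contribution lands in the same dictionary, so there is exactly one update per iteration regardless of the number of heads.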
SR / QGT path (current implementation is an approximation)¶
If you pass an SR/QGT preconditioner, the driver applies it using the geometry of one selected head state only:
ref_state = self.states[preconditioner_state_index]
dp = preconditioner(ref_state, global_grad, step)
This is a practical approximation. It often works well enough to gain SR-like stabilisation, but it is not the exact natural-gradient step for the full ST-MH objective.
Why it is approximate¶
The true ST-MH objective couples all heads through both
per-head energy terms
pairwise orthogonality terms
A strict natural-gradient treatment would require a geometry that represents the tangent space of the full shared multi-head objective, including the way all heads depend on the same trunk features.
The current implementation reuses the existing single-head state geometry for convenience and compatibility. This captures the parameter structure of the shared model and one head view, but it does not build a full combined ST-MH QGT.
Future extension direction¶
A dedicated ST-MH preconditioner could construct a combined geometry from multiple heads and possibly from weighted objective terms. The current driver is designed so such a preconditioner can be plugged in later without changing the outer solver interface.
Usage examples (STMultiSolver-first)¶
Minimal ST-MH setup from a trunk¶
import flax.linen as nn
import jax.numpy as jnp

from neuralqx.experimental.solver import STMultiSolver
from neuralqx.experimental.nn.projectors.stmh import SingleTrunkMultiHeadLogPsi


class MyTrunk(nn.Module):
    @nn.compact
    def __call__(self, sigma):
        x = sigma.astype(jnp.float32)
        x = nn.Dense(64)(x)
        x = nn.tanh(x)
        x = nn.Dense(64)(x)
        x = nn.tanh(x)
        return x  # features, shape (batch, 64)


base_model = SingleTrunkMultiHeadLogPsi(
    trunk=MyTrunk(),
    n_heads=3,
    complex_logpsi=True,
)

solver = STMultiSolver(lqx, output_path="runs/stmh_demo", seed=0)
solver.set_sampler(...)
solver.set_optimizer(...)
solver.set_network(base_model, lambda_ortho=1.0)
solver.initialize_vmc()
solver.run(500)
Measure all heads and one head¶
all_stats = solver.expect(lqx.constraint) # list[Stats]
s0 = solver.expect(lqx.constraint, state_idx=0) # Stats
print(all_stats[0], all_stats[1], all_stats[2])
print("head 0 mean =", float(s0.Mean))
ST-MH with explicit per-head views¶
If your base model exposes heads in a nonstandard way, you can provide explicit head models to
set_network and still use the ST-MH solver:
head_models = [Head0View(base_model), Head1View(base_model), Head2View(base_model)]
solver.set_network(
    base_model,
    head_models=head_models,
    n_heads=3,
    lambda_ortho=0.5,
)
Head-specific operators with energy weights¶
The driver supports one operator per head and weighted energy aggregation:
solver.initialize_vmc(
    energy_weights=[0.6, 0.2, 0.2],
    preconditioner_state_index=0,
    enforce_machine_pow_2=True,
)
If you construct the driver manually, you can also pass a list of operators directly to
SingleTrunkMultiHeadVMC(...).
Diffeomorphism-invariant ST-MH run¶
solver.set_network(
    base_model,
    diff_invariant=True,
    symmetries=symmetries,
    lambda_ortho=0.5,
)
solver.initialize_vmc()
solver.run(500)
The solver applies the group projector to each head view before constructing the MCState objects.
The orthogonality penalty then separates the projected heads.
Checkpoint export and import¶
solver.export_state(marker="after_500")
# later
solver2 = STMultiSolver(lqx, output_path="runs/stmh_resume", seed=123)
solver2.set_sampler(...)
solver2.set_optimizer(...)
solver2.set_network(base_model, n_heads=3)
solver2.initialize_vmc()
solver2.import_state("...SerialisedSTMHState_....mpack")
Practical tuning notes for ST-MH¶
Feature richness of the trunk¶
The ST-MH benefit comes from a useful shared representation. A trunk that returns only a scalar feature severely limits what the heads can express.
In practice, design the trunk to return a moderate feature width and let the heads remain lightweight.
Sampler alignment across heads¶
Pairwise fidelity estimates use sample batches from each head state. If one head mixes poorly or uses very different sampling settings, the orthogonality term becomes noisy and can dominate the update.
It is usually best to keep sampler settings aligned across heads unless there is a deliberate reason to diverge.
Interpreting parameter counts¶
The solver logs both shared and nominal replicated parameter counts.
Shared count approximates the true optimization dimension.
Nominal replicated count is a comparison number that is closer to what an MT-MH run would spend if each head were trained independently.
This is a useful diagnostic when comparing ST-MH and MT-MH runs at similar compute budgets.
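The two counts can be illustrated with a small calculation. The sketch assumes that the nominal replicated count is the number of heads times the size of one trunk plus one head, which is one plausible reading of the description above and not a statement of the solver's exact formula; the layer names and shapes are made up for the example.

```python
from math import prod


def count_params(shapes):
    """Total number of scalar parameters in a {name: shape} mapping."""
    return sum(prod(shape) for shape in shapes.values())


# Hypothetical shapes: a two-layer trunk and one complex head (real + imaginary).
trunk_shapes = {
    "dense0/kernel": (16, 64),
    "dense0/bias": (64,),
    "dense1/kernel": (64, 64),
    "dense1/bias": (64,),
}
head_shapes = {
    "re/kernel": (64, 1),
    "re/bias": (1,),
    "im/kernel": (64, 1),
    "im/bias": (1,),
}
n_heads = 3

trunk_count = count_params(trunk_shapes)
head_count = count_params(head_shapes)

shared_count = trunk_count + n_heads * head_count      # one trunk, K heads
nominal_replicated = n_heads * (trunk_count + head_count)  # K independent networks
```

With a trunk that is much larger than the heads, the shared count stays close to the size of a single network while the nominal replicated count grows linearly with the number of heads.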
Migration notes: MT-MH to ST-MH¶
If you already use MultiSolver, the main conceptual migration
steps are:
Replace a list of independent networks with one shared ST-MH base model.
Ensure the base model can expose a selected head with head=i or provide explicit head views.
Use STMultiSolver instead of MultiSolver.
Interpret parameter counts and SR behaviour with the ST-MH shared-parameter geometry in mind.
What stays the same¶
Sampling still uses ordinary MCState objects.
The orthogonality penalty uses the same joint-sampling estimator family.
The solver workflow remains set_sampler -> set_optimizer -> set_network -> initialize_vmc -> run.
What changes mathematically¶
The optimization variable becomes one shared parameter pytree.
Gradients from all heads and all pair penalties are summed before preconditioning.
SR is applied once on a reference head geometry in the current implementation.
This is the correct update structure for a shared trunk plus multiple heads ansatz.
Troubleshooting checklist¶
Shape errors in the wrapper¶
Symptoms:
Flax Dense shape mismatch in the head layers
unexpected output rank from the base model
Checks:
confirm the trunk returns batch-first features
set return_features=True temporarily and inspect feats.shape
use flatten_features=True if the trunk returns (batch, ..., ...)
TypeError when building head views¶
Symptoms:
solver error stating the base model does not support head=
Fixes:
use SingleTrunkMultiHeadLogPsi or a compatible base model
or pass explicit head_models=[...] to STMultiSolver.set_network(...)
ValueError about machine_pow¶
Symptoms:
driver raises because some head has machine_pow != 2
Fixes:
align all samplers to machine_pow=2 for standard fidelity semantics
only disable enforce_machine_pow_2 if you intentionally want a surrogate overlap penalty
Unexpected head divergence or inconsistent behaviour¶
Checks:
verify the shared state is STMultiMCState and not MultiMCState
confirm reset() is being called through the driver each step
avoid manual mutation of per-head MCState.parameters outside the container
SR instability¶
Checks:
compare against identity preconditioning first
try a different preconditioner_state_index
inspect per-head sample quality and fidelity noise
reduce lambda_ortho temporarily to decouple sources of instability