neuralqx.experimental.nn.projectors.stmh package
- class SingleTrunkMultiHeadLogPsi(trunk, n_heads, latent_dim=None, trunk_output='auto', complex_logpsi=True, flatten_features=True, features_key='features', tuple_index=0, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Generic single-trunk multi-head wrapper for Flax models.
The wrapped trunk is expected to return features per configuration. The preferred convention is NetKet's: a tensor whose leading axis is the batch axis, for example (batch, h). The wrapper flattens any trailing feature dimensions and applies lightweight linear heads to produce K log-amplitudes, one per target state (head). By default, the wrapper outputs complex log-amplitudes
\[\log\psi_k(x) = a_k(f(x)) + i\, b_k(f(x)),\]
where f(x) denotes the flattened trunk features and a_k and b_k are independent affine maps implemented as two Dense layers.
Notes
For NetKet/MCState compatibility, use STMHHeadView to expose a single head with output shape (batch,).
If your existing model returns a scalar (batch,) log-amplitude, you should ideally factor it into a trunk plus a head. This wrapper can lift a scalar trunk output to a one-dimensional feature, but the heads then share only a 1D feature space, so every head becomes an affine function of the same scalar.
- class STMHHeadView(base, head, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Exposes a single head from SingleTrunkMultiHeadLogPsi as a standard scalar-output model.
This is the compatibility wrapper you should pass to MCState so that each head behaves like a normal NQS model returning (batch,) log-amplitudes.
- class STMHAllHeadsView(base, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Thin alias wrapper that returns all heads explicitly, with shape (batch, K).
- wrap_trunk_as_stmh(trunk, n_heads, **kwargs)
Convenience constructor for a single-trunk multi-head Ansatz. Provide a trunk-compatible Flax-based model at construction time and specify the number of desired heads (i.e., the number of orthogonal states).
- Return type:
Examples
>>> stmh = wrap_trunk_as_stmh(MyTrunk(...), n_heads=2)
>>> head0_model = STMHHeadView(stmh, 0)
>>> head1_model = STMHHeadView(stmh, 1)
- make_stmh_head_models(base_stmh_model, n_heads)
Create n_heads scalar-output Flax models, each selecting one head of base_stmh_model.
- Return type: