neuralqx.experimental.nn.projectors.stmh package

class SingleTrunkMultiHeadLogPsi(trunk, n_heads, latent_dim=None, trunk_output='auto', complex_logpsi=True, flatten_features=True, features_key='features', tuple_index=0, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Generic single-trunk multi-head wrapper for Flax models.

The wrapped trunk is expected to return features per configuration. The preferred convention is NetKet's: a tensor whose leading axis is the batch axis, for example of shape (batch, h). The wrapper flattens any trailing feature dimensions into a single feature vector per configuration and applies lightweight linear heads to produce K = n_heads log-amplitudes, one per target state (head).

By default (complex_logpsi=True), the wrapper outputs complex log-amplitudes

\[\log\psi_k(x) = a_k(f(x)) + i\, b_k(f(x)),\]

where f(x) denotes the flattened trunk features of configuration x, and a_k and b_k are independent affine maps implemented as two Dense layers.

Notes

  • For NetKet/MCState compatibility, use STMHHeadView to expose a single head with output shape (batch,).

  • If your existing model returns a scalar log-amplitude of shape (batch,), you should ideally factor it into a trunk and a head. This wrapper can lift a scalar trunk output to a one-dimensional feature, but then all heads share a single scalar feature and each head is only an affine function of it.
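The 1D limitation can be checked numerically. The sketch below uses plain NumPy with randomly drawn, purely illustrative head weights: it builds the real parts of K affine heads on top of either a lifted scalar feature s(x) or a 16-dimensional feature, and compares the rank of the resulting (batch, K) matrices. With a scalar feature every column lies in the span of s(x) and the constant vector, so the rank cannot exceed 2 however many heads there are.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, K = 50, 6

# Hypothetical scalar trunk output s(x), lifted to a 1D feature of shape (batch, 1).
s = rng.normal(size=(batch, 1))
# Hypothetical richer trunk output: a 16-dimensional feature per configuration.
f = rng.normal(size=(batch, 16))


def real_parts(features, rng):
    """Real parts a_k(f(x)) = f(x) @ W[:, k] + c_k for K independent affine heads."""
    W = rng.normal(size=(features.shape[1], K))
    c = rng.normal(size=(K,))
    return features @ W + c  # (batch, K)


A_scalar = real_parts(s, rng)
A_rich = real_parts(f, rng)

# Each column of A_scalar is W[0, k] * s + c_k, i.e. it lies in span{s, 1}.
print(np.linalg.matrix_rank(A_scalar))  # 2
print(np.linalg.matrix_rank(A_rich))    # 6
```

The same argument applies to the imaginary parts b_k, so with a scalar trunk the K heads can differ only by scale and offset of one shared quantity.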

class STMHHeadView(base, head, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Exposes a single head from SingleTrunkMultiHeadLogPsi as a standard scalar-output model.

This is the compatibility wrapper you should hand to MCState so that each head behaves like a normal NQS model returning (batch,) log-amplitudes.

class STMHAllHeadsView(base, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Thin alias wrapper returning all heads explicitly (shape (batch, K)).

wrap_trunk_as_stmh(trunk, n_heads, **kwargs)

Convenience constructor for a single-trunk multi-head Ansatz. Provide a trunk-compatible Flax-based model at construction time and specify the number of desired heads (i.e. the number of orthogonal states).

Return type:

SingleTrunkMultiHeadLogPsi

Examples

>>> stmh = wrap_trunk_as_stmh(MyTrunk(...), n_heads=2)
>>> head0_model = STMHHeadView(stmh, 0)
>>> head1_model = STMHHeadView(stmh, 1)

make_stmh_head_models(base_stmh_model, n_heads)

Create n_heads scalar-output Flax models, one selecting each head of base_stmh_model.

Return type:

list[STMHHeadView]