neuralqx.nn.models package
This module contains the implementations of the different neural networks used in the simulations.
- class RevNet(number_of_cnn_features, number_of_cnn_blocks, lqx, dropout_rate=0.2, free_idxs=None, dtype=<class 'jax.numpy.float64'>, bn_use_running_average=False, use_visible_bias=True, conv_kernel_init=None, conv_bias_init=0.1, bn_scale_init=None, bn_bias_init=0.1, se_fc1_kernel_init=None, se_fc1_bias_init=0.1, se_fc2_kernel_init=None, se_fc2_bias_init=0.1, eval_dense_kernel_init=None, eval_dense_bias_init=0.1, visible_bias_init=0.8, rng_key=None, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
The same RevNet architecture as the earlier implementation, with fully tunable normalisation and initialisers.
Differences from the earlier implementation:
- bn_use_running_average is configurable (True/False). Note: BatchNorm is not currently implemented, so this knob has no effect yet.
- Per-layer initialiser knobs, each accepting:
  a callable initialiser (Flax-style),
  OR a float/stddev (wrapped as nn.initializers.normal(stddev)),
  OR None to use a sensible default.
- Complex-aware initialisation: if dtype is complex, weights are built from two real draws to preserve the magnitude/phase variance at initialisation.
- Activations are now log_cosh instead of hard_silu.
- RNG usage: we do NOT pass an RNG inside the module; use Flax collections instead:
  params = model.init({'params': key}, x)
  The old rng_key field is deprecated and ignored.
- Fallback initialisers (used when a knob is None), chosen to introduce gentle variance without saturation:
  Conv/Dense kernels: normal(stddev=0.05)
  Biases: zeros
  BatchNorm scale: ones; bias: zeros
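A minimal sketch of one way to build such a complex-aware initialiser, assuming the Flax-style signature init(key, shape, dtype). complex_normal is a hypothetical helper, not RevNet's actual code: it draws the real and imaginary parts independently with stddev/√2 each, so that E|w|² = stddev² and the phase is uniform at initialisation.

```python
import jax
import jax.numpy as jnp


def complex_normal(stddev: float = 0.05):
    """Hypothetical complex-aware initialiser built from two real normal draws."""

    def init(key, shape, dtype=jnp.complex64):
        # Real dtype matching the requested complex dtype (e.g. complex64 -> float32).
        real_dtype = jnp.zeros((), dtype).real.dtype
        k_re, k_im = jax.random.split(key)
        # Split the variance evenly between the parts so that E|w|^2 == stddev**2.
        scale = stddev / jnp.sqrt(2.0)
        re = scale * jax.random.normal(k_re, shape, real_dtype)
        im = scale * jax.random.normal(k_im, shape, real_dtype)
        return (re + 1j * im).astype(dtype)

    return init


# Usage: a (4, 8) complex kernel.
w = complex_normal(0.05)(jax.random.PRNGKey(0), (4, 8))
```

Splitting the variance this way keeps the overall weight scale comparable to a real normal(stddev) kernel of the same shape.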
Initializer knobs (per layer)
- For each spec below, you can pass:
callable initializer (Flax style),
or a float/int (interpreted as stddev for normal init),
or None for default described here.
conv_kernel_init : InitLike, default normal(0.05)
conv_bias_init : InitLike, default zeros
bn_scale_init : InitLike, default ones
bn_bias_init : InitLike, default zeros
se_fc1_kernel_init : InitLike, default normal(0.05)
se_fc1_bias_init : InitLike, default zeros
se_fc2_kernel_init : InitLike, default normal(0.05)
se_fc2_bias_init : InitLike, default zeros
eval_dense_kernel_init : InitLike, default normal(0.05)
eval_dense_bias_init : InitLike, default zeros
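The knob semantics above can be sketched as a small resolver. resolve_init and the InitLike alias are hypothetical names used for illustration; the sketch only assumes the three accepted spec types listed above and uses jax.nn.initializers for the normal and zeros initialisers.

```python
from typing import Callable, Optional, Union

import jax
import jax.numpy as jnp

# Flax-style initialiser: (key, shape, dtype) -> array.
Initializer = Callable[..., jax.Array]
InitLike = Optional[Union[Initializer, float, int]]


def resolve_init(spec: InitLike, default: Initializer) -> Initializer:
    """Hypothetical helper mapping an InitLike knob to a concrete initialiser."""
    if spec is None:
        return default  # fall back to the documented default
    if callable(spec):
        return spec  # already a Flax-style initialiser
    if isinstance(spec, (int, float)) and not isinstance(spec, bool):
        return jax.nn.initializers.normal(stddev=float(spec))  # number = stddev
    raise TypeError(f"Unsupported initialiser spec: {spec!r}")


# Usage: resolve the knobs for one convolution layer.
kernel_init = resolve_init(None, jax.nn.initializers.normal(stddev=0.05))
bias_init = resolve_init(None, jax.nn.initializers.zeros)
```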
Notes
RNG: initialize with model.init({'params': key}, x). Different keys give different parameter samples. The deprecated rng_key field is ignored.
If you previously relied on BatchNorm with use_running_average=True during training, be aware that this can suppress variance; toggle it via the bn_use_running_average knob.
- class RevNetAuto(number_of_cnn_features, rng_key, lqx, dropout_rate=0.2, free_idxs=None, dtype=<class 'jax.numpy.float64'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
The RevNet network implementation in Flax.
- class ResNet(num_blocks=2, block_features=60, main_conv=60, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
A ResNet-like neural network composed of several residual blocks with skip connections.
- class RevNetComplex(number_of_cnn_features, number_of_cnn_blocks, lqx, rng_key, dropout_rate=0.0, free_idxs=None, dtype=<class 'jax.numpy.float64'>, use_layer_norm=False, _default_R_activation=<PjitFunction of <function hard_silu>>, reduction_ratio=2, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
RevNet implementation supporting real and complex dtypes.
- make_cfvtnet_kwargs(graph, H)
Construct the graph- and Hamiltonian-dependent keyword arguments for CFVTNet.
- class CFVTNet(n_edges, n_vertices, gauge_dim, cutoff, q_min, q_max, q_step, edge_to_index, loops_edges_typed=None, loops_indices=None, triplets_by_vertex=None, triplet_signs_by_vertex=None, N_F=1, freq_mode='harmonic', freq_log_base=2.0, tie_copies=False, use_vertex_odd=True, use_vertex_even=True, use_phase=True, use_edge_bias=False, use_loop_bias=False, param_dtype=<class 'jax.numpy.float64'>, output_dtype=<class 'jax.numpy.complex128'>, l2_fourier=0.0, l2_vertex=0.0, l2_edge=0.0, l2_loop=0.0, init_std_A=0.1, init_std_B=0.1, init_std_A_phase=0.1, init_std_B_phase=0.1, init_std_alpha=0.1, init_std_lambda=0.1, init_std_beta=0.1, init_std_alpha_phase=0.1, init_std_visible_bias=0.4, init_std_visible_bias_phase=0.5, init_std_loop_bias=0.0, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
A fast gauge-equivariant ansatz for lattice gauge states.
This module implements a composite ansatz for (log-)wavefunction amplitudes on a tensor-product Hilbert space built over K gauge levels of a base graph, one level per gauge dimension (e.g. K = 3 for U(1)^3). Each level corresponds to a block of n_edges oriented edges, so the total visible width is n_edges * gauge_dim.
The architecture fuses three ingredients:
- Loop Fourier features
we extract loop charges q_{ℓ,k} by summing ±σ_e over the oriented edges e that form each minimal loop ℓ, for each level k
we then compute periodic features cos(ω_j q_{ℓ,k}) and sin(ω_j q_{ℓ,k}), where ω_j are harmonics over the cyclic group Z_M (M = 2 * cutoff + 1)
trainable weights A, B (and optionally A_phase, B_phase for complex phase) linearly combine these features into a scalar contribution to log|ψ| (and arg ψ)
- Local, volume-like triplet term (optional, requires K = 3)
at each vertex, we may provide oriented triplets of incident edges (e1, e2, e3)
for each triplet, we form the pseudoscalar T_v := ε · (σ_{e1} · (σ_{e2} × σ_{e3})), where σ_e is the 3-vector of charges on edge e across the three levels and ε is the triplet's orientation sign; this couples the three gauge levels through a triple product
the odd part (tanh with learned scale λ) and the even part (quadratic penalty) contribute to the amplitude; an optional linear part contributes to the phase
- Bias terms (optional)
per-edge linear bias (RBM style, Σ_e b_e σ_e) adds directly to log|ψ|, with an optional phase counterpart
a per-loop constant bias (or per (loop, level) if levels are not tied) that shifts log|ψ| independently of input, useful for calibrating baselines
- Level tying:
if tie_copies = True, loop Fourier weights are shared across levels by averaging features over the level axis before applying (A, B, …). Otherwise, we learn separate parameters per level
- Dtypes:
param_dtype sets the dtype for all learnable parameters
output_dtype chooses whether the network outputs real or complex log-amplitudes. If output_dtype is complex and use_phase=True, we emit log|ψ| + i * phase; otherwise we emit a real log|ψ|
- Regularisation:
regularization_terms() returns a dict of L2 penalties for different parameter blocks (Fourier, vertex, edge bias, loop bias) controlled by l2_* scalars. You sum these into your loss externally.
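The loop Fourier features can be sketched in NumPy (jax.numpy exposes the same calls inside a module). This is a hypothetical illustration, not CFVTNet's code, and it assumes: σ is laid out level-major as described above; loops are given as padded (L, ℓ) arrays of edge indices and ±1 orientation signs, with sign 0 for padding; and the 'harmonic' frequencies are ω_j = 2πj/M for j = 1..N_F, i.e. characters of Z_M. The phase part would use A_phase, B_phase in the same way.

```python
import numpy as np


def loop_charges(sigma, loop_edges, loop_signs, n_levels, n_edges):
    """q[b, k, l] = sum_i loop_signs[l, i] * sigma_k[loop_edges[l, i]].

    sigma: (B, n_levels * n_edges), level-major blocks of n_edges.
    loop_edges, loop_signs: (L, ell); shorter loops padded with sign 0.
    """
    s = sigma.reshape(sigma.shape[0], n_levels, n_edges)
    return np.sum(s[:, :, loop_edges] * loop_signs, axis=-1)  # (B, K, L)


def fourier_features(q, cutoff, n_freqs):
    """cos/sin of omega_j * q with omega_j = 2*pi*j/M and M = 2*cutoff + 1."""
    M = 2 * cutoff + 1
    omegas = 2 * np.pi * np.arange(1, n_freqs + 1) / M
    angles = q[..., None] * omegas  # (B, K, L, F)
    return np.cos(angles), np.sin(angles)


def loop_log_amplitude(cos_f, sin_f, A, B, tie_copies=False):
    """Linear read-out sum(A * cos + B * sin), one scalar per sample."""
    if tie_copies:
        # Share weights across levels: average the features over the level axis.
        cos_f, sin_f = cos_f.mean(axis=1), sin_f.mean(axis=1)  # (B, L, F)
        return np.einsum('blf,lf->b', cos_f, A) + np.einsum('blf,lf->b', sin_f, B)
    return np.einsum('bklf,klf->b', cos_f, A) + np.einsum('bklf,klf->b', sin_f, B)


# Usage: one level, four edges, a single loop over edges 0, 1, 2.
sigma = np.array([[1, 2, -1, 0]])
loop_edges = np.array([[0, 1, 2]])
loop_signs = np.array([[1, -1, 1]])  # orientation of each edge within the loop
q = loop_charges(sigma, loop_edges, loop_signs, n_levels=1, n_edges=4)
cos_f, sin_f = fourier_features(q, cutoff=2, n_freqs=2)
```

Because ω_j · M is a multiple of 2π, shifting any loop charge by M leaves the features unchanged, which is what makes them periodic on Z_M.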
All index tensors (loop edges, triplet edges) and frequencies are built once and stored in the non-trainable “cache” collection to avoid Python loops and recompilations during training.
- class LocalGraphNQS(n_edges, gauge_dim, n_vertices, vertex_edges, edge_hidden=16, vertex_hidden=32, global_hidden=32, n_edge_layers=2, n_vertex_layers=2, n_global_layers=2, use_phase=True, param_dtype=<class 'jax.numpy.float64'>, output_dtype=<class 'jax.numpy.complex128'>, zero_repulsion=0.0, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
A graph-local, lightweight neural quantum state designed for gauge-theoretic lattice models (e.g. U(1)^3 weak-coupling LQG) defined on a fixed graph.
This ansatz is built to be a minimal yet expressive replacement for an RBM-like architecture, with the following design goals:
Graph-awareness: The model knows which edges are incident on which vertices. Instead of connecting all degrees of freedom to all hidden units (as in a dense RBM), we:
embed edges locally,
aggregate those embeddings per vertex,
then combine vertex features into a single global feature vector.
This aligns more naturally with local operators such as vertex Hamilton constraints and volume operators, which only couple a small set of adjacent edges.
Local expressivity: Each edge’s gauge charges (e.g. 3 integers for U(1)^3) are passed through a small MLP -> “edge embedding”. For each vertex, the embeddings of the edges that touch it are summed -> “vertex embedding”, which is then fed through another small MLP. This gives the network enough capacity to model non-linear, gauge-local correlations without being overkill.
Global summary: Vertex embeddings are summed over all vertices to obtain a global feature vector. This is then processed by a final MLP and read out into:
a scalar log-amplitude log|ψ(σ)| (real),
and optionally a scalar phase arg ψ(σ) (real),
which we pack into a complex log ψ(σ).
This “sum over vertices” plays a similar role to summing hidden units in an RBM, but in a structured, graph-informed way.
Efficiency/small footprint: The network is intentionally low-width (tens of hidden units). The cost is dominated by the local Hamiltonian/constraint evaluation in VMC, not by the neural net. This keeps per-iteration runtime reasonable even when the operator (e.g. Thiemann H_v) is expensive.
Avoiding collapse to the trivial configuration σ = 0: The architecture itself does not hard-code any preference for σ = 0. However, the physics of the constraint can favour small flux sectors. To give some control, we provide an optional zero_repulsion term which adds a soft amplitude penalty that grows with ‖σ‖². This can be used to gently discourage trivial/near-trivial configurations while preserving low-energy structure.
Graph/connectivity information
We do not store a full graph object inside this module. Instead, we only require a minimal mapping that encodes which edges are incident on which vertices.
The vertex_edges argument is a Python list of length V:
- vertex_edges[v] : List[int]
A list of edge indices (0 ≤ e < n_edges) that touch vertex v.
From this, _vertex_const() builds static JAX arrays:
vertex_edges : (V, max_deg), padded with -1 where a vertex has fewer than max_deg incident edges.
vertex_mask : (V, max_deg), with 1.0 at valid edge slots and 0.0 at padding.
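The padding step can be sketched as follows. pad_vertex_edges is a hypothetical stand-in for _vertex_const(), written with NumPy (the results can be wrapped with jnp.asarray). Note that -1 is a valid index in NumPy/JAX gathers (it selects the last edge), so the mask is what actually removes padded slots from the vertex sums.

```python
import numpy as np


def pad_vertex_edges(vertex_edges):
    """Hypothetical helper: ragged vertex->edge lists to padded index/mask arrays."""
    n_vertices = len(vertex_edges)
    max_deg = max((len(es) for es in vertex_edges), default=0)
    idx = np.full((n_vertices, max_deg), -1, dtype=np.int32)  # -1 marks padding
    mask = np.zeros((n_vertices, max_deg), dtype=np.float64)  # 1.0 at valid slots
    for v, es in enumerate(vertex_edges):
        idx[v, :len(es)] = es
        mask[v, :len(es)] = 1.0
    return idx, mask


# Usage: three vertices of degree 3, 2 and 1.
idx, mask = pad_vertex_edges([[0, 1, 2], [1, 3], [2]])
```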
The generation of this list is done externally (see make_local_graph_nqs_kwargs()), so that this module stays agnostic of the particular graph library used.
Operationally, given σ of shape (B, n_edges * gauge_dim), the network proceeds as follows:
1. Edge tower (local per-edge MLP):
   - Reshape to (B, n_edges, gauge_dim).
   - Flatten to (B * n_edges, gauge_dim).
   - Apply n_edge_layers × Dense + SiLU.
   - Reshape back to (B, n_edges, edge_hidden).
2. Vertex aggregation:
   - For each vertex v, gather all incident edge embeddings from step 1 using vertex_edges.
   - Multiply by vertex_mask to zero out padding.
   - Sum over the incident edges -> (B, V, edge_hidden).
3. Vertex tower (shared vertex MLP):
   - Flatten to (B * V, edge_hidden).
   - Apply n_vertex_layers × Dense + SiLU, producing (B * V, vertex_hidden).
   - Reshape back to (B, V, vertex_hidden).
4. Global aggregation + head:
   - Sum over vertices -> (B, vertex_hidden).
   - Apply n_global_layers × Dense + SiLU -> (B, global_hidden).
   - Apply a linear amplitude head -> (B,).
   - Optionally add a soft "zero repulsion" term depending on ‖σ‖².
   - Optionally apply a linear phase head -> (B,).
   - Combine into a complex log ψ(σ) if use_phase=True.
Notes
This ansatz is deliberately “shallow”: it has no deep message-passing or attention. It trades some expressive power for simplicity and runtime speed.
All graph-specific structure is injected via vertex_edges, so you can reuse this module across different graphs and models by just changing that list and the Hilbert size.
- make_local_graph_nqs_kwargs(graph, H)
Construct keyword arguments for LocalGraphNQS from a graph object and a Hilbert object H.
This helper extracts just enough structural information from the graph and H objects to initialise a LocalGraphNQS instance without the network itself depending on the full graph API.
In particular, it builds the vertex_edges list-of-lists, which encodes which edges are incident on which vertices.
We assume that the graph object provides:
graph.handler.graph_edges_data["graph"]["connectivities"]: a mapping (str(vertex) -> dict) describing connectivity info per vertex. We only use its keys to enumerate vertex ids here.
graph.handler.list_of_node_connectivity[v]: a dict with keys "outgoing" and "incoming", each a list of oriented edges (u, w, key) that start or end at node v.
graph.edge_to_index(edge): a function mapping an oriented edge key (u, v, key) to a unique integer index in [0, n_edges).
graph.n_edges: the total number of edges at a single gauge-copy level.
The Hilbert object H is assumed to provide:
H.gauge_dimensions: the gauge_dim (e.g. 3 for U(1)^3).
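A hypothetical sketch of how vertex_edges can be extracted under the interface assumptions above, exercised on a mock directed triangle built with SimpleNamespace. It additionally assumes that list_of_node_connectivity is indexed by integer vertex ids and that edge indices are deduplicated per vertex (so a self-loop is counted once); the real helper also fills in default widths, depths and dtypes, which are omitted here.

```python
from types import SimpleNamespace


def build_local_graph_nqs_kwargs(graph, H):
    """Hypothetical sketch of the vertex_edges extraction described above."""
    conn = graph.handler.graph_edges_data["graph"]["connectivities"]
    vertex_ids = sorted(int(k) for k in conn)  # keys are str(vertex)
    vertex_edges = []
    for v in vertex_ids:
        node = graph.handler.list_of_node_connectivity[v]
        oriented = list(node["outgoing"]) + list(node["incoming"])
        # A set removes duplicates, e.g. a self-loop listed as both outgoing and incoming.
        vertex_edges.append(sorted({graph.edge_to_index(e) for e in oriented}))
    return dict(
        n_edges=graph.n_edges,
        gauge_dim=H.gauge_dimensions,
        n_vertices=len(vertex_edges),
        vertex_edges=vertex_edges,
    )


# Usage: a mock directed triangle 0 -> 1 -> 2 -> 0.
edges = [(0, 1, 0), (1, 2, 0), (2, 0, 0)]
index = dict(zip(edges, range(len(edges))))
handler = SimpleNamespace(
    graph_edges_data={"graph": {"connectivities": {"0": {}, "1": {}, "2": {}}}},
    list_of_node_connectivity=[
        {"outgoing": [(0, 1, 0)], "incoming": [(2, 0, 0)]},
        {"outgoing": [(1, 2, 0)], "incoming": [(0, 1, 0)]},
        {"outgoing": [(2, 0, 0)], "incoming": [(1, 2, 0)]},
    ],
)
graph = SimpleNamespace(handler=handler, edge_to_index=index.__getitem__, n_edges=len(edges))
kwargs = build_local_graph_nqs_kwargs(graph, SimpleNamespace(gauge_dimensions=3))
```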
- Returns:
kwargs – Dictionary of keyword arguments compatible with LocalGraphNQS, including:
n_edges
gauge_dim
n_vertices
vertex_edges
and small default widths/depths and dtypes.
- Return type: