neuralqx.nn.models package

This module contains the implementations of the neural networks used in the simulations.

class RevNet(number_of_cnn_features, number_of_cnn_blocks, lqx, dropout_rate=0.2, free_idxs=None, dtype=<class 'jax.numpy.float64'>, bn_use_running_average=False, use_visible_bias=True, conv_kernel_init=None, conv_bias_init=0.1, bn_scale_init=None, bn_bias_init=0.1, se_fc1_kernel_init=None, se_fc1_bias_init=0.1, se_fc2_kernel_init=None, se_fc2_bias_init=0.1, eval_dense_kernel_init=None, eval_dense_bias_init=0.1, visible_bias_init=0.8, rng_key=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

This is the RevNet architecture with fully tunable normalisation and initialisers.

Compared with the earlier implementation, this version differs as follows:

  • bn_use_running_average is configurable (True/False). Note: BatchNorm is currently not implemented.

  • per-layer initialiser knobs, each accepting:

    • a callable initialiser (Flax-style),

    • OR a float/stddev (wrapped into nn.initializers.normal(stddev)),

    • OR None to use a sensible default.

  • complex-aware initialisation: if dtype is complex, weights are built from two real draws so that the magnitude/phase variance is preserved at initialisation.

  • activations are log_cosh instead of hard_silu.
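The complex-aware initialisation can be illustrated with a minimal NumPy sketch. It assumes that the "two real draws" are the real and imaginary parts, each with variance stddev²/2, so that E|w|² = stddev² and the phase is uniformly distributed. The helper name complex_normal_init is hypothetical and is not taken from neuralqx.

```python
import numpy as np

def complex_normal_init(rng, shape, stddev=0.05):
    """Hypothetical complex initialiser built from two independent real draws.

    Real and imaginary parts are each N(0, stddev^2 / 2), so E|w|^2 = stddev^2
    and arg(w) is uniform on (-pi, pi] (circularly symmetric Gaussian).
    """
    re = rng.normal(0.0, stddev / np.sqrt(2.0), size=shape)
    im = rng.normal(0.0, stddev / np.sqrt(2.0), size=shape)
    return (re + 1j * im).astype(np.complex128)

rng = np.random.default_rng(0)
w = complex_normal_init(rng, (200, 200), stddev=0.05)
print(w.dtype)                    # complex128
print(np.mean(np.abs(w) ** 2))    # close to 0.05**2 = 0.0025
```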

RNG usage:
  • the module does NOT take an RNG directly; pass keys through Flax RNG collections instead:

    variables = model.init({'params': key}, x)

The old rng_key field is deprecated and ignored.

Defaults (used when a knob is None) are chosen to introduce gentle variance without saturation:
  • Conv/Dense kernels: normal(stddev=0.05)

  • Biases: zeros

  • BatchNorm scale: ones; bias: zeros

Initializer knobs (per layer)

For each spec below, you can pass:
  • callable initializer (Flax style),

  • or a float/int (interpreted as stddev for normal init),

  • or None for default described here.

conv_kernel_init : InitLike, default normal(0.05)
conv_bias_init : InitLike, default zeros

bn_scale_init : InitLike, default ones
bn_bias_init : InitLike, default zeros

se_fc1_kernel_init : InitLike, default normal(0.05)
se_fc1_bias_init : InitLike, default zeros
se_fc2_kernel_init : InitLike, default normal(0.05)
se_fc2_bias_init : InitLike, default zeros

eval_dense_kernel_init : InitLike, default normal(0.05)
eval_dense_bias_init : InitLike, default zeros
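The InitLike convention can be pictured as a small resolver that maps each accepted form to one initialiser with the Flax-style (key, shape, dtype) signature. This is a NumPy sketch under assumptions: resolve_init, normal, zeros and ones are hypothetical stand-ins for the nn.initializers functions, and a NumPy Generator plays the role of the JAX PRNG key.

```python
import numpy as np

# Stand-ins for Flax initialisers; each follows the (key, shape, dtype) signature.
# A NumPy Generator plays the role of the JAX PRNG key (assumption of this sketch).
def normal(stddev):
    def init(key, shape, dtype=np.float64):
        return (stddev * key.standard_normal(shape)).astype(dtype)
    return init

def zeros(key, shape, dtype=np.float64):
    return np.zeros(shape, dtype)

def ones(key, shape, dtype=np.float64):
    return np.ones(shape, dtype)

def resolve_init(spec, default):
    """Turn an InitLike spec (callable | float/int | None) into an initialiser."""
    if spec is None:
        return default                  # documented default for this knob
    if callable(spec):
        return spec                     # already a Flax-style initialiser
    if isinstance(spec, (int, float)):
        return normal(float(spec))      # a number is read as a stddev
    raise TypeError(f"unsupported initialiser spec: {spec!r}")

key = np.random.default_rng(0)
kernel_init = resolve_init(None, default=normal(0.05))  # conv_kernel_init=None
bias_init = resolve_init(0.1, default=zeros)            # conv_bias_init=0.1 (signature default)
scale_init = resolve_init(ones, default=ones)           # explicit callable

kernel = kernel_init(key, (3, 3, 4, 8))
bias = bias_init(key, (8,))
```

Read literally, this convention means the signature default conv_bias_init=0.1 yields normal(0.1) biases; the zeros default applies only when the knob is explicitly set to None.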

Notes

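  • RNG: initialize with model.init({'params': key}, x). Different keys give different parameter samples. The deprecated rng_key field is ignored.

  • If you previously depended on BatchNorm with use_running_average=True during training, note that this can suppress variance. Toggle it via the bn_use_running_average knob (BatchNorm is currently not implemented).

  • The signature defaults of the bias knobs (0.1) are floats, so they are interpreted as normal(stddev) initialisers; the zero defaults listed above apply only when a knob is passed as None.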

class RevNetAuto(number_of_cnn_features, rng_key, lqx, dropout_rate=0.2, free_idxs=None, dtype=<class 'jax.numpy.float64'>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

The RevNet network, implemented in Flax.

class ResNet(num_blocks=2, block_features=60, main_conv=60, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A ResNet-like neural network composed of several residual blocks with skip connections.

class RevNetComplex(number_of_cnn_features, number_of_cnn_blocks, lqx, rng_key, dropout_rate=0.0, free_idxs=None, dtype=<class 'jax.numpy.float64'>, use_layer_norm=False, _default_R_activation=<PjitFunction of <function hard_silu>>, reduction_ratio=2, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

RevNet implementation supporting real and complex dtypes.

make_cfvtnet_kwargs(graph, H)

Construct the graph- and Hamiltonian-dependent keyword arguments for CFVTNet.

class CFVTNet(n_edges, n_vertices, gauge_dim, cutoff, q_min, q_max, q_step, edge_to_index, loops_edges_typed=None, loops_indices=None, triplets_by_vertex=None, triplet_signs_by_vertex=None, N_F=1, freq_mode='harmonic', freq_log_base=2.0, tie_copies=False, use_vertex_odd=True, use_vertex_even=True, use_phase=True, use_edge_bias=False, use_loop_bias=False, param_dtype=<class 'jax.numpy.float64'>, output_dtype=<class 'jax.numpy.complex128'>, l2_fourier=0.0, l2_vertex=0.0, l2_edge=0.0, l2_loop=0.0, init_std_A=0.1, init_std_B=0.1, init_std_A_phase=0.1, init_std_B_phase=0.1, init_std_alpha=0.1, init_std_lambda=0.1, init_std_beta=0.1, init_std_alpha_phase=0.1, init_std_visible_bias=0.4, init_std_visible_bias_phase=0.5, init_std_loop_bias=0.0, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A fast gauge-equivariant ansatz for lattice gauge states.

This module implements a composite ansatz for (log-)wavefunction amplitudes on a tensor-product Hilbert space built over K gauge levels of a base graph (one level per gauge dimension, e.g. K = 3 for U(1)^3). Each level corresponds to a block of n_edges oriented edges, so the total visible width is n_edges * gauge_dim.
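As an illustration of the visible layout, the NumPy sketch below assumes the level-major block ordering just described (level k occupies positions k * n_edges to (k + 1) * n_edges - 1 of each configuration). A single reshape separates the levels, and a transpose gives one K-vector per edge; the toy sizes are arbitrary.

```python
import numpy as np

B, K, n_edges = 2, 3, 4                       # batch, gauge levels, edges per level
sigma = np.arange(B * K * n_edges).reshape(B, K * n_edges)

levels = sigma.reshape(B, K, n_edges)         # levels[b, k, e] = sigma[b, k * n_edges + e]
per_edge = levels.transpose(0, 2, 1)          # (B, n_edges, K): one K-vector per edge

print(per_edge[0, 1])                         # [1 5 9]: edge 1 at levels 0, 1, 2
```

The per-edge view is the natural input for the triplet term below, which treats σ_e as a vector over the gauge levels.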

The architecture fuses three ingredients:

  1. Loop Fourier features
    • we extract loop charges q_{ℓ,k} by summing ±σ_{e,k} over the oriented edges e that form each minimal loop ℓ, for each level k (the sign reflects the edge's orientation relative to the loop)

    • we then compute periodic features cos(ω_j q_{ℓ,k}), sin(ω_j q_{ℓ,k}) where ω_j are harmonics over the cyclic group Z_M (M = 2 * cutoff + 1)

    • trainable weights A, B (and optionally A_phase, B_phase for complex phase) linearly combine these features into a scalar contribution to log|ψ| (and arg ψ)

  2. Local, volume-like triplet term (optional, requires K = 3)
    • at each vertex, we may provide oriented triplets of incident edges (e1, e2, e3)

    • for each triplet, we form a pseudoscalar T_v := ε · (σ_{e1} · (σ_{e2} × σ_{e3})), where σ_e is the 3-vector of charges of edge e across the three levels and ε is the triplet's orientation sign; this couples the three gauge levels as a triple product

    • odd (tanh with learned scale λ) and even (quadratic penalty) parts contribute to the amplitude; an optional linear part contributes to the phase

  3. Bias terms (optional)
    • per-edge linear bias (RBM style, Σ_e b_e σ_e) adds directly to log|ψ|, with an optional phase counterpart

    • a per-loop constant bias (or per (loop, level) if levels are not tied) that shifts log|ψ| independently of input, useful for calibrating baselines
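The triplet pseudoscalar is a signed triple product, i.e. a 3×3 determinant of the charge vectors on the three edges. The NumPy sketch below assumes σ has already been arranged per edge as a (B, n_edges, 3) array and that each triplet comes with an orientation sign ε = ±1. The odd/even form α·tanh(λT) − βT² and the names alpha, lam and beta are hypothetical: the text only states that the odd part uses tanh with a learned scale λ and that the even part is a quadratic penalty.

```python
import numpy as np

def triplet_pseudoscalars(sigma_edges, triplets, signs):
    """T = eps * sigma_e1 . (sigma_e2 x sigma_e3) for each oriented triplet.

    sigma_edges: (B, n_edges, 3) per-edge charge vectors (one entry per level)
    triplets:    (T, 3) integer edge indices (e1, e2, e3)
    signs:       (T,) orientation signs eps in {+1, -1}
    returns:     (B, T)
    """
    s1 = sigma_edges[:, triplets[:, 0], :]          # (B, T, 3)
    s2 = sigma_edges[:, triplets[:, 1], :]
    s3 = sigma_edges[:, triplets[:, 2], :]
    triple = np.einsum("btk,btk->bt", s1, np.cross(s2, s3))
    return signs[None, :] * triple

def triplet_log_amp(T, alpha, lam, beta):
    # hypothetical form: odd part alpha*tanh(lam*T), even part -beta*T^2
    return np.sum(alpha * np.tanh(lam * T) - beta * T**2, axis=-1)

sigma_edges = np.array([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=float)  # B=1, 3 edges
triplets = np.array([[0, 1, 2], [1, 0, 2]])
signs = np.array([1.0, 1.0])
T = triplet_pseudoscalars(sigma_edges, triplets, signs)
print(T)    # [[ 1. -1.]]: swapping e1 and e2 flips the sign
```

Because T flips sign when two edges of a triplet are swapped, the tanh part is orientation-odd and the quadratic part is orientation-even; with the two oppositely oriented triplets above, the odd contributions cancel and only the even part survives.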

Level tying:
  • if tie_copies = True, loop Fourier weights are shared across levels by averaging features over the level axis before applying (A, B, …). Otherwise, we learn separate parameters per level
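The loop-feature pipeline (ingredient 1) together with level tying can be sketched in NumPy. Assumptions not taken from the source: σ is laid out level-major as (B, K, n_edges); loops are given as padded (L, max_len) arrays of edge indices and ±1 orientation signs, with sign 0 at padding; only harmonic frequencies ω_j = 2πj/M for j = 1..N_F are shown (the log-spaced alternative suggested by freq_log_base is not); and the weights are shared across loops, with one set per level when untied. All function names are hypothetical.

```python
import numpy as np

def loop_charges(sigma, loop_edges, loop_signs):
    """q[b, l, k] = sum_i loop_signs[l, i] * sigma[b, k, loop_edges[l, i]].

    sigma:      (B, K, n_edges) configuration, one block of n_edges per level
    loop_edges: (L, max_len) edge indices; padding slots may hold any valid index
    loop_signs: (L, max_len) orientation signs +1/-1, and 0 at padding slots
    """
    gathered = sigma[:, :, loop_edges]                        # (B, K, L, max_len)
    q = np.sum(gathered * loop_signs[None, None], axis=-1)    # (B, K, L)
    return q.transpose(0, 2, 1)                               # (B, L, K)

def fourier_log_amp(q, A_w, B_w, cutoff, tie_copies):
    """Sum of A_w*cos(w_j q) + B_w*sin(w_j q) with harmonic frequencies of Z_M."""
    M = 2 * cutoff + 1
    n_f = A_w.shape[-1]
    omega = 2.0 * np.pi * np.arange(1, n_f + 1) / M           # w_j = 2*pi*j / M
    arg = q[..., None] * omega                                # (B, L, K, N_F)
    cos_f, sin_f = np.cos(arg), np.sin(arg)
    if tie_copies:
        # average over the level axis, so one (N_F,) weight set serves all levels
        cos_f, sin_f = cos_f.mean(axis=2), sin_f.mean(axis=2)  # (B, L, N_F)
    # untied: A_w, B_w have shape (K, N_F) and broadcast over the loop axis
    terms = A_w * cos_f + B_w * sin_f
    return terms.reshape(terms.shape[0], -1).sum(axis=1)      # (B,)

rng = np.random.default_rng(0)
B, K, n_edges, cutoff, n_f = 2, 3, 6, 2, 2
sigma = rng.integers(-cutoff, cutoff + 1, size=(B, K, n_edges))
loop_edges = np.array([[0, 1, 2, 3], [2, 4, 5, 0]])           # last slot of loop 1 is padding
loop_signs = np.array([[1, 1, -1, -1], [1, -1, 1, 0]])
q = loop_charges(sigma, loop_edges, loop_signs)               # (B, L, K)

A_u, B_u = rng.normal(size=(K, n_f)), rng.normal(size=(K, n_f))  # untied: per level
A_t, B_t = rng.normal(size=n_f), rng.normal(size=n_f)              # tied: shared
untied = fourier_log_amp(q, A_u, B_u, cutoff, tie_copies=False)
tied = fourier_log_amp(q, A_t, B_t, cutoff, tie_copies=True)
```

Because ω_j · M = 2πj, shifting any loop charge by M leaves the features unchanged, which is the Z_M periodicity the harmonic choice is meant to encode.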

Dtypes:
  • param_dtype sets the dtype for all learnable parameters

  • output_dtype chooses whether the network outputs real or complex log amplitudes. If output_dtype is complex and use_phase=True, we emit log|ψ| + i * phase. Otherwise we emit a real log|ψ|.

Regularisation:
  • regularization_terms() returns a dict of L2 penalties for different parameter blocks (Fourier, vertex, edge bias, loop bias) controlled by l2_* scalars. You sum these into your loss externally.
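A minimal sketch of what summing the penalties into the loss externally can look like. The function l2_terms, the grouping of parameter names into blocks and the penalty form coefficient × Σ|p|² are hypothetical stand-ins for regularization_terms(); only the idea of one penalty per block, scaled by the matching l2_* coefficient, comes from the text above.

```python
import numpy as np

# Hypothetical grouping of parameter names into the four penalty blocks.
BLOCKS = {
    "fourier": ["A", "B", "A_phase", "B_phase"],
    "vertex": ["alpha", "lambda", "beta", "alpha_phase"],
    "edge": ["visible_bias", "visible_bias_phase"],
    "loop": ["loop_bias"],
}

def l2_terms(params, l2):
    """One penalty per block: l2[block] * sum of |p|^2 over the block's parameters."""
    return {
        block: l2[block] * sum(float(np.sum(np.abs(params[n]) ** 2)) for n in names if n in params)
        for block, names in BLOCKS.items()
    }

params = {"A": np.ones((3, 2)), "B": np.zeros((3, 2)), "alpha": np.array(0.5)}
l2 = {"fourier": 1e-3, "vertex": 1e-2, "edge": 0.0, "loop": 0.0}  # cf. l2_fourier, l2_vertex, ...
terms = l2_terms(params, l2)
energy = 1.25                         # placeholder for the VMC energy estimate
loss = energy + sum(terms.values())   # penalties are added outside the model
```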

All index tensors (loop edges, triplet edges) and frequencies are built once and stored in the non-trainable “cache” collection to avoid Python loops and recompilations during training.

class LocalGraphNQS(n_edges, gauge_dim, n_vertices, vertex_edges, edge_hidden=16, vertex_hidden=32, global_hidden=32, n_edge_layers=2, n_vertex_layers=2, n_global_layers=2, use_phase=True, param_dtype=<class 'jax.numpy.float64'>, output_dtype=<class 'jax.numpy.complex128'>, zero_repulsion=0.0, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A graph-local, lightweight neural quantum state designed for gauge-theoretic lattice models (e.g. U(1)^3 weak-coupling LQG) defined on a fixed graph.

This ansatz is built to be a minimal yet expressive replacement for an RBM-like architecture, with the following design goals:

  • Graph-awareness: The model knows which edges are incident on which vertices. Instead of connecting all degrees of freedom to all hidden units (as in a dense RBM), we:

    • embed edges locally,

    • aggregate those embeddings per vertex,

    • then combine vertex features into a single global feature vector.

    This aligns more naturally with local operators such as vertex Hamilton constraints and volume operators, which only couple a small set of adjacent edges.

  • Local expressivity: Each edge’s gauge charges (e.g. 3 integers for U(1)^3) are passed through a small MLP -> “edge embedding”. For each vertex, the embeddings of the edges that touch it are summed -> “vertex embedding”, which is then fed through another small MLP. This gives the network enough capacity to model non-linear, gauge-local correlations without being overkill.

  • Global summary: Vertex embeddings are summed over all vertices to obtain a global feature vector. This is then processed by a final MLP and read out into:

    • a scalar log-amplitude log|ψ(σ)| (real),

    • and optionally a scalar phase arg ψ(σ) (real),

    which we pack into a complex log ψ(σ).

    This “sum over vertices” plays a similar role to summing hidden units in an RBM, but in a structured, graph-informed way.

  • Efficiency/small footprint: The network is intentionally low-width (tens of hidden units), so the per-iteration cost in VMC is dominated by the local Hamiltonian/constraint evaluation rather than by the neural net. The network therefore adds little overhead per iteration, even when the operator (e.g. Thiemann H_v) is already expensive.

  • Avoiding collapse to the trivial configuration σ = 0: The architecture itself does not hard-code any preference for σ = 0. However, the physics of the constraint can favour small flux sectors. To give some control, we provide an optional zero_repulsion term, a soft term in the log-amplitude that grows with ‖σ‖² and thereby relatively penalises configurations near σ = 0. This can be used to gently discourage trivial or near-trivial configurations while preserving low-energy structure.

Graph/connectivity information

We do not store a full graph object inside this module. Instead, we only require a minimal mapping that encodes which edges are incident on which vertices.

The vertex_edges argument is a Python list of length V:

vertex_edges[v] : List[int]

A list of edge indices (0 ≤ e < n_edges) that touch vertex v.

From this, _vertex_const() builds static JAX arrays:

  • vertex_edges : (V, max_deg), padded with -1 where a vertex has fewer than max_deg incident edges.

  • vertex_mask : (V, max_deg), with 1.0 at valid edge slots and 0.0 at padding.
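The padding performed by _vertex_const() can be sketched as follows. vertex_const is a hypothetical free-function stand-in written in NumPy rather than JAX, and the int32 dtype of the index array is an assumption.

```python
import numpy as np

def vertex_const(vertex_edges):
    """Pad a ragged vertex -> incident-edge list into (V, max_deg) arrays."""
    V = len(vertex_edges)
    max_deg = max((len(es) for es in vertex_edges), default=0)
    idx = np.full((V, max_deg), -1, dtype=np.int32)     # -1 marks padding
    mask = np.zeros((V, max_deg), dtype=np.float64)     # 1.0 at valid slots
    for v, es in enumerate(vertex_edges):
        idx[v, :len(es)] = es
        mask[v, :len(es)] = 1.0
    return idx, mask

idx, mask = vertex_const([[0, 1, 2], [2, 3], [0, 3, 1]])
# idx  -> [[0, 1, 2], [2, 3, -1], [0, 3, 1]]
# mask -> [[1, 1, 1], [1, 1, 0], [1, 1, 1]]
```

When these arrays are used to gather edge embeddings, the -1 slots still index a real edge (the last one), so multiplying by vertex_mask is what actually removes them from the per-vertex sum.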

The generation of this list is done externally (see make_local_graph_nqs_kwargs()), so that this module stays agnostic of the particular graph library used.

Operationally, the network behaves as follows, given σ of shape (B, n_edges * gauge_dim):

  1. Edge tower (local per-edge MLP):
    • Reshape to (B, n_edges, gauge_dim).

    • Flatten to (B * n_edges, gauge_dim).

    • Apply n_edge_layers × Dense + SiLU.

    • Reshape back to (B, n_edges, edge_hidden).

  2. Vertex aggregation:
    • For each vertex v, gather all incident edge embeddings from step 1 using vertex_edges.

    • Multiply by vertex_mask to zero out padding.

    • Sum over the incident edges -> (B, V, edge_hidden).

  3. Vertex tower (shared vertex MLP):
    • Flatten to (B * V, edge_hidden).

    • Apply n_vertex_layers × Dense + SiLU, producing (B * V, vertex_hidden).

    • Reshape back to (B, V, vertex_hidden).

  4. Global aggregation + head:
    • Sum over vertices -> (B, vertex_hidden).

    • Apply n_global_layers × Dense + SiLU -> (B, global_hidden).

    • Apply a linear amplitude head -> (B,).

    • Optionally add a soft "zero repulsion" term depending on ‖σ‖².

    • Optionally apply a linear phase head -> (B,).

    • Combine into complex log ψ(σ) if use_phase=True.
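The four steps can be sketched in plain NumPy to make the shapes concrete. This is not the Flax module: the helper names (silu, mlp, dense, local_graph_forward), the parameter dictionary layout and the initialisation scale are hypothetical, and the optional zero-repulsion term is left out because its exact form is not specified here. Layer widths follow the documented defaults (edge_hidden=16, vertex_hidden=32, global_hidden=32, two layers per tower).

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def mlp(x, layers):
    # each layer: Dense (x @ W + b) followed by SiLU, as in the towers above
    for W, b in layers:
        x = silu(x @ W + b)
    return x

def dense(rng, n_in, n_out, scale=0.1):
    return scale * rng.standard_normal((n_in, n_out)), np.zeros(n_out)

def local_graph_forward(sigma, params, vertex_idx, vertex_mask, n_edges, gauge_dim,
                        use_phase=True):
    B = sigma.shape[0]
    # 1. edge tower: (B, n_edges * gauge_dim) -> (B * n_edges, gauge_dim) -> MLP
    x = sigma.reshape(B, n_edges, gauge_dim).reshape(B * n_edges, gauge_dim)
    h_edge = mlp(x, params["edge"]).reshape(B, n_edges, -1)
    # 2. vertex aggregation: padded index -1 picks the last edge, the mask zeroes it
    gathered = h_edge[:, vertex_idx, :]                      # (B, V, max_deg, edge_hidden)
    h_vert = (gathered * vertex_mask[None, :, :, None]).sum(axis=2)
    # 3. vertex tower on (B * V, edge_hidden)
    V = h_vert.shape[1]
    h_vert = mlp(h_vert.reshape(B * V, -1), params["vertex"]).reshape(B, V, -1)
    # 4. sum over vertices, global MLP, then linear heads
    g = mlp(h_vert.sum(axis=1), params["global"])            # (B, global_hidden)
    log_abs = g @ params["amp"][0] + params["amp"][1]         # (B,)
    if not use_phase:
        return log_abs
    phase = g @ params["phase"][0] + params["phase"][1]       # (B,)
    return log_abs + 1j * phase                               # complex log psi

rng = np.random.default_rng(0)
B, n_edges, gauge_dim = 5, 4, 3
edge_hidden, vertex_hidden, global_hidden = 16, 32, 32
params = {
    "edge": [dense(rng, gauge_dim, edge_hidden), dense(rng, edge_hidden, edge_hidden)],
    "vertex": [dense(rng, edge_hidden, vertex_hidden), dense(rng, vertex_hidden, vertex_hidden)],
    "global": [dense(rng, vertex_hidden, global_hidden), dense(rng, global_hidden, global_hidden)],
    "amp": (0.1 * rng.standard_normal(global_hidden), 0.0),
    "phase": (0.1 * rng.standard_normal(global_hidden), 0.0),
}
vertex_idx = np.array([[0, 1, 2], [2, 3, -1], [0, 3, 1]])    # vertex_edges padded with -1
vertex_mask = (vertex_idx >= 0).astype(np.float64)
sigma = rng.integers(-2, 3, size=(B, n_edges * gauge_dim)).astype(np.float64)
log_psi = local_graph_forward(sigma, params, vertex_idx, vertex_mask, n_edges, gauge_dim)
```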

Notes

  • This ansatz is deliberately “shallow”: it has no deep message-passing or attention. It trades some expressive power for simplicity and runtime speed.

  • All graph-specific structure is injected via vertex_edges, so you can reuse this module across different graphs and models by just changing that list and the Hilbert size.

make_local_graph_nqs_kwargs(graph, H)

Construct keyword arguments for LocalGraphNQS from a graph object and a Hilbert object H.

This helper extracts just enough structural information from the graph and H objects to initialise a LocalGraphNQS instance without making the network itself depend on the full graph API.

In particular, it builds the vertex_edges list-of-lists, which encodes which edges are incident on which vertices.

We assume that the graph object (and its handler) provides:

  • graph.handler.graph_edges_data["graph"]["connectivities"]: a mapping (str(vertex) -> dict) describing connectivity info per vertex. We only use its keys here, to enumerate the vertex ids.

  • graph.handler.list_of_node_connectivity[v]: a dict with keys "outgoing" and "incoming", each a list of oriented edges (u, w, key) that start or end at node v.

  • graph.edge_to_index(edge): a function mapping an oriented edge key (u, v, key) to a unique integer index in [0, n_edges).

  • graph.n_edges: the total number of edges at a single gauge-copy level.

The Hilbert object H is assumed to provide:

  • H.gauge_dimensions: the gauge_dim (e.g. 3 for U(1)^3).

Returns:

kwargs – Dictionary of keyword arguments compatible with LocalGraphNQS, including:

  • n_edges

  • gauge_dim

  • n_vertices

  • vertex_edges

  • and small default widths/depths and dtypes.

Return type:

dict
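To illustrate the extraction, here is a sketch of how the kwargs could be assembled from the attributes listed above, exercised on a mock triangle graph built with SimpleNamespace. The helper build_local_graph_kwargs is hypothetical. It further assumes that vertex ids are integers stored as strings in the connectivities mapping and that list_of_node_connectivity is indexed by the integer id, and it drops duplicate edge indices (e.g. from self-loops), which the real helper may handle differently. The default widths/depths and dtypes are omitted.

```python
from types import SimpleNamespace

def build_local_graph_kwargs(graph, H):
    """Hypothetical sketch of make_local_graph_nqs_kwargs (widths/dtypes omitted)."""
    conn = graph.handler.graph_edges_data["graph"]["connectivities"]
    vertices = sorted(int(v) for v in conn)              # keys are str(vertex)
    vertex_edges = []
    for v in vertices:
        node = graph.handler.list_of_node_connectivity[v]
        idxs = [graph.edge_to_index(e) for e in node["outgoing"] + node["incoming"]]
        vertex_edges.append(list(dict.fromkeys(idxs)))   # de-duplicate, keep order
    return {
        "n_edges": graph.n_edges,
        "gauge_dim": H.gauge_dimensions,
        "n_vertices": len(vertices),
        "vertex_edges": vertex_edges,
    }

# Mock triangle graph 0 -> 1 -> 2 -> 0 with oriented edge keys (u, w, key)
edges = [(0, 1, 0), (1, 2, 0), (2, 0, 0)]
edge_index = {e: i for i, e in enumerate(edges)}
handler = SimpleNamespace(
    graph_edges_data={"graph": {"connectivities": {"0": {}, "1": {}, "2": {}}}},
    list_of_node_connectivity={
        v: {"outgoing": [e for e in edges if e[0] == v],
            "incoming": [e for e in edges if e[1] == v]}
        for v in range(3)
    },
)
graph = SimpleNamespace(handler=handler, edge_to_index=edge_index.__getitem__, n_edges=len(edges))
H = SimpleNamespace(gauge_dimensions=3)

kwargs = build_local_graph_kwargs(graph, H)
# kwargs["vertex_edges"] == [[0, 2], [1, 0], [2, 1]]
```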

Submodules