neuralqx.experimental.operators.symbolic.compiler.lowering.jax_lowerer module

JAX backend symbolic operator lowerer.

Converts a SymbolicOperatorIR into a concrete CompiledOperator whose _get_conn_padded kernel is built by interpreting the AmplitudeExpr / PredicateExpr / UpdateProgram expression trees as JAX operations at trace time.
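Interpreting an expression tree "at trace time" means the Python recursion over the tree runs only while JAX traces the kernel; the compiled program contains just the resulting jnp operations. The sketch below shows the idea with a few hypothetical node types (Const, Site, Mul, Eq). These are illustrative stand-ins and not the actual AmplitudeExpr / PredicateExpr classes, whose structure is not documented here.

```python
from dataclasses import dataclass
from typing import Union

import jax
import jax.numpy as jnp


# Hypothetical, minimal expression-tree nodes (not the library's IR classes).
@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Site:
    index: int  # reads x[index]


@dataclass(frozen=True)
class Mul:
    lhs: "Expr"
    rhs: "Expr"


@dataclass(frozen=True)
class Eq:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Site, Mul, Eq]


def lower(expr: Expr, x: jax.Array) -> jax.Array:
    """Recursively turn an expression tree into JAX ops acting on x.

    The isinstance dispatch is plain Python and runs during tracing only.
    """
    if isinstance(expr, Const):
        return jnp.asarray(expr.value)
    if isinstance(expr, Site):
        return x[expr.index]
    if isinstance(expr, Mul):
        return lower(expr.lhs, x) * lower(expr.rhs, x)
    if isinstance(expr, Eq):
        return lower(expr.lhs, x) == lower(expr.rhs, x)
    raise TypeError(f"unsupported node: {expr!r}")


# Amplitude 0.5 * x[0] * x[1]; the tree is closed over, so it is static.
amp = Mul(Const(0.5), Mul(Site(0), Site(1)))
kernel = jax.jit(lambda x: lower(amp, x))
print(kernel(jnp.array([1.0, -1.0, 1.0])))  # -0.5
```

Because the tree is a static Python object, each distinct tree compiles to its own specialized kernel and no tree-walking happens at run time.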

Architecture

For each IR term the lowerer generates a term runner: a Python function that, given a single input configuration x (shape [hilbert_size]), returns a tuple (x_primes, mels, valids) with shapes (fanout, hilbert_size), (fanout,), and (fanout,), respectively.
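As an illustration of this contract only, here is a hand-written runner for a hypothetical diagonal term, sum_i x_i * x_{i+1} on an open chain. A diagonal term connects x only to itself, so its fanout is 1. The function name and the term are assumptions for the example; the lowerer generates its runners from the IR rather than by hand.

```python
import jax
import jax.numpy as jnp


def zz_chain_runner(x: jax.Array):
    """Hypothetical term runner for the diagonal term sum_i x[i] * x[i + 1].

    Returns (x_primes, mels, valids) with shapes
    (fanout, hilbert_size), (fanout,), (fanout,), where fanout == 1.
    """
    energy = jnp.sum(x[:-1] * x[1:])      # scalar matrix element
    x_primes = x[None, :]                 # (1, hilbert_size): x' == x
    mels = energy[None]                   # (1,)
    valids = jnp.ones((1,), dtype=bool)   # (1,): the single branch is valid
    return x_primes, mels, valids


x = jnp.array([1.0, -1.0, 1.0, 1.0])
x_primes, mels, valids = jax.jit(zz_chain_runner)(x)
print(x_primes.shape, mels.shape, valids.shape)  # (1, 4) (1,) (1,)
```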

K-body terms use a static index_array (shape [M, K]) and jax.vmap over its rows. Each row instantiates the iterator-label environment and evaluates all emissions, producing E branches per row, for a total fanout of M * E.
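The following sketch mirrors that structure for a hypothetical 2-body term (K = 2) with E = 2 emissions per row: emission 0 flips both spins, and emission 1 swaps them but is valid only when x[i] != x[j]. The factory name, the emissions, and their amplitudes are assumptions chosen for illustration; only the vmap-then-flatten pattern is taken from the description above.

```python
import jax
import jax.numpy as jnp


def make_pair_term_runner(index_array, coupling: float = 1.0):
    """Hypothetical runner for a 2-body term with E = 2 emissions per row.

    index_array: static int array of shape (M, 2); row (i, j) binds the
    iterator labels for one vmap lane.
    """
    def per_row(x, row):
        i, j = row[0], row[1]
        flip = x.at[i].set(-x[i]).at[j].set(-x[j])         # emission 0
        swap = x.at[i].set(x[j]).at[j].set(x[i])           # emission 1
        x_primes = jnp.stack([flip, swap])                 # (E, N)
        mels = jnp.array([coupling, 0.5], dtype=x.dtype)   # (E,)
        valids = jnp.stack([jnp.array(True), x[i] != x[j]])  # (E,)
        return x_primes, mels, valids

    def runner(x):
        # Map over the M rows of index_array; x is shared (in_axes=None).
        xp, me, va = jax.vmap(per_row, in_axes=(None, 0))(x, index_array)
        m, e, n = xp.shape
        # Flatten (M, E, ...) into a single fanout axis of length M * E.
        return xp.reshape(m * e, n), me.reshape(m * e), va.reshape(m * e)

    return runner


index_array = jnp.array([[0, 1], [1, 2]])   # M = 2 rows, K = 2
runner = jax.jit(make_pair_term_runner(index_array))
x_primes, mels, valids = runner(jnp.array([1.0, 1.0, -1.0]))
print(x_primes.shape, valids)  # (4, 3), second branch invalid (x[0] == x[1])
```

Invalid branches are kept in place and masked by valids rather than removed, which keeps the fanout static at M * E as JAX tracing requires.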

Branch-multiset note

Duplicate x' values are not coalesced. If multiple terms or emissions produce the same connected state, they appear as separate rows in the padded output.
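Leaving duplicates in place is harmless for quantities that sum over connected states, such as the local estimator sum_k mel_k * psi(x'_k) / psi(x): duplicate rows simply contribute additively, exactly as a single coalesced row with the summed matrix element would. The sketch below checks this with two hypothetical term outputs that reach the same x'; the helper names and the toy log-amplitude are assumptions for the example.

```python
import jax
import jax.numpy as jnp


def concat_padded(*outputs):
    """Stack (x_primes, mels, valids) tuples row-wise without merging duplicates."""
    xs, ms, vs = zip(*outputs)
    return jnp.concatenate(xs), jnp.concatenate(ms), jnp.concatenate(vs)


def local_estimate(log_psi, x, x_primes, mels, valids):
    """sum_k mel_k * psi(x'_k) / psi(x) over valid rows."""
    ratios = jnp.exp(jax.vmap(log_psi)(x_primes) - log_psi(x))
    return jnp.sum(jnp.where(valids, mels * ratios, 0.0))


def log_psi(s):
    # Toy log-amplitude (assumption), just to have something to evaluate.
    return s @ jnp.array([0.3, 0.1])


x = jnp.array([1.0, -1.0])
flipped = jnp.array([[-1.0, 1.0]])
term_a = (flipped, jnp.array([0.5]), jnp.array([True]))
term_b = (flipped, jnp.array([0.25]), jnp.array([True]))

x_primes, mels, valids = concat_padded(term_a, term_b)
print(x_primes.shape)  # (2, 2): the shared x' appears twice

with_dups = local_estimate(log_psi, x, x_primes, mels, valids)
coalesced = local_estimate(log_psi, x, flipped, jnp.array([0.75]), jnp.array([True]))
print(jnp.allclose(with_dups, coalesced))  # True
```

Coalescing would only matter for consumers that treat rows as distinct states, for example when counting unique connected configurations.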

class JAXSymbolicLowerer

Bases: AbstractSymbolicLowerer

JAX-backend symbolic operator lowerer.