neuralqx.experimental.operators.symbolic.compiler.lowering.jax_lowerer module
JAX backend symbolic operator lowerer.
Converts a SymbolicOperatorIR into a concrete CompiledOperator whose _get_conn_padded kernel is built by interpreting the AmplitudeExpr / PredicateExpr / UpdateProgram expression trees as JAX operations at trace time.
Architecture
For each IR term the lowerer generates a term runner: a Python function that,
given a single input configuration x (shape [hilbert_size]), returns a
tuple (x_primes, mels, valids) with shapes
(fanout, hilbert_size), (fanout,), and (fanout,), respectively.
K-body terms use a static index_array (shape [M, K]) and jax.vmap
over its rows. Each row instantiates the iterator-label environment and
evaluates all emissions, producing E branches per row (total M * E).
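The K-body scheme described above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the lowerer's actual generated code: `row_emissions`, `term_runner`, the two emissions, and the example `index_array` are all hypothetical names and values chosen for the sketch. It uses K = 2, M = 3, and E = 2 emissions per row (one diagonal, one swap guarded by a predicate).

```python
import jax
import jax.numpy as jnp

# Hypothetical static index array: M = 3 rows, K = 2 site labels per row.
index_array = jnp.array([[0, 1], [1, 2], [2, 3]])


def row_emissions(x, row):
    """Evaluate E = 2 hypothetical emissions for one row of index_array."""
    i, j = row[0], row[1]  # iterator-label environment for this row

    # Emission 0: diagonal branch, state unchanged, always valid.
    x0 = x
    mel0 = x[i] * x[j]
    valid0 = jnp.array(True)

    # Emission 1: swap sites i and j, valid only when they differ.
    x1 = x.at[i].set(x[j]).at[j].set(x[i])
    mel1 = jnp.asarray(0.5, dtype=x.dtype)
    valid1 = x[i] != x[j]

    return (
        jnp.stack([x0, x1]),          # (E, hilbert_size)
        jnp.stack([mel0, mel1]),      # (E,)
        jnp.stack([valid0, valid1]),  # (E,)
    )


@jax.jit
def term_runner(x):
    """Map a single configuration x to (x_primes, mels, valids)."""
    # vmap over the rows of the static index array; x is shared by all rows.
    xp, mels, valids = jax.vmap(row_emissions, in_axes=(None, 0))(x, index_array)
    m, e, n = xp.shape
    # Flatten (M, E) into a single fanout axis of length M * E.
    return xp.reshape(m * e, n), mels.reshape(m * e), valids.reshape(m * e)


x = jnp.array([1.0, -1.0, -1.0, 1.0])
x_primes, mels, valids = term_runner(x)
```

Here the fanout is M * E = 6. Branches whose predicate fails keep their row in the output and are flagged only through `valids`; how the real kernel treats the matrix elements of such rows is not specified in this section.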
Branch-multiset note
Duplicate x' values are not coalesced. If multiple terms or emissions
produce the same connected state, they appear as separate rows in the padded
output.
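A small sketch of what the multiset convention means for a consumer, assuming two hypothetical term outputs that happen to share a connected state. The arrays and the amplitude function `psi` are invented for illustration and are not part of the module. Because any sum over branches is linear in the matrix elements, a duplicated `x'` contributes the sum of its matrix elements without any explicit coalescing step.

```python
import jax.numpy as jnp

# Hypothetical outputs of two term runners for the same input configuration.
xp_a = jnp.array([[1, -1], [-1, 1]])
mels_a = jnp.array([0.5, 2.0])
valids_a = jnp.array([True, True])

xp_b = jnp.array([[1, -1]])  # same connected state as xp_a[0]
mels_b = jnp.array([0.25])
valids_b = jnp.array([True])

# Padded output keeps both copies of [1, -1] as separate rows.
x_primes = jnp.concatenate([xp_a, xp_b])        # (3, 2)
mels = jnp.concatenate([mels_a, mels_b])        # (3,)
valids = jnp.concatenate([valids_a, valids_b])  # (3,)


def psi(s):
    """Hypothetical wavefunction amplitude, used only for this example."""
    return jnp.exp(0.1 * s[..., 0] - 0.2 * s[..., 1])


# Masked sum over branches; the duplicate rows add up on their own.
branch_sum = jnp.sum(jnp.where(valids, mels * psi(x_primes), 0.0))
```

The value of `branch_sum` equals what a coalesced representation would give, namely (0.5 + 0.25) * psi([1, -1]) + 2.0 * psi([-1, 1]). Code that instead assumes each row is a distinct connected state (for example, when building a dense matrix by assignment rather than accumulation) would need to merge duplicates itself.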
- class JAXSymbolicLowerer
Bases: AbstractSymbolicLowerer
JAX-backend symbolic operator lowerer.