neuralqx.utils.jit module
This module contains just-in-time compiled implementations of modular-arithmetic helpers: a Numba version (mod_add_njit) and JAX versions (mod_add_jax, mod_sum_jax).
- mod_add_njit(m, n, q_min, q_max, step=1)
Impose the q-deformed addition of representation labels for the \(U_q(1)\) group. The modular addition follows the prescription
\[m \oplus n := \left[(m + n + N) \bmod (2N + 1)\right] - N\]
for two representation labels m and n and a cutoff N.
- Parameters:
m (Union[ndarray, int, float]) – an int or float representing a representation label, or an array of labels
n (Union[ndarray, int, float]) – an int or float representing a representation label, or an array of labels
q_min (Union[ndarray, int, float]) – the smallest allowed value of the degree of freedom (DoF)
q_max (Union[ndarray, int, float]) – the largest allowed DoF value
step (Union[ndarray, int, float]) – the spacing between consecutive allowed DoF values
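The prescription above can be checked with a minimal pure-Python sketch. This is not the Numba implementation: mod_add_symmetric and mod_add_general are hypothetical names, and mod_add_general assumes the Numba routine uses the same q_min/q_max/step wrapping that mod_add_jax documents, which reduces to the formula above when q_min = -N, q_max = N and step = 1.

```python
# Hypothetical reference helpers for illustration; not part of neuralqx.

def mod_add_symmetric(m: int, n: int, N: int) -> int:
    """m (+) n := [(m + n + N) mod (2N + 1)] - N for integer labels in [-N, N]."""
    # Python's % is non-negative for a positive modulus, so the result lies in [-N, N].
    return (m + n + N) % (2 * N + 1) - N


def mod_add_general(m, n, q_min, q_max, step=1):
    """Assumed general form: wrap m + n back onto the lattice {q_min + k * step}."""
    return ((m + n - q_min) % (q_max - q_min + step)) + q_min


N = 2  # allowed labels: -2, -1, 0, 1, 2
print(mod_add_symmetric(2, 1, N))    # 3 is out of range and wraps to -2
print(mod_add_symmetric(-2, -1, N))  # -3 wraps to 2

# For q_min=-N, q_max=N, step=1 the general form reproduces the prescription.
labels = range(-N, N + 1)
assert all(
    mod_add_symmetric(a, b, N) == mod_add_general(a, b, -N, N)
    for a in labels
    for b in labels
)
```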
- mod_add_jax(m, n, *, q_min=None, q_max=None, step=1, cutoff=None)
Perform modular addition over a discrete quantum-number domain in a JAX-friendly way.
This is the JAX-safe counterpart of neuralqx.utils.misc.arithmetic.mod_add(). It computes m (+) n and wraps the result into the specified lattice domain using modular arithmetic, while keeping dtypes stable under JAX transformations (jit, vmap, grad).
Domain specification:
If q_min and q_max are provided, the allowed domain is interpreted as the lattice {q_min + k * step} with endpoints [q_min, q_max] (inclusive).
Otherwise, if cutoff is provided, the domain is the symmetric interval [-cutoff, cutoff] using the given step.
The wrapping formula used is:
res = ((m + n - q_min) % (q_max - q_min + step)) + q_min
Dtype behavior:
Computation is performed in an “operation dtype” that can represent the domain parameters.
If the probed domain is integer-valued (based on q_min and step) and the result type of m and n is an integer type, the result is cast back to that integer dtype. Otherwise, the result remains in the operation dtype.
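The wrapping formula and dtype rules above can be mirrored in NumPy as a rough sketch. It is not the JAX implementation: the name mod_add_sketch, the integrality probe on q_min and step, and the use of np.result_type to pick the operation dtype are all assumptions made for illustration.

```python
import numpy as np


def mod_add_sketch(m, n, *, q_min, q_max, step=1):
    """Hypothetical NumPy mirror of the wrapping and dtype rules described above."""
    m, n = np.asarray(m), np.asarray(n)
    in_dtype = np.result_type(m, n)
    # Assumed "operation dtype": promote the input dtype with the domain parameters' dtypes.
    op_dtype = np.result_type(
        in_dtype,
        np.asarray(q_min).dtype,
        np.asarray(q_max).dtype,
        np.asarray(step).dtype,
    )
    lo = np.asarray(q_min, dtype=op_dtype)
    period = np.asarray(q_max - q_min + step, dtype=op_dtype)
    res = ((m.astype(op_dtype) + n.astype(op_dtype) - lo) % period) + lo
    # Assumed integrality probe: the lattice is integer-valued when q_min and step are.
    int_domain = float(q_min).is_integer() and float(step).is_integer()
    if int_domain and np.issubdtype(in_dtype, np.integer):
        return res.astype(in_dtype)  # cast back to the integer input dtype
    return res  # stays in the operation dtype


# Integer labels with cutoff 2: the result keeps the int32 input dtype.
a = np.array([1, 2, -2], dtype=np.int32)
b = np.array([1, 1, -1], dtype=np.int32)
out = mod_add_sketch(a, b, q_min=-2, q_max=2)
print(out, out.dtype)  # values 2, -2, 2 with dtype int32

# Half-integer lattice {-1, -0.5, 0, 0.5, 1}: 1 (+) 1 wraps to -0.5.
print(mod_add_sketch(1.0, 1.0, q_min=-1.0, q_max=1.0, step=0.5))
```

The half-integer case shows why the period is q_max - q_min + step rather than q_max - q_min: the lattice has five points, so stepping one index past q_max must land back on q_min.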
- Parameters:
m – First addend. Can be a scalar or array; must be broadcast-compatible with n.
n – Second addend. Can be a scalar or array; must be broadcast-compatible with m.
q_min – Lower bound of the domain (inclusive). Must be provided together with q_max unless cutoff is used.
q_max – Upper bound of the domain (inclusive). Must be provided together with q_min unless cutoff is used.
step – Lattice spacing of the domain. Defaults to 1.
cutoff – If provided and q_min/q_max are not, sets q_min=-cutoff and q_max=cutoff.
- Returns:
The modular sum wrapped into the specified domain, with the broadcast shape of m and n.
- Raises:
ValueError – If neither (q_min, q_max) nor cutoff is provided.
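The cutoff fallback and the ValueError condition amount to a small argument-resolution step, sketched below. resolve_domain is a hypothetical helper written for illustration; how neuralqx treats a half-specified pair (only one of q_min/q_max) is not stated here, so this sketch simply treats it as missing and falls back to cutoff.

```python
from typing import Optional, Tuple


def resolve_domain(
    q_min: Optional[float] = None,
    q_max: Optional[float] = None,
    cutoff: Optional[float] = None,
) -> Tuple[float, float]:
    """Hypothetical helper: pick (q_min, q_max) from explicit bounds or a cutoff."""
    if q_min is not None and q_max is not None:
        # Explicit bounds take precedence over cutoff.
        return q_min, q_max
    if cutoff is not None:
        # Symmetric interval [-cutoff, cutoff].
        return -cutoff, cutoff
    raise ValueError("Provide either (q_min, q_max) or cutoff.")


print(resolve_domain(cutoff=3))                   # (-3, 3)
print(resolve_domain(q_min=0, q_max=4, cutoff=3))  # (0, 4): explicit bounds win
```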
- mod_sum_jax(values, *, q_min=None, q_max=None, step=1, cutoff=None)
Compute a modular sum over the last axis of values in a JAX-transformation-friendly way.
This function reduces values along axis=-1 and wraps the result back into a discrete quantum-number domain using modular arithmetic. It is designed to work under JAX transformations (jit, vmap, grad) by avoiding control flow that would cause dtype or shape changes during tracing.
Domain specification:
If q_min and q_max are provided, the allowed domain is interpreted as the lattice {q_min + k * step} with endpoints [q_min, q_max] (inclusive).
Otherwise, if cutoff is provided, the domain is the symmetric interval [-cutoff, cutoff] using the given step.
The wrapping formula used is:
out = ((s - q_min) % (q_max - q_min + step)) + q_min
where s = sum(values, axis=-1).
Dtype behavior:
Computation is performed in an “operation dtype” that can safely represent the modular arithmetic with the provided domain parameters.
If the probed domain is integer-valued (based on q_min and step) and values has an integer dtype, the output is cast back to values.dtype. Otherwise, the output remains in the operation dtype.
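For integer labels, summing first and wrapping once agrees with wrapping after every pairwise addition, since each wrap only shifts a partial sum by a multiple of the period q_max - q_min + step. A NumPy sketch of that check follows; mod_sum_sketch and fold_mod_add are hypothetical names, and the dtype cast-back described above is deliberately omitted. A 1D input reduces to a 0-d result, matching the values.shape[:-1] output shape.

```python
import numpy as np


def mod_sum_sketch(values, *, q_min, q_max, step=1):
    """Hypothetical mirror of the documented wrapping (dtype casting omitted)."""
    s = np.asarray(values).sum(axis=-1)  # reduce over the last axis first
    return ((s - q_min) % (q_max - q_min + step)) + q_min


def fold_mod_add(values, *, q_min, q_max, step=1):
    """Reference: wrap after every pairwise addition along the last axis."""
    values = np.asarray(values)
    period = q_max - q_min + step
    acc = ((values[..., 0] - q_min) % period) + q_min
    for k in range(1, values.shape[-1]):
        acc = ((acc + values[..., k] - q_min) % period) + q_min
    return acc


vals = np.array([[2, 2, 2], [1, -2, 0]])  # batch of two rows, labels in [-2, 2]
out = mod_sum_sketch(vals, q_min=-2, q_max=2)
print(out.shape)  # (2,), i.e. vals.shape[:-1]
print(out)        # row sums 6 and -1 wrap to 1 and -1
assert np.array_equal(out, fold_mod_add(vals, q_min=-2, q_max=2))
```

Wrapping once after the reduction also avoids a sequential loop over the last axis; for floating-point labels the two orderings agree only up to rounding.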
- Parameters:
values – Array of values to modular-sum. The reduction is performed over the last axis. Can be 1D (returns a scalar) or batched with arbitrary leading dimensions.
q_min – Lower bound of the domain (inclusive). Must be provided together with q_max unless cutoff is used.
q_max – Upper bound of the domain (inclusive). Must be provided together with q_min unless cutoff is used.
step – Lattice spacing of the domain. Defaults to 1.
cutoff – If provided and q_min/q_max are not, sets q_min=-cutoff and q_max=cutoff.
- Returns:
The modular sum wrapped into the specified domain. Shape is values.shape[:-1].
- Raises:
ValueError – If neither (q_min, q_max) nor cutoff is provided.