neuralqx.utils.jit module

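This module contains compiled implementations of modular arithmetic on representation labels: a Numba-jitted function (mod_add_njit) and JAX-transformation-friendly functions (mod_add_jax, mod_sum_jax).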

mod_add_njit(m, n, q_min, q_max, step=1)

Impose the q-deformed addition of representation labels for the \(U_q(1)\) group. The modular addition follows the prescription

\[m \oplus n := \bigl((m + n + N) \bmod (2N + 1)\bigr) - N\]

for two representation labels m and n and a cutoff N. The result always lies in {-N, ..., N}.
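As a worked illustration of the prescription, the following pure-Python sketch applies it to scalar labels with N = 2, so the allowed labels are -2, -1, 0, 1, 2. The helper name oplus is hypothetical and not part of neuralqx. Python's % returns a non-negative remainder for a positive modulus, which is what keeps the result in [-N, N].

```python
def oplus(m: int, n: int, N: int) -> int:
    """Hypothetical helper: m (+) n := ((m + n + N) mod (2N + 1)) - N."""
    return (m + n + N) % (2 * N + 1) - N

# With N = 2 the allowed labels are -2, -1, 0, 1, 2.
print(oplus(1, 1, 2))    # 2   (already inside the domain, no wrap)
print(oplus(2, 1, 2))    # -2  (3 lies past N and wraps around to -2)
print(oplus(-2, -1, 2))  # 2   (-3 lies below -N and wraps around to 2)
```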

Parameters:
  • m (Union[ndarray, int, float]) – a representation label (int or float), or an array (e.g. a matrix) of labels

  • n (Union[ndarray, int, float]) – a representation label (int or float), or an array (e.g. a matrix) of labels

  • q_min (Union[ndarray, int, float]) – the smallest allowed value of the degree of freedom (DoF)

  • q_max (Union[ndarray, int, float]) – the largest allowed value of the DoF

  • step (Union[ndarray, int, float]) – the spacing between consecutive allowed DoF values
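The parameter list generalises the cutoff form to a lattice {q_min, q_min + step, ..., q_max} and allows array inputs. The NumPy sketch below assumes that mod_add_njit applies the same wrapping rule that is documented for mod_add_jax, namely ((m + n - q_min) mod (q_max - q_min + step)) + q_min. The name mod_add_reference is hypothetical, and this is not the Numba implementation itself.

```python
import numpy as np

def mod_add_reference(m, n, q_min, q_max, step=1):
    """Hypothetical NumPy reference: wrap m + n onto {q_min, q_min + step, ..., q_max}."""
    period = q_max - q_min + step  # one full cycle of the lattice
    return ((np.asarray(m) + np.asarray(n) - q_min) % period) + q_min

# Element-wise on matrices of labels, cutoff N = 2 (q_min = -2, q_max = 2, step = 1).
m = np.array([[2, -2], [1, 0]])
n = np.array([[1, -1], [1, 2]])
print(mod_add_reference(m, n, q_min=-2, q_max=2).tolist())  # [[-2, 2], [2, 2]]

# Half-integer lattice {-1.0, -0.5, 0.0, 0.5, 1.0}: 1.0 + 0.5 lies past q_max and wraps.
print(float(mod_add_reference(1.0, 0.5, q_min=-1.0, q_max=1.0, step=0.5)))  # -1.0
```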

mod_add_jax(m, n, *, q_min=None, q_max=None, step=1, cutoff=None)

Perform modular addition over a discrete quantum-number domain in a JAX-friendly way.

This is the JAX-safe counterpart of neuralqx.utils.misc.arithmetic.mod_add(). It computes \(m \oplus n\) and wraps the result into the specified lattice domain using modular arithmetic, while keeping dtypes stable under JAX transformations (jit, vmap, grad).

Domain specification:

  • If q_min and q_max are provided, the allowed domain is interpreted as the lattice {q_min + k * step} with endpoints [q_min, q_max] (inclusive).

  • Otherwise, if cutoff is provided, the domain is the lattice with spacing step on the symmetric interval [-cutoff, cutoff], i.e. q_min = -cutoff and q_max = cutoff.
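The precedence between the two domain specifications can be sketched as follows. resolve_domain and lattice are hypothetical helpers written for illustration; they mirror the rules stated above (explicit bounds first, then cutoff, otherwise ValueError) but are not part of the neuralqx API.

```python
def resolve_domain(q_min=None, q_max=None, step=1, cutoff=None):
    """Hypothetical helper: explicit bounds win, then cutoff, else ValueError."""
    if q_min is not None and q_max is not None:
        return q_min, q_max, step
    if cutoff is not None:
        return -cutoff, cutoff, step
    raise ValueError("provide either (q_min, q_max) or cutoff")

def lattice(q_min, q_max, step=1):
    """Enumerate the lattice points q_min + k * step that lie in [q_min, q_max]."""
    count = round((q_max - q_min) / step) + 1
    return [q_min + k * step for k in range(count)]

print(lattice(*resolve_domain(cutoff=2)))                         # [-2, -1, 0, 1, 2]
print(lattice(*resolve_domain(q_min=-1.0, q_max=1.0, step=0.5)))  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```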

The wrapping formula used is:

res = ((m + n - q_min) % (q_max - q_min + step)) + q_min
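The period in this formula is the length of one full cycle of the lattice: the domain {q_min + k * step} contains (q_max - q_min)/step + 1 points, so one cycle spans that many steps, i.e. q_max - q_min + step. Substituting the cutoff form q_min = -N, q_max = N, step = 1 recovers the prescription stated for mod_add_njit:

\[
\begin{aligned}
\text{res} &= \bigl((m + n - q_{\min}) \bmod (q_{\max} - q_{\min} + \text{step})\bigr) + q_{\min} \\
&= \bigl((m + n + N) \bmod (N + N + 1)\bigr) - N \\
&= \bigl((m + n + N) \bmod (2N + 1)\bigr) - N = m \oplus n .
\end{aligned}
\]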

Dtype behavior:

  • Computation is performed in an “operation dtype” that can represent the domain parameters.

  • If the domain is integer-valued (as determined from q_min and step) and the result type of m and n is an integer dtype, the result is cast back to that integer dtype. Otherwise, the result remains in the operation dtype.
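One possible reading of these dtype rules, written with NumPy for illustration: the operation dtype is taken as the promoted type of the addends and all domain parameters, and the result is cast back only when the domain is integer-valued and the addends are integers. mod_add_dtype_sketch is hypothetical, and the exact promotion logic in mod_add_jax (which follows JAX's own type-promotion rules) may differ.

```python
import numpy as np

def mod_add_dtype_sketch(m, n, q_min, q_max, step=1):
    """Hypothetical sketch of the documented dtype rules (not the neuralqx code)."""
    m, n = np.asarray(m), np.asarray(n)
    in_dtype = np.result_type(m, n)  # result type of the two addends
    # Operation dtype: wide enough for the addends and every domain parameter.
    op_dtype = np.result_type(in_dtype, *(np.asarray(p).dtype for p in (q_min, q_max, step)))
    period = np.asarray(q_max - q_min + step, dtype=op_dtype)
    res = ((m.astype(op_dtype) + n.astype(op_dtype) - q_min) % period) + q_min
    # The domain is integer-valued when both q_min and step are whole numbers.
    domain_is_int = float(q_min).is_integer() and float(step).is_integer()
    if domain_is_int and np.issubdtype(in_dtype, np.integer):
        return res.astype(in_dtype)
    return res

a = np.array([1, -1], dtype=np.int32)
b = np.array([1, 0], dtype=np.int32)
r1 = mod_add_dtype_sketch(a, b, q_min=-2.0, q_max=2.0)            # float bounds, integer lattice
r2 = mod_add_dtype_sketch(a, b, q_min=-1.0, q_max=1.0, step=0.5)  # half-integer lattice
print(r1.tolist(), r1.dtype)  # [2, -1] int32
print(r2.tolist(), r2.dtype)  # [-0.5, -1.0] float64
```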

Parameters:
  • m – First addend. Can be a scalar or array; must be broadcast-compatible with n.

  • n – Second addend. Can be a scalar or array; must be broadcast-compatible with m.

  • q_min – Lower bound of the domain (inclusive). Must be provided together with q_max unless cutoff is used.

  • q_max – Upper bound of the domain (inclusive). Must be provided together with q_min unless cutoff is used.

  • step – Lattice spacing of the domain. Defaults to 1.

  • cutoff – If provided and q_min/q_max are not, sets q_min=-cutoff and q_max=cutoff.

Returns:

The modular sum wrapped into the specified domain, with the broadcasted shape of m and n.

Raises:

ValueError – If neither (q_min, q_max) nor cutoff is provided.
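To illustrate the broadcasting contract in Returns, the NumPy sketch below applies the documented wrapping formula with cutoff = 2 to a column of labels and a row of labels. The function wrap_add is a hypothetical stand-in, not mod_add_jax; jax.numpy follows the same broadcasting rules, so a (3, 1) and a (4,) input give a (3, 4) result, here a slice of the addition table of the five-label domain.

```python
import numpy as np

def wrap_add(m, n, cutoff):
    """Hypothetical stand-in: documented formula with q_min=-cutoff, q_max=cutoff, step=1."""
    q_min, q_max, step = -cutoff, cutoff, 1
    return ((np.asarray(m) + np.asarray(n) - q_min) % (q_max - q_min + step)) + q_min

col = np.arange(-1, 2).reshape(3, 1)  # labels -1, 0, 1 as a column, shape (3, 1)
row = np.arange(-2, 2)                # labels -2, -1, 0, 1 as a row, shape (4,)
table = wrap_add(col, row, cutoff=2)
print(table.shape)  # (3, 4)
print(table.tolist())  # [[2, -2, -1, 0], [-2, -1, 0, 1], [-1, 0, 1, 2]]
```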

mod_sum_jax(values, *, q_min=None, q_max=None, step=1, cutoff=None)

Compute a modular sum over the last axis of values in a JAX-transformation-friendly way.

This function reduces values along axis=-1 and wraps the result back into a discrete quantum-number domain using modular arithmetic. It is designed to work under JAX transformations (jit, vmap, grad) by avoiding control flow that would cause dtype/shape changes during tracing.

Domain specification:

  • If q_min and q_max are provided, the allowed domain is interpreted as the lattice {q_min + k * step} with endpoints [q_min, q_max] (inclusive).

  • Otherwise, if cutoff is provided, the domain is the lattice with spacing step on the symmetric interval [-cutoff, cutoff], i.e. q_min = -cutoff and q_max = cutoff.

The wrapping formula used is:

out = ((s - q_min) % (q_max - q_min + step)) + q_min

where s = sum(values, axis=-1).
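A NumPy reference sketch of the reduction (hypothetical mod_sum_reference, not the JAX implementation): it sums over the last axis and wraps once. Because modular reduction is compatible with addition, wrapping once at the end gives the same labels as folding the pairwise cutoff addition along each row, which the second print checks for a cutoff of 2.

```python
from functools import reduce

import numpy as np

def mod_sum_reference(values, q_min, q_max, step=1):
    """Hypothetical NumPy reference for the documented mod_sum rule."""
    s = np.sum(values, axis=-1)  # reduce the last axis first...
    return ((s - q_min) % (q_max - q_min + step)) + q_min  # ...then wrap once

def oplus(a, b, N=2):
    """Pairwise cutoff addition, used only for comparison."""
    return (a + b + N) % (2 * N + 1) - N

values = np.array([[2, 2, 2], [1, -1, 0]])
out = mod_sum_reference(values, q_min=-2, q_max=2)
print(out.shape, out.tolist())  # (2,) [1, 0]
# Folding the pairwise addition along each row gives the same labels.
print([reduce(oplus, row.tolist()) for row in values])  # [1, 0]
```

A 1D input reduces to a scalar, matching the shape values.shape[:-1] = ().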

Dtype behavior:

  • Computation is performed in an “operation dtype” that can safely represent the modular arithmetic with the provided domain parameters.

  • If the domain is integer-valued (as determined from q_min and step) and values has an integer dtype, the output is cast back to values.dtype. Otherwise, the output remains in the operation dtype.
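Why the operation dtype matters for a reduction: summing many narrow integers can overflow before the wrap is applied. The sketch below (hypothetical mod_sum_widened, assuming an int64 operation dtype, which may differ from what mod_sum_jax actually selects) accumulates in the wide dtype, wraps, and only then casts back to the input dtype. The final cast is lossless as long as every domain value fits in the input dtype.

```python
import numpy as np

def mod_sum_widened(values, q_min, q_max, step=1, op_dtype=np.int64):
    """Hypothetical sketch: sum and wrap in op_dtype, then cast back to values.dtype."""
    values = np.asarray(values)
    s = values.sum(axis=-1, dtype=op_dtype)  # accumulate in the operation dtype
    out = ((s - q_min) % (q_max - q_min + step)) + q_min
    return out.astype(values.dtype)  # safe: wrapped labels lie in [q_min, q_max]

vals = np.array([100, 100, 100], dtype=np.int8)  # true sum is 300
safe = mod_sum_widened(vals, q_min=-2, q_max=2)
# Accumulating in int8 instead silently overflows (300 becomes 44) before wrapping.
lossy = mod_sum_widened(vals, q_min=-2, q_max=2, op_dtype=np.int8)
print(int(safe), safe.dtype)  # 0 int8
print(int(lossy))             # -1
```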

Parameters:
  • values – Array of values to modular-sum. The reduction is performed over the last axis. Can be 1D (returns a scalar) or batched with arbitrary leading dimensions.

  • q_min – Lower bound of the domain (inclusive). Must be provided together with q_max unless cutoff is used.

  • q_max – Upper bound of the domain (inclusive). Must be provided together with q_min unless cutoff is used.

  • step – Lattice spacing of the domain. Defaults to 1.

  • cutoff – If provided and q_min/q_max are not, sets q_min=-cutoff and q_max=cutoff.

Returns:

The modular sum wrapped into the specified domain. Shape is values.shape[:-1].

Raises:

ValueError – If neither (q_min, q_max) nor cutoff is provided.