neuralqx.utils.misc.arithmetic module

mod_add(m, n, *, q_min=None, q_max=None, step=1, cutoff=None)

Modular addition over a discrete quantum-number domain.

This function computes m ⊕ n and wraps the result back into the allowed domain using modular arithmetic. The domain is interpreted as a lattice {q_min + k * step} with endpoints [q_min, q_max] (inclusive). If q_min and q_max are not provided, the symmetric domain [-cutoff, cutoff] with the same step is used instead.

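The wrapping formula is as follows, where q_max - q_min + step is the period of the lattice (the number of lattice points times step):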

res = ((m + n - q_min) % (q_max - q_min + step)) + q_min
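A minimal NumPy sketch of this formula follows. mod_add_sketch is a hypothetical name, not the library's function. It assumes the cutoff fallback applies whenever either bound is missing, and it omits the dtype handling described below. np.mod is used because, like Python's %, its result takes the sign of the divisor, so negative sums wrap upward into the domain.

```python
import numpy as np


def mod_add_sketch(m, n, *, q_min=None, q_max=None, step=1, cutoff=None):
    """Hypothetical re-implementation of the documented wrapping formula."""
    # Resolve the domain: explicit bounds take precedence over cutoff
    # (assumption: the cutoff fallback is used if either bound is missing).
    if q_min is None or q_max is None:
        if cutoff is None:
            raise ValueError("Provide (q_min, q_max) or cutoff.")
        q_min, q_max = -cutoff, cutoff
    # Period of the lattice = number of points * step = q_max - q_min + step.
    period = q_max - q_min + step
    # np.mod follows the sign of the divisor, so negative offsets wrap correctly.
    return np.mod(np.asarray(m) + np.asarray(n) - q_min, period) + q_min


# Domain {-2, -1, 0, 1, 2}: 2 + 1 = 3 falls off the top and wraps around.
print(mod_add_sketch(2, 1, cutoff=2))  # 3 wraps around to -2
```

On a half-integer domain such as q_min=-1.5, q_max=1.5, the same formula gives 1.5 + 1.0 = 2.5, which wraps to -1.5.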

Dtype behavior:

  • The function probes a short prefix of the lattice (three points) to determine whether the domain is integer-valued.

  • If the domain is integer-valued, the result is cast back to the NumPy/JAX result type of m and n, so integer inputs on an integer domain yield integer outputs.

  • If the domain is fractional (e.g. half-integer), the result is returned without forcing an integer cast.
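The dtype rule can be sketched as follows. The helper names domain_is_integer and cast_like_inputs are hypothetical, and the exact integrality test (comparing each probed lattice point with its rounded value) is an assumption; only the three-point probe and the cast to the result type of m and n come from the description above.

```python
import numpy as np


def domain_is_integer(q_min, step, n_probe=3):
    """Hypothetical probe: are the first few lattice points whole numbers?"""
    # First n_probe points of the lattice q_min + k * step.
    probe = q_min + step * np.arange(n_probe)
    return bool(np.all(probe == np.round(probe)))


def cast_like_inputs(res, m, n, q_min, step):
    """Hypothetical cast step applied after the wrapping formula."""
    if domain_is_integer(q_min, step):
        # Integer domain: restore the promoted dtype of the inputs.
        return np.asarray(res).astype(np.result_type(m, n))
    # Fractional (e.g. half-integer) domain: leave the result untouched.
    return res


print(domain_is_integer(-2, 1))    # True: probes -2, -1, 0
print(domain_is_integer(-1.5, 1))  # False: probes -1.5, -0.5, 0.5
```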

Parameters:
  • m (Union[ndarray, int, float, Array]) – First addend. Can be a scalar or array; must be broadcast-compatible with n.

  • n (Union[ndarray, int, float, Array]) – Second addend. Can be a scalar or array; must be broadcast-compatible with m.

  • q_min (Union[int, float]) – Lower bound of the domain (inclusive). Must be provided together with q_max unless cutoff is used.

  • q_max (Union[int, float]) – Upper bound of the domain (inclusive). Must be provided together with q_min unless cutoff is used.

  • step (Union[int, float]) – Lattice spacing of the domain. Defaults to 1.

  • cutoff (Union[int, float, None]) – If provided and q_min/q_max are not, sets q_min=-cutoff and q_max=cutoff.

Return type:

Union[ndarray, int, float, Array]

Returns:

The modular sum wrapped into the specified domain, with the broadcasted shape of m and n.

Raises:

ValueError – If neither (q_min, q_max) nor cutoff is provided.

mod_sum(array, *, q_min=None, q_max=None, step=1, cutoff=None)

Compute a modular sum of an iterable using mod_add().

This function behaves like Python’s sum(), except that addition is performed with modular wrapping into the specified quantum-number domain. Summation is performed sequentially (left-fold) using:

s <- mod_add(s, e, ...)
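A left fold of this kind maps naturally onto functools.reduce. The sketch below is hypothetical: mod_sum_sketch and _wrap_add are illustrative names, and the start value of 0 is an assumption modeled on Python's sum(), since the description does not state the initial value of s. With 0 as the start, an empty iterable returns 0 and a single element is simply wrapped into the domain.

```python
from functools import reduce


def _wrap_add(m, n, q_min, q_max, step):
    # Same wrapping formula as mod_add: ((m + n - q_min) % period) + q_min.
    return ((m + n - q_min) % (q_max - q_min + step)) + q_min


def mod_sum_sketch(values, *, q_min=None, q_max=None, step=1, cutoff=None):
    """Hypothetical left-fold modular sum, starting from 0 like sum()."""
    if q_min is None or q_max is None:
        if cutoff is None:
            raise ValueError("Provide (q_min, q_max) or cutoff.")
        q_min, q_max = -cutoff, cutoff
    # s <- mod_add(s, e) for each element e, left to right.
    return reduce(lambda s, e: _wrap_add(s, e, q_min, q_max, step), values, 0)


print(mod_sum_sketch([3, -4, 7], cutoff=2))  # 1
```

Because modular reduction is compatible with addition, the folded result equals wrapping the ordinary sum once: in the example, 3 - 4 + 7 = 6, which wraps to 1 on [-2, 2].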

Domain specification:

  • If q_min and q_max are provided, the allowed domain is the lattice {q_min + k * step} with endpoints [q_min, q_max] (inclusive).

  • Otherwise, if cutoff is provided, the domain is the symmetric interval [-cutoff, cutoff] using the given step.

Parameters:
  • array – Iterable of values to sum with modular wrapping.

  • q_min – Lower bound of the domain (inclusive). Must be provided together with q_max unless cutoff is used.

  • q_max – Upper bound of the domain (inclusive). Must be provided together with q_min unless cutoff is used.

  • step – Lattice spacing of the domain. Defaults to 1.

  • cutoff – If provided and q_min/q_max are not, sets q_min=-cutoff and q_max=cutoff.

Returns:

The modular sum wrapped into the specified domain.

Raises:

ValueError – If neither (q_min, q_max) nor cutoff is provided.

get_sgn(n)

Return the sign of a number as -1 or +1.

Note: zero is treated as non-negative and returns +1.

Parameters:

n (Union[int, float]) – Number whose sign should be returned.

Return type:

int

Returns:

-1 if n < 0, otherwise +1.
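For illustration, the documented behavior amounts to a one-line conditional; get_sgn_sketch is a hypothetical stand-in, not the library's code. Note the contrast with numpy.sign, which returns 0 for an input of 0. Here zero maps to +1, so the result can always be used as a multiplicative sign.

```python
from typing import Union


def get_sgn_sketch(n: Union[int, float]) -> int:
    # Zero counts as non-negative, so it maps to +1.
    return -1 if n < 0 else 1


print(get_sgn_sketch(-3.2), get_sgn_sketch(0), get_sgn_sketch(5))  # -1 1 1
```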

get_signed_value(s)

Parse a leading sign from a string and return it as -1 or +1.

This helper inspects only the first character of the string: if s starts with "-" it returns -1, otherwise it returns +1. It does not validate that the rest of the string is numeric.

Parameters:

s (str) – Input string potentially beginning with a minus sign.

Return type:

int

Returns:

-1 if s starts with "-", otherwise +1.
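A minimal sketch of the documented check, using the hypothetical name get_signed_value_sketch. It relies on str.startswith, which (unlike indexing s[0]) does not raise on an empty string; how the actual helper treats an empty string is not stated above.

```python
def get_signed_value_sketch(s: str) -> int:
    # Only the first character is inspected; the rest is not validated.
    return -1 if s.startswith("-") else 1


print(get_signed_value_sketch("-3"))   # -1
print(get_signed_value_sketch("+2"))   # 1
print(get_signed_value_sketch("-abc")) # -1 (no numeric validation)
```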

factorial(n)

Compute the factorial of a non-negative integer.

The implementation is recursive and returns n!, with the base case 0! = 1.

Parameters:

n (int) – Non-negative integer for which to compute the factorial.

Return type:

int

Returns:

The factorial n!.

Raises:

ValueError – If n is negative.
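The recursive definition can be sketched as below (factorial_sketch is a hypothetical name). One practical consequence of recursion in CPython is that very large n will exceed the interpreter's recursion limit (sys.getrecursionlimit(), 1000 by default) and raise RecursionError; math.factorial computes the same value without that limit.

```python
def factorial_sketch(n: int) -> int:
    if n < 0:
        raise ValueError("n must be non-negative")
    # Base case 0! = 1; otherwise n! = n * (n - 1)!.
    return 1 if n == 0 else n * factorial_sketch(n - 1)


print(factorial_sketch(5))  # 120
```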

generate_plus_minus_one(key, shape=())

Generate random values in {+1, -1} using JAX PRNG.

The function samples integer bits in {0, 1} and maps them to {-1, +1} via 2 * bits - 1. The output is compatible with JAX transformations.

Parameters:
  • key – JAX PRNGKey used for sampling.

  • shape – Output shape. Defaults to () (a scalar).

Returns:

A JAX array of the given shape containing only -1 and +1.
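A minimal JAX sketch of the documented bit mapping follows. It assumes the bits are drawn with jax.random.randint (the description does not say which sampler is used), and generate_plus_minus_one_sketch is a hypothetical name. Under jax.jit, shape must be static, for example via static_argnums=1.

```python
import jax
import jax.numpy as jnp


def generate_plus_minus_one_sketch(key, shape=()):
    # Draw bits uniformly from {0, 1}; maxval is exclusive in randint.
    bits = jax.random.randint(key, shape, minval=0, maxval=2)
    # Map 0 -> -1 and 1 -> +1.
    return 2 * bits - 1


key = jax.random.PRNGKey(0)
signs = generate_plus_minus_one_sketch(key, (8,))
# shape is a Python tuple, so mark it static when jitting.
jitted = jax.jit(generate_plus_minus_one_sketch, static_argnums=1)
print(signs.shape, bool(jnp.all(jnp.abs(signs) == 1)))  # (8,) True
```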

plus_key(t, number)

Add a constant to all but the last element of a tuple/list key.

This helper returns a new tuple where number is added to each element of t[:-1]. The last element t[-1] is preserved unchanged. This is useful when the last component is a label/metadata entry that should not be shifted.

Parameters:
  • t (Union[tuple, list]) – Input tuple/list. The last element is kept unchanged.

  • number (float) – Constant to add to each element of t[:-1].

Return type:

tuple

Returns:

A new tuple with shifted elements and the original last element.
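The key-shifting rule can be sketched with tuple slicing (plus_key_sketch is a hypothetical name, not the library's implementation). Slicing t[:-1] works for both tuples and lists, and the result is always a tuple, matching the documented return type.

```python
from typing import Union


def plus_key_sketch(t: Union[tuple, list], number: float) -> tuple:
    # Shift every component except the last, which is kept as a label.
    return tuple(x + number for x in t[:-1]) + (t[-1],)


print(plus_key_sketch((1, 2, "spin_up"), 1))  # (2, 3, 'spin_up')
```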

minus_key(t, number)

Subtract a constant from all but the last element of a tuple/list key.

This helper returns a new tuple where number is subtracted from each element of t[:-1]. The last element t[-1] is preserved unchanged. This is useful when the last component is a label/metadata entry that should not be shifted.

Parameters:
  • t (Union[tuple, list]) – Input tuple/list. The last element is kept unchanged.

  • number (float) – Constant to subtract from each element of t[:-1].

Return type:

tuple

Returns:

A new tuple with shifted elements and the original last element.
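minus_key mirrors plus_key; a hypothetical sketch (minus_key_sketch) follows. For a one-element key, t[:-1] is empty, so the key is returned unchanged as a tuple.

```python
from typing import Union


def minus_key_sketch(t: Union[tuple, list], number: float) -> tuple:
    # Subtract from every component except the last, which is kept as a label.
    return tuple(x - number for x in t[:-1]) + (t[-1],)


print(minus_key_sketch((3, 4, "b"), 1))  # (2, 3, 'b')
print(minus_key_sketch(["tag"], 5))      # ('tag',)
```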