neuralqx.utils.numbers module

is_scalar_like(x)

Return True for Python numbers, NumPy scalars, and 0-d arrays (NumPy or JAX).

Return type:

bool
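The check described above can be approximated with a short sketch. This is not the library's code: the name is_scalar_like_sketch is hypothetical, and treating "anything with an empty shape tuple" as a 0-d array is an assumption that happens to cover both NumPy and JAX arrays.

```python
import numbers

import numpy as np


def is_scalar_like_sketch(x):
    """Hypothetical stand-in for is_scalar_like; not the library's code."""
    # Python numbers (int, float, complex) and NumPy scalars such as np.float32(1).
    if isinstance(x, (numbers.Number, np.generic)):
        return True
    # Assumption: a 0-d array is any object whose shape is the empty tuple.
    return getattr(x, "shape", None) == ()
```

Under this sketch, np.array(2.0) counts as scalar-like, while np.array([2.0]) and a Python list do not.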

canonical_jax_dtype(x)

Return a jnp.dtype for x, where x may be a dtype, a scalar, an array, or None. Returns None if the dtype cannot be deduced.

Return type:

dtype | None
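A rough illustration of the "return None if not deducible" behaviour, using NumPy dtypes as a stand-in for jnp.dtype. The name canonical_dtype_sketch and the order of the checks are assumptions, not the library's implementation.

```python
import numpy as np


def canonical_dtype_sketch(x):
    """Hypothetical stand-in for canonical_jax_dtype, using NumPy dtypes."""
    if x is None:
        # np.dtype(None) would silently yield float64, so handle None first.
        return None
    if isinstance(x, (bool, int, float, complex)):
        # Python scalars: take the default dtype NumPy assigns to them.
        return np.asarray(x).dtype
    # Array and NumPy-scalar instances carry a concrete dtype attribute.
    if not isinstance(x, type) and hasattr(x, "dtype"):
        x = x.dtype
    try:
        return np.dtype(x)
    except (TypeError, ValueError):
        # Nothing dtype-like could be deduced.
        return None
```

Checking None before calling np.dtype matters in this sketch, because NumPy maps None to its default float dtype rather than raising.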

promote_constant_for_op_dtype(op_dtype, const, *, rtol=1e-12, atol=0.0)

Given an operator dtype and a scalar-like const, return (target_dtype, const_cast) where const_cast is a 0-d JAX array of target_dtype.

Return type:

Tuple[dtype, Array]

Rules:
  • If operator is complex -> cast const to operator dtype (promote real->complex).

  • If operator is real:
    • If const is real -> cast to operator dtype.

    • If const is complex with non-zero imag (beyond tolerance) -> error. (User should cast operator to complex or pass a real constant.)

    • If const is complex with ~zero imag -> drop imag safely and cast to op dtype.

This keeps the operator dtype stable unless we explicitly choose to upcast elsewhere.
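The rules above can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: it returns a 0-d NumPy array in place of the 0-d JAX array, the name promote_constant_sketch is hypothetical, the tolerance test |imag(c)| <= atol + rtol·|real(c)| is a guess at what "beyond tolerance" means, and ValueError is an assumed error type.

```python
import numpy as np


def promote_constant_sketch(op_dtype, const, *, rtol=1e-12, atol=0.0):
    """Hypothetical stand-in for promote_constant_for_op_dtype (NumPy only)."""
    op_dtype = np.dtype(op_dtype)
    c = complex(const)  # accepts Python numbers, NumPy scalars, 0-d arrays

    # Complex operator: cast the constant to the operator dtype.
    if np.issubdtype(op_dtype, np.complexfloating):
        return op_dtype, np.asarray(c, dtype=op_dtype)

    # Real operator: reject a genuinely complex constant.
    # Assumed tolerance form; the library's exact test may differ.
    if abs(c.imag) > atol + rtol * abs(c.real):
        raise ValueError(
            "complex constant with a real operator dtype; "
            "cast the operator to complex or pass a real constant"
        )

    # Real constant, or complex with ~zero imaginary part: keep the real part.
    return op_dtype, np.asarray(c.real, dtype=op_dtype)
```

In this sketch the returned dtype always equals op_dtype, which mirrors the stated goal of keeping the operator dtype stable.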

hermitian_flag_for_sum_with_scalar(op_is_hermitian, const_cast, *, rtol=1e-12, atol=0.0)

For (op + c·I), result is Hermitian iff:
  • op is Hermitian, AND

  • imag(c) ~ 0 (within tolerance).

Return type:

bool
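The two conditions above amount to a logical AND, sketched here in plain Python. The name hermitian_flag_sketch is hypothetical, and the tolerance test |imag(c)| <= atol + rtol·|real(c)| is an assumption about how "within tolerance" is evaluated.

```python
def hermitian_flag_sketch(op_is_hermitian, const, *, rtol=1e-12, atol=0.0):
    """Hypothetical stand-in for hermitian_flag_for_sum_with_scalar."""
    c = complex(const)
    # Assumed tolerance form; the library's exact test may differ.
    imag_is_zero = abs(c.imag) <= atol + rtol * abs(c.real)
    # Hermitian only if the operator is Hermitian AND the shift is real.
    return bool(op_is_hermitian) and imag_is_zero
```

For example, a Hermitian operator shifted by 2.0 keeps its flag under this sketch, while a shift by 2 + 1j clears it.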

to_python_scalar(x0d)

Convert a 0-d JAX array to a Python scalar (float/complex/int).

Return type:

Number
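One way such a conversion could work, shown with NumPy as a stand-in for a JAX array. The name to_python_scalar_sketch is hypothetical, and rejecting inputs that are not 0-d is an assumption rather than documented behaviour.

```python
import numpy as np


def to_python_scalar_sketch(x0d):
    """Hypothetical stand-in for to_python_scalar (NumPy only)."""
    arr = np.asarray(x0d)
    if arr.shape != ():
        # Assumption: only 0-d inputs are accepted.
        raise ValueError("expected a 0-d array")
    # ndarray.item() returns the matching built-in type (int, float, complex).
    return arr.item()
```

The result is a built-in Python number, so a float32 input comes back as a Python float rather than a NumPy scalar.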