neuralqx.utils.numbers module
- is_scalar_like(x)
Return True if x is a Python number, a NumPy scalar, or a 0-d NumPy or JAX array, and False otherwise.
- Return type: bool
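Below is a minimal sketch of this check. It uses NumPy only, so it runs without JAX. The function name is reused for illustration, and this is not the library's implementation. The sketch relies on two things. First, NumPy registers its scalar types with Python's numbers ABCs. Second, NumPy and JAX arrays both expose an .ndim attribute.

```python
import numbers

import numpy as np


def is_scalar_like(x) -> bool:
    """Sketch: True for Python numbers, NumPy scalars, and 0-d arrays."""
    # Python bool/int/float/complex and NumPy numeric scalars
    # (e.g. np.float32(1.0)) are all instances of numbers.Number.
    if isinstance(x, numbers.Number):
        return True
    # Anything exposing .ndim (NumPy arrays, JAX arrays, np.bool_) counts
    # only if it is 0-dimensional; objects without .ndim give None -> False.
    return getattr(x, "ndim", None) == 0


print(is_scalar_like(np.array(1.0)), is_scalar_like(np.zeros(3)))
```

Scalars are checked before arrays so that plain Python numbers, which have no .ndim, are still accepted.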
- canonical_jax_dtype(x)
Try to return a jnp.dtype for x, where x may be a dtype, a scalar, an array, or None. Returns None if no dtype can be deduced.
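The following is one way to deduce a dtype from these kinds of input, sketched with NumPy only. deduce_dtype is a hypothetical name. The documented function returns a jnp.dtype and may handle JAX-specific details that this sketch ignores.

```python
import numpy as np


def deduce_dtype(x):
    """Sketch: return a NumPy dtype for x, or None if it cannot be deduced."""
    if x is None:
        return None  # note: np.dtype(None) would silently return float64
    # dtypes, type objects (np.float32, float) and strings ("complex64").
    # Types are handled here because np.float32 has a class-level .dtype
    # descriptor that is not itself a dtype.
    if isinstance(x, (np.dtype, type, str)):
        try:
            return np.dtype(x)
        except (TypeError, ValueError):
            return None
    # NumPy/JAX arrays and NumPy scalars carry a .dtype attribute.
    dt = getattr(x, "dtype", None)
    if dt is not None:
        return np.dtype(dt)
    # Python scalars: let NumPy choose its default dtype (float -> float64).
    if isinstance(x, (bool, int, float, complex)):
        return np.asarray(x).dtype
    return None
```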
- promote_constant_for_op_dtype(op_dtype, const, *, rtol=1e-12, atol=0.0)
Given an operator dtype and a scalar-like const, return a pair (target_dtype, const_cast), where const_cast is a 0-d JAX array of dtype target_dtype.
- Rules:
  - If the operator is complex, cast const to the operator dtype (promoting a real const to complex).
  - If the operator is real:
    - If const is real, cast it to the operator dtype.
    - If const is complex with a non-zero imaginary part (beyond the rtol/atol tolerance), raise an error. The caller should either cast the operator to complex or pass a real constant.
    - If const is complex with a negligible imaginary part, drop the imaginary part and cast to the operator dtype.
These rules keep the operator dtype stable unless an upcast is chosen explicitly elsewhere.
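Here is a sketch of these rules that uses NumPy 0-d arrays in place of JAX arrays. Three details are assumptions rather than documented behavior: the name promote_constant is hypothetical, raising ValueError is a guess (the text only says "error"), and the tolerance test |imag(c)| <= atol + rtol·|real(c)| is one plausible choice.

```python
import numpy as np


def promote_constant(op_dtype, const, *, rtol=1e-12, atol=0.0):
    """Sketch: return (target_dtype, const as a 0-d array of that dtype)."""
    op_dtype = np.dtype(op_dtype)
    c = complex(const)  # assumes const is scalar-like
    if np.issubdtype(op_dtype, np.complexfloating):
        # Complex operator: cast const to the operator dtype (real -> complex).
        return op_dtype, np.asarray(c, dtype=op_dtype)
    # Real operator. ASSUMPTION: the imaginary part counts as negligible
    # when |imag| <= atol + rtol * |real|.
    if abs(c.imag) > atol + rtol * abs(c.real):
        raise ValueError(
            f"constant {const!r} has a non-zero imaginary part; "
            "cast the operator to complex or pass a real constant"
        )
    # Drop the (negligible) imaginary part and keep the operator dtype.
    return op_dtype, np.asarray(c.real, dtype=op_dtype)
```

With JAX, jnp.asarray(value, dtype=target_dtype) would produce the 0-d JAX array instead of the NumPy one.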
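- hermitian_flag_for_sum_with_scalar(op_is_hermitian, const_cast, *, rtol=1e-12, atol=0.0)
- Return type: bool
- For op + c·I, the result is flagged as Hermitian if and only if:
  - op is Hermitian, and
  - imag(c) ≈ 0 (within tolerance).
These conditions suffice because (op + c·I)† = op† + conj(c)·I, which equals op + c·I when op† = op and c is real.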
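The sketch below implements the flag and checks it numerically against the identity (A + cI)† = A† + conj(c)·I, using NumPy. The name hermitian_flag_for_sum is hypothetical. The tolerance test |imag(c)| <= atol + rtol·|real(c)| is an assumption.

```python
import numpy as np


def hermitian_flag_for_sum(op_is_hermitian, c, *, rtol=1e-12, atol=0.0):
    """Sketch: flag op + c*I as Hermitian iff op is Hermitian and imag(c) ~ 0."""
    c = complex(c)
    # ASSUMPTION: "imag(c) ~ 0" means |imag| <= atol + rtol * |real|.
    imag_is_zero = abs(c.imag) <= atol + rtol * abs(c.real)
    return bool(op_is_hermitian) and imag_is_zero


# Numerical check on a small Hermitian matrix A.
A = np.array([[1.0, 2 - 1j], [2 + 1j, -3.0]])  # A equals its conjugate transpose
I = np.eye(2)
for c in (0.5, 0.5 + 0.25j):
    M = A + c * I
    is_herm = np.allclose(M, M.conj().T)
    # A real shift keeps M Hermitian; an imaginary shift breaks it.
    assert is_herm == hermitian_flag_for_sum(True, c)
```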