neuralqx.operators.computational.wrappers.jax package
- class ProductJax(A, B, *, tol=0.0, is_hermitian=False)
Bases: ComputationalJaxOperator
JAX-native, matrix-free product of two padded local operators A (left) and B (right):
\[(A @ B)\,|\sigma\rangle \;=\; A\!\left( B\,|\sigma\rangle \right).\]
The class composes the connection graphs without building dense matrices. It proceeds as follows:
1. Apply B once to the input(s) to obtain σ'_B of shape (B, N, C_B, D) and m_B of shape (B, N, C_B), where the leading B is the batch size.
2. For each branch index cb of B:
   - If that branch is zero over the whole batch (\(\max_{b,n}|m_B[b,n,cb]| \le \mathrm{tol}\)), skip it; its slots are filled with zeros in the padded tail.
   - Otherwise, apply A to σ'_B[..., cb, :] and combine the path matrix elements: m_path = m_B[..., cb, None] * m_A.
3. Concatenate across the kept branches, stable-partition so that all nonzero connections come first, and pad the tail with copies of the input state(s) and zero matrix elements.
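The steps above can be sketched in plain NumPy as follows. This is a minimal, non-jitted illustration under stated assumptions, not the class's implementation. `apply_A` and `apply_B` stand in for the operators' padded connection functions. `flip_each_site` and `total_z` are hypothetical toy operators on ±1 spins. `Bt` denotes the batch size, which avoids a clash with the operator B. Sizing the padded tail as C_A slots per pruned branch is an assumption. Similar array operations (`argsort`, `take_along_axis`, `where`) are available in `jax.numpy`.

```python
import numpy as np

def product_conn_padded(apply_A, apply_B, sigma, tol=0.0):
    """Sketch: padded connections of A @ B for kets sigma of shape (Bt, N, D).

    apply_A / apply_B are hypothetical callables mapping states (..., D) to
    (connected states (..., C, D), matrix elements (..., C)).
    """
    sp_B, m_B = apply_B(sigma)                        # (Bt, N, C_B, D), (Bt, N, C_B)
    # One keep mask for the whole batch: a branch survives if nonzero anywhere.
    keep = np.max(np.abs(m_B), axis=(0, 1)) > tol     # (C_B,)
    if not keep.any():
        # Nothing survives: a single padded slot (input state, zero element).
        return sigma[..., None, :], np.zeros(sigma.shape[:-1] + (1,), m_B.dtype)

    sp_parts, m_parts = [], []
    for cb in np.flatnonzero(keep):
        sp_A, m_A = apply_A(sp_B[..., cb, :])         # (Bt, N, C_A, D), (Bt, N, C_A)
        sp_parts.append(sp_A)
        m_parts.append(m_B[..., cb, None] * m_A)      # path element m_B * m_A

    # Assumed sizing: every pruned branch contributes C_A padded slots.
    pad = (m_B.shape[-1] - len(m_parts)) * m_parts[0].shape[-1]
    sp_parts.append(np.broadcast_to(sigma[..., None, :],
                                    sigma.shape[:-1] + (pad, sigma.shape[-1])))
    m_parts.append(np.zeros(sigma.shape[:-1] + (pad,), m_parts[0].dtype))
    sp = np.concatenate(sp_parts, axis=-2)
    m = np.concatenate(m_parts, axis=-1)

    # Stable partition: nonzero connections first, original order preserved.
    valid = np.abs(m) > tol
    order = np.argsort(~valid, axis=-1, kind="stable")
    valid = np.take_along_axis(valid, order, axis=-1)
    m = np.where(valid, np.take_along_axis(m, order, axis=-1), 0)
    sp = np.take_along_axis(sp, order[..., None], axis=-2)
    sp = np.where(valid[..., None], sp, sigma[..., None, :])  # tail = input state
    return sp, m


# Toy operators on ±1 spins (hypothetical, for illustration only).
def flip_each_site(s):
    """sum_i sigma^x_i: one connection per site, each with element 1."""
    D = s.shape[-1]
    flipped = np.where(np.eye(D, dtype=bool), -s[..., None, :], s[..., None, :])
    return flipped, np.ones(s.shape[:-1] + (D,))

def total_z(s):
    """sum_i sigma^z_i: diagonal, a single connection."""
    return s[..., None, :], s.sum(axis=-1, keepdims=True).astype(float)
```

For example, `product_conn_padded(total_z, flip_each_site, sigma)` builds (Σ_i σ^z_i)(Σ_i σ^x_i)|σ⟩. Each single-site flip from B is followed by the diagonal A, and paths whose product vanishes move to the padded tail. Because the keep mask is shared across the batch, a branch that is zero for one sample but nonzero for another is still evaluated. The zero entries for that sample then end up in its tail.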
- Parameters:
A (ComputationalJaxOperator) – Left factor of the product A @ B. A and B must share the same Hilbert space; their internal implementations may themselves use jitted kernels and PyTree registration.
B (ComputationalJaxOperator) – Right factor of the product A @ B (applied first to the input).
tol (float, optional) – Threshold for treating matrix elements as zero (default 0.0). It controls both branch pruning and the stable partitioning, i.e. which connections are considered "valid".
is_hermitian (bool, optional) – Set to True if you know that A @ B is Hermitian (e.g., A = B = H with H Hermitian, or B = A†), so that covariance gradients are selected downstream.
Notes
- Padded contract: the connection axis has a fixed size per call. All nonzero connections are stably moved to the front, and the tail contains copies of the input state with zero matrix element. Downstream Monte-Carlo code should only consume the first counts[b,n] = (|mels[b,n,:]| > tol).sum() connections.
- Batch behavior: pruning uses a single global keep mask over the batch to avoid data-dependent recompiles. Larger batches may therefore retain more branches; entries that are zero for a given sample end up in its padded tail.
- JAX transforms: the heavy lifting (packing/partitioning) is a single, stateless jax.jit() kernel that takes only arrays; there are no Python-side loops or dynamic slices inside jit.
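The following NumPy sketch shows the padded contract from the consumer side. It implements a Monte-Carlo local estimator O_loc(σ) = Σ_c m_c ψ(σ'_c)/ψ(σ) that reads only the first counts slots. `padded_local_estimator` and the log-amplitude callable `log_psi` are hypothetical names, not part of this package.

```python
import numpy as np

def padded_local_estimator(log_psi, sigma, sigma_p, mels, tol=0.0):
    """Sketch: O_loc(σ) = Σ_c m_c ψ(σ'_c) / ψ(σ) over the first counts[...] slots.

    log_psi is a hypothetical vectorised log-amplitude: (..., D) -> (...).
    Shapes: sigma (..., D), sigma_p (..., C, D), mels (..., C).
    """
    counts = (np.abs(mels) > tol).sum(axis=-1)            # valid connections per ket
    C = mels.shape[-1]
    use = np.arange(C) < counts[..., None]                 # first counts[...] slots only
    ratio = np.exp(log_psi(sigma_p) - log_psi(sigma)[..., None])
    terms = np.where(use, mels * ratio, 0)                 # padded tail is masked out
    return terms.sum(axis=-1), counts
```

The padded tail carries zero matrix elements, so the mask does not change this particular sum. It matters for consumers that index, count, or sample individual connections.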
Example: compose a Euclidean Thiemann regularised constraint for the 4D WCL model with its adjoint:

```python
import jax.numpy as jnp
from neuralqx.operators.computational.wrappers.jax import ProductJax

# Assumes `lqx`, `vertex`, `sigma` (N, D), `batch_sigma` (B, N, D) and
# ThiemannRegularisedVertexConstraintJax are already defined/imported.
A = ThiemannRegularisedVertexConstraintJax(lqx.model, vertex, apply_lapse=True, adjoint=False)
Ad = ThiemannRegularisedVertexConstraintJax(lqx.model, vertex, apply_lapse=True, adjoint=True)
HvHvdag = ProductJax(A, Ad, tol=1e-12)   # A @ A†: Ad is applied first

# Single input ket (N, D):
sigma_p, mels = HvHvdag.get_conn_padded(sigma)            # (N, C, D), (N, C)
counts = (jnp.abs(mels) > 1e-12).sum(axis=1)              # (N,)

# Batched (B, N, D):
sigma_p_b, mels_b = HvHvdag.get_conn_padded(batch_sigma)  # (B, N, C, D), (B, N, C)
counts_b = (jnp.abs(mels_b) > 1e-12).sum(axis=2)          # (B, N)
```
- class SumJax(A, B, *, subtract=False, tol=0.0, is_hermitian=False)
Bases: ComputationalJaxOperator
JAX-native, matrix-free sum or difference of two ComputationalJaxOperators A and B:
\[S\,|\sigma\rangle = \begin{cases} (A + B)\,|\sigma\rangle, & \text{if } \texttt{subtract=False}, \\ (A - B)\,|\sigma\rangle, & \text{if } \texttt{subtract=True}. \end{cases}\]
- Parameters:
A (ComputationalJaxOperator) – First operand. A and B are JAX-native local operators and must share the same Hilbert space.
B (ComputationalJaxOperator) – Second operand; added to A, or subtracted from A if subtract=True.
subtract (bool, optional) – If True, computes A - B instead of A + B (default: False).
tol (float, optional) – Threshold for treating matrix elements as zero (default: 0.0).
is_hermitian (bool, optional) – Whether to mark this composite operator as Hermitian (default: False).
Notes
- Uses the padded connection API (NetKet contract): the connection axis has a fixed size per call.
- No Python-side loops; pure JAX operations and pytree registration.
- Works with both single-input (N, D) and batched (B, N, D) configurations.
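One straightforward way to realise the sum under the padded contract is to concatenate the two connection lists along the connection axis. B's elements have their sign flipped when subtract=True, and the nonzero-first ordering is then restored. The NumPy sketch below assumes exactly that approach. `sum_conn_padded` and the `conn_A`/`conn_B` callables are hypothetical. Whether SumJax also merges duplicate connected states is not stated here.

```python
import numpy as np

def sum_conn_padded(conn_A, conn_B, sigma, subtract=False, tol=0.0):
    """Sketch: padded connections of A + B (or A - B) for states sigma (..., D).

    conn_A / conn_B are hypothetical callables returning (sigma_p, mels) with
    shapes (..., C_A, D), (..., C_A) and (..., C_B, D), (..., C_B).
    """
    sp_A, m_A = conn_A(sigma)
    sp_B, m_B = conn_B(sigma)
    sign = -1 if subtract else 1
    # Linearity: (A ± B)|σ> is A's connections plus B's connections scaled by ±1.
    sp = np.concatenate([sp_A, sp_B], axis=-2)            # (..., C_A + C_B, D)
    m = np.concatenate([m_A, sign * m_B], axis=-1)        # (..., C_A + C_B)
    # Restore the padded contract: nonzero first (stable), tail = (input state, 0).
    valid = np.abs(m) > tol
    order = np.argsort(~valid, axis=-1, kind="stable")
    valid = np.take_along_axis(valid, order, axis=-1)
    m = np.where(valid, np.take_along_axis(m, order, axis=-1), 0)
    sp = np.take_along_axis(sp, order[..., None], axis=-2)
    sp = np.where(valid[..., None], sp, sigma[..., None, :])
    return sp, m
```

Leaving duplicates unmerged is harmless for linear consumers such as a local estimator, because their contributions simply add. The price is a connection axis of length C_A + C_B.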
- class ScaledJax(op, alpha)
Bases: ComputationalJaxOperator
JAX-native wrapper for scalar multiplication of a ComputationalJaxOperator.
Given an operator O and a scalar α (float or complex; a Python, jnp, or np scalar), this represents the operator α·O: (α·O)|σ⟩ has the same connected states as O, with matrix elements scaled by α.
Notes
- Scales all matrix elements (diagonal and off-diagonal).
- Dtype is promoted with promote_constant_for_op_dtype(op.dtype, α) and then combined via jnp.result_type.
- Hermiticity: marked Hermitian iff the base operator is Hermitian and α is (numerically) real.
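Below is a minimal NumPy sketch of the scaling rule and the Hermiticity flag. `scale_conn_padded`, `scaled_is_hermitian`, the tolerance `atol` and the toy `two_slots` operator are hypothetical. The dtype promotion through promote_constant_for_op_dtype is not reproduced, so the result dtype here is simply NumPy's.

```python
import numpy as np

def scale_conn_padded(conn, alpha, sigma):
    """Sketch: padded connections of alpha * O from those of O (hypothetical helper)."""
    sp, m = conn(sigma)
    # Same connected states; every element (diagonal and off-diagonal) is scaled.
    # NOTE: promote_constant_for_op_dtype is not reproduced; plain NumPy promotion applies.
    return sp, alpha * m

def scaled_is_hermitian(base_is_hermitian, alpha, atol=0.0):
    """Hermitian iff the base operator is Hermitian and alpha is (numerically) real."""
    return bool(base_is_hermitian) and bool(abs(np.imag(alpha)) <= atol)

# Toy operator with two padded slots (hypothetical): element 3.0 and a zero tail.
two_slots = lambda s: (np.stack([s, s], axis=-2), np.array([[3.0, 0.0]]))
sp, m = scale_conn_padded(two_slots, 2j, np.array([[1, -1]]))
```

Only the matrix elements change, so for α ≠ 0 the padded contract of O carries over: zero tail entries stay zero, and the nonzero-first ordering is unchanged.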