neuralqx.operators.computational.wrappers.jax package

class ProductJax(A, B, *, tol=0.0, is_hermitian=False)

Bases: ComputationalJaxOperator

JAX-native, matrix-free product of two padded local operators A (left) and B (right):

\[(A @ B)\,|\sigma\rangle \;=\; A\!\left( B\,|\sigma\rangle \right).\]

The class composes the connection graphs of A and B without building dense matrices. It proceeds in three steps:

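  1. Apply B once to the input(s) to obtain (σ'_B, m_B), with shapes (B, N, C_B, D) and (B, N, C_B). In these shapes the leading B is the batch size, not the operator B, and C_B is the size of B's padded connection axis.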

  2. For each branch index cb of B:

    • If that branch is globally zero over the batch (\max_{b,n}|m_B[b,n,cb]| \le tol), skip it (zeros are injected in the padded tail instead).

    • Otherwise, apply A to σ'_B[..., cb, :] and combine the path matrix elements: m_{path} = m_B[..., cb, None] * m_A.

  3. Concatenate across the kept branches, stable-partition so that all nonzero connections come first, and pad the tail with the input state(s) and zero matrix elements.
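The three steps above can be sketched in plain JAX with toy operators. This is only an illustration under stated assumptions, not the ProductJax implementation: toy_flip, toy_lower and product_conn are hypothetical names, the toy operators accept arbitrary leading dimensions so A is applied to every branch of B at once, and zero paths are pushed to the tail per row instead of being pruned with a global keep mask.

```python
import jax
import jax.numpy as jnp

def toy_flip(sigma):
    """Hypothetical padded operator: connection c flips site c, matrix element 1."""
    D = sigma.shape[-1]
    flip = 1 - 2 * jnp.eye(D, dtype=sigma.dtype)          # (C, D) with C = D
    return sigma[..., None, :] * flip, jnp.ones(sigma.shape[:-1] + (D,), sigma.dtype)

def toy_lower(sigma):
    """Hypothetical padded operator: same connections, element 1 only if site c is +1."""
    sigma_p, _ = toy_flip(sigma)
    return sigma_p, (sigma > 0).astype(sigma.dtype)

def product_conn(conn_A, conn_B, sigma, tol=0.0):
    """Padded connections of A @ B for callables returning (sigma', mels)."""
    lead, D = sigma.shape[:-1], sigma.shape[-1]
    # Step 1: apply B once.
    sp_B, m_B = conn_B(sigma)                  # (..., C_B, D), (..., C_B)
    # Step 2: apply A to every branch of B and multiply elements along each path.
    sp_AB, m_A = conn_A(sp_B)                  # (..., C_B, C_A, D), (..., C_B, C_A)
    m = (m_B[..., None] * m_A).reshape(lead + (-1,))
    sp = sp_AB.reshape(lead + (-1, D))
    # Step 3: stable partition (jnp.argsort is stable by default), then pad the tail.
    valid = jnp.abs(m) > tol
    order = jnp.argsort((~valid).astype(jnp.int32), axis=-1)   # valid entries first
    valid = jnp.take_along_axis(valid, order, axis=-1)
    m = jnp.where(valid, jnp.take_along_axis(m, order, axis=-1), 0)
    sp = jnp.take_along_axis(sp, jnp.broadcast_to(order[..., None], sp.shape), axis=-2)
    sp = jnp.where(valid[..., None], sp, sigma[..., None, :])  # tail = input state
    return sp, m

sigma = jnp.array([[1.0, -1.0, 1.0]])          # (N, D) = (1, 3)
sp, m = jax.jit(lambda s: product_conn(toy_flip, toy_lower, s))(sigma)
counts = (jnp.abs(m) > 0.0).sum(axis=-1)       # number of nonzero connections per row
```

Because every output shape is fixed by the input shape, the sketch compiles once per input shape, at the price of applying A even to branches that a pruning pass would have skipped.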

Parameters:
  • A (ComputationalJaxOperator) – Left factor of the product A @ B. It must share its Hilbert space with B. Its internal implementation may itself use jitted kernels and PyTree registration.

  • B (ComputationalJaxOperator) – Right factor of the product A @ B (applied first). It must share its Hilbert space with A. Its internal implementation may itself use jitted kernels and PyTree registration.

  • tol (float, optional) – Threshold for treating matrix elements as zero (default 0.0). This affects both pruning and the stable partitioning, i.e. which connections are considered “valid”.

  • is_hermitian (bool, optional) – If you know A @ B is Hermitian (e.g., A=H, B=H or B=A† and the product is provably Hermitian), set this to True to select covariance gradients downstream.

Notes

  • Padded contract: The connection axis has a fixed size per call. All nonzero connections are stably moved to the front. The tail contains copies of the input state with zero matrix element. Downstream Monte-Carlo code should only consume the first counts[b,n] = (|mels[b,n,:]|>tol).sum() connections.

  • Batch behavior: Pruning uses a single keep mask computed over the whole batch, which avoids data-dependent recompiles. As a consequence, larger batches may retain more branches (pruned branches are replaced by padding).

  • JAX transforms: The heavy lifting (packing/partitioning) is done by a single, stateless jax.jit() kernel that takes only arrays. There are no Python-side loops or dynamic slices inside jit.
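As an illustration of the padded contract, the sketch below consumes only the leading counts[n] connections of a hand-written padded output. The arrays mels and ratios are made-up example data (ratios stands in for hypothetical amplitude ratios ψ(σ')/ψ(σ)); nothing here calls the library.

```python
import jax.numpy as jnp

tol = 1e-12

# Made-up padded matrix elements, shape (N, C): nonzero entries first, zero tail.
mels = jnp.array([[0.5, -2.0, 0.0, 0.0],
                  [1.0,  0.0, 0.0, 0.0]])

# Made-up amplitude ratios for the same connections; tail entries are irrelevant.
ratios = jnp.array([[2.0, 1.0, 7.0, 7.0],
                    [4.0, 9.0, 9.0, 9.0]])

# Number of valid connections per row, exactly as in the contract above.
counts = (jnp.abs(mels) > tol).sum(axis=-1)

# Because valid entries come first, a prefix mask selects the same entries
# as the threshold test itself.
prefix = jnp.arange(mels.shape[-1]) < counts[:, None]

# Masked contraction over the connection axis: the padded tail contributes nothing.
local = jnp.sum(jnp.where(prefix, mels * ratios, 0.0), axis=-1)
```

Using a mask instead of slicing to counts[n] keeps every shape static, which is what allows such a contraction to live inside a jitted function.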

Compose a Euclidean Thiemann regularised constraint for the 4D WCL model with its adjoint:

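import jax.numpy as jnp

# Assumes lqx.model, vertex, sigma and batch_sigma are already defined, and that
# ThiemannRegularisedVertexConstraintJax and ProductJax have been imported.
A  = ThiemannRegularisedVertexConstraintJax(lqx.model, vertex, apply_lapse=True, adjoint=False)
Ad = ThiemannRegularisedVertexConstraintJax(lqx.model, vertex, apply_lapse=True, adjoint=True)
HvHvdag = ProductJax(A, Ad, tol=1e-12)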

# single input ket (N, D):
sigma_p, mels = HvHvdag.get_conn_padded(sigma)     # (N, C, D), (N, C)
counts = (jnp.abs(mels) > 1e-12).sum(axis=1)       # (N,)

# batched (B, N, D):
sigma_p_b, mels_b = HvHvdag.get_conn_padded(batch_sigma)  # (B, N, C, D), (B, N, C)
counts_b = (jnp.abs(mels_b) > 1e-12).sum(axis=2)          # (B, N)

class SumJax(A, B, *, subtract=False, tol=0.0, is_hermitian=False)

Bases: ComputationalJaxOperator

JAX-native, matrix-free sum or difference of two ComputationalJaxOperators A and B:

\[\begin{split}S\,|\sigma\rangle = \begin{cases} (A + B)\,|\sigma\rangle, & \text{if } \texttt{subtract=False}, \\ (A - B)\,|\sigma\rangle, & \text{if } \texttt{subtract=True}. \end{cases}\end{split}\]

Parameters:
  • A (ComputationalJaxOperator) – First JAX-native local operator. It must share its Hilbert space with B.

  • B (ComputationalJaxOperator) – Second JAX-native local operator, added to A or, with subtract=True, subtracted from it. It must share its Hilbert space with A.

  • subtract (bool, optional) – If True, computes A - B instead of A + B (default: False).

  • tol (float, optional) – Threshold for treating matrix elements as zero (default 0.0).

  • is_hermitian (bool, optional) – Whether to mark this composite operator as Hermitian (default: False).

Notes

  • Uses the padded connection API (NetKet contract): the connection axis has a fixed size per call.

  • No Python-side loops: only pure JAX operations and pytree registration.

  • Works with both single-input (N, D) and batched (B, N, D) configurations.
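One plausible way to realise a matrix-free sum on padded outputs is to concatenate the two connection lists and flip the sign of B's matrix elements when subtracting. The sketch below assumes exactly that; the text above does not say how SumJax is implemented, and toy_diag, toy_flip and sum_conn are hypothetical names.

```python
import jax.numpy as jnp

def toy_diag(sigma):
    """Hypothetical diagonal operator: one connection (the state itself)."""
    return sigma[..., None, :], jnp.sum(sigma, axis=-1, keepdims=True)

def toy_flip(sigma):
    """Hypothetical off-diagonal operator: connection c flips site c, element 1."""
    D = sigma.shape[-1]
    flip = 1 - 2 * jnp.eye(D, dtype=sigma.dtype)
    return sigma[..., None, :] * flip, jnp.ones(sigma.shape[:-1] + (D,), sigma.dtype)

def sum_conn(conn_A, conn_B, sigma, subtract=False):
    """Padded connections of A + B (or A - B) by concatenation."""
    sp_A, m_A = conn_A(sigma)                  # (..., C_A, D), (..., C_A)
    sp_B, m_B = conn_B(sigma)                  # (..., C_B, D), (..., C_B)
    sign = -1.0 if subtract else 1.0
    dtype = jnp.result_type(m_A, m_B)          # common dtype of both element arrays
    sp = jnp.concatenate([sp_A, sp_B], axis=-2)
    m = jnp.concatenate([m_A.astype(dtype), (sign * m_B).astype(dtype)], axis=-1)
    return sp, m                               # (..., C_A + C_B, D), (..., C_A + C_B)

sigma = jnp.array([[1.0, -1.0, 1.0]])          # (N, D) = (1, 3)
sp_sum, m_sum = sum_conn(toy_diag, toy_flip, sigma)
sp_diff, m_diff = sum_conn(toy_diag, toy_flip, sigma, subtract=True)
```

The connection axis of the result has the fixed size C_A + C_B. Duplicate connected states are not merged in this sketch, which is harmless when the consumer sums over connections.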

class ScaledJax(op, alpha)

Bases: ComputationalJaxOperator

JAX-native wrapper for scalar multiplication of a ComputationalJaxOperator.

Given an operator O and a scalar α (real or complex; a Python, jnp, or np scalar), this represents the operator α·O:

(α·O)|σ⟩ has the same connected states as O|σ⟩, with every matrix element multiplied by α.

Notes

  • Scales all matrix elements (diagonal and off-diagonal).

  • Dtype is promoted with promote_constant_for_op_dtype(op.dtype, α) and then combined via jnp.result_type.

  • Hermiticity: marked Hermitian iff the base operator is Hermitian and α is (numerically) real.
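The scaling, dtype and Hermiticity rules can be sketched as follows. The names scaled_conn, scaled_is_hermitian and toy_flip are hypothetical, the tolerance atol is an assumed stand-in for "numerically real", and the library helper promote_constant_for_op_dtype is not called here; only jnp.result_type is used to pick the output dtype.

```python
import jax.numpy as jnp

def toy_flip(sigma):
    """Hypothetical padded operator: connection c flips site c, element 1."""
    D = sigma.shape[-1]
    flip = 1 - 2 * jnp.eye(D, dtype=sigma.dtype)
    return sigma[..., None, :] * flip, jnp.ones(sigma.shape[:-1] + (D,), sigma.dtype)

def scaled_conn(conn, alpha, sigma):
    """Connections of alpha * O: same states as O, elements multiplied by alpha."""
    sp, m = conn(sigma)
    dtype = jnp.result_type(m.dtype, alpha)    # complex alpha promotes real elements
    return sp, jnp.asarray(alpha, dtype=dtype) * m.astype(dtype)

def scaled_is_hermitian(base_is_hermitian, alpha, atol=1e-12):
    """Hermitian iff the base operator is and alpha is real up to atol (assumed)."""
    return bool(base_is_hermitian) and abs(complex(alpha).imag) <= atol

sigma = jnp.array([[1.0, -1.0, 1.0]])          # (N, D) = (1, 3)
sp_r, m_r = scaled_conn(toy_flip, 0.5, sigma)  # real scalar keeps a real dtype
sp_c, m_c = scaled_conn(toy_flip, 2j, sigma)   # complex scalar gives complex elements
```

Only the matrix elements change; the connected states returned by the base operator are passed through untouched.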