neuralqx.operators.computational.wrappers package

class Product(A, B, *, tol=0.0, is_hermitian=False)

Bases: ComputationalOperator

Matrix-free product of two local ComputationalOperators A (left) and B (right):

\[(A @ B)\,|\sigma\rangle \;=\; A\big( B\,|\sigma\rangle \big).\]

This class does not form dense matrices. Instead, it composes the connection graphs provided by A and B. We:

  1. Apply B once to the input ket(s) to obtain all connected states and matrix elements

  2. For each kept branch produced by B, apply A to those branch states

  3. Multiply the path matrix elements and return the final connected states
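The three steps above can be sketched with plain NumPy arrays. This is only an illustration under stated assumptions: `toy_flip` and `compose` are hypothetical helpers, not part of this package, and the real class additionally prunes and pads its output.

```python
import numpy as np

def toy_flip(sigma, site, amp):
    """Hypothetical padded operator with two connections per input state:
    the state itself (matrix element 1.0) and the state with `site`
    flipped (matrix element `amp`). sigma has shape (..., D)."""
    flipped = sigma.copy()
    flipped[..., site] *= -1
    states = np.stack([sigma, flipped], axis=-2)            # (..., 2, D)
    mels = np.broadcast_to(np.array([1.0, amp]), sigma.shape[:-1] + (2,))
    return states, mels

def compose(apply_A, apply_B, sigma):
    """Connections of A @ B: apply B first, then A on every branch of B."""
    states_B, mels_B = apply_B(sigma)          # (..., C_B, D), (..., C_B)
    states_AB, mels_A = apply_A(states_B)      # (..., C_B, C_A, D), (..., C_B, C_A)
    mels_AB = mels_B[..., None] * mels_A       # multiply elements along each path
    lead, D = sigma.shape[:-1], sigma.shape[-1]
    # Flatten the (C_B, C_A) pair of axes into one connection axis.
    return states_AB.reshape(lead + (-1, D)), mels_AB.reshape(lead + (-1,))

sigma = np.array([[1, 1, 1]])                  # (N=1, D=3)
A = lambda s: toy_flip(s, site=0, amp=0.5)
B = lambda s: toy_flip(s, site=1, amp=2.0)
states, mels = compose(A, B, sigma)            # (1, 4, 3), (1, 4)
```

Each of the C_B branches of B fans out into C_A branches of A, so the flattened connection axis has length C_B * C_A before any pruning.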

To mitigate potential memory blow-up from too many connections, the implementation:

  • Early-prunes whole branches from B that are numerically zero for the entire batch: a branch index cb is kept iff any_{b,n} (abs(mB[b,n,cb]) > tol). This avoids ever calling A on completely useless branches

  • Stable-partitions the final connections so that all nonzeros come first along the connection axis, and only the tail is padded with (σ, 0.0) to preserve NetKet’s padded contract (fixed shapes)

  • Works for both a single input ket of shape (N, D) and a batch (B, N, D). The single-input path simply promotes to (1, N, D) and drops the batch dimension at the end
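The stable partition and tail padding described above can be sketched as follows. This is a NumPy illustration, not the package's kernel: `partition_and_pad` is a hypothetical name, and a stable argsort on a boolean key is just one way to obtain the ordering.

```python
import numpy as np

def partition_and_pad(sigma, states, mels, tol=0.0):
    """Move connections with |mel| > tol to the front of the connection
    axis and overwrite the tail with (sigma, 0.0).

    sigma: (N, D), states: (N, C, D), mels: (N, C).
    """
    valid = np.abs(mels) > tol
    # Sorting on ~valid puts valid entries first; kind="stable" keeps
    # their original relative order.
    order = np.argsort(~valid, axis=-1, kind="stable")
    mels_s = np.take_along_axis(mels, order, axis=-1)
    states_s = np.take_along_axis(states, order[..., None], axis=-2)
    valid_s = np.take_along_axis(valid, order, axis=-1)
    mels_out = np.where(valid_s, mels_s, 0.0)
    states_out = np.where(valid_s[..., None], states_s, sigma[..., None, :])
    return states_out, mels_out

sigma = np.array([[1, 1]])                                   # (N=1, D=2)
states = np.array([[[1, -1], [-1, 1], [1, 1], [-1, -1]]])    # (1, 4, 2)
mels = np.array([[0.0, 2.0, 1e-15, -3.0]])                   # (1, 4)
states_out, mels_out = partition_and_pad(sigma, states, mels, tol=1e-12)
```

The output shapes equal the input shapes, which is what keeps the padded contract intact; only the ordering and the tail contents change.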

Parameters:
  • A (ComputationalOperator) – Left factor of the product A @ B; it is applied second, to the states produced by B. Its Hilbert space must match that of B

  • B (ComputationalOperator) – Right factor of the product A @ B; it is applied first, directly to the input states. Its Hilbert space must match that of A

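  • tol (float, optional) – Numerical threshold (default: 0.0). Any connection with |mel| <= tol is treated as zero for the purposes of pruning and partitioning

  • is_hermitian (bool, optional) – Whether to mark the product as Hermitian (default: False). See the Hermiticity note below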

Notes

Padded API contract.

NetKet’s operator application returns tensors with a fixed connection axis length C for a whole call (and a whole batch). Even if different items/samples would produce different numbers of connections, we must still return a uniform shape. We therefore:

  • move all nonzeros to the front,

  • pad the tail with the input configurations and zero matrix elements

In principle, downstream code should consume only the leading valid connections. The per-sample valid counts can be computed on the fly via:

counts = (jnp.abs(mels) > tol).sum(axis=2)  # (B, N)

and use only mels[b, n, :counts[b,n]] and σp[b, n, :counts[b,n], :].
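A small NumPy sketch of both ways to consume the valid prefix. The array values are made up for illustration. Prefix slicing produces ragged results, so a mask over the connection axis is shown as a fixed-shape alternative; under jax.jit, where slice sizes must be static, the masked form is usually the practical one.

```python
import numpy as np

tol = 1e-12
# Already partitioned: nonzeros first, zero padding in the tail.
mels = np.array([[[2.0, -3.0, 0.0, 0.0],
                  [1.5,  0.0, 0.0, 0.0]]])        # (B=1, N=2, C=4)
counts = (np.abs(mels) > tol).sum(axis=2)         # (B, N)

# Ragged view: one array per (b, n) holding only the valid prefix.
valid = [[mels[b, n, :counts[b, n]] for n in range(mels.shape[1])]
         for b in range(mels.shape[0])]

# Fixed-shape alternative: mask the padded tail instead of slicing it off.
mask = np.arange(mels.shape[2]) < counts[..., None]   # (B, N, C)
row_sums = np.where(mask, mels, 0.0).sum(axis=2)      # (B, N)
```

Because the padded matrix elements are exactly zero, any sum over the full connection axis that is linear in mels gives the same result as the sum over the valid prefix.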

Why batch shapes differ from single:

In batch mode we keep the union of B-branches that are needed by any example in the batch. Therefore the connection axis length C can be larger for larger batches. For examples that do not require some of those branches, the corresponding slots are in the padded tail (zeros).
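The union effect can be seen on a toy example. The helper `keep_mask` is a hypothetical name that mirrors the keep rule stated for early pruning (a branch is kept iff any entry over the batch and chain axes exceeds tol); the matrix elements are made up.

```python
import numpy as np

def keep_mask(mels_B, tol=0.0):
    """Branch cb is kept iff some (b, n) has |mels_B[b, n, cb]| > tol."""
    return (np.abs(mels_B) > tol).any(axis=(0, 1))

m0 = np.array([[[1.0, 0.0, 0.0]]])           # sample 0 alone: (1, N=1, C_B=3)
m1 = np.array([[[0.0, 0.0, 2.0]]])           # sample 1 alone
batch = np.concatenate([m0, m1], axis=0)     # (2, 1, 3)

kept_single = int(keep_mask(m0).sum())       # branches kept for sample 0 alone
kept_batch = int(keep_mask(batch).sum())     # branches kept for the batch
```

Sample 0 on its own needs one branch of B, while the batch keeps two, so the connection axis grows even though neither sample uses more than one branch.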

Hermiticity:

Even if A and B are individually Hermitian, the product A @ B need not be. The is_hermitian flag therefore defaults to False; set it to True at initialisation when the product is known to be Hermitian.

Examples

Compose a Euclidean Thiemann regularised constraint for the 4D WCL model with its adjoint:

import jax.numpy as jnp

A  = ThiemannRegularisedVertexConstraint(lqx.model, vertex, apply_lapse=True, adjoint=False)
Ad = ThiemannRegularisedVertexConstraint(lqx.model, vertex, apply_lapse=True, adjoint=True)
HvHvdag = Product(A, Ad, tol=1e-12)

# single input ket (N, D):
sigma_p, mels = HvHvdag.get_conn_padded(sigma)     # (N, C, D), (N, C)
counts = (jnp.abs(mels) > 1e-12).sum(axis=1)       # (N,)

# batched (B, N, D):
sigma_p_b, mels_b = HvHvdag.get_conn_padded(batch_sigma)  # (B, N, C, D), (B, N, C)
counts_b = (jnp.abs(mels_b) > 1e-12).sum(axis=2)          # (B, N)

Complexity

Let C_B be the number of connections emitted by B and C_A the number emitted by A. If pruning keeps K_B <= C_B branches of B, the number of connections returned is approximately C = K_B * C_A per input chain position. Memory scales linearly with C.
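As a worked instance of this estimate, the following sketch computes C and the size of the matrix-element array for made-up numbers. Both helper names and all the sizes are hypothetical; only the formula C = K_B * C_A and the (B, N, C) shape of mels come from the text.

```python
def connection_count(C_A, C_B, K_B=None):
    """Approximate connection-axis length C = K_B * C_A.

    K_B defaults to C_B, i.e. no branch of B is pruned.
    """
    K_B = C_B if K_B is None else K_B
    if not 0 <= K_B <= C_B:
        raise ValueError("K_B must satisfy 0 <= K_B <= C_B")
    return K_B * C_A

def mels_bytes(B, N, C, itemsize=8):
    """Bytes held by a (B, N, C) array of matrix elements (float64 by default)."""
    return B * N * C * itemsize

C_unpruned = connection_count(C_A=12, C_B=12)          # 12 * 12
C_pruned = connection_count(C_A=12, C_B=12, K_B=5)     # 5 * 12
saved = mels_bytes(64, 10, C_unpruned) - mels_bytes(64, 10, C_pruned)
```

The connected-state array has an extra trailing axis of length D, so its footprint is D times the element count of mels, scaled by the ratio of the two item sizes.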

class ProductJax(A, B, *, tol=0.0, is_hermitian=False)

Bases: ComputationalJaxOperator

JAX-native, matrix-free product of two padded local operators A (left) and B (right):

\[(A @ B)\,|\sigma\rangle \;=\; A\!\left( B\,|\sigma\rangle \right).\]

The class composes the connection graphs without building dense matrices. What it does is:

  1. Apply B once to the input(s): obtain (σ'_B, m_B) of shape (B, N, C_B, D) and (B, N, C_B)

  2. For each branch index cb of B:

    • If that branch is globally zero over the batch (\max_{b,n}|m_B[b,n,cb]| \le tol), skip it (we will inject zeros in the padded tail).

    • Else apply A to σ'_B[..., cb, :] and combine path matrix elements: m_{path} = m_B[..., cb, None] * m_A.

  3. Concatenate across kept branches, stable-partition so that all nonzero connections come first, and pad the tail with the input state(s) and zeros
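Step 2 and the concatenation in step 3 can be sketched with NumPy as below. This is a plain Python loop for illustration, not the jitted kernel; `product_branches` and `toy_A` are hypothetical names, and the final partition-and-pad stage is omitted.

```python
import numpy as np

calls = []   # records which branches A was actually applied to

def toy_A(s):
    """Hypothetical padded operator: identity plus a flip of site 0,
    both with matrix element 0.5. s has shape (..., D)."""
    calls.append(1)
    flipped = s.copy()
    flipped[..., 0] *= -1
    return np.stack([s, flipped], axis=-2), np.full(s.shape[:-1] + (2,), 0.5)

def product_branches(apply_A, states_B, mels_B, tol=0.0):
    """states_B: (B, N, C_B, D), mels_B: (B, N, C_B)."""
    out_states, out_mels = [], []
    for cb in range(mels_B.shape[-1]):
        if np.max(np.abs(mels_B[..., cb])) <= tol:
            continue                                  # globally zero: skip A
        s_A, m_A = apply_A(states_B[..., cb, :])      # (B, N, C_A, D), (B, N, C_A)
        out_states.append(s_A)
        out_mels.append(mels_B[..., cb, None] * m_A)  # path matrix elements
    if not out_mels:                                  # every branch was pruned
        lead, D = mels_B.shape[:-1], states_B.shape[-1]
        return np.zeros(lead + (0, D), states_B.dtype), np.zeros(lead + (0,))
    return np.concatenate(out_states, axis=-2), np.concatenate(out_mels, axis=-1)

states_B = np.ones((1, 1, 3, 2), dtype=int)           # (B=1, N=1, C_B=3, D=2)
mels_B = np.array([[[1.0, 0.0, 2.0]]])                # middle branch is zero
states, mels = product_branches(toy_A, states_B, mels_B)
```

Here A is called twice rather than three times, and the result has K_B * C_A = 4 connections.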

Parameters:
  • A (ComputationalJaxOperator) – JAX-native left factor of the product A @ B (applied second). It must share its Hilbert space with B. Its internal implementation may itself use jitted kernels and PyTree registration.

  • B (ComputationalJaxOperator) – JAX-native right factor of the product A @ B (applied first). It must share its Hilbert space with A. Its internal implementation may itself use jitted kernels and PyTree registration.

  • tol (float, optional) – Threshold for treating matrix elements as zero (default 0.0). This affects pruning and the stable partitioning (which connections are considered “valid”)

  • is_hermitian (bool, optional) – If you know A @ B is Hermitian (e.g., A=H, B=H or B=A† and the product is provably Hermitian), set this to True to select covariance gradients downstream

Notes

  • Padded contract: The connection axis has a fixed size per call. All nonzero connections are stably moved to the front. The tail contains copies of the input state with zero matrix element. Downstream Monte-Carlo code should only consume the first counts[b,n] = (|mels[b,n,:]|>tol).sum() connections

  • Batch behavior: Pruning uses a single global keep mask over the batch to avoid data-dependent recompiles. Larger batches may therefore retain more branches; for samples that do not need a retained branch, the corresponding slots end up in the padded tail

  • JAX transforms: The heavy lifting (packing/partitioning) is a single, stateless jax.jit() kernel taking only arrays. There are no Python-side loops or dynamic slices inside jit

Examples

Compose a Euclidean Thiemann regularised constraint for the 4D WCL model with its adjoint:

import jax.numpy as jnp

A  = ThiemannRegularisedVertexConstraintJax(lqx.model, vertex, apply_lapse=True, adjoint=False)
Ad = ThiemannRegularisedVertexConstraintJax(lqx.model, vertex, apply_lapse=True, adjoint=True)
HvHvdag = ProductJax(A, Ad, tol=1e-12)

# single input ket (N, D):
sigma_p, mels = HvHvdag.get_conn_padded(sigma)     # (N, C, D), (N, C)
counts = (jnp.abs(mels) > 1e-12).sum(axis=1)       # (N,)

# batched (B, N, D):
sigma_p_b, mels_b = HvHvdag.get_conn_padded(batch_sigma)  # (B, N, C, D), (B, N, C)
counts_b = (jnp.abs(mels_b) > 1e-12).sum(axis=2)          # (B, N)

class Sum(A, B, *, subtract=False, tol=0.0, is_hermitian=False)

Bases: ComputationalOperator

Matrix-free sum or difference of two local ComputationalOperators A and B:

\[\begin{split}S\,|\sigma\rangle = \begin{cases} (A + B)\,|\sigma\rangle, & \text{if } \texttt{subtract=False}, \\ (A - B)\,|\sigma\rangle, & \text{if } \texttt{subtract=True}. \end{cases}\end{split}\]

This class does not construct dense matrices. Instead, it merges the connection graphs of A and B and adds (or, for subtract=True, subtracts) their matrix elements. Padded output follows NetKet’s fixed-shape contract.

Parameters:
  • A (ComputationalOperator) – First operand. Its Hilbert space must match that of B.

  • B (ComputationalOperator) – Second operand; it is the one subtracted when subtract=True. Its Hilbert space must match that of A.

  • subtract (bool, optional) – If True, computes A - B instead of A + B (default: False).

  • tol (float, optional) – Numerical threshold below which connections are treated as zero (default: 0.0).

  • is_hermitian (bool, optional) – Whether to mark this composite operator as Hermitian (default: False).

Notes

  • All nonzero connections from A and B are merged and returned in padded format

  • Overlapping connections (identical connected states reached through both A and B) are deliberately not deduplicated, for speed. Downstream contraction routines handle duplicates if needed
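A minimal NumPy sketch of such a merge, assuming it amounts to concatenating the two padded outputs along the connection axis. `merge_sum` and `psi` are hypothetical names, and the real class also partitions and pads the result. The example shows why duplicates are harmless for a contraction that sums over all connections.

```python
import numpy as np

def merge_sum(states_A, mels_A, states_B, mels_B, subtract=False):
    """Concatenate the connections of A and B; B's elements are negated
    when subtract=True. Duplicate connected states are left in place."""
    sign = -1.0 if subtract else 1.0
    states = np.concatenate([states_A, states_B], axis=-2)
    mels = np.concatenate([mels_A, sign * mels_B], axis=-1)
    return states, mels

def psi(states):
    """Toy stand-in for a wavefunction amplitude: the sum of the entries."""
    return states.sum(axis=-1).astype(float)

# A and B both connect the input to the same state [1, 1].
states_A = np.array([[[1, 1]]]); mels_A = np.array([[1.0]])   # (N=1, C=1, D=2)
states_B = np.array([[[1, 1]]]); mels_B = np.array([[2.0]])

s_add, m_add = merge_sum(states_A, mels_A, states_B, mels_B)
s_sub, m_sub = merge_sum(states_A, mels_A, states_B, mels_B, subtract=True)

total_add = (m_add * psi(s_add)).sum(axis=-1)   # (1 + 2) * psi([1, 1])
total_sub = (m_sub * psi(s_sub)).sum(axis=-1)   # (1 - 2) * psi([1, 1])
```

The merged output lists the state [1, 1] twice, yet the contraction equals what a deduplicated entry with matrix element 3 (or -1) would give.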

class SumJax(A, B, *, subtract=False, tol=0.0, is_hermitian=False)

Bases: ComputationalJaxOperator

JAX-native, matrix-free sum or difference of two ComputationalJaxOperators A and B:

\[\begin{split}S\,|\sigma\rangle = \begin{cases} (A + B)\,|\sigma\rangle, & \text{if } \texttt{subtract=False}, \\ (A - B)\,|\sigma\rangle, & \text{if } \texttt{subtract=True}. \end{cases}\end{split}\]

Parameters:
  • A (ComputationalJaxOperator) – JAX-native first operand. It must share its Hilbert space with B.

  • B (ComputationalJaxOperator) – JAX-native second operand; it is the one subtracted when subtract=True. It must share its Hilbert space with A.

  • subtract (bool, optional) – If True, computes A - B instead of A + B (default: False).

  • tol (float, optional) – Threshold for treating matrix elements as zero (default 0.0).

  • is_hermitian (bool, optional) – Whether to mark this composite operator as Hermitian (default: False).

Notes

  • Uses the padded connection API (NetKet contract): the connection axis has a fixed size per call

  • No Python-side loops: the implementation uses pure JAX operations and pytree registration

  • Works with both single-input (N, D) and batched (B, N, D) configurations

class Scaled(op, alpha)

Bases: ComputationalOperator

Wrapper for scalar multiplication of a ComputationalOperator.

Given an operator O and a scalar α (float/complex, Python/jnp/np), this represents the operator α·O:

(α·O) |σ⟩ -> same connected states as O, matrix elements scaled by α.

Notes

  • Scales all matrix elements (diagonal and off-diagonal).

  • Dtype is promoted with promote_constant_for_op_dtype(op.dtype, α) and then combined via jnp.result_type.

  • Hermiticity: marked Hermitian iff the base operator is Hermitian and α is (numerically) real.
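The scaling and the Hermiticity rule can be illustrated with NumPy. Several assumptions apply: `promote_constant_for_op_dtype` is not reproduced, `np.result_type` stands in for `jnp.result_type` (the two libraries do not promote Python scalars identically), and the tolerance `atol` used for "numerically real" is a guess. Both helper names are hypothetical.

```python
import numpy as np

def scale_mels(mels, alpha):
    """Scale all matrix elements by alpha; connected states are untouched."""
    alpha = np.asarray(alpha)
    out_dtype = np.result_type(mels.dtype, alpha.dtype)
    return (alpha * mels).astype(out_dtype)

def scaled_is_hermitian(base_is_hermitian, alpha, atol=0.0):
    """alpha * O is marked Hermitian iff O is Hermitian and alpha is real
    to within atol (assumed tolerance)."""
    return bool(base_is_hermitian) and abs(complex(alpha).imag) <= atol

mels = np.array([[1.0, -2.0]])            # float64 matrix elements
real_scaled = scale_mels(mels, 3.0)       # stays real
complex_scaled = scale_mels(mels, 1j)     # promoted to a complex dtype

herm_real = scaled_is_hermitian(True, 3.0)
herm_imag = scaled_is_hermitian(True, 1j)
herm_base = scaled_is_hermitian(False, 3.0)
```

A purely imaginary α turns a Hermitian operator into an anti-Hermitian one, which is why the flag requires α to be real.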

class ScaledJax(op, alpha)

Bases: ComputationalJaxOperator

JAX-native wrapper for scalar multiplication of a ComputationalJaxOperator.

Given an operator O and a scalar α (float/complex, Python/jnp/np), this represents the operator α·O:

(α·O) |σ⟩ -> same connected states as O, matrix elements scaled by α.

Notes

  • Scales all matrix elements (diagonal and off-diagonal).

  • Dtype is promoted with promote_constant_for_op_dtype(op.dtype, α) and then combined via jnp.result_type.

  • Hermiticity: marked Hermitian iff the base operator is Hermitian and α is (numerically) real.

Subpackages