neuralqx.operators.computational.wrappers.numba package

class Product(A, B, *, tol=0.0, is_hermitian=False)

Bases: ComputationalOperator

Matrix-free product of two local ComputationalOperators, A (left) and B (right):

\[(A @ B)\,|\sigma\rangle \;=\; A\big( B\,|\sigma\rangle \big).\]

This class does not form dense matrices. Instead, it composes the connection graphs provided by A and B in three steps:

  1. Apply B once to the input ket(s) to obtain all connected states and matrix elements

  2. For each kept branch produced by B, apply A to those branch states

  3. Multiply the path matrix elements and return the final connected states
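The three steps can be sketched for a single configuration in plain NumPy. This is a minimal illustration, not the neuralqx implementation. The hypothetical helpers make_get_conn and product_get_conn wrap small dense matrices only to generate test connections. They assume the ket convention of the formula above, O|σ⟩ = Σ_c mels[c] |states[c]⟩. The final loop checks the composed connections against the dense product A @ B.

```python
import itertools

import numpy as np

# Toy computational basis: all bitstrings on D sites (hypothetical stand-in for a Hilbert space)
D = 3
basis = [np.array(s) for s in itertools.product([0, 1], repeat=D)]
index = {tuple(s): i for i, s in enumerate(basis)}


def make_get_conn(M):
    """Hypothetical get_conn for a dense matrix M, in ket convention:
    O|σ⟩ = Σ_c mels[c] |states[c]⟩, i.e. the nonzero entries of column σ of M."""
    def get_conn(sigma):
        col = M[:, index[tuple(sigma)]]
        nz = np.flatnonzero(col)
        states = np.array([basis[i] for i in nz], dtype=int).reshape(len(nz), D)
        return states, col[nz]
    return get_conn


def product_get_conn(get_conn_A, get_conn_B, sigma, tol=0.0):
    """(A @ B)|σ⟩ = A(B|σ⟩), composed without forming A @ B."""
    states_B, mels_B = get_conn_B(sigma)        # 1. apply B once
    out_states, out_mels = [], []
    for sp, mB in zip(states_B, mels_B):
        if abs(mB) <= tol:                      # skip branches that are numerically zero
            continue
        states_A, mels_A = get_conn_A(sp)       # 2. apply A to the branch state
        out_states.append(states_A)
        out_mels.append(mB * mels_A)            # 3. multiply the path matrix elements
    if not out_states:
        return np.empty((0, D), dtype=int), np.empty(0)
    return np.concatenate(out_states), np.concatenate(out_mels)


# Check against the dense product on random sparse matrices
rng = np.random.default_rng(0)
MA = rng.normal(size=(8, 8)) * (rng.random((8, 8)) < 0.3)
MB = rng.normal(size=(8, 8)) * (rng.random((8, 8)) < 0.3)
for sigma in basis:
    states, mels = product_get_conn(make_get_conn(MA), make_get_conn(MB), sigma)
    column = np.zeros(8)
    for s, m in zip(states, mels):
        column[index[tuple(s)]] += m            # repeated connected states are summed
    assert np.allclose(column, (MA @ MB)[:, index[tuple(sigma)]])
```

Repeated connected states are returned as separate entries here. Summing them, as the check does, recovers the dense column.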

To mitigate potential memory blow-up from too many connections, the implementation:

  • Early-prunes whole branches of B that are numerically zero for the entire batch: a branch index cb is kept iff abs(mB[b, n, cb]) > tol for at least one (b, n), where mB holds B's matrix elements. This avoids ever calling A on branches that cannot contribute

  • Stable-partitions the final connections so that all nonzeros come first along the connection axis, and only the tail is padded with (σ, 0.0) to preserve NetKet’s padded contract (fixed shapes)

  • Works for both a single input ket of shape (N, D) and a batch (B, N, D). The single-input path simply promotes to (1, N, D) and drops the batch dimension at the end
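The pruning rule and the stable partition can be sketched with NumPy; jax.numpy offers the same calls. The helper names keep_branches and stable_partition are hypothetical. mels_B is assumed to have shape (B, N, C_B), as in the pruning rule above. A single ket of shape (N, D) would be handled by passing sigma[None] and dropping axis 0 of the results.

```python
import numpy as np


def keep_branches(mels_B, tol=0.0):
    """Mask over B's branch axis: True iff |mels_B[b, n, cb]| > tol for some (b, n)."""
    return np.any(np.abs(mels_B) > tol, axis=(0, 1))


def stable_partition(sigma, states, mels, tol=0.0):
    """Move connections with |mel| > tol to the front of the connection axis,
    keeping their relative order, and pad the tail with (σ, 0.0).

    sigma: (B, N, D), states: (B, N, C, D), mels: (B, N, C)."""
    zero = np.abs(mels) <= tol
    order = np.argsort(zero, axis=-1, kind="stable")     # False (nonzero) sorts first
    mels_p = np.take_along_axis(mels, order, axis=-1)
    states_p = np.take_along_axis(states, order[..., None], axis=-2)
    tail = np.take_along_axis(zero, order, axis=-1)      # True on the padded tail
    states_p = np.where(tail[..., None], sigma[..., None, :], states_p)
    mels_p = np.where(tail, 0, mels_p)
    return states_p, mels_p


# B=1 batch of N=2 configurations with D=1 and C_B=4 branches; branch 3 is zero everywhere
sigma = np.array([[[0], [1]]])
states = np.array([[[[10], [11], [12], [13]], [[20], [21], [22], [23]]]])
mels = np.array([[[0.0, 2.0, 0.0, 0.0], [1.0, 0.0, 3.0, 0.0]]])

keep = keep_branches(mels)                               # [True, True, True, False]
states_p, mels_p = stable_partition(sigma, states[:, :, keep], mels[:, :, keep])
# mels_p == [[[2, 0, 0], [1, 3, 0]]], states_p[..., 0] == [[[11, 0, 0], [20, 22, 1]]]
```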

Parameters:
  • A (ComputationalOperator) – Left factor of the product A @ B. It is applied second, to the states produced by B. Its Hilbert space must match that of B

  • B (ComputationalOperator) – Right factor of the product A @ B. It is applied first, to the input ket(s). Its Hilbert space must match that of A

  • tol (float, optional) – Numerical threshold (default: 0.0). Any connection with |mel| <= tol is treated as zero for the purposes of pruning and partitioning

  • is_hermitian (bool, optional) – Whether to mark the product as Hermitian (default: False). See the Hermiticity note below

Notes

Padded API contract.

NetKet’s operator application returns tensors whose connection axis has a fixed length C for the whole call (and the whole batch). Since different samples may produce different numbers of connections, the output must still have a uniform shape, so we:

  • move all nonzeros to the front,

  • pad the tail with the input configurations and zero matrix elements
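For ragged per-sample connection lists, the padded contract can be illustrated as follows. pad_connections is a hypothetical helper, not part of neuralqx or NetKet. It assumes each sample's list already has its nonzeros in front, and it pads every list to the longest one.

```python
import numpy as np


def pad_connections(sigmas, conn_list):
    """Pad ragged connection lists to one fixed length C.

    sigmas: (N, D) input configurations.
    conn_list: N pairs (states_i of shape (C_i, D), mels_i of shape (C_i,)).
    Returns states (N, C, D) and mels (N, C) with C = max_i C_i; slots past C_i
    hold the input configuration σ_i and a zero matrix element."""
    C = max(len(m) for _, m in conn_list)
    dtype = np.result_type(*(m.dtype for _, m in conn_list))
    states = np.repeat(sigmas[:, None, :], C, axis=1)   # tail defaults to σ_i
    mels = np.zeros((len(sigmas), C), dtype=dtype)      # ... with matrix element 0.0
    for i, (s, m) in enumerate(conn_list):
        states[i, : len(m)] = s
        mels[i, : len(m)] = m
    return states, mels


sigmas = np.array([[0, 0], [1, 1]])
conn_list = [
    (np.array([[1, 0], [0, 1]]), np.array([0.5, -0.5])),  # 2 connections
    (np.array([[0, 0]]), np.array([2.0])),                # 1 connection
]
states, mels = pad_connections(sigmas, conn_list)
# mels == [[0.5, -0.5], [2.0, 0.0]]; states[1, 1] == sigmas[1]
```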

Downstream code should consume only the leading valid connections. The per-sample valid counts can be computed on the fly via:

counts = (jnp.abs(mels) > tol).sum(axis=2)  # (B, N)

and use only mels[b, n, :counts[b,n]] and σp[b, n, :counts[b,n], :].
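The following sketch consumes only the leading valid connections with a vectorised masked contraction. It uses NumPy; the same calls exist in jax.numpy. It evaluates Σ_{c < counts[b,n]} mels[b,n,c]·f(σp[b,n,c]) in one pass and checks the result against the explicit slices. contract_valid and the test function f are hypothetical. The mask relies on the nonzeros-first layout described above.

```python
import numpy as np


def contract_valid(states, mels, f, tol=0.0):
    """Vectorised Σ_{c < counts[b, n]} mels[b, n, c] * f(states[b, n, c]).

    states: (B, N, C, D), mels: (B, N, C); f maps (..., D) -> (...)."""
    counts = (np.abs(mels) > tol).sum(axis=2)               # (B, N)
    valid = np.arange(mels.shape[2]) < counts[..., None]    # (B, N, C), leading slots only
    return np.where(valid, mels * f(states), 0.0).sum(axis=2), counts


def f(s):
    """Hypothetical per-configuration function, e.g. standing in for an amplitude ratio."""
    return s.sum(axis=-1).astype(float)


rng = np.random.default_rng(0)
true_counts = np.array([[4, 2, 0], [1, 3, 4]])
states = rng.integers(0, 2, size=(2, 3, 4, 5))
mels = rng.normal(size=(2, 3, 4)) * (np.arange(4) < true_counts[..., None])  # zero tail

result, counts = contract_valid(states, mels, f)

# Same result with the explicit slices mels[b, n, :k] and states[b, n, :k, :]
for b in range(2):
    for n in range(3):
        k = counts[b, n]
        assert np.isclose(result[b, n], np.sum(mels[b, n, :k] * f(states[b, n, :k, :])))
```

The padded tail carries zero matrix elements, so it would add nothing to the sum anyway. The mask guards against f returning non-finite values on the padding configurations (0 · inf gives nan). A loop over the slices additionally skips that work.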

Why batched shapes differ from single-input shapes:

In batch mode we keep the union of B-branches that are needed by any example in the batch. Therefore the connection axis length C can be larger for larger batches. For examples that do not require some of those branches, the corresponding slots are in the padded tail (zeros).
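A toy illustration of the union rule, using NumPy with made-up values: each sample needs a different branch of B, so the batch keeps both. Its connection axis is therefore about twice as long as either single-sample call would need.

```python
import numpy as np

# mels_B with shape (B=2, N=1, C_B=4): each sample has a different nonzero branch
mels_B = np.array([
    [[1.0, 0.0, 0.0, 0.0]],   # sample 0 only needs branch 0
    [[0.0, 0.0, 2.0, 0.0]],   # sample 1 only needs branch 2
])

# Branches kept when each sample is processed on its own (any over N only)
keep_single = [np.any(np.abs(m) > 0.0, axis=0) for m in mels_B]
# Branches kept for the whole batch (any over B and N): the union of the above
keep_batch = np.any(np.abs(mels_B) > 0.0, axis=(0, 1))

K_single = [int(k.sum()) for k in keep_single]   # [1, 1]
K_batch = int(keep_batch.sum())                  # 2, so C ≈ 2 * C_A in batch mode
```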

Hermiticity:

Even if A and B are individually Hermitian, the product A @ B need not be: (AB)† = B†A† = BA, which equals AB only when A and B commute. Therefore is_hermitian defaults to False; set it to True at initialisation when the product is known to be Hermitian.
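A quick numerical check of this statement uses random Hermitian matrices as stand-ins for A and B. The code uses NumPy, and the helpers random_hermitian and check_hermitian are hypothetical. The check also shows that a product of the form A @ A† is Hermitian, since (AA†)† = AA†. That would cover the HvHvdag example only if Ad is exactly the adjoint of A.

```python
import numpy as np

rng = np.random.default_rng(1)


def random_hermitian(n):
    """Random dense Hermitian matrix, standing in for a Hermitian operator."""
    x = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
    return (x + x.conj().T) / 2


def check_hermitian(m, atol=1e-12):
    return np.allclose(m, m.conj().T, atol=atol)


A = random_hermitian(4)
B = random_hermitian(4)

print(check_hermitian(A @ B))            # False: (AB)† = BA ≠ AB for non-commuting A, B
print(check_hermitian(A @ A.conj().T))   # True: (AA†)† = AA†
```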

Examples

Compose a Euclidean Thiemann regularised constraint for the 4D WCL model with its adjoint:

import jax.numpy as jnp

# ThiemannRegularisedVertexConstraint, Product, lqx, vertex, sigma and batch_sigma are assumed to be available
A  = ThiemannRegularisedVertexConstraint(lqx.model, vertex, apply_lapse=True, adjoint=False)
Ad = ThiemannRegularisedVertexConstraint(lqx.model, vertex, apply_lapse=True, adjoint=True)
HvHvdag = Product(A, Ad, tol=1e-12)  # represents A @ Ad

# single input ket (N, D):
sigma_p, mels = HvHvdag.get_conn_padded(sigma)     # (N, C, D), (N, C)
counts = (jnp.abs(mels) > 1e-12).sum(axis=1)       # (N,) valid connections per input

# batched (B, N, D):
sigma_p_b, mels_b = HvHvdag.get_conn_padded(batch_sigma)  # (B, N, C, D), (B, N, C)
counts_b = (jnp.abs(mels_b) > 1e-12).sum(axis=2)          # (B, N) valid connections per input

Complexity

Let C_B be the number of connections emitted by B and C_A the number emitted by A. If pruning keeps K_B <= C_B branches, the connection axis of the result has length approximately C = K_B * C_A for each input configuration. Memory scales linearly with C, since σp has shape (B, N, C, D) and mels has shape (B, N, C).
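A back-of-the-envelope helper makes the scaling concrete. estimate_product_conn_bytes is hypothetical. The item sizes are assumptions (1-byte integer states, complex128 matrix elements); substitute the dtypes your operators actually use.

```python
def estimate_product_conn_bytes(batch, n, d, k_b, c_a, state_itemsize=1, mel_itemsize=16):
    """Rough size in bytes of the padded (σp, mels) pair returned by a Product.

    Hypothetical helper: item sizes default to 1-byte integer states and
    complex128 matrix elements, which are assumptions."""
    c = k_b * c_a                                      # connection axis length C ≈ K_B · C_A
    states_bytes = batch * n * c * d * state_itemsize  # σp has shape (B, N, C, D)
    mels_bytes = batch * n * c * mel_itemsize          # mels has shape (B, N, C)
    return c, states_bytes + mels_bytes


c, nbytes = estimate_product_conn_bytes(batch=4, n=16, d=100, k_b=10, c_a=20)
# c == 200, nbytes == 1_484_800 (about 1.48 MB)
```

Doubling K_B, for instance through a larger batch that needs more branches, doubles both C and the memory.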

class Sum(A, B, *, subtract=False, tol=0.0, is_hermitian=False)

Bases: ComputationalOperator

Matrix-free sum or difference of two local ComputationalOperators A and B:

\[S\,|\sigma\rangle = \begin{cases} (A + B)\,|\sigma\rangle, & \text{if } \texttt{subtract=False} \\ (A - B)\,|\sigma\rangle, & \text{if } \texttt{subtract=True} \end{cases}\]

This class does not construct dense matrices; it merges the connection graphs of A and B and sums (or subtracts) their matrix elements. The padded output follows NetKet’s fixed-shape contract.
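One plausible reading of this merge is a plain concatenation along the connection axis, with B's matrix elements negated when subtract=True. That reading is consistent with the note that overlapping connections are not deduplicated. The sketch uses a hypothetical sum_get_conn (NumPy, one input configuration) and shows only that step, not zero filtering or padding.

```python
import numpy as np


def sum_get_conn(conn_A, conn_B, subtract=False):
    """Merge the connections of A and B for one input σ by concatenation.

    conn_X = (states of shape (C_X, D), mels of shape (C_X,)). B's matrix
    elements are negated when subtract=True; duplicates are kept as-is."""
    states_A, mels_A = conn_A
    states_B, mels_B = conn_B
    sign = -1 if subtract else 1
    states = np.concatenate([states_A, states_B], axis=0)
    mels = np.concatenate([mels_A, sign * mels_B])
    return states, mels


conn_A = (np.array([[0, 1], [1, 0]]), np.array([1.0, 2.0]))
conn_B = (np.array([[0, 1]]), np.array([0.5]))
states, mels = sum_get_conn(conn_A, conn_B, subtract=True)
# states has C_A + C_B = 3 rows; mels == [1.0, 2.0, -0.5]
```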

Parameters:
  • A (ComputationalOperator) – First operand. Its Hilbert space must match that of B.

  • B (ComputationalOperator) – Second operand, added to A (or subtracted from it when subtract=True). Its Hilbert space must match that of A.

  • subtract (bool, optional) – If True, computes A - B instead of A + B (default: False).

  • tol (float, optional) – Numerical threshold below which connections are treated as zero (default: 0.0).

  • is_hermitian (bool, optional) – Whether to mark this composite operator as Hermitian (default: False).

Notes

  • All nonzero connections from A and B are merged and returned in padded format

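  • Overlapping connections (identical connected states) are not deduplicated, for speed. A state reached by both A and B therefore appears twice. This is harmless for linear contractions of the form Σ_c mel_c · f(σ'_c), which give the same result with or without merging duplicates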
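To see why skipping deduplication is safe for linear contractions, the sketch merges duplicate states explicitly and compares both contractions. It uses NumPy's np.unique and np.add.at. deduplicate and the weight function f are hypothetical illustrations; nothing here implies that neuralqx performs this merge.

```python
import numpy as np


def deduplicate(states, mels):
    """Merge identical connected states by summing their matrix elements."""
    uniq, inv = np.unique(states, axis=0, return_inverse=True)
    merged = np.zeros(len(uniq), dtype=mels.dtype)
    np.add.at(merged, inv.reshape(-1), mels)   # reshape guards against NumPy-version shape differences
    return uniq, merged


def f(s):
    """Hypothetical per-state weight, standing in for e.g. an amplitude ratio."""
    return 1.0 + s @ np.array([1.0, 10.0])


# Output of A + B for one σ, where the state [0, 1] is reached by both A and B
states = np.array([[0, 1], [1, 0], [0, 1]])
mels = np.array([1.0, 2.0, -0.5])

uniq, merged = deduplicate(states, mels)
raw = np.sum(mels * f(states))      # contraction over the duplicated list: 9.5
deduped = np.sum(merged * f(uniq))  # contraction after merging: 9.5
```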

class Scaled(op, alpha)

Bases: ComputationalOperator

Wrapper for scalar multiplication of a ComputationalOperator.

Given an operator O (op) and a scalar α (alpha), which may be a real or complex Python, NumPy or JAX scalar, this represents the operator α·O:

\[(\alpha O)\,|\sigma\rangle \;=\; \alpha\,\big(O\,|\sigma\rangle\big),\]

i.e. the connected states are those of O and every matrix element is multiplied by α.

Notes

  • Scales all matrix elements (diagonal and off-diagonal).

  • Dtype is promoted with promote_constant_for_op_dtype(op.dtype, α) and then combined via jnp.result_type.

  • Hermiticity: marked Hermitian iff the base operator is Hermitian and α is (numerically) real.
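A minimal sketch of the scaling and the Hermiticity rule, assuming a (states, mels) pair for a single configuration. scaled_get_conn and scaled_is_hermitian are hypothetical names. NumPy's ordinary type promotion in alpha * mels only stands in for promote_constant_for_op_dtype, whose exact rules are not reproduced here.

```python
import numpy as np


def scaled_get_conn(conn, alpha):
    """Same connected states as the base operator; every matrix element times alpha."""
    states, mels = conn
    return states, alpha * mels    # NumPy promotion stands in for promote_constant_for_op_dtype


def scaled_is_hermitian(base_is_hermitian, alpha, atol=1e-12):
    """α·O is marked Hermitian iff O is Hermitian and α is (numerically) real."""
    return bool(base_is_hermitian and abs(np.imag(alpha)) <= atol)


conn = (np.array([[0, 1], [1, 0]]), np.array([1.0, -2.0]))
states, mels = scaled_get_conn(conn, 0.5)        # mels == [0.5, -1.0]
_, mels_c = scaled_get_conn(conn, 1j)            # complex matrix elements
```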