neuralqx.operators.computational.wrappers package
- class Product(A, B, *, tol=0.0, is_hermitian=False)
Bases: ComputationalOperator
Matrix-free product of two local ComputationalOperators A (left) and B (right):
\[(A @ B)\,|\sigma\rangle \;=\; A\big( B\,|\sigma\rangle \big).\]
This class does not form dense matrices. Instead, it composes the connection graphs provided by A and B:
1. Apply B once to the input ket(s) to obtain all connected states and matrix elements.
2. For each kept branch produced by B, apply A to those branch states.
3. Multiply the path matrix elements and return the final connected states.
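The three steps above can be sketched in NumPy on toy operators. This is a minimal illustration under stated assumptions, not the package's code: get_conn(sigma) is assumed to return padded arrays of shapes (N, C, D) and (N, C), mels[n, c] is read as ⟨σ'_c|O|σ_n⟩ (the ket-action picture of the formula above), and flip_op, diag_op, compose_conn and dense are hypothetical names. Instead of looping over the branches of B, all branches are folded into one batch for A, which yields the same connections. The result is compared with the dense product A @ B.

```python
import itertools

import numpy as np

D = 3  # sites in the toy chain; states are vectors of +/-1 (assumption for the demo)


def flip_op(sigma):
    """Toy operator sum_i X_i: connection i flips site i (C = D), every mel is 1."""
    sigma_p = np.repeat(sigma[:, None, :], D, axis=1)  # (N, D, D)
    idx = np.arange(D)
    sigma_p[:, idx, idx] *= -1
    return sigma_p, np.ones((sigma.shape[0], D))


def diag_op(sigma):
    """Toy diagonal operator sum_i i * Z_i: one connection, the state itself."""
    mel = (sigma * np.arange(D)).sum(axis=-1).astype(float)  # (N,)
    return sigma[:, None, :], mel[:, None]


def compose_conn(get_conn_A, get_conn_B, sigma):
    """Connections of A @ B: apply B first, then A to every branch of B."""
    sB, mB = get_conn_B(sigma)                      # (N, C_B, D), (N, C_B)
    N, C_B, _ = sB.shape
    sA, mA = get_conn_A(sB.reshape(N * C_B, -1))    # all branches as one batch
    C_A = mA.shape[1]
    sigma_p = sA.reshape(N, C_B * C_A, -1)          # (N, C_B * C_A, D)
    mels = (mB[:, :, None] * mA.reshape(N, C_B, C_A)).reshape(N, C_B * C_A)
    return sigma_p, mels


def dense(get_conn):
    """Dense matrix M[j, i] = <sigma_j|O|sigma_i>, reading mels as <sigma'|O|sigma>."""
    states = np.array(list(itertools.product([-1, 1], repeat=D)))
    index = {tuple(s): k for k, s in enumerate(states)}
    sigma_p, mels = get_conn(states)
    M = np.zeros((len(states), len(states)))
    for i in range(len(states)):
        for c in range(mels.shape[1]):
            M[index[tuple(sigma_p[i, c])], i] += mels[i, c]
    return M


A_dense, B_dense = dense(diag_op), dense(flip_op)
# Build the dense matrix of the composed connections and compare with A_dense @ B_dense.
AB_dense = dense(lambda s: compose_conn(diag_op, flip_op, s))
```

Under NetKet's row convention, where mels[n, c] = ⟨σ_n|O|σ'_c⟩, the same order of application (B first, then A) produces the rows of B @ A instead, so the convention a given get_conn follows matters when the factors do not commute.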
To mitigate potential memory blow-up from too many connections, the implementation:
- Early-prunes whole branches of B that are numerically zero for the entire batch: a branch index cb is kept iff any_{b,n} (abs(mB[b,n,cb]) > tol). This avoids ever calling A on branches that contribute nothing.
- Stable-partitions the final connections so that all nonzeros come first along the connection axis, and pads only the tail with (σ, 0.0) to preserve NetKet's padded contract (fixed shapes).
Both a single input ket of shape (N, D) and a batch of shape (B, N, D) are supported (here B is the batch size, not the operator B). The single-input path simply promotes the input to (1, N, D) and drops the batch dimension at the end.
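Both memory-saving steps can be illustrated in NumPy. This is a hedged sketch: prune_mask and stable_partition are hypothetical names, the shapes (B, N, C, D) and (B, N, C) follow this page, and the stable partition is done here with an argsort on a key that is unique per slot, which is one way (not necessarily the package's) to move nonzero connections to the front while keeping their original order.

```python
import numpy as np


def prune_mask(mB, tol=0.0):
    """Global keep mask over B-branches: keep cb iff any_{b,n} |mB[b, n, cb]| > tol."""
    return (np.abs(mB) > tol).any(axis=(0, 1))  # (C_B,)


def stable_partition(sigma, sigma_p, mels, tol=0.0):
    """Move connections with |mel| > tol to the front, keeping their order; pad the tail.

    sigma: (B, N, D) inputs, sigma_p: (B, N, C, D), mels: (B, N, C).
    Tail slots become (sigma, 0.0), i.e. the input state with a zero matrix element.
    """
    C = mels.shape[-1]
    valid = np.abs(mels) > tol                      # (B, N, C)
    key = (~valid) * C + np.arange(C)               # valid slots first, ties broken by position
    order = np.argsort(key, axis=-1)
    mels_s = np.take_along_axis(mels, order, axis=-1)
    sp_s = np.take_along_axis(sigma_p, order[..., None], axis=-2)
    valid_s = np.take_along_axis(valid, order, axis=-1)
    mels_s = np.where(valid_s, mels_s, 0.0)
    sp_s = np.where(valid_s[..., None], sp_s, sigma[:, :, None, :])
    return sp_s, mels_s


# Two batch items, one sample each, four B-branches: branches 0 and 2 are zero everywhere.
mB = np.array([[[0.0, 2.0, 0.0, 1.0]], [[0.0, 0.0, 0.0, 3.0]]])  # (B=2, N=1, C_B=4)
keep = prune_mask(mB)

sigma = np.zeros((2, 1, 2))                                         # (B, N, D=2)
sigma_p = np.arange(16, dtype=float).reshape(2, 1, 4, 2) + 100.0    # distinct dummy states
mels = np.array([[[0.0, 5.0, 0.0, 7.0]], [[1.0, 0.0, 0.0, 0.0]]])
sp, m = stable_partition(sigma, sigma_p, mels)
```

Because the keep mask is an "any" over the whole batch, a branch survives as soon as one (b, n) needs it, which is exactly why larger batches can carry more branches.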
- Parameters:
A (ComputationalOperator) – Left factor of A @ B; it acts on the output of B. A and B must act on the same Hilbert space.
B (ComputationalOperator) – Right factor of A @ B; it acts first on the input ket(s).
tol (float, optional) – Numerical threshold (default: 0.0). Any connection with |mel| <= tol is treated as zero for the purposes of pruning and partitioning.
is_hermitian (bool, optional) – Whether to mark the product as Hermitian (default: False); see the Hermiticity note below.
Notes
- Padded API contract.
NetKet's operator application returns tensors with a fixed connection-axis length C for a whole call (and a whole batch). If different items/samples would produce different numbers of connections, a uniform shape must still be returned, so the implementation:
  - moves all nonzero connections to the front, and
  - pads the tail with the input configurations and zero matrix elements.
Downstream code should consume only the leading valid connections. The per-sample valid counts can be computed on the fly via
    counts = (jnp.abs(mels) > tol).sum(axis=2)  # (B, N)
after which only mels[b, n, :counts[b, n]] and σp[b, n, :counts[b, n], :] are used.
- Why batch shapes differ from single:
In batch mode the union of the B-branches needed by any example in the batch is kept. Therefore the connection-axis length C can be larger for larger batches. For examples that do not need some of those branches, the corresponding slots lie in the padded tail (zeros).
- Hermiticity:
Even if A and B are Hermitian individually, the product A @ B need not be: (A B)† = B† A† = B A, so A @ B is Hermitian only when A and B commute. Therefore is_hermitian defaults to False; pass is_hermitian=True at initialisation when the product is known to be Hermitian.
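A minimal sketch of how downstream code might consume only the leading valid connections when forming a local estimate of the form sum_c mels[b, n, c] * ψ(σ'_c) / ψ(σ). The toy amplitude psi and the helper local_estimate are hypothetical; the point is that the padded tail contributes nothing because its matrix elements are exactly zero, so masking with counts and summing over the full axis agree.

```python
import numpy as np


def psi(states):
    """Hypothetical toy amplitude psi(s) = exp(0.1 * sum(s))."""
    return np.exp(0.1 * states.sum(axis=-1))


def local_estimate(sigma, sigma_p, mels, tol=0.0):
    """Sum over the leading counts[b, n] connections of mels * psi(sigma') / psi(sigma)."""
    counts = (np.abs(mels) > tol).sum(axis=2)         # (B, N), as in the note above
    C = mels.shape[2]
    leading = np.arange(C) < counts[..., None]        # True for the first counts[b, n] slots
    ratio = psi(sigma_p) / psi(sigma)[..., None]      # (B, N, C)
    return np.where(leading, mels * ratio, 0.0).sum(axis=2), counts


# One batch item, two samples, C = 3; tails are padded with (sigma, 0.0).
sigma = np.array([[[1.0, 1.0], [1.0, -1.0]]])                       # (1, 2, 2)
sigma_p = np.array([[[[-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]],
                     [[1.0, 1.0], [1.0, -1.0], [1.0, -1.0]]]])      # (1, 2, 3, 2)
mels = np.array([[[0.5, 2.0, 0.0], [1.5, 0.0, 0.0]]])               # (1, 2, 3)
est, counts = local_estimate(sigma, sigma_p, mels)
full = (mels * psi(sigma_p) / psi(sigma)[..., None]).sum(axis=2)    # no masking at all
```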
Examples
Compose a Euclidean Thiemann-regularised constraint for the 4D WCL model with its adjoint (lqx, vertex, sigma and batch_sigma are assumed to come from the surrounding model setup):
    import jax.numpy as jnp
    from neuralqx.operators.computational.wrappers import Product

    A = ThiemannRegularisedVertexConstraint(lqx.model, vertex, apply_lapse=True, adjoint=False)
    Ad = ThiemannRegularisedVertexConstraint(lqx.model, vertex, apply_lapse=True, adjoint=True)
    # if Ad is exactly the adjoint of A, A @ Ad is Hermitian and is_hermitian=True can be passed
    HvHvdag = Product(A, Ad, tol=1e-12)

    # single input ket (N, D):
    sigma_p, mels = HvHvdag.get_conn_padded(sigma)            # (N, C, D), (N, C)
    counts = (jnp.abs(mels) > 1e-12).sum(axis=1)              # (N,)

    # batched (B, N, D):
    sigma_p_b, mels_b = HvHvdag.get_conn_padded(batch_sigma)  # (B, N, C, D), (B, N, C)
    counts_b = (jnp.abs(mels_b) > 1e-12).sum(axis=2)          # (B, N)
Complexity
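Let C_B be the number of connections emitted by B and C_A the number emitted by A. If pruning keeps K_B <= C_B branches, the number of connections returned per input configuration is approximately C = K_B * C_A. Memory scales linearly with C: the returned connected states have shape (B, N, C, D) and the matrix elements (B, N, C).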
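A back-of-the-envelope helper for the estimate above. The function name and the example numbers are illustrative assumptions; it counts only the returned connected-state and matrix-element arrays, takes C = K_B * C_A as stated, assumes one byte per local state (e.g. int8) and complex64 matrix elements (8 bytes), and ignores intermediate buffers.

```python
def product_output_bytes(batch, n, d, k_b, c_a, state_itemsize=1, mel_itemsize=8):
    """Connection-axis length C and bytes of sigma_p (batch, n, C, d) plus mels (batch, n, C)."""
    c = k_b * c_a
    sigma_p_bytes = batch * n * c * d * state_itemsize
    mels_bytes = batch * n * c * mel_itemsize
    return c, sigma_p_bytes + mels_bytes


# 1024 samples, 100 local degrees of freedom, 40 kept B-branches, 50 A-connections each.
C, nbytes = product_output_bytes(batch=1, n=1024, d=100, k_b=40, c_a=50)
# Doubling k_b doubles C and therefore the whole footprint.
```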
- class ProductJax(A, B, *, tol=0.0, is_hermitian=False)
Bases: ComputationalJaxOperator
JAX-native, matrix-free product of two padded local operators A (left) and B (right):
\[(A @ B)\,|\sigma\rangle \;=\; A\!\left( B\,|\sigma\rangle \right).\]
The class composes the connection graphs without building dense matrices. It proceeds as follows:
1. Apply B once to the input(s) to obtain (σ'_B, m_B) of shapes (B, N, C_B, D) and (B, N, C_B).
2. For each branch index cb of B:
   - if the branch is zero over the whole batch (max_{b,n} |m_B[b,n,cb]| <= tol), skip it (its slots are filled with zeros in the padded tail);
   - otherwise apply A to σ'_B[..., cb, :] and combine the path matrix elements: m_path = m_B[..., cb, None] * m_A.
3. Concatenate across kept branches, stable-partition so that all nonzero connections come first, and pad the tail with the input state(s) and zeros.
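The branch loop in step 2 can be written without Python-side loops, which is what a jit-friendly kernel needs. This NumPy sketch (every array call used has a jax.numpy function of the same name) is an assumption about how such a kernel could look, not the package's kernel: it applies A to all C_B branches at once by folding the branch axis into the sample axis, and it zeroes pruned branches with a mask instead of skipping them. The padded result has the same content, but this simpler variant does not save the A evaluations. apply_A and compose_fixed_shape are hypothetical names, and the stable partition of step 3 is left out.

```python
import numpy as np

D = 4  # number of sites in the toy example (assumption)


def apply_A(sigma):
    """Hypothetical padded operator with C_A = D: connection i flips site i, every mel is 0.5."""
    sigma_p = np.repeat(sigma[..., None, :], D, axis=-2)          # (..., D, D)
    sigma_p = sigma_p * np.where(np.eye(D, dtype=bool), -1, 1)    # flip site i in connection i
    return sigma_p, np.full(sigma.shape[:-1] + (D,), 0.5)


def compose_fixed_shape(sB, mB, get_conn_A, tol=0.0):
    """A applied after precomputed B connections; output axis has fixed length C_B * C_A."""
    n_batch, n_samples, C_B, n_sites = sB.shape
    keep = (np.abs(mB) > tol).any(axis=(0, 1))                    # (C_B,) global keep mask
    sA, mA = get_conn_A(sB.reshape(-1, n_sites))                  # fold branches into samples
    C_A = mA.shape[-1]
    m_path = mB[..., None] * mA.reshape(n_batch, n_samples, C_B, C_A)  # m_B[..., cb, None] * m_A
    m_path = np.where(keep[:, None], m_path, 0.0)                 # pruned branches -> zeros
    sigma_p = sA.reshape(n_batch, n_samples, C_B * C_A, n_sites)
    return sigma_p, m_path.reshape(n_batch, n_samples, C_B * C_A)


sigma = np.ones((2, 3, D))          # (B=2, N=3, D)
sB, mB = apply_A(sigma)             # reuse the toy operator as B, so C_B = D
mB[..., 0] = 0.0                    # branch 0 is now zero for the whole batch
sp, m = compose_fixed_shape(sB, mB, apply_A)
```

The output shape depends only on C_B and C_A, never on the data, so the same compiled kernel can be reused for every call of that shape.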
- Parameters:
A (ComputationalJaxOperator) – Left factor of A @ B. A and B must share the same Hilbert space; their internal implementations may themselves use jitted kernels and PyTree registration.
B (ComputationalJaxOperator) – Right factor of A @ B, applied first to the input(s). Same requirements as A.
tol (float, optional) – Threshold for treating matrix elements as zero (default 0.0). This affects pruning and the stable partitioning (which connections are considered “valid”).
is_hermitian (bool, optional) – If A @ B is known to be Hermitian (e.g. A = B = H with H Hermitian, or B = A†, since A A† is always Hermitian), set this to True so that downstream code selects covariance gradients (default: False).
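A small dense check of the cases mentioned for is_hermitian, with random matrices standing in for the operators (an illustration only; the wrappers themselves never build dense matrices). A A† is always Hermitian, H H is Hermitian for Hermitian H, while a product of two generic Hermitian matrices is not, because (A B)† = B A differs from A B unless they commute. random_hermitian and is_hermitian_matrix are hypothetical helper names.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_hermitian(n):
    """Random n x n Hermitian matrix."""
    x = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
    return (x + x.conj().T) / 2


def is_hermitian_matrix(m, atol=1e-12):
    """True if m equals its conjugate transpose up to atol."""
    return np.allclose(m, m.conj().T, atol=atol)


A = rng.normal(size=(5, 5)) + 1j * rng.normal(size=(5, 5))  # generic, non-Hermitian
H1, H2 = random_hermitian(5), random_hermitian(5)

aadag_ok = is_hermitian_matrix(A @ A.conj().T)   # A A^dagger
hh_ok = is_hermitian_matrix(H1 @ H1)             # H H
h1h2_ok = is_hermitian_matrix(H1 @ H2)           # generic H1 H2, which do not commute
```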
Notes
- Padded contract: the connection axis has a fixed size per call. All nonzero connections are stably moved to the front; the tail contains copies of the input state with zero matrix elements. Downstream Monte Carlo code should consume only the first counts[b,n] = (|mels[b,n,:]| > tol).sum() connections.
- Batch behavior: pruning uses a single global keep mask over the batch to avoid data-dependent recompiles. As a result, larger batches may retain more branches (the unused slots are padded).
- JAX transforms: the heavy lifting (packing/partitioning) is a single, stateless jax.jit() kernel that takes only arrays; there are no Python-side loops or dynamic slices inside jit.
Examples
Compose a Euclidean Thiemann-regularised constraint for the 4D WCL model with its adjoint (lqx, vertex, sigma and batch_sigma are assumed to come from the surrounding model setup):
    import jax.numpy as jnp
    from neuralqx.operators.computational.wrappers import ProductJax

    A = ThiemannRegularisedVertexConstraintJax(lqx.model, vertex, apply_lapse=True, adjoint=False)
    Ad = ThiemannRegularisedVertexConstraintJax(lqx.model, vertex, apply_lapse=True, adjoint=True)
    # if Ad is exactly the adjoint of A, A @ Ad is Hermitian and is_hermitian=True can be passed
    HvHvdag = ProductJax(A, Ad, tol=1e-12)

    # single input ket (N, D):
    sigma_p, mels = HvHvdag.get_conn_padded(sigma)            # (N, C, D), (N, C)
    counts = (jnp.abs(mels) > 1e-12).sum(axis=1)              # (N,)

    # batched (B, N, D):
    sigma_p_b, mels_b = HvHvdag.get_conn_padded(batch_sigma)  # (B, N, C, D), (B, N, C)
    counts_b = (jnp.abs(mels_b) > 1e-12).sum(axis=2)          # (B, N)
- class Sum(A, B, *, subtract=False, tol=0.0, is_hermitian=False)
Bases: ComputationalOperator
Matrix-free sum or difference of two local ComputationalOperators A and B:
\[S\,|\sigma\rangle = \begin{cases} (A + B)\,|\sigma\rangle, & \text{if } \texttt{subtract=False}, \\ (A - B)\,|\sigma\rangle, & \text{if } \texttt{subtract=True}. \end{cases}\]
This class does not construct dense matrices; it merges the connection graphs of A and B and sums (or subtracts) their matrix elements. The padded output follows NetKet's fixed-shape contract.
- Parameters:
A (ComputationalOperator) – First operand. A and B must act on the same Hilbert space.
B (ComputationalOperator) – Second operand; its matrix elements enter with a minus sign when subtract=True.
subtract (bool, optional) – If True, computes A - B instead of A + B (default: False).
tol (float, optional) – Numerical threshold below which connections are treated as zero (default: 0.0).
is_hermitian (bool, optional) – Whether to mark this composite operator as Hermitian (default: False).
Notes
- All nonzero connections from A and B are merged and returned in padded format.
- Overlapping connections (identical connected states) are not deduplicated, for speed. Downstream contraction routines handle that if needed; for linear quantities such as local estimators, duplicate entries simply add up.
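The merge can be sketched as a concatenation along the connection axis, with B's matrix elements negated when subtract=True. This NumPy sketch uses hypothetical names (sum_conn, diag_a, diag_b, psi) and the padded shapes from this page; it also shows why skipping deduplication is harmless for a linear estimator: two entries pointing at the same connected state add up to the same total as one merged entry.

```python
import numpy as np


def sum_conn(get_conn_A, get_conn_B, sigma, subtract=False):
    """Connections of A + B (or A - B): concatenate both padded outputs along the connection axis."""
    sA, mA = get_conn_A(sigma)                        # (N, C_A, D), (N, C_A)
    sB, mB = get_conn_B(sigma)                        # (N, C_B, D), (N, C_B)
    sign = -1.0 if subtract else 1.0
    sigma_p = np.concatenate([sA, sB], axis=1)        # (N, C_A + C_B, D)
    mels = np.concatenate([mA, sign * mB], axis=1)    # (N, C_A + C_B)
    return sigma_p, mels


def diag_a(sigma):
    """Toy diagonal operator: mel = sum(sigma), connected state = sigma itself."""
    return sigma[:, None, :], sigma.sum(axis=1, keepdims=True).astype(float)


def diag_b(sigma):
    """Toy diagonal operator with constant mel 2.0."""
    return sigma[:, None, :], np.full((sigma.shape[0], 1), 2.0)


def psi(states):
    """Hypothetical toy amplitude, used only to form a local estimate."""
    return np.exp(0.3 * states.sum(axis=-1))


sigma = np.array([[1, -1, 1], [1, 1, 1]])
sp, m = sum_conn(diag_a, diag_b, sigma, subtract=True)
# Both connections point to sigma itself (a duplicate), yet the estimate is just their sum.
e_loc = (m * psi(sp) / psi(sigma)[:, None]).sum(axis=1)
```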
- class SumJax(A, B, *, subtract=False, tol=0.0, is_hermitian=False)
Bases: ComputationalJaxOperator
JAX-native, matrix-free sum or difference of two ComputationalJaxOperators A and B:
\[S\,|\sigma\rangle = \begin{cases} (A + B)\,|\sigma\rangle, & \text{if } \texttt{subtract=False}, \\ (A - B)\,|\sigma\rangle, & \text{if } \texttt{subtract=True}. \end{cases}\]
- Parameters:
A (ComputationalJaxOperator) – First JAX-native local operator to be combined. A and B must share the same Hilbert space.
B (ComputationalJaxOperator) – Second JAX-native local operator to be combined.
subtract (bool, optional) – If True, computes A - B instead of A + B (default: False).
tol (float, optional) – Threshold for treating matrix elements as zero (default: 0.0).
is_hermitian (bool, optional) – Whether to mark this composite operator as Hermitian (default: False).
Notes
- Uses the padded connection API (NetKet contract): the connection axis has a fixed size per call.
- No Python-side loops: pure JAX operations and pytree registration.
- Works with both single-input (N, D) and batched (B, N, D) configurations.
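The single-input/batched handling mentioned in these notes (and in the Product description) can be sketched as a thin wrapper that promotes (N, D) to (1, N, D), calls a batched kernel, and drops the batch axis again. batched_kernel and get_conn_padded_any are hypothetical names; the promotion rule is the one described on this page.

```python
import numpy as np


def batched_kernel(sigma):
    """Hypothetical batched operator: (B, N, D) -> (B, N, C, D), (B, N, C) with C = 2.

    Connection 0 is sigma itself (mel 1.0); connection 1 flips the sign of every site (mel 0.5).
    """
    sigma_p = np.stack([sigma, -sigma], axis=-2)
    mels = np.broadcast_to(np.array([1.0, 0.5]), sigma.shape[:-1] + (2,))
    return sigma_p, mels


def get_conn_padded_any(sigma):
    """Accept (N, D) or (B, N, D); return outputs without or with the batch axis accordingly."""
    single = sigma.ndim == 2
    if single:
        sigma = sigma[None]             # (N, D) -> (1, N, D)
    sigma_p, mels = batched_kernel(sigma)
    if single:
        return sigma_p[0], mels[0]      # drop the batch axis again
    return sigma_p, mels


sp1, m1 = get_conn_padded_any(np.ones((5, 3)))      # single input (N=5, D=3)
spb, mb = get_conn_padded_any(np.ones((4, 5, 3)))   # batch (B=4, N=5, D=3)
```

Branching on sigma.ndim is safe under jax.jit, because the number of dimensions is static at trace time.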
- class Scaled(op, alpha)
Bases: ComputationalOperator
Wrapper for scalar multiplication of a ComputationalOperator.
Given an operator O and a scalar α (float or complex; Python, jnp or np), this represents the operator α·O: (α·O)|σ⟩ has the same connected states as O|σ⟩, with every matrix element scaled by α.
Notes
- Scales all matrix elements (diagonal and off-diagonal).
- Dtype is promoted with promote_constant_for_op_dtype(op.dtype, α) and then combined via jnp.result_type.
- Hermiticity: marked Hermitian iff the base operator is Hermitian and α is (numerically) real.
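A minimal sketch of what scaling does to a padded connection list. The real class uses promote_constant_for_op_dtype, whose exact rule is not documented here, so this sketch substitutes plain np.result_type for the dtype step and implements the stated Hermiticity rule with a small tolerance. scaled_conn, scaled_is_hermitian and toy_op are hypothetical names.

```python
import numpy as np


def scaled_conn(get_conn, alpha, sigma):
    """Connections of alpha * O: same connected states, every matrix element times alpha."""
    sigma_p, mels = get_conn(sigma)
    # Stand-in for the package's promotion helper (assumption): plain NumPy type promotion.
    dtype = np.result_type(mels.dtype, np.asarray(alpha).dtype)
    return sigma_p, (mels * alpha).astype(dtype)


def scaled_is_hermitian(base_is_hermitian, alpha, atol=1e-12):
    """Hermitian iff the base operator is Hermitian and alpha is (numerically) real."""
    return bool(base_is_hermitian and abs(np.imag(alpha)) <= atol)


def toy_op(sigma):
    """Toy operator: diagonal element 1.0 and one off-diagonal element -0.5 (a sign flip)."""
    sigma_p = np.stack([sigma, -sigma], axis=1)                        # (N, 2, D)
    mels = np.tile(np.array([1.0, -0.5], dtype=np.float32), (sigma.shape[0], 1))
    return sigma_p, mels


sigma = np.ones((3, 4), dtype=np.int8)
sp, m = scaled_conn(toy_op, 2j, sigma)
```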
- class ScaledJax(op, alpha)
Bases: ComputationalJaxOperator
JAX-native wrapper for scalar multiplication of a ComputationalJaxOperator.
Given an operator O and a scalar α (float or complex; Python, jnp or np), this represents the operator α·O: (α·O)|σ⟩ has the same connected states as O|σ⟩, with every matrix element scaled by α.
Notes
- Scales all matrix elements (diagonal and off-diagonal).
- Dtype is promoted with promote_constant_for_op_dtype(op.dtype, α) and then combined via jnp.result_type.
- Hermiticity: marked Hermitian iff the base operator is Hermitian and α is (numerically) real.
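The Hermiticity rule can be checked on a dense stand-in (the wrapper itself never builds dense matrices): for a nonzero Hermitian H, (αH)† = α* H, which equals αH only when α is real. The 2x2 matrix and the helper is_hermitian_matrix below are arbitrary illustrative choices.

```python
import numpy as np

H = np.array([[1.0, 2.0 - 1.0j], [2.0 + 1.0j, -3.0]])  # a Hermitian 2x2 example


def is_hermitian_matrix(m, atol=1e-12):
    """True if m equals its conjugate transpose up to atol."""
    return np.allclose(m, m.conj().T, atol=atol)


# Map each scalar alpha to whether alpha * H is still Hermitian.
results = {alpha: is_hermitian_matrix(alpha * H) for alpha in (2.5, -1.0, 1j, 1.0 + 0.5j)}
```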