neuralqx.experimental.operators.symbolic package

Symbolic operator subsystem for neuraLQX.

Typical workflow:

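import netket as nk

from neuralqx.experimental.operators.symbolic import DOperator
from neuralqx.experimental.operators.symbolic.dsl import site, shift

# Example inputs (assumed for illustration): a 4-site bosonic Fock space
# and all of its basis states as the batch of input configurations.
hi = nk.hilbert.Fock(n_max=3, N=4)
x_batch = hi.all_states()

hop = (
    DOperator(hi, "hopping")
    .for_each_pair("i", "j")
    .where(site("i").value > 0)  # only hop out of occupied sites
    .emit(shift("i", -1).shift("j", +1), amplitude=1.0)
    .build()
)
compiled = hop.compile()
xp, mels = compiled.get_conn_padded(x_batch)

class DOperator(hilbert, name='operator', *, dtype='float64', hermitian=False)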

Bases: object

Fluent builder for declarative symbolic quantum operators.

The builder accumulates one or more terms. Each term consists of an iterator (which sites to visit), an optional predicate (which visits to activate), and one or more emissions (how to rewrite the configuration and what matrix element to assign per active visit).

Calling any iterator method (for_each_site, for_each_pair, …, globally) seals the previous term (if any) and begins a new one. .where and .emit always target the current open term.
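The term-sealing rule can be modelled with a small pure-Python builder. This is a hypothetical sketch: ToyBuilder and its dict-based terms are illustrative names, not the DOperator implementation. It only shows that each iterator call closes the previous term, while .where and .emit modify the currently open one.

```python
class ToyBuilder:
    """Hypothetical model of the seal-on-iterator rule (not DOperator itself)."""

    def __init__(self):
        self.terms = []    # sealed terms, in declaration order
        self._open = None  # the current open term, if any

    def _begin(self, iterator):
        # Any iterator call seals the previous term and opens a new one.
        if self._open is not None:
            self.terms.append(self._open)
        self._open = {"iterator": iterator, "predicate": None, "emissions": []}
        return self

    def _current(self):
        if self._open is None:
            raise RuntimeError("call an iterator method before where/emit")
        return self._open

    def for_each_site(self, label):
        return self._begin(("site", label))

    def for_each_pair(self, label_a, label_b):
        return self._begin(("pair", label_a, label_b))

    def globally(self):
        return self._begin(("global",))

    def where(self, predicate):
        self._current()["predicate"] = predicate  # targets the open term only
        return self

    def emit(self, update, amplitude):
        self._current()["emissions"].append((update, amplitude))
        return self

    def build(self):
        # Seal the last open term and hand back every term.
        if self._open is not None:
            self.terms.append(self._open)
            self._open = None
        return list(self.terms)


terms = (
    ToyBuilder()
    .for_each_site("i")
    .where("x[i] > 0")
    .emit("lower i", 1.0)
    .emit("raise i", 1.0)
    .globally()  # seals the site term and opens a global one
    .emit("identity", 0.5)
    .build()
)
print(len(terms))  # 2
```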

Parameters:
  • hilbert (DiscreteHilbert) – NetKet DiscreteHilbert space.

  • name (str) – Readable operator name (accessible as .name on the resulting SymbolicOperator).

  • dtype (str) – Matrix-element dtype string (default "float64").

  • hermitian (bool) – Whether to declare the operator Hermitian.

class Update(_program=None)

Bases: object

Immutable, chainable site-update program builder.

Every instance method appends one operation and returns a new Update object; the original is never mutated. The canonical entry points are the module-level free functions (shift(), shift_mod(), write(), swap(), permute(), affine(), scatter(), identity()), which avoid the Update() boilerplate.

Parameters:

_program (UpdateProgram | None) – Internal update program (do not pass manually).
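The copy-on-append design can be illustrated with a frozen dataclass that stores its operations in a tuple. ToyUpdate and its apply() interpreter are hypothetical. The real Update stores an UpdateProgram that is lowered to JAX. Only shift, write and swap are modelled here, and they act on a flat Python list of site values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyUpdate:
    """Hypothetical immutable update program: a tuple of (kind, *args) ops."""
    ops: tuple = ()

    def _append(self, *op):
        return ToyUpdate(self.ops + (op,))  # new object; self is unchanged

    def shift(self, site, delta):
        return self._append("shift", site, delta)

    def write(self, site, value):
        return self._append("write", site, value)

    def swap(self, site_a, site_b):
        return self._append("swap", site_a, site_b)

    def apply(self, x):
        """Interpret the ops in order on a copy of a flat configuration."""
        xp = list(x)
        for kind, *args in self.ops:
            if kind == "shift":
                site, delta = args
                xp[site] += delta
            elif kind == "write":
                site, value = args
                xp[site] = value
            elif kind == "swap":
                a, b = args
                xp[a], xp[b] = xp[b], xp[a]
        return xp


def shift(site, delta):
    """Free-function entry point that hides the ToyUpdate() boilerplate."""
    return ToyUpdate().shift(site, delta)


base = shift(0, -1)
hop = base.shift(1, +1)      # base is not modified
print(base.ops)              # (('shift', 0, -1),)
print(hop.apply([1, 0, 2]))  # [0, 1, 2]
```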

shift(site_ref, delta)

Appends x'[i] = x[i] + delta.

Parameters:
  • site_ref (str | SiteSelector | int | AmplitudeExpr) – Target site (label string, selector, or flat index).

  • delta (Any) – Shift amount, numeric or amplitude expression.

Return type:

Update

Returns:

New Update with this operation appended.

shift_mod(site_ref, delta)

Appends a Hilbert-aware wrapped shift.

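Semantics are resolved from the enclosing operator's Hilbert space at build/compile time. Currently, this requires contiguous, unit-spaced integer local_states such as [-m_max, …, m_max].

Resulting runtime semantics:

x'[i] = ((x[i] + delta - state_min) % mod_span) + state_min

where state_min is the smallest local state and mod_span is the number of local states, so the result always stays within local_states.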

Return type:

Update
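The wrap-around arithmetic can be checked in plain Python. The sketch assumes contiguous unit-spaced integer local states, as the text requires, and takes mod_span to be the number of local states. Python's % (like jnp.mod) returns a result with the sign of the divisor, so negative shifts also wrap correctly.

```python
def shift_mod(x_i, delta, local_states):
    """Wrapped shift of one site value over contiguous integer local states."""
    state_min = min(local_states)
    mod_span = len(local_states)  # assumed: one slot per local state
    return ((x_i + delta - state_min) % mod_span) + state_min


# Local states [-m_max, ..., m_max] with m_max = 2.
local_states = list(range(-2, 3))

print(shift_mod(2, +1, local_states))   # -2  (top state wraps to the bottom)
print(shift_mod(-2, -1, local_states))  # 2   (bottom state wraps to the top)
```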

write(site_ref, value)

Appends x'[i] = value.

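Parameters:
  • site_ref – Target site (label string, selector, or flat index).

  • value – Value to write, numeric or amplitude expression.

Return type: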

Update

Returns:

New Update with this operation appended.

swap(site_a, site_b)

Appends x'[a], x'[b] = x[b], x[a].

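Parameters:
  • site_a – First site to exchange (label string, selector, or flat index).

  • site_b – Second site to exchange (label string, selector, or flat index).

Return type: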

Update

Returns:

New Update with this operation appended.

permute(*site_refs)

Appends a cyclic rotation over K sites.

After the operation:

x'[s0] ← x[s1],   x'[s1] ← x[s2],   ...,   x'[sK-1] ← x[s0]

All K source values are captured from the current x' state before any writes are applied, so the rotation is atomic.

Parameters:

*site_refs (str | SiteSelector | int | AmplitudeExpr) – Two or more site references in rotation order.

Return type:

Update

Returns:

New Update with this operation appended.

Raises:

ValueError – If fewer than 2 site references are provided.
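The atomic capture-then-write order matters. Writing sequentially in place would read values that have already been overwritten. This hypothetical sketch operates on a flat list and contrasts the two orders for K = 3. The site order follows the rotation x'[s0] ← x[s1], ..., x'[sK-1] ← x[s0].

```python
def permute(xp, sites):
    """Atomic cyclic rotation: x'[s_k] <- x[s_(k+1)], with x'[s_(K-1)] <- x[s_0]."""
    if len(sites) < 2:
        raise ValueError("permute requires at least 2 site references")
    sources = [xp[s] for s in sites]  # capture all K values before writing
    out = list(xp)
    for k, s in enumerate(sites):
        out[s] = sources[(k + 1) % len(sites)]
    return out


def naive_permute(xp, sites):
    """Sequential in-place writes, shown only for contrast (not atomic)."""
    out = list(xp)
    for k, s in enumerate(sites):
        out[s] = out[sites[(k + 1) % len(sites)]]
    return out


x = [10, 20, 30]
print(permute(x, [0, 1, 2]))        # [20, 30, 10]
print(naive_permute(x, [0, 1, 2]))  # [20, 30, 20]  (x[0] was overwritten too early)
```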

affine(site_ref, *, scale, bias=0)

Appends x'[i] = scale * x[i] + bias.

Parameters:
  • site_ref (str | SiteSelector | int | AmplitudeExpr) – Target site.

  • scale (Any) – Multiplicative scale, numeric or amplitude expression.

  • bias (Any) – Additive bias, numeric or amplitude expression (default 0).

Return type:

Update

Returns:

New Update with this operation appended.

scatter(flat_indices, values)

Appends bulk writes to static flat site indices.

For each (flat_index, value) pair:

x'[flat_index] = value

Indices must be compile-time-constant integers (baked into the IR). Values may be arbitrary amplitude expressions.

Parameters:
  • flat_indices (list[int] | tuple[int, ...]) – Sequence of static integer site indices.

  • values (list[Any] | tuple[Any, ...]) – Sequence of amplitude expressions (or coercible values).

Return type:

Update

Returns:

New Update with this operation appended.

Raises:

ValueError – If flat_indices and values have different lengths.
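The following list-based sketch shows the scatter semantics, including the length check described under Raises. Here the values are plain numbers. The real op also accepts amplitude expressions, which would be evaluated first. The helper name is hypothetical.

```python
def scatter(xp, flat_indices, values):
    """Write values[k] into xp[flat_indices[k]] for every k (hypothetical sketch)."""
    if len(flat_indices) != len(values):
        raise ValueError("flat_indices and values must have the same length")
    out = list(xp)  # leave the input configuration untouched
    for idx, val in zip(flat_indices, values):
        out[idx] = val
    return out


x = [5, 5, 5, 5]
print(scatter(x, [0, 2], [1, -1]))  # [1, 5, -1, 5]
```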

shift(site_ref, delta)

Returns an Update that shifts site site_ref by delta.

Example:

shift("i", +1)  # raise site i by 1
shift(0, -1)  # lower flat site 0 by 1
shift("j", site("i").value)  # shift j by x[i]

Return type:

Update

shift_mod(site_ref, delta)

Returns an Update performing a Hilbert-aware wrapped modular shift.

Example:

shift_mod("i", +1)
shift_mod(0, -2)

Return type:

Update

write(site_ref, value)

Returns an Update that writes value to site site_ref.

Example:

write("i", 0)  # zero site i
write(5, site("j").value)  # copy x[j] into flat site 5

Return type:

Update

swap(site_a, site_b)

Returns an Update that swaps sites site_a and site_b.

Example:

swap("i", "j")  # exchange x[i] and x[j]
swap(0, 10)  # exchange flat sites 0 and 10

Return type:

Update

permute(*site_refs)

Returns an Update performing a cyclic rotation over K sites.

Example:

permute("i", "j", "k")  # x'[i]←x[j], x'[j]←x[k], x'[k]←x[i]
permute(0, 5, 10)  # same with flat indices

Return type:

Update

affine(site_ref, *, scale, bias=0)

Returns an Update computing x'[i] = scale * x[i] + bias.

Example:

affine("i", scale=2, bias=-1)  # x'[i] = 2*x[i] - 1
affine(0, scale=-1, bias=0)  # negate flat site 0

Return type:

Update

scatter(flat_indices, values)

Returns an Update performing bulk writes to static flat indices.

Example:

scatter([0, 10, 20], [1, -1, 0])  # write constant values
scatter([0, 10], [site("i").value, 0])  # mixed expr / constant

Return type:

Update

identity()

Returns the identity (no-op) Update.

Use for diagonal operators where x' = x:

DOperator(hi, "diagonal").globally().emit(identity(), amplitude=my_expr)

Return type:

Update
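For a diagonal term, the connected state is the input itself. Each configuration therefore contributes a single branch whose matrix element is the amplitude. The sketch below mimics NetKet's get_conn_padded output layout using nested lists with n_conn = 1: x' has shape (batch, n_conn, N) and mels has shape (batch, n_conn). The names diagonal_conn and on_site_energy are hypothetical stand-ins for the compiled kernel and my_expr.

```python
def on_site_energy(x):
    """Stand-in for my_expr: sum_i x_i**2 (hypothetical amplitude)."""
    return sum(v * v for v in x)


def diagonal_conn(x_batch, amplitude):
    """Model x' = x with exactly one branch per input configuration."""
    xp = [[list(x)] for x in x_batch]         # shape (batch, 1, N)
    mels = [[amplitude(x)] for x in x_batch]  # shape (batch, 1)
    return xp, mels


xp, mels = diagonal_conn([[0, 1], [2, 1]], on_site_energy)
print(mels)  # [[1], [5]]
```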

site(label)

Returns a symbolic site selector.

Parameters:

label (str) – Iterator label bound by for_each_site(label) or for_each_pair(label_a, label_b).

Return type:

SiteSelector

Returns:

Site selector handle.

Example

from neuralqx.experimental.operators.symbolic.dsl import site

s = site("i")
print(s.value)          # AmplitudeExpr, x[i]
print(s.index)          # AmplitudeExpr, i
print(s.value < 3)      # PredicateExpr, x[i] < 3
print(s.value + 1)      # AmplitudeExpr, x[i] + 1

emitted(label)

Returns a symbolic selector bound to the emitted or connected state x'.

Return type:

SiteSelector

Example

from neuralqx.experimental.operators.symbolic.dsl import emitted

e = emitted("i")
print(e.value)   # AmplitudeExpr, x'[i]
print(e.index)   # AmplitudeExpr, i

symbol(name)

Returns a free symbolic amplitude expression by name.

Free symbols are not bound to any site iterator; they are resolved at operator-evaluation time from external parameter dictionaries.

Parameters:

name (str) – Symbol name.

Return type:

AmplitudeExpr

Returns:

Symbolic amplitude expression.
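Deferred resolution of free symbols can be pictured as evaluating an expression tree against a parameter dictionary that is supplied only at evaluation time. Expr mirrors the (op, args) shape of AmplitudeExpr. However, Expr, const, mul and evaluate are hypothetical helpers, not the package's API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Expr:
    op: str
    args: tuple = ()


def symbol(name):
    return Expr("symbol", (name,))


def const(value):
    return Expr("const", (value,))


def mul(a, b):
    return Expr("mul", (a, b))


def evaluate(expr, params):
    """Resolve free symbols from params and evaluate the tree recursively."""
    if expr.op == "const":
        return expr.args[0]
    if expr.op == "symbol":
        name = expr.args[0]
        if name not in params:
            raise KeyError(f"unbound symbol {name!r}")
        return params[name]
    if expr.op == "mul":
        left, right = expr.args
        return evaluate(left, params) * evaluate(right, params)
    raise ValueError(f"unknown op {expr.op!r}")


# The same symbolic amplitude -t is reused with different parameter values.
amplitude = mul(const(-1.0), symbol("t"))
print(evaluate(amplitude, {"t": 0.5}))  # -0.5
```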

class SiteSelector(label, namespace='site')

Bases: object

Symbolic selector for one Hilbert-space site iterator.

SiteSelector is created by site() (or by emitted() for the emitted state x') and used inside DSL predicates, amplitude rules, and update programs. Attribute access on a selector (for example .value or .index) returns symbolic AmplitudeExpr nodes, which the compiler resolves at lowering time.

Parameters:

label (str) – Iterator label bound by for_each_site(label).

class SymbolicOperator(hilbert, name, ir_terms, *, dtype_str='complex64', is_hermitian=False, metadata=None)

Bases: AbstractSymbolicOperator

A symbolic operator built via the DOperator DSL.

SymbolicOperator is the canonical result of DOperator(...).build(). It holds an ordered list of typed IR terms and provides a .compile() method to lower them to an executable CompiledOperator.

Instances are not directly executable: calling get_conn_padded before compilation raises SymbolicOperatorExecutionError.

name

User-facing operator name.

hilbert

The NetKet Hilbert space.

dtype

Matrix-element dtype.

is_hermitian

Whether this operator is declared Hermitian.

Example:

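from neuralqx.experimental.operators.symbolic import DOperator
from neuralqx.experimental.operators.symbolic.dsl import site, shift

# hi (a NetKet DiscreteHilbert) and x_batch (a batch of configurations)
# are assumed to be defined, e.g. a Fock space and its all_states().
op = (
    DOperator(hi, "hopping")
    .for_each_pair("i", "j")
    .where(site("i").value > 0)
    .emit(shift("i", -1).shift("j", +1), amplitude=1.0)
    .build()
)
compiled = op.compile()
xp, mels = compiled.get_conn_padded(x_batch)

class CompiledOperator(hilbert, *, name, fn, is_hermitian, dtype)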

Bases: ComputationalJaxOperator

An executable operator produced by lowering a SymbolicOperator.

CompiledOperator is the concrete result of symbolic_op.compile() or DOperator(...).compile(). Its get_conn_padded kernel is a pure JAX function that can be JIT-compiled, vmapped, and differentiated.

The class name is fixed and stable; it does not encode the operator name or structure. The readable operator name is available via the name property.

name

Operator name (from the DSL definition).

is_hermitian

Whether this operator is declared Hermitian.

dtype

Matrix-element NumPy dtype.

class AbstractSymbolicOperator(hilbert, *, name, dtype_str='complex64', is_hermitian=False, metadata=None)

Bases: ComputationalJaxOperator

Abstract base class for all symbolic (DSL-defined) operators.

Symbolic operators extend ComputationalJaxOperator and declare their action through a typed IR rather than a hand-written JAX kernel. They cannot execute until the compiler has lowered them to a concrete JAX kernel via neuralqx.experimental.operators.symbolic.compiler.SymbolicCompiler.compile().

Attempting to call _get_conn_padded() before compilation raises SymbolicOperatorExecutionError.

Parameters:
  • hilbert (DiscreteHilbert) – Discrete Hilbert space this operator is defined on.

  • name (str) – User-facing operator name.

  • dtype_str (str) – String label for the matrix-element dtype.

  • is_hermitian (bool) – Whether this operator is declared Hermitian.

  • metadata (dict[str, Any] | None) – Optional extra metadata dictionary.

class SymbolicOperatorSum(hilbert, terms, *, name=None, dtype_str=None, is_hermitian=None, metadata=None)

Bases: AbstractSymbolicOperator

Additive composition of multiple symbolic operators sharing one Hilbert space.

SymbolicOperatorSum is the canonical Hamiltonian-style container for DSL-defined operators. It preserves term ordering, flattens nested sums, and aggregates fanout bounds across all contained terms.

Parameters:
  • hilbert (DiscreteHilbert) – Shared Hilbert space.

  • terms (Sequence[AbstractSymbolicOperator]) – Sequence of symbolic operator terms.

  • name (str | None) – Optional user-facing operator name.

  • dtype_str (str | None) – Optional explicit dtype override.

  • is_hermitian (bool | None) – Optional Hermiticity override (defaults to True iff all contained terms are Hermitian).

  • metadata (dict[str, Any] | None) – Optional metadata dictionary.
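The sketch below shows three behaviours of SymbolicOperatorSum: order-preserving flattening, aggregated fanout bounds, and the all-terms-Hermitian default. Term, OpSum and make_sum are hypothetical. Summing the per-term fanout bounds is an assumption about how the aggregate is formed, with each term contributing its own padded slots.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Term:
    name: str
    max_conn: int  # per-term fanout bound
    is_hermitian: bool


@dataclass(frozen=True)
class OpSum:
    terms: tuple
    hermitian_override: object = None  # None, True or False

    @property
    def max_conn(self):
        # Assumed aggregation: total padded size is the sum of term bounds.
        return sum(t.max_conn for t in self.terms)

    @property
    def is_hermitian(self):
        if self.hermitian_override is not None:
            return self.hermitian_override
        return all(t.is_hermitian for t in self.terms)


def make_sum(*items, is_hermitian=None):
    """Flatten nested sums while preserving term order."""
    flat = []
    for item in items:
        flat.extend(item.terms if isinstance(item, OpSum) else (item,))
    return OpSum(tuple(flat), is_hermitian)


hop = Term("hop", 6, True)
diag = Term("diag", 1, True)
drive = Term("drive", 2, False)
H = make_sum(make_sum(hop, diag), drive)
print([t.name for t in H.terms], H.max_conn, H.is_hermitian)
# ['hop', 'diag', 'drive'] 9 False
```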

class ExpressionContext(*args, **kwargs)

Bases: object

Utility context passed to DSL callables at build time.

This context only builds IR expression nodes and never captures Python-runtime callbacks.

Example

>>> def my_amplitude(ctx):
...     i = ctx.site("i")
...     return ctx.sqrt(i.value + 1)

symbol(name)

Returns a free symbolic amplitude expression.

Return type:

AmplitudeExpr

site(label)

Returns a site selector by label.

Return type:

SiteSelector

emitted(label)

Returns an emitted-state selector by label.

Return type:

SiteSelector

class SymbolicCompiler(*, pipeline=None, lowerer_registry=None, artifact_store=None, options=None)

Bases: object

Orchestrates the symbolic operator compilation pipeline.

The compiler accepts a symbolic operator (an AbstractSymbolicOperator) and runs it through the registered pass pipeline. It then checks the artifact cache for a previously compiled result (when caching is enabled). On a cache miss, it invokes the appropriate lowerer to produce a concrete ComputationalJaxOperator.

Typical usage:

from neuralqx.experimental.operators.symbolic import SymbolicCompiler

compiler = SymbolicCompiler()
compiled_op = compiler.compile_operator(my_symbolic_op)
xp, mels = compiled_op.get_conn_padded(x_batch)

compile_symbolic_operator(operator, *, options=None, metadata=None)

Module-level convenience function for one-shot symbolic compilation.

Uses the module-level shared SymbolicCompiler instance (lazily created). The shared compiler reuses the global in-process artifact cache.

Return type:

ComputationalJaxOperator

Returns:

Executable ComputationalJaxOperator.

Example:

from neuralqx.experimental.operators.symbolic import compile_symbolic_operator

compiled_op = compile_symbolic_operator(my_symbolic_op)
xp, mels = compiled_op.get_conn_padded(x_batch)

class SymbolicCompilerOptions(backend_preference='auto', enable_fusion=True, strict_validation=True, cache_enabled=True, cache_namespace='nqx_symbolic_v1', debug_flags=<factory>)

Bases: object

Static and runtime controls for symbolic compiler execution.

backend_preference

Preferred lowering backend (currently only jax is supported; auto resolves to jax).

enable_fusion

Whether fusion-planning passes are enabled.

strict_validation

Whether validation passes fail hard on errors.

cache_enabled

Whether compiled artifacts are cached in-process.

cache_namespace

Namespace string used in cache-key generation.

debug_flags

Optional debug / instrumentation flags.

class SymbolicCompiledArtifact(operator_name, backend, lowerer_name, compiled_operator, cache_key=None, pass_reports=<factory>, metadata=<factory>)

Bases: object

Compilation artifact produced by the symbolic compiler pipeline.

operator_name

Source operator name.

backend

Selected backend name.

lowerer_name

Lowerer identifier used for code generation.

compiled_operator

Executable compiled operator object.

cache_key

Optional compilation cache key.

pass_reports

Ordered tuple of pass-execution reports.

metadata

Optional artifact metadata.

class SymbolicCompilationContext(*, operator, ir, options, metadata=None)

Bases: object

Holds per-compilation mutable state across pipeline stages.

The context is created by the compiler, mutated in-place by passes and lowerers, and finally read when packaging the compiled artifact.

class SymbolicCacheKey(token, namespace)

Bases: object

Immutable cache key for compiled symbolic operator artifacts.

class SymbolicCompilationSignature(operator_ir_fingerprint, backend_target, hilbert_size, dtype_str, options_signature=<factory>)

Bases: object

Deterministic compilation signature for cache-key generation.

operator_ir_fingerprint

Stable digest of the operator IR.

backend_target

Resolved backend name.

hilbert_size

Hilbert space size (number of sites).

dtype_str

Matrix-element dtype string.

options_signature

Static compiler-options signature.
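A deterministic key can be derived by serialising the signature fields canonically and hashing the result. This is a sketch under assumptions. The package's actual digest and serialisation are not documented here, so SHA-256 over sorted-key JSON is an illustrative choice. Signature and cache_token are hypothetical names. The default namespace matches the cache_namespace default of SymbolicCompilerOptions.

```python
import hashlib
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class Signature:
    operator_ir_fingerprint: str
    backend_target: str
    hilbert_size: int
    dtype_str: str
    options_signature: tuple = ()


def cache_token(sig, namespace="nqx_symbolic_v1"):
    """Hash a canonical JSON encoding of the signature plus namespace."""
    payload = json.dumps({"namespace": namespace, **asdict(sig)}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


opts = (("enable_fusion", True), ("strict_validation", True))
a = Signature("ab12", "jax", 8, "float64", opts)
b = Signature("ab12", "jax", 8, "float64", opts)
c = Signature("ab12", "jax", 8, "complex64", opts)
print(cache_token(a) == cache_token(b))  # True: equal signatures, equal keys
```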

default_symbolic_pass_pipeline()

Builds the default two-stage symbolic compiler pass pipeline.

Pre-cache passes (run on every compile() call):
  1. SymbolicValidationPass - validates IR symbol scopes and update-op parameters.

  2. SymbolicNormalizationPass - computes the IR fingerprint and resolves the target backend.

Post-cache passes (run only on cache misses):
  1. SymbolicFanoutAnalysisPass - derives per-term fanout bounds and the total padded output size.

  2. SymbolicFusionPass - groups terms into fusion-compatible clusters for the lowerer.
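Splitting the passes into pre-cache and post-cache stages means that a cache hit only pays for validation and fingerprinting. The sketch below models that control flow with plain functions over a dict context. The pass bodies, the repr-based fingerprint and compile_with_cache are hypothetical stand-ins for the named passes, not their implementations.

```python
import hashlib


def validation_pass(ctx):
    if not ctx["ir"]["terms"]:
        raise ValueError("operator has no terms")


def normalization_pass(ctx):
    # Stand-in fingerprint: hash of the IR's repr (deterministic for this dict).
    ctx["fingerprint"] = hashlib.sha256(repr(ctx["ir"]).encode()).hexdigest()
    ctx["backend"] = "jax"


post_cache_log = []


def fanout_pass(ctx):
    post_cache_log.append("fanout")
    ctx["max_conn"] = sum(t["fanout"] for t in ctx["ir"]["terms"])


def fusion_pass(ctx):
    post_cache_log.append("fusion")
    ctx["clusters"] = [list(range(len(ctx["ir"]["terms"])))]  # one trivial cluster


def lower(ctx):
    return ("compiled", ctx["backend"], ctx["max_conn"])


def compile_with_cache(ir, cache):
    ctx = {"ir": ir}
    for run in (validation_pass, normalization_pass):  # every call
        run(ctx)
    key = ctx["fingerprint"]
    if key in cache:
        return cache[key], True
    for run in (fanout_pass, fusion_pass):  # cache misses only
        run(ctx)
    cache[key] = lower(ctx)
    return cache[key], False


ir = {"name": "H", "terms": [{"fanout": 4}, {"fanout": 1}]}
cache = {}
first, hit1 = compile_with_cache(ir, cache)
second, hit2 = compile_with_cache(ir, cache)
print(hit1, hit2, post_cache_log)  # False True ['fanout', 'fusion']
```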

Return type:

SymbolicPassPipeline

Returns:

Configured SymbolicPassPipeline.

default_symbolic_lowerer_registry()

Builds the default symbolic lowerer registry.

Currently registers only the JAX backend lowerer (JAXSymbolicLowerer).

Return type:

SymbolicLowererRegistry

Returns:

Configured SymbolicLowererRegistry.

default_symbolic_artifact_store()

Returns the module-level shared in-memory artifact store.

The store is lazily created and reused across compiler instances in the same process. Call InMemorySymbolicArtifactStore.clear() to evict all compiled artifacts if needed.

Return type:

InMemorySymbolicArtifactStore

Returns:

Shared InMemorySymbolicArtifactStore.

class SymbolicOperatorIR(operator_name, mode, hilbert_size, dtype_str, is_hermitian, terms=<factory>, metadata=<factory>)

Bases: object

Immutable symbolic operator IR container.

operator_name

Name of the operator this IR represents.

mode

IR mode (symbolic for DSL-built operators, jax_kernel for direct JAX-kernel operators).

hilbert_size

Size of the Hilbert space (number of sites).

dtype_str

String representation of the matrix-element dtype.

is_hermitian

Whether the source operator is declared Hermitian.

terms

Declarative term tuple for symbolic mode.

metadata

Optional stable metadata tuple.

class SymbolicIRTerm(name, iterator, predicate, update_program, amplitude, branch_tag=None, metadata=<factory>, fanout_hint=None, emissions=None)

Bases: object

One primitive declarative symbolic operator term.

name

Term name.

iterator

Iterator descriptor (KBodyIteratorSpec).

predicate

Branch-selection predicate.

update_program

Site-update program mapping x -> x'.

amplitude

Matrix-element expression.

branch_tag

Optional branch tag for diagnostics.

metadata

Optional stable term metadata tuple.

fanout_hint

Optional static upper-bound hint on the number of connected states this term produces per input configuration.

emissions

Optional multi-emission tuple that, when present, supersedes update_program and amplitude. Each entry is an EmissionSpec representing one output branch per iterator evaluation.

class EmissionSpec(update_program, amplitude, branch_tag=None)

Bases: object

One output branch (connected state + matrix element) of a term.

A single iterator evaluation can produce multiple branches, one per EmissionSpec in the parent term’s emissions tuple. This allows a plaquette term, for example, to emit both + and - connected states from the same site-tuple without splitting into two separate terms.

update_program

Site-update program mapping x -> x'.

amplitude

Matrix-element expression evaluated in the source environment.

branch_tag

Optional diagnostic tag for this emission slot.
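Multiple emissions per iterator evaluation multiply the number of output branches. M index tuples with E emissions yield M × E (connected state, matrix element) pairs before any predicate filtering. The sketch models a plaquette-style term with a "+" and a "-" emission on a flat list. The names apply_term, raise_sites and lower_sites are hypothetical, and the predicate is omitted.

```python
def raise_sites(x, sites):
    out = list(x)
    for s in sites:
        out[s] += 1
    return out


def lower_sites(x, sites):
    out = list(x)
    for s in sites:
        out[s] -= 1
    return out


def apply_term(x, index_sets, emissions):
    """One branch per (index tuple, emission) pair; amplitudes use the source x."""
    xps, mels = [], []
    for sites in index_sets:
        for update, amplitude in emissions:
            xps.append(update(x, sites))
            mels.append(amplitude(x, sites))
    return xps, mels


emissions = [(raise_sites, lambda x, s: 1.0), (lower_sites, lambda x, s: -1.0)]
xps, mels = apply_term([0, 0, 0], ((0, 1), (1, 2)), emissions)
print(len(xps))  # 4
```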

class KBodyIteratorSpec(labels, index_sets)

Bases: object

Static K-body iterator over a pre-computed list of site-index tuples.

This iterator evaluates a term kernel once per entry in index_sets. Each entry is a K-tuple of integer site indices that are bound to the corresponding element of labels inside the evaluation environment.

For a single-site iterator over all N sites, use KBodyIteratorSpec(labels=("i",), index_sets=tuple((k,) for k in range(N))). For a static triplet iterator, provide the explicit list of (e1, e2, e3) triplets. For a global (one-branch) term, use KBodyIteratorSpec(labels=(), index_sets=((),)).
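The index_sets patterns described above are ordinary nested tuples, so they can be built with standard Python. The ring of nearest-neighbour bonds is an example choice, not something the package prescribes. check_index_sets is a hypothetical helper that enforces one index per label in every entry.

```python
from itertools import combinations


def check_index_sets(labels, index_sets):
    """Raise if any entry does not bind exactly one index per label."""
    for entry in index_sets:
        if len(entry) != len(labels):
            raise ValueError(f"{entry!r} does not match labels {labels!r}")
    return index_sets


N = 4
single_site = check_index_sets(("i",), tuple((k,) for k in range(N)))
ring_bonds = check_index_sets(("i", "j"), tuple((k, (k + 1) % N) for k in range(N)))
all_pairs = check_index_sets(("i", "j"), tuple(combinations(range(N), 2)))
global_term = check_index_sets((), ((),))  # one evaluation, no labels bound

print(ring_bonds)  # ((0, 1), (1, 2), (2, 3), (3, 0))
```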

labels

Ordered tuple of K label strings bound per iteration.

index_sets

Tuple of M K-tuples of int site indices, one K-tuple per iterator evaluation.

class AmplitudeExpr(op, args=<factory>)

Bases: object

Typed expression node for operator matrix elements.

op

Expression operation name.

args

Ordered operation arguments (frozen tuple).

classmethod symbol(name)

Builds a symbol-reference expression node.

Return type:

AmplitudeExpr

class PredicateExpr(op, args=<factory>)

Bases: object

Typed boolean expression node for operator branch filtering.

op

Predicate operation name.

args

Ordered operation arguments.

class UpdateProgram(ops=<factory>)

Bases: object

Ordered immutable sequence of site-update operations.

ops

Ordered update-operation tuple.

class UpdateOp(kind, params=<factory>)

Bases: object

One primitive site-update operation.

kind

Update operation kind (see _UPDATE_OP_KINDS).

params

Deterministic parameter tuple ((key, value), ...). Values are AmplitudeExpr nodes, plain integers, or nested structures depending on kind.

coerce_amplitude_expr(value)

Coerces user values into typed amplitude-expression nodes.

Parameters:

value (Any) – Input expression value.

Return type:

AmplitudeExpr

Returns:

Typed amplitude expression.

Raises:

TypeError – If value cannot be converted.

coerce_predicate_expr(value)

Coerces user values into typed predicate-expression nodes.

Parameters:

value (Any) – Input predicate value.

Return type:

PredicateExpr

Returns:

Typed predicate expression.

Raises:

TypeError – If value cannot be converted.

Subpackages