neuralqx.experimental.utils.struct.base module
Public struct base classes and decorators.
This module defines:
- StructMeta: metaclass that processes classes into immutable pytrees.
- Struct: ergonomic user-facing base class.
- Decorator helpers for registering existing classes as structs.
This module is the primary user entry point for defining immutable, traceable state containers that behave well under JAX transformations and neuralQX serialisation flows.
Compared to plain dataclasses, Struct adds:
- explicit node/static/opaque field semantics,
- a deterministic freeze/validation lifecycle,
- built-in pytree registration,
- integrated state export/load methods.
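For contrast, a plain frozen dataclass from the standard library already provides immutability and copy-on-replace, but treats every field alike and is not registered as a JAX pytree. The stdlib-only sketch below shows that baseline; `PlainState` is a hypothetical name used only for illustration.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class PlainState:
    params: dict
    name: str = "exp"


st = PlainState(params={"w": [1, 2, 3]})
# Derive an updated copy; the original instance is left untouched.
st2 = dataclasses.replace(st, name="exp-2")

mutation_blocked = False
try:
    st.name = "mutated"  # frozen dataclasses reject attribute assignment
except dataclasses.FrozenInstanceError:
    mutation_blocked = True

# There is no node/static/opaque split here, and jax.tree_util would treat
# a PlainState instance as a single leaf because it is not registered.
```

Struct keeps this replace-based workflow while making the role of each field explicit to JAX.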
Example

```python
import neuralqx as nqx

s = nqx.utils.struct


class State(s.Struct):
    params: object
    name: str = s.field(static=True, default="exp")  # static metadata field


st = State(params={"w": [1, 2, 3]})
st2 = st.replace(name="exp-2")  # returns a new instance; st is unchanged
```
- class Struct
Bases: object

Immutable, JAX-native container with explicit field semantics.
Struct is designed as a forward-looking abstraction for differentiable systems code, providing:
- deterministic construction and validation,
- strict post-init immutability,
- first-class pytree registration,
- stable serialisation contracts.
Field behaviour is driven by neuralqx.utils.struct.field(). For generated constructors and reconstruction paths, the lifecycle is:
1. Assign user/default values to non-derived fields.
2. Compute derived fields.
3. Run __post_init__ (if defined).
4. Recompute derived fields (post-init may have changed their dependencies).
5. Validate values and static constraints.
6. Freeze the instance (further mutation raises FrozenStructError).
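The six lifecycle steps can be traced in a small, self-contained toy class. This is a hypothetical sketch, not the neuralqx implementation: `LifecycleSketch`, its fields `x`/`scale`, the derived field `y`, and the local `FrozenStructError` definition are all assumptions made for illustration. It shows why derived fields are recomputed after `__post_init__`.

```python
class FrozenStructError(AttributeError):
    """Local stand-in for the error named in the docs (hypothetical)."""


class LifecycleSketch:
    """Toy class: ``x`` and ``scale`` are user fields, ``y`` is derived."""

    def __init__(self, x, scale=2):
        object.__setattr__(self, "_frozen", False)
        # 1. assign user/default values to non-derived fields
        self.x = x
        self.scale = scale
        # 2. compute derived fields
        self._compute_derived()
        # 3. run __post_init__ (if defined)
        self.__post_init__()
        # 4. recompute derived fields, since post-init changed ``x``
        self._compute_derived()
        # 5. validate values and constraints
        if self.scale <= 0:
            raise ValueError("scale must be positive")
        # 6. freeze: any later assignment raises FrozenStructError
        object.__setattr__(self, "_frozen", True)

    def _compute_derived(self):
        self.y = self.x * self.scale

    def __post_init__(self):
        # Example hook that modifies a dependency of the derived field.
        self.x = abs(self.x)

    def __setattr__(self, name, value):
        if self._frozen:
            raise FrozenStructError(f"cannot set {name!r} on a frozen instance")
        object.__setattr__(self, name, value)
```

With `LifecycleSketch(-3)`, step 2 computes `y == -6`, but step 4 corrects it to `6` after the hook sets `x` to `3`; skipping the recompute would leave the derived field stale.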
Struct instances are logically (almost) immutable after initialisation. Use replace() to derive updated copies.

Pytree contract:
- Node fields are leaves.
- Static fields are auxiliary data.
- Opaque fields are auxiliary, identity-preserving references.
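The pytree contract amounts to a flatten/unflatten pair that routes each field kind to the right place. The stdlib-only sketch below assumes a hypothetical `FIELD_KINDS` table and works on plain dicts rather than Struct instances; a real registration would hand equivalent functions to `jax.tree_util.register_pytree_node`, whose flatten function returns `(children, aux_data)` and whose unflatten function takes `(aux_data, children)`.

```python
# Hypothetical field-kind table for a toy class; in neuralqx this
# information would come from field() declarations.
FIELD_KINDS = {"params": "node", "name": "static", "handle": "opaque"}


def flatten(values):
    """Split a dict of field values into (leaves, aux_data)."""
    node_names = tuple(n for n, k in FIELD_KINDS.items() if k == "node")
    leaves = [values[n] for n in node_names]
    # Static fields travel in aux data by value.
    static = tuple((n, values[n]) for n, k in FIELD_KINDS.items() if k == "static")
    # Opaque fields travel in aux data as the very same objects.
    opaque = tuple((n, values[n]) for n, k in FIELD_KINDS.items() if k == "opaque")
    return leaves, (node_names, static, opaque)


def unflatten(aux, leaves):
    """Rebuild field values from aux data and (possibly transformed) leaves."""
    node_names, static, opaque = aux
    values = dict(zip(node_names, leaves))
    values.update(static)
    values.update(opaque)
    return values


handle = object()  # stand-in for a runtime-only resource
leaves, aux = flatten({"params": 1.5, "name": "exp", "handle": handle})
# A transformation only ever sees and rewrites the leaves.
rebuilt = unflatten(aux, [leaf * 2 for leaf in leaves])
```

Only `params` is exposed as a leaf; `name` is carried by value and `handle` comes back as the identical object.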
Serialisation contract: Struct exposes convenience wrappers over neuralqx.utils.struct.io:
- to_state_dict()
- from_state_dict()
- export()
- load()
- classmethod fields()
Return an ordered mapping of field names to FieldSpec.

This is the canonical schema view for the class. The mapping order matches the declaration/MRO resolution order used by constructor generation and derived-field evaluation.
- classmethod node_fields()
Return the names of fields traced as pytree leaves.

These fields are visible to jax.tree_util.tree_leaves and therefore participate in transformations such as jit, grad, and vmap.
- classmethod static_fields()
Return the names of static pytree metadata fields.

Static fields are part of the pytree aux data and therefore influence JAX cache keys: changing a static value causes jit to retrace. They must remain hashable and array-free.
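A validator enforcing the "hashable and array-free" rule might look like the sketch below. `check_static_value` is a hypothetical name, and detecting arrays by the presence of `shape` and `dtype` attributes is an assumption; the actual validation in neuralqx may differ.

```python
def check_static_value(name, value):
    """Hypothetical validator: static values must be hashable and array-free."""
    try:
        hash(value)
    except TypeError as exc:
        raise TypeError(f"static field {name!r} must be hashable") from exc
    # Duck-typed array check (assumption): anything exposing shape and dtype.
    if hasattr(value, "shape") and hasattr(value, "dtype"):
        raise TypeError(f"static field {name!r} must not hold an array")
    return value
```

Strings and tuples of ints pass; lists, dicts, and array-like objects are rejected, which keeps aux data usable as part of a cache key.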
- classmethod opaque_fields()
Return the names of opaque runtime fields excluded from pytree leaves.

Opaque fields are useful for runtime-only objects that should not be traced or serialised by default (for example, handles or ephemeral tokens).
- class StructMeta(name, bases, namespace, **kwargs)
Bases: type

Metaclass that finalises classes into Struct-compatible pytrees.
Every subclass of Struct is processed at class creation time:
- field definitions are collected and normalised,
- the constructor signature is synthesised if needed,
- JAX pytree hooks are registered,
- the class reference is cached for serialisation.

The metaclass guarantees that class-level invariants are established exactly once, when classes are defined, rather than at first instance construction. This provides:
- predictable failure modes for invalid field declarations,
- zero per-instance schema analysis overhead,
- consistent runtime behaviour regardless of object creation path.
- class StructABCMeta(name, bases, namespace, **kwargs)
Bases: StructMeta, ABCMeta

Combined metaclass giving both Struct processing and ABC enforcement.
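A combined metaclass is needed because Python refuses to build a class whose bases have unrelated metaclasses. The sketch below reproduces that conflict with a hypothetical `ToyMeta` standing in for StructMeta, then resolves it with a metaclass deriving from both, analogous in spirit to StructABCMeta.

```python
from abc import ABC, ABCMeta, abstractmethod


class ToyMeta(type):
    """Hypothetical stand-in for a Struct-style metaclass."""


class ToyBase(metaclass=ToyMeta):
    pass


conflict_detected = False
try:
    # ToyMeta and ABCMeta are unrelated, so Python raises a metaclass conflict.
    class Broken(ToyBase, ABC):
        pass
except TypeError:
    conflict_detected = True


class ToyABCMeta(ToyMeta, ABCMeta):
    """Derives from both metaclasses, so both behaviours apply."""


class Model(ToyBase, metaclass=ToyABCMeta):
    @abstractmethod
    def energy(self): ...


class Concrete(Model):
    def energy(self):
        return 0.0
```

`Model()` raises `TypeError` because `energy` is abstract, while `Concrete()` can be instantiated.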
- dataclass(cls=None, *, name=None)
Dataclass-style alias for register_class().

This helper exists for readability in codebases that conceptually treat Struct definitions as immutable dataclasses with JAX semantics.
- register_class(cls=None, *, name=None)
Register an existing class as a Struct type.

Supports both styles:

```python
@register_class
class Graph:
    nodes: ...


@register_class(name="RenamedGraph")
class Graph:
    ...
```
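- Parameters:
  - cls: the class to register, or None when the decorator is called with keyword arguments.
  - name: optional name to register the class under.
- Returns:
Registered class or class decorator depending on invocation style.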
Notes
If the class already subclasses Struct, registration is idempotent. Otherwise, a new class is created with bases (raw_cls, Struct), so methods and super() semantics remain valid.

Existing class methods are preserved. Note that __init__ is always replaced by the generated Struct constructor; any custom __init__ on raw_cls is discarded. Zero-argument super() in class methods remains valid.

Struct processing (field collection, pytree registration, class reference caching) is applied exactly once.
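The dual invocation style and the "new class with bases (raw_cls, Struct)" behaviour can be sketched without neuralqx. `StructBase`, `register_toy`, and the `_registry_name` attribute are hypothetical; the sketch omits field processing and constructor generation and only shows the decorator shape, idempotence, and why zero-argument `super()` keeps working.

```python
class StructBase:
    """Hypothetical stand-in for Struct."""

    _registry_name = None


def register_toy(cls=None, *, name=None):
    """Sketch of a decorator usable bare or with keyword arguments."""

    def wrap(raw_cls):
        if issubclass(raw_cls, StructBase):
            return raw_cls  # already a struct: registration is a no-op
        # raw_cls precedes StructBase in the MRO, so its methods win and its
        # zero-argument super() calls continue along the new MRO.
        new_cls = type(raw_cls.__name__, (raw_cls, StructBase), {
            "__module__": raw_cls.__module__,
            "__qualname__": raw_cls.__qualname__,
        })
        new_cls._registry_name = name or raw_cls.__name__
        return new_cls

    if cls is None:
        return wrap  # called as @register_toy(name=...)
    return wrap(cls)  # called as bare @register_toy


class Greeter:
    def hello(self):
        return "base"


@register_toy
class Graph(Greeter):
    def hello(self):
        return "graph+" + super().hello()


@register_toy(name="RenamedGraph")
class Graph2:
    pass
```

`Graph().hello()` returns `"graph+base"`, and registering `Graph` a second time returns the same class.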
- fields(cls)
Return the class field mapping.

Thin convenience wrapper over cls.fields().
- node_fields(cls)
Return a tuple of node-field names.

Thin convenience wrapper over cls.node_fields().
- static_fields(cls)
Return a tuple of static-field names.

Thin convenience wrapper over cls.static_fields().
- opaque_fields(cls)
Return a tuple of opaque-field names.

Thin convenience wrapper over cls.opaque_fields().
- derived_fields(cls)
Return a tuple of derived-field names.

Thin convenience wrapper over cls.derived_fields().