neuralqx.experimental.utils.struct.base module

Public struct base classes and decorators.

This module defines:
  • StructMeta: metaclass that processes classes into immutable pytree types.

  • Struct: ergonomic user-facing base class.

  • Decorator helpers (register_class() and dataclass()) for registering existing classes as structs.

This module is the primary user entry point for defining immutable, traceable state containers that work correctly under JAX transformations such as jit, grad, and vmap, and with neuralQX serialisation.

Compared to plain dataclasses, Struct adds:
  • explicit node/static/opaque field semantics,

  • deterministic freeze/validation lifecycle,

  • built-in pytree registration,

  • integrated state export/load methods.

Example

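import neuralqx as nqx
s = nqx.utils.struct

class State(s.Struct):
    params: object
    # static=True stores this field as pytree auxiliary data, not as a leaf.
    name: str = s.field(static=True, default="exp")

st = State(params={"w": [1, 2, 3]})
# replace() returns an updated copy; st itself is unchanged.
st2 = st.replace(name="exp-2")

class Struct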

Bases: object

Immutable, JAX-native container with explicit field semantics.

Struct is designed as a foundation for differentiable systems code and provides:

  • deterministic construction and validation,

  • strict post-init immutability,

  • first-class pytree registration,

  • stable serialisation contracts.

Field behaviour is driven by neuralqx.utils.struct.field().

For generated constructors and reconstruction paths, the lifecycle is:
  1. assign user/default values to non-derived fields

  2. compute derived fields

  3. run __post_init__ (if defined)

  4. recompute derived fields (post-init may have changed dependencies)

  5. validate values and static constraints

  6. freeze instance (further mutation raises FrozenStructError)
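The lifecycle above can be sketched in plain Python. This is a minimal, hypothetical model rather than neuralqx's implementation: MiniStruct, its _defaults and _derived class attributes, and validate() are illustrative names, and FrozenStructError is a local stand-in for the exception named above. The Interval example shows why step 4 matters: __post_init__ reorders the endpoints, so the derived width computed in step 2 is stale until it is recomputed.

```python
from typing import Any, Callable, Dict


class FrozenStructError(AttributeError):
    """Local stand-in for the exception named in the lifecycle above."""


class MiniStruct:
    """Hypothetical model of the six-step construction lifecycle."""

    _defaults: Dict[str, Any] = {}                  # plain fields and their defaults
    _derived: Dict[str, Callable[[Any], Any]] = {}  # derived field -> function

    def __init__(self, **kwargs: Any) -> None:
        object.__setattr__(self, "_frozen", False)
        # 1. assign user/default values to non-derived fields
        for name, default in self._defaults.items():
            setattr(self, name, kwargs.pop(name, default))
        if kwargs:
            raise TypeError(f"unexpected fields: {sorted(kwargs)}")
        # 2. compute derived fields
        self._compute_derived()
        # 3. run __post_init__ (if defined)
        post_init = getattr(self, "__post_init__", None)
        if post_init is not None:
            post_init()
        # 4. recompute derived fields (post-init may have changed dependencies)
        self._compute_derived()
        # 5. validate
        self.validate()
        # 6. freeze: any later assignment raises FrozenStructError
        object.__setattr__(self, "_frozen", True)

    def _compute_derived(self) -> None:
        for name, fn in self._derived.items():
            setattr(self, name, fn(self))

    def validate(self) -> None:
        """Override to check field values; the default accepts everything."""

    def __setattr__(self, name: str, value: Any) -> None:
        if self._frozen:
            raise FrozenStructError(f"cannot assign to {name!r} on a frozen struct")
        object.__setattr__(self, name, value)

    def replace(self, **changes: Any) -> "MiniStruct":
        # Rebuild through the constructor so the whole lifecycle runs again.
        values = {name: getattr(self, name) for name in self._defaults}
        values.update(changes)
        return type(self)(**values)


class Interval(MiniStruct):
    _defaults = {"lo": 0.0, "hi": 1.0}
    _derived = {"width": lambda s: s.hi - s.lo}

    def __post_init__(self) -> None:
        # Reorder the endpoints; the width computed in step 2 is now stale.
        if self.lo > self.hi:
            self.lo, self.hi = self.hi, self.lo

    def validate(self) -> None:
        if self.width < 0:
            raise ValueError("width must be non-negative")


iv = Interval(lo=3.0, hi=1.0)
print(iv.lo, iv.hi, iv.width)  # 1.0 3.0 2.0
```

Without the second derived-field pass, validate() would see the stale width of -2.0 and reject an interval that __post_init__ had already repaired.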

After initialisation, Struct instances are effectively immutable. Use replace() to derive updated copies instead of mutating instances in place.

Pytree contract:
  • Node fields are leaves.

  • Static fields are auxiliary data.

  • Opaque fields are auxiliary identity-preserving references.
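One way to picture this contract is as a pair of flatten/unflatten functions in the shape JAX expects: flatten returns (children, aux_data) and unflatten takes (aux_data, children), which is the form jax.tree_util.register_pytree_node accepts. The sketch below is an illustration in plain Python over a dict of field values, not neuralqx's code; the field-name tuples and the Opaque wrapper are hypothetical. Opaque values are wrapped so that aux-data equality and hashing use object identity, which keeps the aux data hashable even when the referenced object is not.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical schema: one field of each kind.
NODE_FIELDS = ("params",)
STATIC_FIELDS = ("name",)
OPAQUE_FIELDS = ("handle",)


class Opaque:
    """Wraps a runtime object so equality and hashing use its identity."""

    __slots__ = ("obj",)

    def __init__(self, obj: Any) -> None:
        self.obj = obj

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Opaque) and other.obj is self.obj

    def __hash__(self) -> int:
        return id(self.obj)


def flatten(values: Dict[str, Any]) -> Tuple[List[Any], tuple]:
    """Return (children, aux_data). Node values become children, which JAX
    would flatten further; static values and identity-wrapped opaque values
    go into aux_data."""
    children = [values[n] for n in NODE_FIELDS]
    aux_data = (
        tuple((n, values[n]) for n in STATIC_FIELDS),
        tuple((n, Opaque(values[n])) for n in OPAQUE_FIELDS),
    )
    return children, aux_data


def unflatten(aux_data: tuple, children: List[Any]) -> Dict[str, Any]:
    """Rebuild field values from aux_data and (possibly transformed) children."""
    static, opaque = aux_data
    values = dict(zip(NODE_FIELDS, children))
    values.update(static)
    values.update((n, wrapped.obj) for n, wrapped in opaque)
    return values


handle = object()
children, aux = flatten({"params": {"w": [1, 2, 3]}, "name": "exp", "handle": handle})
print(children)  # [{'w': [1, 2, 3]}]
```

A real registration would read attributes from an instance in flatten and construct an instance in unflatten; dicts are used here only to keep the sketch self-contained.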

Serialisation contract:
Struct exposes convenience wrappers over neuralqx.utils.struct.io:
  • to_state_dict()

  • from_state_dict()

  • export()

  • load()
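The exact state-dict format is defined in neuralqx.utils.struct.io and is not reproduced here. As a hedged sketch of the general idea, the example below uses a standard frozen dataclass and tags field kinds through dataclasses.field metadata (a hypothetical convention). It assumes that a state dict keeps node and static fields, drops opaque fields (which are not serialised by default), and drops derived fields, which are recomputed when the object is rebuilt through its constructor.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class Run:
    params: Dict[str, Any]  # node field (no "kind" metadata)
    name: str = dataclasses.field(default="exp", metadata={"kind": "static"})
    handle: Any = dataclasses.field(default=None, metadata={"kind": "opaque"})
    n_params: int = dataclasses.field(init=False, metadata={"kind": "derived"})

    def __post_init__(self) -> None:
        # Frozen dataclasses need object.__setattr__ to fill derived values.
        object.__setattr__(self, "n_params", len(self.params))


def field_kind(f: dataclasses.Field) -> str:
    return f.metadata.get("kind", "node")


def to_state_dict(obj: Any) -> Dict[str, Any]:
    """Keep node and static fields; skip opaque (runtime-only) and derived ones."""
    return {
        f.name: getattr(obj, f.name)
        for f in dataclasses.fields(obj)
        if field_kind(f) in ("node", "static")
    }


def from_state_dict(cls: type, state: Dict[str, Any]) -> Any:
    """Rebuild through the constructor so derived fields are recomputed."""
    return cls(**state)


run = Run(params={"w": 1.0, "b": 0.0}, handle=object())
state = to_state_dict(run)
restored = from_state_dict(Run, state)
print(state)              # {'params': {'w': 1.0, 'b': 0.0}, 'name': 'exp'}
print(restored.n_params)  # 2
```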

classmethod fields()

Return ordered mapping of field names to FieldSpec.

This is the canonical schema view for the class. Fields appear in declaration order, resolved across the MRO, which is the same order used for constructor generation and derived-field evaluation.

Return type:

Mapping[str, FieldSpec]

classmethod node_fields()

Return names of fields traced as pytree leaves.

These fields are visible to jax.tree_util.tree_leaves and therefore participate in transformations such as jit, grad, and vmap.

Return type:

tuple[str, ...]

classmethod static_fields()

Return names of static pytree metadata fields.

Static fields are stored in the pytree aux data, so they become part of JAX cache keys: changing a static value causes a retrace under jit. They must remain hashable and array-free.

Return type:

tuple[str, ...]
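A check for these two constraints might look like the sketch below. check_static is a hypothetical helper, and detecting arrays by the presence of .shape and .dtype is an assumption (duck typing that matches NumPy and JAX arrays); neuralqx may perform this check differently or at a different point in the lifecycle.

```python
from typing import Any


def check_static(name: str, value: Any) -> None:
    """Raise TypeError unless value is usable as static pytree metadata."""
    # Assumption: anything exposing both .shape and .dtype is treated as an array.
    if hasattr(value, "shape") and hasattr(value, "dtype"):
        raise TypeError(f"static field {name!r} must not hold an array")
    if isinstance(value, (tuple, frozenset)):
        # Check container elements too, so arrays nested in a hashable tuple are caught.
        for item in value:
            check_static(name, item)
    try:
        hash(value)
    except TypeError:
        raise TypeError(f"static field {name!r} must be hashable") from None


check_static("name", "exp")     # passes
check_static("shape", (3, 4))   # passes
```

The recursive step matters because a tuple holding a hashable array-like object would otherwise pass the hash check.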

classmethod opaque_fields()

Return names of opaque runtime fields excluded from pytree leaves.

Opaque fields are useful for runtime-only objects that should not be traced or serialised by default (for example, handles or ephemeral tokens).

Return type:

tuple[str, ...]

classmethod derived_fields()

Return names of fields computed from other fields.

Derived fields are recomputed automatically during construction, replacement, and deserialisation.

Return type:

tuple[str, ...]
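The four classmethods above are views over the same ordered schema. Assuming a FieldSpec that records a field's kind and whether it is derived (both hypothetical attributes; treating "derived" as independent of the node/static/opaque kind is also an assumption), the name tuples can be obtained by filtering the mapping, which preserves declaration order:

```python
from dataclasses import dataclass
from typing import Mapping, Tuple


@dataclass(frozen=True)
class FieldSpec:
    """Hypothetical stand-in; the real FieldSpec's attributes may differ."""
    kind: str = "node"     # "node", "static" or "opaque"
    derived: bool = False  # True if computed from other fields


def names_of_kind(fields: Mapping[str, FieldSpec], kind: str) -> Tuple[str, ...]:
    # Filtering an ordered mapping keeps declaration order.
    return tuple(name for name, spec in fields.items() if spec.kind == kind)


def derived_names(fields: Mapping[str, FieldSpec]) -> Tuple[str, ...]:
    return tuple(name for name, spec in fields.items() if spec.derived)


schema = {
    "params": FieldSpec(),
    "name": FieldSpec(kind="static"),
    "handle": FieldSpec(kind="opaque"),
    "n_params": FieldSpec(kind="static", derived=True),
}
print(names_of_kind(schema, "node"))    # ('params',)
print(names_of_kind(schema, "static"))  # ('name', 'n_params')
print(derived_names(schema))            # ('n_params',)
```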

class StructMeta(name, bases, namespace, **kwargs)

Bases: type

Metaclass that finalises classes so that their instances are Struct-compatible pytrees.

Every subclass of Struct is processed at class creation time:
  • field definitions are collected and normalised,

  • constructor signature is synthesised if needed,

  • JAX pytree hooks are registered,

  • class reference is cached for serialisation.

The metaclass guarantees that class-level invariants are established exactly once when classes are defined, not at first instance construction. This provides:

  • predictable failure modes for invalid field declarations,

  • zero per-instance schema analysis overhead,

  • consistent runtime behaviour regardless of object creation path.

class StructABCMeta(name, bases, namespace, **kwargs)

Bases: StructMeta, ABCMeta

Combined metaclass that provides both Struct processing and abc.ABCMeta enforcement of abstract methods.
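A combined metaclass is needed because Python requires a class's metaclass to be a subclass of the metaclasses of all its bases: mixing a class built by a custom metaclass with abc.ABC directly raises a metaclass-conflict TypeError. The sketch below uses a hypothetical ProcessingMeta in place of StructMeta to show the pattern; the combined metaclass runs the custom processing and still lets ABCMeta block instantiation of classes with unimplemented abstract methods.

```python
from abc import ABCMeta, abstractmethod


class ProcessingMeta(type):
    """Hypothetical stand-in for StructMeta: marks every class it processes."""

    def __new__(mcls, name, bases, namespace, **kwargs):
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
        cls._processed = True
        return cls


class ProcessingABCMeta(ProcessingMeta, ABCMeta):
    """Combined metaclass, analogous in shape to StructABCMeta."""


class Model(metaclass=ProcessingABCMeta):
    @abstractmethod
    def energy(self) -> float:
        """Subclasses must implement this."""


class Harmonic(Model):
    def energy(self) -> float:
        return 0.5


print(Harmonic._processed, Harmonic().energy())  # True 0.5
```

Because ProcessingMeta.__new__ calls super().__new__, the MRO of the combined metaclass routes class creation through ABCMeta.__new__, which computes the abstract-method set.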

dataclass(cls=None, *, name=None)

Dataclass-style alias for register_class().

This helper exists for readability in codebases that conceptually treat Struct definitions as immutable dataclasses with JAX semantics.

Return type:

type[Any] | Callable[[type[Any]], type[Any]]

register_class(cls=None, *, name=None)

Register an existing class as a Struct type.

Supports both styles:

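# Bare decorator: the class is passed to register_class directly.
@register_class
class Graph:
    nodes: object

# Called with arguments: register_class returns a decorator.
@register_class(name="RenamedGraph")
class Graph:
    ...

Parameters: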
  • cls (type[Any] | None) – Class to register. When omitted, returns a decorator.

  • name (str | None) – Optional override for generated class name.

Return type:

type[Any] | Callable[[type[Any]], type[Any]]

Returns:

Registered class or class decorator depending on invocation style.

Notes

If the class already subclasses Struct, registration is idempotent. Otherwise, a new class is created with bases (raw_cls, Struct), where raw_cls is the class being registered, so that its methods and super() semantics remain valid.

  • Methods defined on raw_cls are preserved. However, __init__ is always replaced by the generated Struct constructor, so any custom __init__ on raw_cls is discarded.

  • Zero-argument super() in methods defined on raw_cls remains valid.

  • Struct processing (field collection, pytree registration, class reference caching) is applied exactly once.
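The two invocation styles and the (raw_cls, Struct) base layout can be sketched with a plain stand-in base class. register and Base are hypothetical names, and the sketch omits field processing and constructor generation; it only shows the decorator dispatch, the idempotency check, and why zero-argument super() keeps working. Inside a method of the original class, super() starts the MRO lookup after that original class, and the next class in the new class's MRO is the base class.

```python
from typing import Callable, Optional, Union


class Base:
    """Hypothetical stand-in for Struct."""

    def describe(self) -> str:
        return "base"


def register(cls: Optional[type] = None, *, name: Optional[str] = None
             ) -> Union[type, Callable[[type], type]]:
    def wrap(raw_cls: type) -> type:
        if issubclass(raw_cls, Base):
            return raw_cls  # already a Base subclass: registration is idempotent
        new_name = name or raw_cls.__name__
        namespace = {"__module__": raw_cls.__module__, "__qualname__": new_name}
        # Bases (raw_cls, Base): MRO is new_cls -> raw_cls -> Base -> object.
        return type(new_name, (raw_cls, Base), namespace)

    # @register passes the class directly; @register(...) returns a decorator.
    return wrap if cls is None else wrap(cls)


@register
class Graph:
    def describe(self) -> str:
        # super() starts after the original Graph in the MRO, so it finds Base.
        return "graph/" + super().describe()


@register(name="RenamedGraph")
class Other:
    pass


print(Graph().describe())  # graph/base
print(Other.__name__)      # RenamedGraph
```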

fields(cls)

Return class field mapping.

Thin convenience wrapper over cls.fields().

node_fields(cls)

Return tuple of node-field names.

Thin convenience wrapper over cls.node_fields().

static_fields(cls)

Return tuple of static-field names.

Thin convenience wrapper over cls.static_fields().

opaque_fields(cls)

Return tuple of opaque-field names.

Thin convenience wrapper over cls.opaque_fields().

derived_fields(cls)

Return tuple of derived-field names.

Thin convenience wrapper over cls.derived_fields().