neuralqx.experimental.utils.struct package

Top-level public API for struct and pytree infrastructure.

This is the user-facing namespace for the struct system. It intentionally consolidates all stable, high-level entry points required to:

  1. Define immutable, JAX-native data containers via Struct.

  2. Declare rich field behaviour via field() and FieldSpec.

  3. Register non-Struct Python classes as pytrees.

  4. Serialize and restore Struct/registered-pytree object graphs.

  5. Resolve classes and adapters by stable runtime references.

Basic workflow:

import neuralqx as nqx
s = nqx.utils.struct

class State(s.Struct):
    x: object
    tag: str = s.field(static=True, default="default")

obj = State(x=[1, 2, 3])
payload = obj.to_state_dict()
restored = State.from_state_dict(payload)

Advanced workflow (non-Struct class registration):

class MyNode:
    def __init__(self, data, label):
        self.data = data
        self.label = label

s.register_attrs_type(
    MyNode,
    node_fields=("data",),
    static_fields=("label",),
)

# MyNode now participates in pytree traversal and struct I/O.

class Struct

Bases: object

Immutable, JAX-native container with explicit field semantics.

Struct is designed as a foundation for differentiable systems code, with:

  • deterministic construction and validation,

  • strict post-init immutability,

  • first-class pytree registration,

  • stable serialization contracts.

Field behaviour is driven by neuralqx.utils.struct.field().

For generated constructors and reconstruction paths, the lifecycle is:
  1. assign user/default values to non-derived fields

  2. compute derived fields

  3. run __post_init__ (if defined)

  4. recompute derived fields (post-init may have changed dependencies)

  5. validate values and static constraints

  6. freeze instance (further mutation raises FrozenStructError)
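The ordering above can be sketched in plain Python. `MiniStruct`, its `_derived`/`_validators` tables, and the local `FrozenStructError` are hypothetical stand-ins rather than the neuralqx implementation. The sketch only mirrors the documented step order, including why derived fields are computed twice.

```python
class FrozenStructError(AttributeError):
    """Local stand-in for the library's FrozenStructError."""


class MiniStruct:
    # Hypothetical schema hooks; the real metaclass builds these from field().
    _derived = {}     # name -> callable(self) returning the derived value
    _validators = {}  # name -> callable(value) returning False on failure

    def __init__(self, **values):
        object.__setattr__(self, "_frozen", False)
        # 1. assign user/default values to non-derived fields
        for name, value in values.items():
            object.__setattr__(self, name, value)
        # 2. compute derived fields
        self._compute_derived()
        # 3. run __post_init__ if the subclass defines one
        post_init = getattr(self, "__post_init__", None)
        if post_init is not None:
            post_init()
        # 4. recompute derived fields: __post_init__ may have changed their inputs
        self._compute_derived()
        # 5. validate values
        for name, check in self._validators.items():
            if check(getattr(self, name)) is False:
                raise ValueError(f"validation failed for {name!r}")
        # 6. freeze the instance
        object.__setattr__(self, "_frozen", True)

    def _compute_derived(self):
        for name, fn in self._derived.items():
            object.__setattr__(self, name, fn(self))

    def __setattr__(self, name, value):
        if self._frozen:
            raise FrozenStructError(f"cannot assign {name!r}: instance is frozen")
        object.__setattr__(self, name, value)


class Scaled(MiniStruct):
    _derived = {"total": lambda self: sum(self.xs)}
    _validators = {"xs": lambda xs: len(xs) > 0}

    def __post_init__(self):
        self.xs = [2 * x for x in self.xs]  # changes an input of `total`


s = Scaled(xs=[1, 2, 3])
print(s.total)  # 12: step 4 recomputed `total` after __post_init__
```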

Struct instances are effectively immutable after initialisation: assigning to a field raises FrozenStructError. Use replace() to derive updated copies.

Pytree contract:
  • Node fields are pytree children; their leaves are traced by JAX.

  • Static fields are auxiliary data.

  • Opaque fields are auxiliary identity-preserving references.
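A plain-Python sketch of how the three roles could map onto JAX's `(children, aux_data)` flatten convention. The `Particle` class and its `node`/`static`/`opaque` tuples are hypothetical. For Struct, the metaclass derives this split from field() declarations and registers the hooks with JAX, which this sketch does not do.

```python
class Particle:
    # Hypothetical schema; Struct derives these tuples from field() declarations.
    node = ("pos", "vel")   # pytree children: traced by JAX
    static = ("name",)      # aux data: hashable, part of the treedef/cache key
    opaque = ("handle",)    # aux data: carried by reference, never traced

    def __init__(self, pos, vel, name, handle):
        self.pos, self.vel, self.name, self.handle = pos, vel, name, handle


def flatten(obj):
    """Return (children, aux_data), the shape JAX flatten functions use."""
    children = tuple(getattr(obj, n) for n in obj.node)
    aux = (
        tuple(getattr(obj, n) for n in obj.static),
        tuple(getattr(obj, n) for n in obj.opaque),  # same objects, not copies
    )
    return children, aux


def unflatten(cls, aux, children):
    """Rebuild an instance without calling __init__ (cls is bound explicitly here)."""
    static_vals, opaque_vals = aux
    obj = object.__new__(cls)
    for name, value in zip(cls.node + cls.static + cls.opaque,
                           tuple(children) + static_vals + opaque_vals):
        setattr(obj, name, value)
    return obj


handle = object()
p = Particle(pos=[0.0], vel=[1.0], name="p0", handle=handle)
children, aux = flatten(p)
q = unflatten(Particle, aux, children)
print(children)  # ([0.0], [1.0])
```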

Serialisation contract:
Struct exposes convenience wrappers over neuralqx.utils.struct.io:
  • to_state_dict()

  • from_state_dict()

  • export()

  • load()
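As an illustration of the round trip, the sketch below uses a frozen `dataclasses.dataclass` as a stand-in for a Struct. It assumes, for illustration only, that opaque fields are left out of the payload (they are not serialised by default) and that derived fields are recomputed on load rather than stored. `SERIALIZED`, `to_state_dict`, and `from_state_dict` here are hypothetical free functions, not the library's methods.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Box:
    width: float
    height: float
    handle: object = field(default=None, compare=False)  # opaque: runtime-only
    area: float = field(init=False, default=0.0)          # derived

    def __post_init__(self):
        # Frozen dataclasses need object.__setattr__ to fill derived values.
        object.__setattr__(self, "area", self.width * self.height)


# Assumption: only these fields are written; opaque and derived ones are not.
SERIALIZED = ("width", "height")


def to_state_dict(obj):
    return {name: getattr(obj, name) for name in SERIALIZED}


def from_state_dict(cls, payload):
    # Construction recomputes the derived `area`.
    return cls(**payload)


box = Box(2.0, 3.0, handle=object())
payload = to_state_dict(box)  # {'width': 2.0, 'height': 3.0}
restored = from_state_dict(Box, payload)
```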

classmethod fields()

Return ordered mapping of field names to FieldSpec.

This is the canonical schema view for the class. The mapping order matches declaration/MRO resolution order used by constructor generation and derived-field evaluation.

Return type:

Mapping[str, FieldSpec]

classmethod node_fields()

Return names of node fields, i.e. the fields flattened as traced pytree children.

The leaves of these fields are visible to jax.tree_util.tree_leaves and therefore participate in transformations such as jit, grad, and vmap.

Return type:

tuple[str, ...]

classmethod static_fields()

Return names of static pytree metadata fields.

Static fields are part of pytree aux data and influence JAX cache keys. They must remain hashable and array-free.

Return type:

tuple[str, ...]
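Why static values must be hashable can be seen from a toy compile cache keyed on them: a new static value triggers a new (simulated) compile, while new node values reuse the cached entry. `cached_call` and `is_valid_static` are hypothetical, and the `__array__` test is only a heuristic for "array-like". JAX's real cache key also includes the tree structure and the shapes and dtypes of the leaves.

```python
_cache = {}
compile_log = []  # records every simulated trace/compile


def is_valid_static(value):
    """Static values must be hashable and array-free (heuristic check)."""
    if hasattr(value, "__array__"):  # NumPy/JAX arrays expose __array__
        return False
    try:
        hash(value)
    except TypeError:
        return False
    return True


def cached_call(fn, static, node):
    if not is_valid_static(static):
        raise TypeError(f"invalid static value: {static!r}")
    key = (fn, static)  # node values are deliberately NOT part of the key
    if key not in _cache:
        compile_log.append(static)  # stands in for an expensive trace/compile
        _cache[key] = lambda x: fn(x, static)
    return _cache[key](node)


def scale(x, k):
    return x * k


cached_call(scale, 2, 3.0)  # compiles for k=2
cached_call(scale, 2, 5.0)  # cache hit: only the node value changed
cached_call(scale, 3, 5.0)  # new static value: compiles again
print(compile_log)  # [2, 3]
```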

classmethod opaque_fields()

Return names of opaque runtime fields excluded from pytree leaves.

Opaque fields are useful for runtime-only objects that should not be traced or serialized by default (for example handles or ephemeral tokens).

Return type:

tuple[str, ...]

classmethod derived_fields()

Return names of fields computed from other fields.

Derived fields are recomputed automatically during construction, replacement, and deserialisation.

Return type:

tuple[str, ...]

class StructMeta(name, bases, namespace, **kwargs)

Bases: type

Metaclass that finalises classes into Struct-compatible pytrees.

Every subclass of Struct is processed at class creation time:
  • field definitions are collected and normalised,

  • constructor signature is synthesised if needed,

  • JAX pytree hooks are registered,

  • class reference is cached for serialisation.

The metaclass guarantees that class-level invariants are established exactly once when classes are defined, not at first instance construction. This provides:

  • predictable failure modes for invalid field declarations,

  • zero per-instance schema analysis overhead,

  • consistent runtime behaviour regardless of object creation path.

class StructABCMeta(name, bases, namespace, **kwargs)

Bases: StructMeta, ABCMeta

Combined metaclass giving both Struct processing and ABC enforcement.
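The combined metaclass is needed because Python requires a class's metaclass to be a (non-strict) subclass of the metaclass of every base. A sketch with hypothetical stand-ins `Meta` (for StructMeta) and `MetaABC` (for StructABCMeta) shows the conflict and its resolution:

```python
from abc import ABC, ABCMeta, abstractmethod


class Meta(type):
    """Stand-in for StructMeta."""


class Base(metaclass=Meta):
    """Stand-in for Struct."""


try:
    class Broken(Base, ABC):  # Meta and ABCMeta are unrelated metaclasses
        pass
except TypeError as exc:
    conflict = str(exc)  # "metaclass conflict: ..."


class MetaABC(Meta, ABCMeta):
    """Stand-in for StructABCMeta: Struct processing plus ABC enforcement."""


class Shape(Base, metaclass=MetaABC):
    @abstractmethod
    def area(self): ...


class Square(Shape):
    def __init__(self, side):
        self.side = side

    def area(self):
        return self.side ** 2
```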

dataclass(cls=None, *, name=None)

Dataclass-style alias for register_class().

This helper exists for readability in codebases that conceptually treat Struct definitions as immutable dataclasses with JAX semantics.

Return type:

type[Any] | Callable[[type[Any]], type[Any]]

field(*, static=False, pytree=True, default=MISSING, default_factory=MISSING, init=True, repr=True, compare=True, serialize=None, kw_only=False, doc=None, metadata=None, converter=None, validator=None, derived=None)

Declare metadata for one Struct field.

This function mirrors the ergonomics of dataclasses.field() while adding pytree-aware semantics and runtime hooks (converter, validator, derived).

Typical usage

import jax
import numpy as np
from neuralqx.utils.struct import Struct, field

class State(Struct):
    x: jax.Array
    tag: str = field(static=True, default="exp-A")
    token: object = field(pytree=False, default_factory=object)
    norm: float = field(
        static=True,
        init=False,
        derived=lambda self: float(np.asarray(self.x).sum()),
    )

Beyond the dataclasses.field() options, it adds JAX/Struct-specific controls:

  • static and pytree decide field role.

  • converter and validator add runtime hooks.

  • derived supports deterministic computed fields.

Parameters:
  • static (bool) – Mark this field as static pytree metadata. Static values are part of JAX cache keys, so they must be hashable and array-free.

  • pytree (bool) – When False, the field is opaque runtime data (not a pytree leaf).

  • default (Any) – Concrete default value used when input is omitted.

  • default_factory (Callable[[], Any] | MissingType) – Zero-argument callable used to lazily create defaults.

  • init (bool) – Include field in generated constructor.

  • repr (bool) – Include field in __repr__ output.

  • compare (bool) – Include field in __eq__ checks.

  • serialize (bool | None) – Override whether the field participates in state-dict/export I/O.

  • kw_only (bool) – Make field keyword-only in generated constructor.

  • doc (str | None) – Optional per-field documentation string.

  • metadata (Mapping[str, Any] | None) – Free-form immutable metadata mapping for tooling.

  • converter (Callable[..., Any] | None) – Optional converter called on assignment paths. Accepts either converter(value) or converter(self, value).

  • validator (Union[Callable[..., Any], Sequence[Callable[..., Any]], None]) – Optional validator (or sequence of validators). Validators accept either validator(value) or validator(self, value). Returning False raises ValidationError.

  • derived (Callable[..., Any] | None) – Optional callable for computed fields (must use init=False). Accepts derived() or derived(self) and is recomputed after construction, replacement, and deserialisation.

Return type:

FieldSpec

Returns:

Frozen metadata object consumed by struct class processing.

Raises:

ValueError – If declaration invariants are violated, for example:
  • both default and default_factory are provided,

  • static=True and pytree=False are combined,

  • derived-field constraints are broken.
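The three documented invariants, plus the init=False requirement for derived fields, can be sketched as a standalone check. `check_field_decl` and the `MISSING` sentinel are hypothetical; the library may enforce further derived-field constraints not listed here.

```python
MISSING = object()  # sentinel meaning "no default supplied"


def check_field_decl(*, static=False, pytree=True, default=MISSING,
                     default_factory=MISSING, init=True, derived=None):
    """Hypothetical mirror of the documented field() declaration checks."""
    if default is not MISSING and default_factory is not MISSING:
        raise ValueError("specify at most one of default and default_factory")
    if static and not pytree:
        raise ValueError("static=True cannot be combined with pytree=False")
    if derived is not None and init:
        raise ValueError("derived fields must use init=False")


check_field_decl(static=True, default="exp-A")          # ok
check_field_decl(init=False, derived=lambda self: 0.0)  # ok
```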

Notes

Validators may either:
  • raise exceptions directly, or

  • return False to trigger a ValidationError.

Converters/validators support ergonomic signatures:
  • fn(value)

  • fn(self, value)
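One way to accept both fn(value) and fn(self, value) is to count the callable's required positional parameters. The sketch below assumes that rule; `call_hook`, `run_validators`, and the local `ValidationError` are hypothetical, and the library may resolve signatures differently (for example, once at class creation).

```python
import inspect


class ValidationError(ValueError):
    """Local stand-in for the library's ValidationError."""


def call_hook(fn, owner, value):
    """Call fn(value) or fn(owner, value), based on required positional params."""
    required = [
        p for p in inspect.signature(fn).parameters.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
    ]
    return fn(owner, value) if len(required) >= 2 else fn(value)


def run_validators(validators, owner, name, value):
    for check in validators:
        # Validators may raise on their own; a literal False becomes an error.
        if call_hook(check, owner, value) is False:
            raise ValidationError(f"{name}={value!r} failed {check.__name__}")


class Limits:
    upper = 10


def positive(v):
    return v > 0


run_validators([positive, lambda self, v: v < self.upper], Limits(), "x", 5)
```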

class FieldKind(value)

Bases: str, Enum

Enum describing a field’s role in Struct/pytree behaviour.
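Only `FieldKind.NODE` appears in this reference (as the default `kind` of FieldSpec). The sketch assumes two further members, `STATIC` and `OPAQUE`, named after the three roles in the pytree contract, and shows how field()'s static/pytree flags could map onto them. Those member names, all string values, and `kind_from_flags` are hypothetical.

```python
from enum import Enum


class FieldKind(str, Enum):
    NODE = "node"      # documented default kind; the string value is assumed
    STATIC = "static"  # hypothetical member: hashable aux data
    OPAQUE = "opaque"  # hypothetical member: runtime-only aux data


def kind_from_flags(*, static=False, pytree=True):
    """Hypothetical mapping from field(static=..., pytree=...) to a kind."""
    if static and not pytree:
        raise ValueError("static=True cannot be combined with pytree=False")
    if static:
        return FieldKind.STATIC
    if not pytree:
        return FieldKind.OPAQUE
    return FieldKind.NODE
```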

class FieldSpec(name, default=MISSING, kind=FieldKind.NODE, factory=MISSING, init=True, repr=True, compare=True, serialize=None, kw_only=False, doc=None, metadata=<factory>, converter=None, validators=(), derived=None)

Bases: object

Immutable metadata describing one struct field.

FieldSpec is the normalised internal representation used by the struct class processor. Its lifecycle is:

  1. You write class annotations/defaults/field(...) declarations.

  2. Struct metaclass collects declarations and builds FieldSpec objects.

  3. Runtime processing validates and uses specs to drive constructor, pytree, equality, repr, and serialisation behaviour.

Notes

FieldSpec instances are normally created by field(). Advanced users may construct them directly when building custom decorators.

register_class(cls=None, *, name=None)

Register an existing class as a Struct type.

Supports both styles:

from neuralqx.utils.struct import register_class

@register_class
class Graph:
    nodes: ...

@register_class(name="RenamedGraph")
class Graph:
    ...

Parameters:
  • cls (type[Any] | None) – Class to register. When omitted, returns a decorator.

  • name (str | None) – Optional override for generated class name.

Return type:

type[Any] | Callable[[type[Any]], type[Any]]

Returns:

Registered class or class decorator depending on invocation style.

Notes

If the class already subclasses Struct, registration is idempotent. Otherwise, a new class is created with bases (raw_cls, Struct), where raw_cls is the class being registered, so methods and super() semantics remain valid:

  • Existing methods of raw_cls are preserved, except __init__: it is always replaced by the generated Struct constructor, so any custom __init__ on raw_cls is discarded.

  • Zero-argument super() in class methods remains valid.

  • Struct processing (field collection, pytree registration, class reference caching) is applied exactly once.
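Building a subclass with bases (raw_cls, Struct) is ordinary Python. The sketch uses a hypothetical `wrap_as_struct` helper and a `StructBase` stand-in; unlike register_class it does not replace `__init__`. It shows why zero-argument super() keeps working: the original methods stay bound to raw_cls through their `__class__` cell, and raw_cls is followed by the Struct base in the new MRO.

```python
class StructBase:
    """Stand-in for Struct."""

    def describe(self):
        return "struct"


def wrap_as_struct(raw_cls, name=None):
    """Hypothetical helper: subclass raw_cls and StructBase, in that order."""
    if issubclass(raw_cls, StructBase):
        return raw_cls  # already Struct-based: nothing to do
    namespace = {
        "__module__": raw_cls.__module__,
        "__qualname__": name or raw_cls.__qualname__,
    }
    return type(name or raw_cls.__name__, (raw_cls, StructBase), namespace)


class Graph:
    def describe(self):
        # Zero-argument super() is bound to Graph; in the wrapped class the
        # next entry after Graph in the MRO is StructBase.
        return "graph+" + super().describe()


G = wrap_as_struct(Graph)
print(G().describe())  # graph+struct
```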

fields(cls)

Return class field mapping.

Thin convenience wrapper over cls.fields().

node_fields(cls)

Return tuple of node-field names.

Thin convenience wrapper over cls.node_fields().

static_fields(cls)

Return tuple of static-field names.

Thin convenience wrapper over cls.static_fields().

opaque_fields(cls)

Return tuple of opaque-field names.

Thin convenience wrapper over cls.opaque_fields().

derived_fields(cls)

Return tuple of derived-field names.

Thin convenience wrapper over cls.derived_fields().

class_ref(cls)

Return canonical class reference string for cls.

Parameters:

cls (type[Any]) – Runtime class object to encode.

Return type:

str

Returns:

A deterministic "module:qualname" reference suitable for manifests.

Notes

The returned value is purely structural and does not validate importability at call time. Import validation occurs during resolve_class().

resolve_class(ref)

Resolve a class reference into a runtime class object.

This function resolves a reference by:
  1. looking up ref in the in-memory registry cache,

  2. if it is missing, lazily importing the module and traversing the qualified name,

  3. caching the resolved class for future calls.

Parameters:

ref (str) – Class reference in "module:qualname" format.

Return type:

type[Any]

Returns:

Imported and validated class object.

Raises:

SerializationError – If the reference format is invalid, import fails, attribute traversal fails, or the resolved object is not a class.

class PyTreeTypeSpec(cls, flatten, unflatten, flatten_with_keys=None, serializer=None, deserializer=None)

Bases: object

Complete pytree + serialisation spec for a registered type.

Notes

A PyTreeTypeSpec is effectively an adapter contract for one class. Runtime systems use this single specification for:

  • tree flattening/unflattening,

  • optional serialiser-based persistence paths,

  • lazy resolution by reference at load time.

register_pytree_type(cls, *, flatten, unflatten, flatten_with_keys=None, serializer=None, deserializer=None)

Register an arbitrary class as a pytree-capable and serialisable type.

This is the low-level adapter API. Use this when you need full control over flattening and reconstruction behaviour.

Parameters:
Return type:

type[Any]

Returns:

The original class, allowing decorator-style usage.

Raises:

TypeError – Propagated if the provided callables are invalid when invoked at runtime.

Example

from neuralqx.utils.struct import register_pytree_type

class Node:
    def __init__(self, data, tag):
        self.data = data
        self.tag = tag

register_pytree_type(
    Node,
    flatten=lambda obj: ([obj.data], obj.tag),
    unflatten=lambda aux, children: Node(children[0], aux),
)

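The spec-per-class design can be sketched as a small registry plus a recursive `tree_map` that only consults the registered flatten/unflatten pair. `Spec`, `SPECS`, `register`, and `tree_map` are hypothetical stand-ins; unlike JAX, this toy treats lists and tuples as leaves rather than containers.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class Spec:
    """Hypothetical mirror of PyTreeTypeSpec's fields."""
    cls: type
    flatten: Callable[[Any], tuple]
    unflatten: Callable[[Any, list], Any]
    flatten_with_keys: Optional[Callable] = None
    serializer: Optional[Callable] = None
    deserializer: Optional[Callable] = None


SPECS = {}  # class -> Spec


def register(cls, *, flatten, unflatten, **extra):
    SPECS[cls] = Spec(cls, flatten, unflatten, **extra)
    return cls  # returning cls is what makes decorator-style use possible


def tree_map(fn, obj):
    """Apply fn to every leaf; unregistered objects count as leaves here."""
    spec = SPECS.get(type(obj))
    if spec is None:
        return fn(obj)
    children, aux = spec.flatten(obj)
    return spec.unflatten(aux, [tree_map(fn, child) for child in children])


class Node:
    def __init__(self, data, tag):
        self.data = data
        self.tag = tag


register(
    Node,
    flatten=lambda obj: ([obj.data], obj.tag),
    unflatten=lambda aux, children: Node(children[0], aux),
)

tree = Node(Node(1, "inner"), "outer")
bumped = tree_map(lambda x: x + 1, tree)
print(bumped.data.data, bumped.data.tag, bumped.tag)  # 2 inner outer
```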
register_attrs_type(cls, *, node_fields, static_fields=(), constructor=None, serializer=None, deserializer=None)

Register a class by naming its node/static attributes.

This higher-level API is useful when your class is attribute-backed and you do not want to implement flatten/unflatten manually.

Parameters:
Return type:

type[Any]

Returns:

The original class.

Notes

This function is intentionally permissive for reconstruction:
  • it first tries cls(**values),

  • then falls back to object.__new__(cls) followed by setattr assignment.

This makes it robust for both dataclass-like and legacy classes.
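The documented fallback order can be sketched as follows. `rebuild`, `Modern`, and `Legacy` are hypothetical; the real function also accepts a `constructor` argument, whose interaction with this fallback is not described here and is not modelled.

```python
def rebuild(cls, values):
    """Hypothetical helper mirroring the documented fallback order."""
    try:
        return cls(**values)              # dataclass-like constructors
    except TypeError:
        obj = object.__new__(cls)         # legacy classes: bypass __init__
        for name, value in values.items():
            setattr(obj, name, value)
        return obj


class Modern:
    def __init__(self, data, label):
        self.data, self.label = data, label


class Legacy:
    def __init__(self, config_path):
        # Signature does not match the stored attributes, so cls(**values)
        # fails with TypeError before this body ever runs.
        raise RuntimeError("would load from disk")


m = rebuild(Modern, {"data": [1, 2], "label": "a"})
legacy = rebuild(Legacy, {"data": [3], "label": "b"})
print(type(legacy).__name__, legacy.data)  # Legacy [3]
```

Catching TypeError broadly is the trade-off of this approach: a TypeError raised inside a matching __init__ also triggers the fallback.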

is_registered_pytree_type(cls)

Return whether cls currently has an active pytree adapter.

This is useful for idempotent registration patterns in large applications where modules may be imported in different orders.

Return type:

bool

resolve_pytree_spec(ref)

Resolve a pytree type spec from class reference.

Resolution order:
  1. In-memory ref cache.

  2. Resolve the class via the class registry (resolve_class()), then look it up in the class-keyed spec cache.

  3. Auto-register classes implementing tree_flatten/tree_unflatten.

Parameters:

ref (str) – Class reference string in "module:qualname" format.

Return type:

PyTreeTypeSpec

Returns:

Resolved adapter specification.

Raises:

SerializationError – If no adapter is registered for, or can be derived for, the given reference.

Notes

If no explicit spec is found but the resolved class implements both tree_flatten and tree_unflatten, a default adapter is registered automatically. This mirrors the tree_flatten/tree_unflatten protocol used by jax.tree_util.register_pytree_node_class.
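The fallback relies on the duck-typed protocol that `jax.tree_util.register_pytree_node_class` expects: an instance method `tree_flatten()` returning `(children, aux_data)` and a classmethod `tree_unflatten(aux_data, children)`. A sketch with hypothetical `resolve_spec` and `SPECS`, taking a class instead of a reference string for brevity:

```python
class SerializationError(Exception):
    """Local stand-in for the library's SerializationError."""


SPECS = {}  # class -> {"flatten": ..., "unflatten": ...}


def resolve_spec(cls):
    """Hypothetical: explicit spec first, then the tree_flatten protocol."""
    if cls in SPECS:
        return SPECS[cls]
    if callable(getattr(cls, "tree_flatten", None)) and callable(
        getattr(cls, "tree_unflatten", None)
    ):
        spec = {"flatten": lambda obj: obj.tree_flatten(),
                "unflatten": cls.tree_unflatten}
        SPECS[cls] = spec  # auto-registered once, reused afterwards
        return spec
    raise SerializationError(f"no pytree adapter for {cls.__qualname__}")


class Pair:
    def __init__(self, a, b):
        self.a, self.b = a, b

    def tree_flatten(self):
        return (self.a, self.b), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


spec = resolve_spec(Pair)
children, aux = spec["flatten"](Pair(1, 2))
rebuilt = spec["unflatten"](aux, children)
```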

Submodules