neuralqx.experimental.utils.struct package
Top-level public API for struct and pytree infrastructure.
This is the user-facing namespace for the struct system. It intentionally consolidates all stable, high-level entry points required to:
Basic workflow:
import neuralqx as nqx

s = nqx.utils.struct

class State(s.Struct):
    x: object
    tag: str = s.field(static=True, default="default")

obj = State(x=[1, 2, 3])
payload = obj.to_state_dict()
restored = State.from_state_dict(payload)
Advanced workflow (non-Struct class registration):
class MyNode:
    def __init__(self, data, label):
        self.data = data
        self.label = label

s.register_attrs_type(
    MyNode,
    node_fields=("data",),
    static_fields=("label",),
)
# MyNode now participates in pytree traversal and struct I/O.
- class Struct
Bases: object
Immutable, JAX-native container with explicit field semantics.
Struct is designed as a forward-looking abstraction for differentiable systems code:
deterministic construction and validation,
strict post-init immutability,
first-class pytree registration,
stable serialisation contracts.
Field behaviour is driven by neuralqx.utils.struct.field().
- For generated constructors and reconstruction paths, the lifecycle is:
1. assign user/default values to non-derived fields
2. compute derived fields
3. run __post_init__ (if defined)
4. recompute derived fields (post-init may have changed dependencies)
5. validate values and static constraints
6. freeze instance (further mutation raises FrozenStructError)
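The lifecycle above can be sketched in plain Python. This is a minimal illustration of the documented ordering, not neuralqx's implementation: the FrozenPoint class, its _frozen flag, the _compute_derived helper, and the local FrozenStructError stand-in are all hypothetical names invented for the example.

```python
class FrozenStructError(AttributeError):
    """Local stand-in for the library's exception (hypothetical definition)."""


class FrozenPoint:
    """Toy class that follows the documented construction order."""

    def __init__(self, x, y=0.0):
        # 1. assign user/default values to non-derived fields
        self.x = x
        self.y = y
        # 2. compute derived fields
        self._compute_derived()
        # 3. run __post_init__ (if defined)
        post_init = getattr(self, "__post_init__", None)
        if post_init is not None:
            post_init()
        # 4. recompute derived fields, since post-init may have changed inputs
        self._compute_derived()
        # 5. validate values
        if self.norm1 < 0:
            raise ValueError("norm1 must be non-negative")
        # 6. freeze: later attribute writes raise FrozenStructError
        object.__setattr__(self, "_frozen", True)

    def _compute_derived(self):
        self.norm1 = abs(self.x) + abs(self.y)

    def __post_init__(self):
        # Post-init normalises a dependency of the derived field.
        self.y = float(self.y)

    def __setattr__(self, name, value):
        if getattr(self, "_frozen", False):
            raise FrozenStructError(f"cannot assign to {name!r}: instance is frozen")
        object.__setattr__(self, name, value)


p = FrozenPoint(3, -4)
try:
    p.x = 10
    mutated = True
except FrozenStructError:
    mutated = False
```

Recomputing derived values after the post-init hook is what keeps them consistent when the hook rewrites one of their inputs.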
Struct instances are logically (almost) immutable after initialisation. Use replace() to derive updated copies.
- Pytree contract:
Node fields are leaves.
Static fields are auxiliary data.
Opaque fields are auxiliary identity-preserving references.
- Serialisation contract:
Struct exposes convenience wrappers over neuralqx.utils.struct.io:
to_state_dict()
from_state_dict()
export()
load()
- classmethod fields()
Return ordered mapping of field names to FieldSpec.
This is the canonical schema view for the class. The mapping order matches the declaration/MRO resolution order used by constructor generation and derived-field evaluation.
- classmethod node_fields()
Return names of fields traced as pytree leaves.
These fields are visible to jax.tree_util.tree_leaves and therefore participate in transformations such as jit, grad, and vmap.
- classmethod static_fields()
Return names of static pytree metadata fields.
Static fields are part of pytree aux data and influence JAX cache keys. They must remain hashable and array-free.
- classmethod opaque_fields()
Return names of opaque runtime fields excluded from pytree leaves.
Opaque fields are useful for runtime-only objects that should not be traced or serialised by default (for example handles or ephemeral tokens).
- class StructMeta(name, bases, namespace, **kwargs)
Bases: type
Metaclass that finalises classes into Struct-compatible pytrees.
- Every subclass of Struct is processed at class creation time:
field definitions are collected and normalised,
constructor signature is synthesised if needed,
JAX pytree hooks are registered,
class reference is cached for serialisation.
The metaclass guarantees that class-level invariants are established exactly once, when classes are defined, not at first instance construction. This provides:
predictable failure modes for invalid field declarations,
zero per-instance schema analysis overhead,
consistent runtime behaviour regardless of object creation path.
- class StructABCMeta(name, bases, namespace, **kwargs)
Bases: StructMeta, ABCMeta
Combined metaclass giving both Struct processing and ABC enforcement.
- dataclass(cls=None, *, name=None)
Dataclass-style alias for register_class().
This helper exists for readability in codebases that conceptually treat Struct definitions as immutable dataclasses with JAX semantics.
- field(*, static=False, pytree=True, default=MISSING, default_factory=MISSING, init=True, repr=True, compare=True, serialize=None, kw_only=False, doc=None, metadata=None, converter=None, validator=None, derived=None)
Declare metadata for one Struct field.
This function mirrors the ergonomics of dataclasses.field() while adding pytree-aware semantics and runtime hooks (converter, validator, derived).
Typical usage
class State(Struct):
    x: jax.Array
    tag: str = field(static=True, default="exp-A")
    token: object = field(pytree=False, default_factory=object)
    norm: float = field(
        static=True,
        init=False,
        derived=lambda self: float(np.asarray(self.x).sum()),
    )
The shape intentionally resembles dataclasses.field(), but it adds JAX/Struct-specific controls:
static and pytree decide field role.
converter and validator add runtime hooks.
derived supports deterministic computed fields.
- Parameters:
static – Mark this field as static pytree metadata. Static values are part of JAX cache keys, so they must be hashable and array-free.
pytree – When False, the field is opaque runtime data (not a pytree leaf).
default – Concrete default value used when input is omitted.
default_factory – Zero-argument callable used to lazily create defaults.
init – Include field in generated constructor.
repr – Include field in __repr__ output.
compare – Include field in __eq__ checks.
serialize – Override whether field participates in state dict/export I/O.
kw_only – Make field keyword-only in generated constructor.
doc – Optional per-field documentation string.
metadata – Free-form immutable metadata mapping for tooling.
converter – Optional converter called on assignment paths. Accepts either converter(value) or converter(self, value).
validator (Union[Callable[..., Any], Sequence[Callable[..., Any]], None]) – Optional validator (or sequence of validators). Validators accept either validator(value) or validator(self, value). Returning False raises ValidationError.
derived – Optional callable for computed fields (must use init=False). Accepts derived() or derived(self) and is recomputed after construction, replacement, and deserialisation.
- Returns:
Frozen metadata object consumed by struct class processing.
- Raises:
ValueError – If declaration invariants are violated, for example:
both default and default_factory are provided,
static=True and pytree=False are combined,
derived-field constraints are broken.
Notes
- Validators may either:
raise exceptions directly, or
return False to trigger a ValidationError.
- Converters/validators support ergonomic signatures:
fn(value)
fn(self, value)
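One way such dual-signature hooks could be dispatched is by counting positional parameters. The sketch below is an assumption about the mechanism, not the library's code: call_hook, run_validator, Owner, and the local ValidationError stand-in are hypothetical names used only for illustration.

```python
import inspect


class ValidationError(ValueError):
    """Local stand-in for the library's exception (hypothetical definition)."""


def call_hook(fn, instance, value):
    """Call fn(value) or fn(instance, value) based on positional arity."""
    params = [
        p
        for p in inspect.signature(fn).parameters.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]
    if len(params) >= 2:
        return fn(instance, value)
    return fn(value)


def run_validator(fn, instance, value):
    # A validator may raise on its own; an explicit False is converted here.
    if call_hook(fn, instance, value) is False:
        raise ValidationError(f"validator {fn.__name__} rejected {value!r}")


def is_positive(value):
    return value > 0


def below_limit(self, value):
    return value < self.limit


class Owner:
    limit = 10


owner = Owner()
run_validator(is_positive, owner, 5)  # fn(value) form
run_validator(below_limit, owner, 5)  # fn(self, value) form

try:
    run_validator(below_limit, owner, 50)
    rejected = False
except ValidationError:
    rejected = True
```

Checking for an explicit False (rather than any falsy value) lets validators that return None after raising nothing still count as passing.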
- class FieldSpec(name, default=MISSING, kind=FieldKind.NODE, factory=MISSING, init=True, repr=True, compare=True, serialize=None, kw_only=False, doc=None, metadata=<factory>, converter=None, validators=(), derived=None)
Bases: object
Immutable metadata describing one struct field.
FieldSpec is the normalised internal representation used by the struct class processor. In user code, these are typically created by field().
You write class annotations/defaults/field(...) declarations.
The Struct metaclass collects declarations and builds FieldSpec objects.
Runtime processing validates and uses specs to drive constructor, pytree, equality, repr, and serialisation behaviour.
Notes
FieldSpec instances are normally created by field(). Advanced users may construct them directly when building custom decorators.
- register_class(cls=None, *, name=None)
Register an existing class as a Struct type.
Supports both styles:
@register_class
class Graph:
    nodes: ...

@register_class(name="RenamedGraph")
class Graph:
    ...
- Returns:
Registered class or class decorator depending on invocation style.
Notes
If the class already subclasses Struct, registration is idempotent. Otherwise, a new class is created with MRO (raw_cls, Struct) so methods and super() semantics remain valid.
Existing class methods are preserved. Note that __init__ is always replaced by the generated Struct constructor; any custom __init__ on raw_cls will be discarded.
Zero-argument super() in class methods remains valid.
Struct processing (field collection, pytree registration, class reference caching) is applied exactly once.
- fields(cls)
Return class field mapping.
Thin convenience wrapper over cls.fields().
- node_fields(cls)
Return tuple of node-field names.
Thin convenience wrapper over cls.node_fields().
- static_fields(cls)
Return tuple of static-field names.
Thin convenience wrapper over cls.static_fields().
- opaque_fields(cls)
Return tuple of opaque-field names.
Thin convenience wrapper over cls.opaque_fields().
- derived_fields(cls)
Return tuple of derived-field names.
Thin convenience wrapper over cls.derived_fields().
- class_ref(cls)
Return canonical class reference string for cls.
- Returns:
A deterministic "module:qualname" reference suitable for manifests.
Notes
The returned value is purely structural and does not validate importability at call time. Import validation occurs during resolve_class().
- resolve_class(ref)
Resolve a class reference into a runtime class object.
- This function resolves a reference in three steps:
1. Look up ref in the in-memory registry cache.
2. If missing, import module + qualified name lazily.
3. Cache resolved result for future calls.
- Parameters:
ref (str) – Class reference in "module:qualname" format.
- Returns:
Imported and validated class object.
- Raises:
SerializationError – If the reference format is invalid, import fails, attribute traversal fails, or the resolved object is not a class.
- class PyTreeTypeSpec(cls, flatten, unflatten, flatten_with_keys=None, serializer=None, deserializer=None)
Bases: object
Complete pytree + serialisation spec for a registered type.
Notes
A PyTreeTypeSpec is effectively an adapter contract for one class. Runtime systems use this single specification for:
tree flattening/unflattening,
optional serialiser-based persistence paths,
lazy resolution by reference at load time.
- register_pytree_type(cls, *, flatten, unflatten, flatten_with_keys=None, serializer=None, deserializer=None)
Register an arbitrary class as a pytree-capable and serialisable type.
This is the low-level adapter API. Use this when you need full control over flattening and reconstruction behaviour.
- Parameters:
flatten (Callable[[Any], tuple[Sequence[Any], Any]]) – Callable mapping obj -> (children, aux).
unflatten (Callable[[Any, Sequence[Any]], Any]) – Callable mapping (aux, children) -> obj.
flatten_with_keys (Optional[Callable[[Any], tuple[Sequence[tuple[Any, Any]], Any]]]) – Optional key-aware flatten variant for richer tree metadata.
serializer (Optional[Callable[[Any], Any]]) – Optional persistence path for cases where raw flatten aux-data is not directly serialisable or where a custom wire format is preferred.
deserializer (Optional[Callable[[Any], Any]]) – Optional persistence path for cases where raw flatten aux-data is not directly serialisable or where a custom wire format is preferred.
- Returns:
The original class, allowing decorator-style usage.
- Raises:
TypeError – Propagated if the provided callables turn out to be invalid when they are used at runtime.
Example
class Node:
    def __init__(self, data, tag):
        self.data = data
        self.tag = tag

register_pytree_type(
    Node,
    flatten=lambda obj: ([obj.data], obj.tag),
    unflatten=lambda aux, children: Node(children[0], aux),
)
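The flatten/unflatten contract can be exercised without any library: children are the values a transformation may change, and aux is carried through untouched. The following is a library-free sketch of that round trip; map_children is a hypothetical helper written for this example and is not part of the documented API.

```python
class Node:
    def __init__(self, data, tag):
        self.data = data
        self.tag = tag


def flatten(obj):
    # children are transformable values, aux is static metadata
    return [obj.data], obj.tag


def unflatten(aux, children):
    return Node(children[0], aux)


def map_children(fn, obj):
    """Apply fn to every child and rebuild an object of the same type."""
    children, aux = flatten(obj)
    return unflatten(aux, [fn(child) for child in children])


node = Node(data=2.0, tag="weights")
doubled = map_children(lambda v: v * 2, node)
```

The argument order matters: flatten returns (children, aux), while unflatten takes (aux, children).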
- register_attrs_type(cls, *, node_fields, static_fields=(), constructor=None, serializer=None, deserializer=None)
Register a class by naming its node/static attributes.
This higher-level API is useful when your class is attribute-backed and you do not want to implement flatten/unflatten manually.
- Parameters:
node_fields (Sequence[str]) – Attribute names treated as pytree leaves.
static_fields (Sequence[str]) – Attribute names treated as pytree auxiliary metadata.
constructor (Callable[[Mapping[str, Any]], Any] | None) – Optional callable used to construct instances from the resolved attribute map. If omitted, cls(**values) is attempted first, then attribute fallback.
serializer (Optional[Callable[[Any], Any]]) – Optional persistence adapters (same semantics as register_pytree_type()).
deserializer (Optional[Callable[[Any], Any]]) – Optional persistence adapters (same semantics as register_pytree_type()).
- Returns:
The original class.
Notes
- This function is intentionally permissive for reconstruction:
it first tries cls(**values),
then falls back to object.__new__(cls) + setattr assignment.
This makes it robust for both dataclass-like and legacy classes.
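The two-stage reconstruction described in the notes might look like the sketch below. It assumes that a constructor mismatch surfaces as TypeError, which the source does not state; rebuild, Modern, and Legacy are hypothetical names for illustration.

```python
def rebuild(cls, values):
    """Rebuild cls from an attribute map, with a setattr fallback."""
    try:
        return cls(**values)
    except TypeError:
        # Assumption: a signature mismatch surfaces as TypeError.
        obj = object.__new__(cls)  # bypasses __init__ entirely
        for name, value in values.items():
            setattr(obj, name, value)
        return obj


class Modern:
    def __init__(self, data, label):
        self.data = data
        self.label = label


class Legacy:
    def __init__(self, path):  # constructor does not accept data/label
        self.data = None
        self.label = path


values = {"data": [1, 2], "label": "a"}
modern = rebuild(Modern, values)  # succeeds through cls(**values)
legacy = rebuild(Legacy, values)  # falls back to object.__new__ + setattr
```

A caveat of this approach is that a TypeError raised inside a matching __init__ body would also trigger the fallback, so a stricter variant could check the signature first.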
- is_registered_pytree_type(cls)
Return whether cls currently has an active pytree adapter.
This is useful for idempotent registration patterns in large applications where modules may be imported in different orders.
- resolve_pytree_spec(ref)
Resolve a pytree type spec from class reference.
- Resolution order:
1. In-memory ref cache.
2. Import class via class registry, then class cache.
3. Auto-register classes implementing tree_flatten/tree_unflatten.
- Parameters:
ref (str) – Class reference string in "module:qualname" format.
- Returns:
Resolved adapter specification.
- Raises:
SerializationError – If no adapter is found/derivable for the given reference.
Notes
If no explicit spec is found but the resolved class implements both tree_flatten and tree_unflatten, a default adapter is registered automatically. This mirrors common JAX ecosystem conventions.