neuralqx.experimental.utils.struct.pytree module
Public registration APIs for non-Struct pytree types.
These helpers let you bring arbitrary Python classes under the same pytree and
serialisation contracts used by neuralqx.utils.struct.Struct.
Not every domain type should inherit from Struct. Sometimes you need to
adapt an existing class hierarchy (third-party objects, legacy models, minimal
runtime wrappers) while still participating in:
- JAX tree traversal and transformation,
- Struct I/O export/load pipelines,
- reference-based reconstruction across sessions.
This module provides two registration levels:
register_pytree_type(): full-control API for when you want custom flatten/unflatten behaviour.
register_attrs_type(): convenience API for when adaptation is naturally attribute-based.
Both APIs register adapters in:
- the JAX pytree registry (runtime transformation support),
- the struct adapter registry (serialisation and lazy reconstruction support).
- class PyTreeTypeSpec(cls, flatten, unflatten, flatten_with_keys=None, serializer=None, deserializer=None)
Bases: object
Complete pytree + serialisation spec for a registered type.
Notes
A PyTreeTypeSpec is effectively an adapter contract for one class. Runtime systems use this single specification for:
- tree flattening/unflattening,
- optional serialiser-based persistence paths,
- lazy resolution by reference at load time.
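The "single specification" idea can be made concrete with a stand-in. Below is a hypothetical frozen dataclass with the same six fields as PyTreeTypeSpec, plus two helper methods, `to_wire` and `from_wire`, that are not part of the documented class. The sketch assumes, without confirmation from this page, two things. First, `serializer` maps a whole object to a wire payload and `deserializer` maps it back. Second, without them the persistence path stores the raw flatten output.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Sequence, Tuple


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in mirroring the fields of PyTreeTypeSpec."""
    cls: type
    flatten: Callable[[Any], Tuple[Sequence[Any], Any]]
    unflatten: Callable[[Any, Sequence[Any]], Any]
    flatten_with_keys: Optional[Callable[[Any], Any]] = None
    serializer: Optional[Callable[[Any], Any]] = None
    deserializer: Optional[Callable[[Any], Any]] = None

    def to_wire(self, obj: Any) -> Any:
        # Assumed persistence rule: prefer the custom serializer when given,
        # otherwise store the raw (children, aux) pair produced by flatten.
        if self.serializer is not None:
            return {"kind": "custom", "payload": self.serializer(obj)}
        children, aux = self.flatten(obj)
        return {"kind": "tree", "children": list(children), "aux": aux}

    def from_wire(self, wire: Any) -> Any:
        if wire["kind"] == "custom":
            if self.deserializer is None:
                raise ValueError("custom payload but no deserializer registered")
            return self.deserializer(wire["payload"])
        return self.unflatten(wire["aux"], wire["children"])


class Interval:
    def __init__(self, lo, hi, unit):
        self.lo, self.hi, self.unit = lo, hi, unit


def parse_interval(text: str) -> Interval:
    """Inverse of the custom wire format "lo..hi unit" used below."""
    bounds, unit = text.split()
    lo, hi = (float(x) for x in bounds.split(".."))
    return Interval(lo, hi, unit)


plain = SpecSketch(
    cls=Interval,
    flatten=lambda iv: ([iv.lo, iv.hi], iv.unit),
    unflatten=lambda unit, ch: Interval(ch[0], ch[1], unit),
)
custom = SpecSketch(
    cls=Interval,
    flatten=plain.flatten,
    unflatten=plain.unflatten,
    serializer=lambda iv: f"{iv.lo}..{iv.hi} {iv.unit}",
    deserializer=parse_interval,
)

print(plain.to_wire(Interval(1.0, 2.5, "s")))   # {'kind': 'tree', 'children': [1.0, 2.5], 'aux': 's'}
print(custom.to_wire(Interval(1.0, 2.5, "s")))  # {'kind': 'custom', 'payload': '1.0..2.5 s'}
```

Both specs share the same flatten/unflatten pair, so tree transformations behave identically. Only the persistence path differs.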
- register_pytree_type(cls, *, flatten, unflatten, flatten_with_keys=None, serializer=None, deserializer=None)
Register an arbitrary class as a pytree-capable and serialisable type.
This is the low-level adapter API. Use this when you need full control over flattening and reconstruction behaviour.
- Parameters:
flatten (Callable[[Any], tuple[Sequence[Any], Any]]) – Callable mapping obj -> (children, aux).
unflatten (Callable[[Any, Sequence[Any]], Any]) – Callable mapping (aux, children) -> obj.
flatten_with_keys (Optional[Callable[[Any], tuple[Sequence[tuple[Any, Any]], Any]]]) – Optional key-aware flatten variant for richer tree metadata.
serializer (Optional[Callable[[Any], Any]]) – Optional persistence hook for cases where the raw flatten aux-data is not directly serialisable or a custom wire format is preferred.
deserializer (Optional[Callable[[Any], Any]]) – Optional counterpart to serializer on the same persistence path.
- Return type:
- Returns:
The original class, allowing decorator-style usage.
- Raises:
TypeError – Propagated if the provided callables are invalid when used at runtime.
Example
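    from neuralqx.experimental.utils.struct.pytree import register_pytree_type

    class Node:
        def __init__(self, data, tag):
            self.data = data  # pytree child
            self.tag = tag    # static auxiliary data

    register_pytree_type(
        Node,
        flatten=lambda obj: ([obj.data], obj.tag),
        unflatten=lambda aux, children: Node(children[0], aux),
    )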
- register_attrs_type(cls, *, node_fields, static_fields=(), constructor=None, serializer=None, deserializer=None)
Register a class by naming its node/static attributes.
This higher-level API is useful when your class is attribute-backed and you do not want to implement flatten/unflatten manually.
- Parameters:
node_fields (Sequence[str]) – Attribute names treated as pytree children (leaves or nested subtrees).
static_fields (Sequence[str]) – Attribute names treated as pytree auxiliary metadata.
constructor (Callable[[Mapping[str, Any]], Any] | None) – Optional callable used to construct instances from the resolved attribute map. If omitted, cls(**values) is attempted first, then the attribute fallback.
serializer (Optional[Callable[[Any], Any]]) – Optional persistence adapter (same semantics as in register_pytree_type()).
deserializer (Optional[Callable[[Any], Any]]) – Optional persistence adapter (same semantics as in register_pytree_type()).
- Return type:
- Returns:
The original class.
Notes
- This function is intentionally permissive during reconstruction: it first tries cls(**values), then falls back to object.__new__(cls) plus setattr assignment. This makes it robust for both dataclass-like and legacy classes.
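The attribute-based flatten/unflatten and the permissive reconstruction order can be sketched as follows. `attrs_flatten` and `attrs_rebuild` are hypothetical helpers, not this module's API. Two details are assumptions about how such a policy could be implemented: storing static fields as a tuple of (name, value) pairs, and catching TypeError to trigger the fallback.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple


def attrs_flatten(obj: Any, node_fields: Sequence[str],
                  static_fields: Sequence[str]) -> Tuple[List[Any], Tuple[Tuple[str, Any], ...]]:
    """Node fields become children; static fields become (name, value) aux pairs."""
    children = [getattr(obj, name) for name in node_fields]
    aux = tuple((name, getattr(obj, name)) for name in static_fields)
    return children, aux


def attrs_rebuild(cls: type, node_fields: Sequence[str], aux: Tuple[Tuple[str, Any], ...],
                  children: Sequence[Any],
                  constructor: Optional[Callable[[Mapping[str, Any]], Any]] = None) -> Any:
    """Rebuild an instance: constructor, then cls(**values), then setattr fallback."""
    values: Dict[str, Any] = dict(zip(node_fields, children))
    values.update(aux)
    if constructor is not None:
        return constructor(values)
    try:
        return cls(**values)  # dataclass-like classes accept fields as keywords
    except TypeError:
        obj = object.__new__(cls)  # legacy classes: bypass __init__ entirely
        for name, value in values.items():
            setattr(obj, name, value)
        return obj


@dataclass
class Point:
    x: float
    y: float
    label: str


class LegacyPoint:
    def __init__(self):  # accepts no field arguments, so cls(**values) fails
        self.x, self.y, self.label = 0.0, 0.0, ""


children, aux = attrs_flatten(Point(1.0, 2.0, "p"), ["x", "y"], ["label"])
print(children, aux)  # [1.0, 2.0] (('label', 'p'),)
print(attrs_rebuild(Point, ["x", "y"], aux, children))  # Point(x=1.0, y=2.0, label='p')

legacy = LegacyPoint()
legacy.x, legacy.y, legacy.label = 3.0, 4.0, "q"
children, aux = attrs_flatten(legacy, ["x", "y"], ["label"])
rebuilt = attrs_rebuild(LegacyPoint, ["x", "y"], aux, children)
print(type(rebuilt).__name__, rebuilt.x, rebuilt.y, rebuilt.label)  # LegacyPoint 3.0 4.0 q
```

Catching TypeError is deliberately broad in this sketch. A TypeError raised inside an __init__ whose signature does match would also trigger the fallback, which is one cost of a permissive policy.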
- is_registered_pytree_type(cls)
Return whether cls currently has an active pytree adapter.
This is useful for idempotent registration patterns in large applications where modules may be imported in different orders.
- Return type:
bool
- resolve_pytree_spec(ref)
Resolve a pytree type spec from class reference.
- Resolution order:
1. In-memory ref cache.
2. Import the class via the class registry, then the class cache.
3. Auto-register classes implementing tree_flatten/tree_unflatten.
- Parameters:
ref (str) – Class reference string in "module:qualname" format.
- Return type:
- Returns:
Resolved adapter specification.
- Raises:
SerializationError – If no adapter can be found or derived for the given reference.
Notes
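If no explicit spec is found but the resolved class implements both tree_flatten and tree_unflatten, a default adapter is registered automatically. This mirrors the JAX convention used by jax.tree_util.register_pytree_node_class, where tree_flatten returns (children, aux_data) and the classmethod tree_unflatten takes (aux_data, children).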
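The resolution order can be sketched as below. This is a simplified model, not the module's code. `_REF_CACHE`, `resolve_spec_sketch`, `SpecSketch` and the local SerializationError class are hypothetical. Step 2 is collapsed into a plain importlib import of the "module:qualname" reference, with no separate class registry or class cache. The auto-derived adapter follows the JAX tree_flatten/tree_unflatten convention.

```python
import importlib
from dataclasses import dataclass
from typing import Any, Callable, Dict


class SerializationError(Exception):
    """Local stand-in for the module's SerializationError."""


@dataclass(frozen=True)
class SpecSketch:
    cls: type
    flatten: Callable[[Any], Any]
    unflatten: Callable[[Any, Any], Any]


_REF_CACHE: Dict[str, SpecSketch] = {}  # hypothetical in-memory ref cache


def _import_ref(ref: str) -> Any:
    """Import the object named by a "module:qualname" reference."""
    module_name, sep, qualname = ref.partition(":")
    if not sep or not module_name or not qualname:
        raise SerializationError(f"malformed reference {ref!r}")
    try:
        obj: Any = importlib.import_module(module_name)
        for part in qualname.split("."):  # walk nested classes, e.g. "Outer.Inner"
            obj = getattr(obj, part)
    except (ImportError, AttributeError) as exc:
        raise SerializationError(f"cannot import {ref!r}") from exc
    return obj


def resolve_spec_sketch(ref: str) -> SpecSketch:
    # 1. In-memory cache of previously resolved references.
    if ref in _REF_CACHE:
        return _REF_CACHE[ref]
    # 2. Import the class (simplified: no class registry or class cache).
    cls = _import_ref(ref)
    # 3. Derive a default adapter from a JAX-style tree_flatten/tree_unflatten pair.
    if callable(getattr(cls, "tree_flatten", None)) and callable(getattr(cls, "tree_unflatten", None)):
        spec = SpecSketch(
            cls=cls,
            flatten=lambda obj: obj.tree_flatten(),
            unflatten=lambda aux, children: cls.tree_unflatten(aux, children),
        )
        _REF_CACHE[ref] = spec  # later lookups are served by step 1
        return spec
    raise SerializationError(f"no pytree adapter found or derivable for {ref!r}")


try:
    resolve_spec_sketch("collections:OrderedDict")
except SerializationError as err:
    print(err)  # no pytree adapter found or derivable for 'collections:OrderedDict'
```

Caching the derived spec means the attribute probing in step 3 runs at most once per reference.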