neuralqx.experimental.utils.struct.pytree module

Public registration APIs for non-Struct pytree types.

These helpers let you bring arbitrary Python classes under the same pytree and serialisation contracts used by neuralqx.utils.struct.Struct.

Not every domain type should inherit from Struct. Sometimes you need to adapt an existing class hierarchy (third-party objects, legacy models, minimal runtime wrappers) while still participating in:

  • JAX tree traversal and transformation,

  • Struct I/O export/load pipelines,

  • reference-based reconstruction across sessions.

This module provides two registration levels:

  1. register_pytree_type(): full-control API for when you want custom flatten/unflatten behaviour.

  2. register_attrs_type(): convenience API for when adaptation is naturally attribute-based.

Both APIs register adapters in:

  • the JAX pytree registry (runtime transformation support),

  • the struct adapter registry (serialisation and lazy reconstruction support).
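The dual registration can be pictured with a toy, pure-Python model. This is a minimal sketch, not neuralqx's implementation. `RUNTIME_ADAPTERS` stands in for the JAX pytree registry (keyed by class), and `REF_ADAPTERS` stands in for the struct adapter registry (keyed by a "module:qualname" reference string). Both dicts and the `toy_register` and `class_ref` helpers are hypothetical.

```python
from typing import Any, Callable

# Hypothetical stand-ins for the two registries; not neuralqx internals.
RUNTIME_ADAPTERS: dict[type, tuple[Callable, Callable]] = {}  # keyed by class (traversal)
REF_ADAPTERS: dict[str, tuple[Callable, Callable]] = {}  # keyed by "module:qualname" (loading)


def class_ref(cls: type) -> str:
    """Build a "module:qualname" reference string for a class."""
    return f"{cls.__module__}:{cls.__qualname__}"


def toy_register(cls: type, flatten: Callable, unflatten: Callable) -> type:
    """Record one adapter in both registries and return cls unchanged."""
    adapter = (flatten, unflatten)
    RUNTIME_ADAPTERS[cls] = adapter
    REF_ADAPTERS[class_ref(cls)] = adapter
    return cls


class Point:
    def __init__(self, x: Any, y: Any) -> None:
        self.x, self.y = x, y


toy_register(
    Point,
    flatten=lambda p: ([p.x, p.y], None),  # two leaves, no static metadata
    unflatten=lambda aux, children: Point(*children),
)
```

Keying the second registry by a string rather than by the class object is what would let a loader find an adapter from a stored reference, which matches the reference-based reconstruction described above.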

class PyTreeTypeSpec(cls, flatten, unflatten, flatten_with_keys=None, serializer=None, deserializer=None)

Bases: object

Complete pytree + serialisation spec for a registered type.

Notes

A PyTreeTypeSpec is effectively an adapter contract for one class. Runtime systems use this single specification for:

  • tree flattening/unflattening,

  • optional serialiser-based persistence paths,

  • lazy resolution by reference at load time.
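A minimal sketch of what such a contract can look like as a frozen dataclass follows. The field names follow the PyTreeTypeSpec signature above. `ToySpec`, `to_payload` and `from_payload` are hypothetical. The rule "prefer the serialiser when one is provided, otherwise fall back to flattening" is an assumption about how the optional persistence path might be chosen, not documented behaviour.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class ToySpec:
    # Field names mirror the documented PyTreeTypeSpec signature.
    cls: type
    flatten: Callable[[Any], tuple[list[Any], Any]]
    unflatten: Callable[[Any, list[Any]], Any]
    flatten_with_keys: Optional[Callable] = None
    serializer: Optional[Callable[[Any], Any]] = None
    deserializer: Optional[Callable[[Any], Any]] = None


def to_payload(spec: ToySpec, obj: Any) -> dict:
    """Hypothetical persistence step: use the serialiser if present, else flatten."""
    if spec.serializer is not None:
        return {"kind": "serialized", "data": spec.serializer(obj)}
    children, aux = spec.flatten(obj)
    return {"kind": "flattened", "children": children, "aux": aux}


def from_payload(spec: ToySpec, payload: dict) -> Any:
    """Inverse of to_payload."""
    if payload["kind"] == "serialized":
        return spec.deserializer(payload["data"])
    return spec.unflatten(payload["aux"], payload["children"])


class Pair:
    def __init__(self, a: Any, b: Any) -> None:
        self.a, self.b = a, b


spec = ToySpec(
    cls=Pair,
    flatten=lambda p: ([p.a, p.b], None),
    unflatten=lambda aux, children: Pair(*children),
)
restored = from_payload(spec, to_payload(spec, Pair(1, 2)))
```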

register_pytree_type(cls, *, flatten, unflatten, flatten_with_keys=None, serializer=None, deserializer=None)

Register an arbitrary class as a pytree-capable and serialisable type.

This is the low-level adapter API. Use this when you need full control over flattening and reconstruction behaviour.

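Parameters:

cls – Class to register.

flatten – Callable mapping an instance to a (children, aux_data) pair: children are traversed as pytree nodes, aux_data is carried as static metadata.

unflatten – Callable taking (aux_data, children) and returning a reconstructed instance.

flatten_with_keys – Optional keyed variant of flatten, for key-path-aware traversal.

serializer – Optional serialiser hook for the persistence path.

deserializer – Optional counterpart to serializer, used at load time.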
Return type:

type[Any]

Returns:

The original class, allowing decorator-style usage.

Raises:

TypeError – Propagated if the provided callables turn out to be invalid when they are invoked at runtime.

Example

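from neuralqx.experimental.utils.struct.pytree import register_pytree_type

class Node:
    def __init__(self, data, tag):
        self.data = data  # dynamic value: becomes a pytree child
        self.tag = tag    # static metadata: carried as aux data

register_pytree_type(
    Node,
    # flatten: instance -> (children, aux_data)
    flatten=lambda obj: ([obj.data], obj.tag),
    # unflatten: (aux_data, children) -> new instance
    unflatten=lambda aux, children: Node(children[0], aux),
)
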
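Before registering a flatten/unflatten pair, it can help to check that the two callables invert each other. The sketch below reuses the logic of the Node example's lambdas as plain functions and calls them directly, without neuralqx or JAX. It assumes the (children, aux) and (aux, children) argument orders used in that example.

```python
class Node:
    def __init__(self, data, tag):
        self.data = data
        self.tag = tag


def flatten(obj):
    # Same logic as the example's flatten: leaves first, static metadata second.
    return [obj.data], obj.tag


def unflatten(aux, children):
    # Same logic as the example's unflatten: static metadata first, then leaves.
    return Node(children[0], aux)


original = Node(data=3.0, tag="leaf")
children, aux = flatten(original)
rebuilt = unflatten(aux, children)
# A correct pair preserves both the leaves and the static metadata.
assert (rebuilt.data, rebuilt.tag) == (original.data, original.tag)

# A tree_map-style transform rewrites only the children and reuses aux unchanged.
doubled = unflatten(aux, [c * 2 for c in children])
```

Keeping tag out of the children is what lets a transformation such as doubling touch only data.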
register_attrs_type(cls, *, node_fields, static_fields=(), constructor=None, serializer=None, deserializer=None)

Register a class by naming its node/static attributes.

This higher-level API is useful when your class is attribute-backed and you do not want to implement flatten/unflatten manually.

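Parameters:

cls – Class to register.

node_fields – Names of the attributes treated as pytree children (dynamic data).

static_fields – Names of the attributes carried as static auxiliary data.

constructor – Optional callable used to rebuild instances from the collected attribute values; when omitted, the reconstruction strategy described in Notes applies.

serializer – Optional serialiser hook, as in register_pytree_type().

deserializer – Optional counterpart to serializer, used at load time.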
Return type:

type[Any]

Returns:

The original class.

Notes

This function is intentionally permissive for reconstruction:

  • it first tries cls(**values),

  • then falls back to object.__new__(cls) followed by setattr assignment.

This makes it robust for both dataclass-like and legacy classes.
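The attribute-based adaptation and the permissive reconstruction described in these notes can be sketched in plain Python. `make_attrs_adapter` is a hypothetical helper, not part of this module. The sketch assumes that node_fields become the flattened children, that static_fields become the auxiliary data, and that a TypeError from cls(**values) is what triggers the fallback.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


def make_attrs_adapter(
    cls: type, node_fields: Sequence[str], static_fields: Sequence[str] = ()
) -> tuple[Callable, Callable]:
    """Hypothetical: build (flatten, unflatten) callables from attribute names."""

    def flatten(obj: Any) -> tuple[list[Any], tuple[Any, ...]]:
        children = [getattr(obj, name) for name in node_fields]
        aux = tuple(getattr(obj, name) for name in static_fields)
        return children, aux

    def unflatten(aux: tuple[Any, ...], children: Sequence[Any]) -> Any:
        values = dict(zip(node_fields, children))
        values.update(zip(static_fields, aux))
        try:
            return cls(**values)  # first attempt: keyword construction
        except TypeError:
            obj = object.__new__(cls)  # fallback: bypass __init__ entirely
            for name, value in values.items():
                setattr(obj, name, value)
            return obj

    return flatten, unflatten


class Legacy:
    # __init__ does not accept the attribute names as keywords.
    def __init__(self) -> None:
        self.weights = None
        self.name = "unset"


@dataclass
class Params:
    weights: list
    name: str


legacy_flatten, legacy_unflatten = make_attrs_adapter(Legacy, ["weights"], ["name"])
src = Legacy()
src.weights, src.name = [1.0, 2.0], "layer0"
children, aux = legacy_flatten(src)
legacy_copy = legacy_unflatten(aux, children)  # cls(**values) fails, fallback used

_, params_unflatten = make_attrs_adapter(Params, ["weights"], ["name"])
params_copy = params_unflatten(("p",), [[0.5]])  # keyword construction succeeds
```

One trade-off of catching TypeError broadly is that a TypeError raised inside a working __init__ is also silently routed to the fallback path.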

is_registered_pytree_type(cls)

Return whether cls currently has an active pytree adapter.

This is useful for idempotent registration patterns in large applications where modules may be imported in different orders.

Return type:

bool

resolve_pytree_spec(ref)

Resolve a pytree type spec from a class reference.

Resolution order:
  1. In-memory ref cache.

  2. Import the class via the class registry, then look it up in the class cache.

  3. Auto-register classes implementing tree_flatten/tree_unflatten.

Parameters:

ref (str) – Class reference string in "module:qualname" format.

Return type:

PyTreeTypeSpec

Returns:

Resolved adapter specification.

Raises:

SerializationError – If no adapter can be found or derived for the given reference.

Notes

If no explicit spec is found but the resolved class implements both tree_flatten and tree_unflatten, a default adapter is registered automatically. This mirrors common JAX ecosystem conventions.
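This is the class-side shape that auto-registration looks for, in the style used with jax.tree_util.register_pytree_node_class. It has an instance method tree_flatten that returns (children, aux_data) and a classmethod tree_unflatten(aux_data, children). The sketch does not import JAX. It only shows that a derived default adapter would call exactly these two methods. The `Scaled` class is illustrative.

```python
from typing import Any


class Scaled:
    """Follows the JAX tree_flatten/tree_unflatten convention."""

    def __init__(self, values: list[float], scale: float) -> None:
        self.values = values
        self.scale = scale  # treated as static metadata, not a leaf

    def tree_flatten(self) -> tuple[list[Any], float]:
        # children first, aux_data second
        return [self.values], self.scale

    @classmethod
    def tree_unflatten(cls, aux_data: float, children: list[Any]) -> "Scaled":
        # aux_data first, children second
        return cls(children[0], aux_data)


obj = Scaled([1.0, 2.0], scale=0.5)
children, aux = obj.tree_flatten()
restored = Scaled.tree_unflatten(aux, children)
```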