Structs and data containers¶
Warning
Experimental feature.
The neuraLQX immutable Struct is not part of the stable neuraLQX public API. Its internals, functionality and behaviour may change without prior warning or deprecation period. Pin your neuraLQX version if you depend on it in research scripts.
See Experimental API for the full experimental-API policy.
neuralqx.experimental.utils.struct is the foundation for immutable, JAX-native data
containers in neuraLQX. It provides a single base class Struct
whose subclasses are automatically:
registered as JAX pytrees with correct leaf/auxiliary-data split,
frozen after construction so they can be safely used as JAX static arguments and in functional code,
equipped with a stable serialisation contract (in-memory state dicts, directory bundles, .zip archives),
identified by a portable class reference so serialised payloads can be reconstructed across processes.
The field system that drives all of this behaviour is declared with a single
field() helper that mirrors dataclasses.field()
but adds JAX-specific controls.
import jax.numpy as jnp
import neuralqx.experimental.utils.struct as s
class TrainState(s.Struct):
    params: object  # JAX array tree (node field)
    step: int = s.field(static=True, default=0)  # static metadata
    tag: str = s.field(static=True, default="run")
state = TrainState(params={"w": jnp.ones(4)})
state2 = state.replace(step=1)  # new instance, original unchanged
payload = state.to_state_dict()  # portable mapping
state3 = TrainState.from_state_dict(payload)
Import path:
import neuralqx.experimental.utils.struct as s
# or pick individual names:
from neuralqx.experimental.utils.struct import Struct, field, FieldKind, register_class
Why Struct?¶
Plain Python objects do not compose well with JAX. JAX transformations
(jit, grad, vmap) can only trace through values that are
registered as pytrees. If you pass an arbitrary object to a jitted function,
JAX either rejects it as an invalid argument type or, when the argument is
marked static or captured by closure, treats it as a compile-time constant.
In that case changes to its fields are invisible to the tracer, and you lose
the ability to differentiate through it.
Python’s standard dataclasses module does not know about this.
dataclasses.dataclass() gives you a nice constructor and __repr__,
but it does not register anything with JAX.
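The gap can be reproduced with plain JAX alone. The sketch below uses only dataclasses and jax.tree_util; PlainParams is a hypothetical class invented for illustration, and the manual registration stands in for what this page says Struct does automatically. It is not neuraLQX's actual implementation.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class PlainParams:  # hypothetical example class, not part of neuraLQX
    w: object
    scale: float


p = PlainParams(w=jnp.ones(3), scale=2.0)

# Unregistered: JAX treats the whole object as a single opaque leaf.
leaves_before = jax.tree_util.tree_leaves(p)

# Manual registration: `w` becomes a child (leaf), `scale` becomes aux data.
jax.tree_util.register_pytree_node(
    PlainParams,
    lambda obj: ((obj.w,), obj.scale),
    lambda aux, children: PlainParams(w=children[0], scale=aux),
)
leaves_after = jax.tree_util.tree_leaves(p)


def loss(params):
    return jnp.sum(params.scale * params.w ** 2)


# After registration, grad returns a PlainParams holding d(loss)/d(w).
g = jax.grad(loss)(p)
```

Before registration the only leaf is the object itself; afterwards the only leaf is the array w, and the gradient comes back wrapped in the same class.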
Struct closes this gap:
Every Struct subclass is a JAX pytree, registered automatically at class-definition time. JAX can trace through its node fields (the pytree leaves) while keeping static fields as compile-time constants.
Instances are immutable after construction. Attempting to set an attribute raises FrozenStructError. This makes Struct instances safe to use as pytree auxiliary data, as jit static arguments, and as dictionary keys.
Serialisation is built in. to_state_dict / from_state_dict and export / load work automatically for any Struct, including nested ones, without extra configuration.
The field kinds (node, static, opaque) are the key design decision. Getting
them right is what makes a Struct behave correctly under JAX
transformations.
Field kinds¶
Every field in a Struct has a kind that controls how it participates in
JAX pytree traversal. The three kinds are expressed by
FieldKind:
| Kind | What it means |
|---|---|
| Node | The field is a pytree leaf. JAX traces through it. Gradients can flow through it. |
| Static | The field becomes pytree auxiliary data. |
| Opaque | The field is excluded from the pytree entirely. JAX does not see it. Use this for runtime handles, caches, tokens, or any process-local object that should not be traced or serialised. |
Choosing correctly:
Use node for anything JAX should differentiate through: weight arrays, configuration tensors, floating-point scalars that change across calls.
Use static for things that shape your computation but are fixed during a training run: hyperparameters, network architecture choices, string labels, integer sizes. Every unique combination of static field values triggers a separate jit compilation.
Use opaque for runtime-only side-data that would confuse JAX: file handles, logger objects, reference caches, non-serialisable callbacks.
class Config(s.Struct):
    weights: object  # NODE: JAX traces this
    n_layers: int = s.field(static=True)  # STATIC: re-traces on change
    cache: object = s.field(pytree=False, default_factory=dict)  # OPAQUE: invisible to JAX
The field() declaration helper¶
field() is the single function used to control
field behaviour. It returns a FieldSpec that
the metaclass picks up at class-definition time.
Signature (all arguments are keyword-only):
s.field(
    *,
    static=False,             # True -> FieldKind.STATIC
    pytree=True,              # False -> FieldKind.OPAQUE
    default=MISSING,          # concrete default value
    default_factory=MISSING,  # zero-arg callable for mutable defaults
    init=True,                # include in generated constructor
    repr=True,                # include in __repr__
    compare=True,             # include in __eq__
    serialize=None,           # override serialisation (None = inferred)
    kw_only=False,            # force keyword-only in constructor
    doc=None,                 # per-field docstring
    metadata=None,            # free-form mapping for tooling
    converter=None,           # normalisation hook (see below)
    validator=None,           # validation hook(s) (see below)
    derived=None,             # computed-field function (see below)
)
Common patterns:
class State(s.Struct):
    # plain node field, no field() call needed
    params: object
    # static field with a default
    name: str = s.field(static=True, default="run-1")
    # opaque field, invisible to JAX, not serialised by default
    logger: object = s.field(pytree=False, default_factory=object)
    # mutable container default, avoids shared-state bugs
    history: list = s.field(default_factory=list)
    # keyword-only in constructor
    seed: int = s.field(static=True, kw_only=True, default=0)
    # exclude from repr and equality
    debug_token: str = s.field(
        static=True, repr=False, compare=False, default=""
    )
Default rules:
default and default_factory are mutually exclusive.
Fields without either must receive a value in the constructor.
default_factory is called afresh each time a default is needed; this is the correct way to declare mutable defaults (lists, dicts, etc.).
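The shared-state problem that default_factory avoids can be shown with the standard library, whose dataclasses.field() this helper is said to mirror. The sketch below uses only stdlib dataclasses; SharedDefault and FreshDefault are hypothetical names, and the neuraLQX field() is assumed to behave analogously.

```python
import dataclasses


class SharedDefault:
    history = []  # one list object shared by every instance


@dataclasses.dataclass
class FreshDefault:
    # the factory is called once per instance, so each gets its own list
    history: list = dataclasses.field(default_factory=list)


a, b = SharedDefault(), SharedDefault()
a.history.append("x")  # also visible through b: same underlying list

c, d = FreshDefault(), FreshDefault()
c.history.append("x")  # d.history stays empty
```

With the class-level default, appending through one instance leaks into the other; with a factory, every instance starts from its own empty container.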
Converter and validator hooks¶
Converters¶
A converter is called on the raw value every time the field is assigned:
in the constructor, in replace, and during deserialisation. Use it to
normalise input types:
class Box(s.Struct):
    count: int = s.field(converter=int)  # coerce str to int on input
    label: str = s.field(converter=str.strip)  # strip whitespace
b = Box(count="3", label=" hello ")
b.count  # 3 (int)
b.label  # "hello"
Converters accept either converter(value) or converter(self, value).
The two-argument form lets the converter inspect other already-assigned fields:
def _clamp(self, value):
    return max(0, min(value, self.maximum))
class Bounded(s.Struct):
    maximum: int
    value: int = s.field(converter=_clamp)
Validators¶
A validator is called after the converter (if any) with the normalised
value. If it returns False, a ValidationError
is raised. If it raises an exception directly, that exception propagates.
def _positive(value):
    return value > 0
class Rate(s.Struct):
    lr: float = s.field(validator=_positive)
Rate(lr=0.01)  # fine
Rate(lr=-1.0)  # raises ValidationError
Pass a list to apply multiple validators in order:
import math
def _finite(value):
    return math.isfinite(value)
class Rate(s.Struct):
    lr: float = s.field(validator=[_positive, _finite])
Like converters, validators may accept validator(value) or
validator(self, value).
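The ordering described above (convert first, then validate the converted value) can be sketched without the library. Everything in this block is hypothetical: assign is an illustrative helper, ValidationError is a local stand-in for the library's exception, and treating only an explicit False as a rejection is an assumption based on the wording above.

```python
class ValidationError(ValueError):
    """Local stand-in for the library's ValidationError."""


def assign(raw, converter=None, validators=()):
    """Hypothetical helper: convert first, then validate the converted value."""
    value = converter(raw) if converter is not None else raw
    for check in validators:
        if check(value) is False:  # assumption: only an explicit False rejects
            raise ValidationError(f"{check.__name__} rejected {value!r}")
    return value


def positive(value):
    return value > 0


ok = assign("3", converter=int, validators=[positive])  # "3" -> 3, then 3 > 0

try:
    assign("-1", converter=int, validators=[positive])
    rejected = False
except ValidationError:
    rejected = True
```

Because the validator sees the converted value, it can compare an int against 0 even though the caller passed a string.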
Derived fields¶
A derived field is computed automatically from other fields. It is never part of the constructor, never serialised (the value is recomputed instead), and must be either static or opaque (not a node field).
Declare it with init=False and a derived callable:
class Dataset(s.Struct):
    samples: list
    n: int = s.field(
        static=True,
        init=False,
        derived=lambda self: len(self.samples),
    )
ds = Dataset(samples=[1, 2, 3])
ds.n  # 3, computed automatically
The derived callable receives either no arguments (derived()) or the
instance (derived(self)). It is called:
after initial construction,
after replace(),
after from_state_dict() / load(),
explicitly when you call rederive().
Derived fields run twice during construction: once before __post_init__
and once after (in case __post_init__ modifies source fields).
Construction lifecycle¶
When you call MyStruct(...) the following steps happen in order:
1. Assign values to every non-derived field: init fields take the provided keyword arguments (or their defaults), and non-init fields take their defaults. Converters are applied at this point.
2. Compute derived fields for the first time.
3. Run __post_init__ if defined. You may freely mutate fields inside __post_init__ because the instance is not frozen yet.
4. Recompute derived fields (in case __post_init__ changed source fields).
5. Validate static constraints (static fields must be hashable and array-free) and run field validators.
6. Freeze the instance. Any subsequent setattr or delattr raises FrozenStructError.
class Model(s.Struct):
    x: int
    y: int = s.field(static=True, init=False, derived=lambda self: self.x * 2)
    def __post_init__(self):
        # called at step 3, can still mutate
        self.x = self.x + 1
m = Model(x=2)
m.x  # 3 (post_init incremented it)
m.y  # 6 (derived after post_init)
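The lifecycle can be mimicked in plain Python to see why the derived value ends up as 6. This is a sketch of the described ordering only: MiniModel and FrozenError are hypothetical stand-ins, and nothing here reflects how neuraLQX implements freezing internally.

```python
class FrozenError(AttributeError):
    """Local stand-in for FrozenStructError."""


class MiniModel:
    def __init__(self, x):
        object.__setattr__(self, "_frozen", False)
        self.x = x                                  # 1. assign init fields
        self._derive()                              # 2. first derived pass (y == 4)
        self.__post_init__()                        # 3. user hook, still mutable
        self._derive()                              # 4. second derived pass (y == 6)
        hash(self.y)                                # 5. static values must be hashable
        object.__setattr__(self, "_frozen", True)   # 6. freeze

    def _derive(self):
        self.y = self.x * 2

    def __post_init__(self):
        self.x = self.x + 1

    def __setattr__(self, name, value):
        if getattr(self, "_frozen", False):
            raise FrozenError(f"cannot assign to field {name!r}")
        object.__setattr__(self, name, value)


m = MiniModel(2)

try:
    m.x = 10
    frozen = False
except FrozenError:
    frozen = True
```

The second derived pass is what keeps y consistent with the value of x that __post_init__ left behind.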
replace() and rederive()¶
replace¶
Because Struct instances are frozen, the way to “update” one is to create
a new instance with selected fields changed:
state2 = state.replace(step=state.step + 1, tag="run-2")
replace runs the full lifecycle again: converters, __post_init__,
derived fields, validators, freeze. You cannot replace a derived field
directly; update its source fields instead.
# Wrong, raises TypeError:
ds.replace(n=10)
# Correct, update the source and n is recomputed:
ds.replace(samples=[1, 2, 3, 4])
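The same pattern exists in the standard library, which may help as a mental model. The sketch below uses a frozen stdlib dataclass with an init=False field playing the role of a derived field; it is an analogue, not the neuraLQX API, and the exception type raised by dataclasses.replace differs from the TypeError mentioned above.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Dataset:  # stdlib analogue, not a neuraLQX Struct
    samples: tuple
    n: int = dataclasses.field(init=False)  # plays the role of a derived field

    def __post_init__(self):
        # frozen dataclasses need object.__setattr__ to fill computed fields
        object.__setattr__(self, "n", len(self.samples))


ds = Dataset(samples=(1, 2, 3))
ds2 = dataclasses.replace(ds, samples=(1, 2, 3, 4))  # new instance, n recomputed

try:
    dataclasses.replace(ds, n=10)  # init=False fields cannot be replaced directly
    direct_replace_failed = False
except (TypeError, ValueError):
    direct_replace_failed = True
```

Replacing the source field rebuilds the instance and recomputes n, while the original ds is left untouched.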
rederive¶
Occasionally a node field holds a mutable container (a list, a dict)
that has changed in place. In this case derived fields that depend on it
are stale. rederive() recomputes all derived fields without rebuilding
the whole instance:
class Bag(s.Struct):
    items: list
    size: int = s.field(static=True, init=False, derived=lambda self: len(self.items))
bag = Bag(items=[1, 2])
bag.items.append(3)  # mutate in place
bag.size  # 2 (stale!)
bag.rederive()
bag.size  # 3
Note
In most JAX code you would avoid in-place mutation entirely and use
replace instead. rederive exists for the rare cases where a
mutable container is intentional (e.g. a growing replay buffer).
Introspection helpers¶
Both class-level and module-level forms exist for convenience:
# class methods (State is assumed here to declare params, step and tag, like TrainState above)
State.fields() # OrderedDict of name -> FieldSpec
State.node_fields() # ("params",)
State.static_fields() # ("step", "tag")
State.opaque_fields() # ()
State.derived_fields() # ()
# module-level wrappers (same result)
s.fields(State)
s.node_fields(State)
s.static_fields(State)
s.opaque_fields(State)
s.derived_fields(State)
Inspecting a FieldSpec:
spec = State.fields()["step"]
spec.name # "step"
spec.kind # FieldKind.STATIC
spec.default # 0
spec.has_default # True
spec.is_derived # False
spec.should_serialize # True
Counting pytree leaves:
state.tree_size() # == len(jax.tree_util.tree_leaves(state))
Converting to a plain dict (for debugging/logging, not persistence):
state.to_dict() # shallow, values are raw Python objects
state.to_dict(recursive=True) # recurse through nested Structs
state.to_dict(include_opaque=False) # skip opaque fields
Serialisation¶
Struct provides four serialisation methods. All handle nested Struct
values, registered pytrees, and NumPy/JAX arrays automatically.
to_state_dict / from_state_dict¶
Produces and consumes a portable in-memory Python mapping:
payload = state.to_state_dict()
# {
# "version": 1,
# "manifest": {...}, # JSON-safe structural encoding
# "arrays": {...}, # shape/dtype metadata
# "array_data": {...}, # in-memory NumPy arrays
# }
restored = TrainState.from_state_dict(payload)
The payload is suitable for:
sending over the network (after serialising array_data separately),
passing to logging systems,
checkpointing in memory without writing to disk.
Derived fields are not stored; they are recomputed during
from_state_dict. Opaque fields are not stored by default.
export / load¶
Writes a two-file bundle to disk and reads it back:
# Write to a directory:
state.export("checkpoints/step_100")
# Creates:
# checkpoints/step_100/manifest.json
# checkpoints/step_100/arrays.npz
# Or write a zip archive:
state.export("checkpoints/step_100.zip")
# Load back (class is inferred from the manifest):
restored = TrainState.load("checkpoints/step_100")
# Optionally enforce the expected class:
restored = TrainState.load("checkpoints/step_100", load_cls=...)
The manifest.json contains the structural encoding, including the class
reference ("module:qualname" string). The arrays.npz contains all
array payloads, stored in NumPy's compressed .npz format. Both files together
are a self-contained checkpoint.
export raises FileExistsError if the target already exists unless you
pass overwrite=True.
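To make the two-file layout concrete, here is a minimal sketch that writes and reads a manifest.json plus arrays.npz pair using only json and NumPy. The helper names export_bundle and load_bundle and the manifest keys are hypothetical; the real manifest schema is not shown on this page and will differ.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def export_bundle(target, static, arrays, overwrite=False):
    """Write manifest.json + arrays.npz into directory `target` (hypothetical layout)."""
    target = Path(target)
    if target.exists() and not overwrite:
        raise FileExistsError(str(target))
    target.mkdir(parents=True, exist_ok=True)
    manifest = {"version": 1, "static": static, "array_names": sorted(arrays)}
    (target / "manifest.json").write_text(json.dumps(manifest))
    np.savez_compressed(target / "arrays.npz", **arrays)


def load_bundle(target):
    """Read the bundle back as (static metadata, dict of arrays)."""
    target = Path(target)
    manifest = json.loads((target / "manifest.json").read_text())
    with np.load(target / "arrays.npz") as data:
        arrays = {name: data[name] for name in manifest["array_names"]}
    return manifest["static"], arrays


with tempfile.TemporaryDirectory() as tmp:
    bundle = Path(tmp) / "step_100"
    export_bundle(bundle, {"step": 100, "tag": "run"}, {"w": np.ones(4)})
    static, arrays = load_bundle(bundle)
    try:
        export_bundle(bundle, {"step": 101, "tag": "run"}, {"w": np.zeros(4)})
        second_export_failed = False
    except FileExistsError:
        second_export_failed = True
```

Keeping JSON-safe metadata and binary array data in separate files means the manifest stays human-readable while the arrays stay compact.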
Serialisation policy per field¶
By default:
Node fields are serialised.
Static fields are serialised.
Opaque fields are not serialised.
Derived fields are never serialised (recomputed on load).
You can override the policy for any individual field with serialize=True
or serialize=False:
class State(s.Struct):
    params: object
    cache: object = s.field(pytree=False, serialize=True)  # opaque but saved
    debug: str = s.field(static=True, serialize=False)  # static but skipped
Registering non-Struct classes¶
Not every class needs to (or should) inherit from Struct. The package
provides two APIs to register existing classes so they participate in JAX
pytree traversal and struct I/O without changing their inheritance hierarchy.
register_class decorator¶
register_class() is a decorator (or direct call)
that promotes an existing class to a Struct. The original class becomes
the first entry in the MRO so its methods and super() calls remain valid:
@s.register_class
class Params:
    weights: object
    bias: object = s.field(default_factory=lambda: jnp.zeros(1))
p = Params(weights=jnp.ones(4))
p.replace(bias=jnp.zeros(4))  # Struct API available
With an optional name override:
@s.register_class(name="TrainingParams")
class Params:
    ...
dataclass() is an alias for
register_class provided for readability in codebases that conceptually
think of structs as immutable dataclasses:
@s.dataclass
class MyConfig:
    lr: float = s.field(static=True, default=1e-3)
register_pytree_type: full-control API¶
Use register_pytree_type() when you need
complete control over the flatten/unflatten logic for a third-party or
legacy class:
class Node:
    def __init__(self, data, tag):
        self.data = data
        self.tag = tag
s.register_pytree_type(
    Node,
    flatten=lambda obj: ([obj.data], obj.tag),
    unflatten=lambda aux, children: Node(children[0], aux),
)
# Node is now a JAX pytree:
import jax
leaves = jax.tree_util.tree_leaves(Node(jnp.ones(3), "x")) # [array([1,1,1])]
Optional flatten_with_keys provides richer key metadata for tools like
jax.tree_util.tree_map_with_path:
s.register_pytree_type(
    Node,
    flatten=lambda obj: ([obj.data], obj.tag),
    unflatten=lambda aux, children: Node(children[0], aux),
    flatten_with_keys=lambda obj: (
        [(jax.tree_util.GetAttrKey("data"), obj.data)],
        obj.tag,
    ),
)
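What the key metadata buys you can be seen with plain JAX. The sketch below uses jax.tree_util.register_pytree_with_keys directly on a hypothetical Tagged class; it is assumed, not confirmed, that the flatten_with_keys argument above feeds the same JAX mechanism.

```python
import jax
import jax.numpy as jnp


class Tagged:  # hypothetical class used only for this sketch
    def __init__(self, data, tag):
        self.data = data
        self.tag = tag


jax.tree_util.register_pytree_with_keys(
    Tagged,
    # flatten_with_keys: ((key, child), ...) plus aux data
    lambda obj: (((jax.tree_util.GetAttrKey("data"), obj.data),), obj.tag),
    # unflatten: rebuild the object from aux data and children
    lambda aux, children: Tagged(children[0], aux),
)

paths_and_leaves, treedef = jax.tree_util.tree_flatten_with_path(
    Tagged(jnp.ones(3), "x")
)
path, leaf = paths_and_leaves[0]
path_string = jax.tree_util.keystr(path)  # readable path to the leaf
```

With keys registered, path-aware utilities can report the leaf as an attribute named data rather than as an anonymous positional child.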
Optional serializer / deserializer pair teaches struct I/O how to
persist instances of this class in export / load round-trips when the
raw aux data is not directly JSON-serialisable:
s.register_pytree_type(
    Node,
    flatten=lambda obj: ([obj.data], obj.tag),
    unflatten=lambda aux, children: Node(children[0], aux),
    serializer=lambda obj: {"tag": obj.tag},
    deserializer=lambda payload, children: Node(children[0], payload["tag"]),
)
register_attrs_type: attribute-based shorthand¶
register_attrs_type() is a higher-level
convenience that generates the flatten/unflatten functions automatically from
attribute names. Use this for attribute-backed classes when you do not want
to write the plumbing manually:
class Edge:
    def __init__(self, flux, source, target):
        self.flux = flux      # array, pytree leaf
        self.source = source  # int, static metadata
        self.target = target  # int, static metadata
s.register_attrs_type(
    Edge,
    node_fields=("flux",),
    static_fields=("source", "target"),
)
e = Edge(jnp.array(1.5), source=0, target=3)
# JAX now sees Edge as a pytree with one leaf (flux)
# and static aux (source, target).
An optional constructor overrides how instances are rebuilt from the
attribute map (useful when the class constructor has positional-only
arguments or special initialisation logic):
s.register_attrs_type(
    Edge,
    node_fields=("flux",),
    static_fields=("source", "target"),
    constructor=lambda vals: Edge(vals["flux"], vals["source"], vals["target"]),
)
Checking and resolving registrations¶
s.is_registered_pytree_type(Node) # True / False
# Resolve by class-reference string (used internally by the I/O layer):
spec = s.resolve_pytree_spec("mymodule:Node")
spec.cls # Node
spec.flatten # the flatten callable
spec.unflatten # the unflatten callable
The class registry¶
neuraLQX serialisation uses class references (strings in the format
"module.path:QualifiedClassName") to identify types portably across
processes. These references are stored in manifests so that load can
reconstruct the correct class without hard-coding imports.
s.class_ref(TrainState)
# "neuralqx.my_module:TrainState"
s.resolve_class("neuralqx.my_module:TrainState")
# <class 'neuralqx.my_module.TrainState'>
You rarely need to call these directly. They are used internally by
to_state_dict, from_state_dict, export, and load. They
become relevant if you build custom I/O tooling on top of the struct layer or
if you need to verify that a class is importable under a given reference.
Resolution follows three steps:
1. In-memory process cache (fastest).
2. Lazy import via importlib.
3. Cache the result for future calls.
All Struct subclasses and all types registered via
register_pytree_type / register_attrs_type are automatically added to
the registry at registration time.
JAX pytree integration in practice¶
Because Struct subclasses are pytrees, they work transparently with all
JAX utilities:
import jax
import jax.numpy as jnp
class Params(s.Struct):
    w: object
    b: object
    lr: float = s.field(static=True, default=1e-3)
p = Params(w=jnp.ones(4), b=jnp.zeros(4))
# tree operations
jax.tree_util.tree_leaves(p)  # [w_array, b_array]
jax.tree_util.tree_map(jnp.zeros_like, p)  # new Params with zero arrays
# jit: re-traces only when static field 'lr' changes
@jax.jit
def step(params, grad):
    return jax.tree_util.tree_map(lambda p, g: p - params.lr * g, params, grad)
# grad: differentiates through node fields
def loss(params):
    return jnp.sum(params.w ** 2)
g = jax.grad(loss)(p) # g is a Params with gradient arrays
# vmap: batches over node fields
batch = jax.tree_util.tree_map(lambda x: jnp.stack([x, x]), p) # "batch of 2"
jax.vmap(loss)(batch)
Static field behaviour under jit:
Each unique combination of static field values is a separate jit
specialisation. If you change only a static field value and call the same
jitted function again, you pay a re-compilation cost. This is
expected and correct: static fields encode things like layer sizes or string
keys that change the computation graph.
If a value changes frequently (e.g. a step counter), use a node field or pass it as a separate traced argument, not as a static field.
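The re-tracing behaviour can be observed with plain JAX. In the sketch below, Scaled is a hypothetical manually registered pytree standing in for a Struct with one node field and one static field; the trace counter relies on the fact that Python side effects inside a jitted function run only while JAX is tracing.

```python
import jax
import jax.numpy as jnp


class Scaled:  # hypothetical stand-in for a Struct with one node and one static field
    def __init__(self, w, lr):
        self.w = w    # plays the role of a node field (traced)
        self.lr = lr  # plays the role of a static field (aux data)


jax.tree_util.register_pytree_node(
    Scaled,
    lambda obj: ((obj.w,), obj.lr),
    lambda aux, children: Scaled(children[0], aux),
)

traces = []  # appended to only while JAX is tracing `step`


@jax.jit
def step(params):
    traces.append(params.lr)  # Python side effect: runs at trace time only
    return params.w * params.lr


step(Scaled(jnp.ones(3), 0.1))   # first call: traces and compiles
step(Scaled(jnp.zeros(3), 0.1))  # same static value, new array values: cache hit
step(Scaled(jnp.ones(3), 0.2))   # new static value: traces again
```

Three calls produce only two traces: changing the array contents reuses the compiled function, while changing the aux value forces a new specialisation.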
Opaque fields under jit:
Opaque fields are preserved across pytree round-trips (JAX stores them in aux
data via an identity-preserving wrapper) but they are invisible to the tracer.
They are passed through as Python objects and must be the same Python object
(same id()) after unflattening for JAX cache consistency. Do not store
JAX arrays in opaque fields if you intend to differentiate through them.
StructABCMeta: abstract struct base classes¶
If you want to define an abstract struct (with @abc.abstractmethod
methods) use StructABCMeta as the metaclass
instead of the default StructMeta:
import abc
class AbstractSolver(s.Struct, metaclass=s.StructABCMeta):
    lr: float = s.field(static=True)
    @abc.abstractmethod
    def step(self, params): ...
class SGD(AbstractSolver):
    def step(self, params):
        return jax.tree_util.tree_map(lambda p: p - self.lr, params)
StructABCMeta is the combined metaclass for StructMeta and
abc.ABCMeta and gives you both struct processing and ABC enforcement.
You cannot instantiate an abstract struct class directly. Any concrete
subclass that does not implement all abstract methods will raise
TypeError at instantiation time, exactly like a plain ABC.
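Why a combined metaclass is needed can be shown with the standard library. In this sketch RecordingMeta is a hypothetical stand-in for StructMeta (it merely records the classes it creates) and RecordingABCMeta plays the role of StructABCMeta; the real metaclasses do far more.

```python
import abc

registered = []  # names of classes processed by the stand-in metaclass


class RecordingMeta(type):
    """Hypothetical stand-in for StructMeta: records each class it creates."""

    def __new__(mcls, name, bases, namespace, **kwargs):
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
        registered.append(name)
        return cls


class RecordingABCMeta(RecordingMeta, abc.ABCMeta):
    """Combined metaclass: class processing plus ABC enforcement."""


class AbstractSolver(metaclass=RecordingABCMeta):
    @abc.abstractmethod
    def step(self, x): ...


class SGD(AbstractSolver):
    def step(self, x):
        return x - 1


try:
    AbstractSolver()
    abstract_instantiable = True
except TypeError:
    abstract_instantiable = False

result = SGD().step(3)
```

A class can have only one metaclass, so mixing a custom metaclass with abc.ABCMeta requires a metaclass that derives from both; otherwise Python reports a metaclass conflict at class-definition time.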