neuralqx.experimental.utils.struct.fields module

Public field declarations for neuralqx.utils.struct.

This module defines the declarative schema language for Struct classes.

At a high level:
  • FieldKind determines how values participate in JAX pytree semantics.

  • FieldSpec stores normalized, immutable metadata for a single field.

  • field() is the primary user API for declaring field behaviour.

Struct fields are split into three runtime categories:

  1. Node fields (FieldKind.NODE): dynamic pytree leaves. JAX traces and transforms these values.

  2. Static fields (FieldKind.STATIC): stored as pytree auxiliary metadata rather than leaves. They are part of the JIT cache key, so values must be hashable and must not contain array-like objects; changing a static value causes retracing.

  3. Opaque fields (FieldKind.OPAQUE): intentionally excluded from pytree leaves and treated as runtime side data. They are useful for handles, caches, tokens, or process-local objects that should not be traced.
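To make the three-way split concrete, here is a toy model in plain Python (no JAX required). It is not the neuralqx implementation: `split_fields` is a hypothetical helper, and the enum's string values are assumed. It only shows which bucket each kind of value would land in when a struct is flattened; how neuralqx actually carries opaque values through flatten/unflatten is not specified here.

```python
from enum import Enum
from typing import Any, Hashable


class FieldKind(str, Enum):
    # Hypothetical replica: the member values "node"/"static"/"opaque" are assumed.
    NODE = "node"
    STATIC = "static"
    OPAQUE = "opaque"


def split_fields(values: dict[str, Any], kinds: dict[str, FieldKind]):
    """Partition field values into (leaves, aux, side) by kind (toy model).

    leaves: NODE values, which JAX would trace;
    aux:    STATIC (name, value) pairs; JAX hashes and compares aux data
            when caching compiled functions, so values must be hashable;
    side:   OPAQUE values, kept out of the pytree entirely.
    """
    leaves: list[Any] = []
    aux: list[tuple[str, Hashable]] = []
    side: dict[str, Any] = {}
    for name, value in values.items():
        kind = kinds[name]
        if kind is FieldKind.NODE:
            leaves.append(value)
        elif kind is FieldKind.STATIC:
            hash(value)  # raises TypeError for unhashable static values
            aux.append((name, value))
        else:
            side[name] = value
    return leaves, tuple(aux), side


leaves, aux, side = split_fields(
    {"x": [1.0, 2.0], "tag": "exp-A", "token": object()},
    {"x": FieldKind.NODE, "tag": FieldKind.STATIC, "token": FieldKind.OPAQUE},
)
```

Because FieldKind mixes in `str`, each member also compares equal to its string value (here, `FieldKind.STATIC == "static"`).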

The derived=... hook declares computed fields, which are recomputed after construction, replacement, and deserialisation. Derived fields:

  • must use init=False,

  • cannot define defaults,

  • cannot be node fields,

  • are never serialised (they are recomputed instead).
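Standard dataclasses offer a close analogue of the derived-field rules, sketched below. This uses plain `dataclasses`, not Struct, and `to_state_dict`/`from_state_dict` are hypothetical names standing in for whatever serialisation API Struct provides. The point is that a field with `init=False` and no default, filled in `__post_init__`, is recomputed on construction, on `dataclasses.replace`, and on rebuild from serialised data, and never needs to be written out.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class State:
    # Analogy only: a standard frozen dataclass, not a neuralqx Struct.
    x: tuple[float, ...]
    # "Derived": excluded from __init__ (init=False) and has no default.
    norm: float = dataclasses.field(init=False)

    def __post_init__(self):
        # Recompute on every construction; frozen dataclasses need
        # object.__setattr__ to assign the value.
        object.__setattr__(self, "norm", float(sum(self.x)))

    def to_state_dict(self) -> dict:
        # Hypothetical serialiser: the derived field is not written out.
        return {"x": list(self.x)}

    @classmethod
    def from_state_dict(cls, d: dict) -> "State":
        # Rebuilding through __init__ re-runs __post_init__, recomputing norm.
        return cls(x=tuple(d["x"]))


s = State(x=(1.0, 2.0))
s2 = dataclasses.replace(s, x=(3.0, 4.0))  # replace() calls __init__ again
s3 = State.from_state_dict(s2.to_state_dict())
```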

class FieldKind(value)

Bases: str, Enum

Enum describing a field’s role in Struct/pytree behaviour.

class FieldSpec(name, default=MISSING, kind=FieldKind.NODE, factory=MISSING, init=True, repr=True, compare=True, serialize=None, kw_only=False, doc=None, metadata=<factory>, converter=None, validators=(), derived=None)

Bases: object

Immutable metadata describing one struct field.

FieldSpec is the normalised internal representation used by the struct class processor. Specs are produced and consumed in three stages:

  1. You write class annotations, defaults, and field(...) declarations.

  2. The Struct metaclass collects these declarations and builds FieldSpec objects.

  3. Runtime processing validates the specs and uses them to drive constructor, pytree, equality, repr, and serialisation behaviour.
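The first two stages can be pictured with the following toy sketch. Everything here is hypothetical (`ToySpec`, `ToyDecl`, `collect_specs`, and the local `MISSING` sentinel are stand-ins, not neuralqx names), and it uses a plain function rather than a metaclass. It reads a class's annotations and class attributes and turns each one into a frozen spec object.

```python
from dataclasses import dataclass
from typing import Any

MISSING = object()  # stand-in sentinel; neuralqx defines its own MISSING


@dataclass(frozen=True)
class ToySpec:
    # Hypothetical, heavily reduced analogue of FieldSpec.
    name: str
    default: Any = MISSING
    static: bool = False


class ToyDecl:
    # Hypothetical marker for a field(...) declaration written in the class body.
    def __init__(self, *, static: bool = False, default: Any = MISSING):
        self.static = static
        self.default = default


def collect_specs(cls: type) -> tuple[ToySpec, ...]:
    """Stage 2: turn annotations, defaults, and declarations into specs."""
    specs = []
    for name in cls.__annotations__:
        raw = cls.__dict__.get(name, MISSING)
        if isinstance(raw, ToyDecl):
            specs.append(ToySpec(name, raw.default, raw.static))
        else:
            # A bare default (or no default at all) becomes a plain spec.
            specs.append(ToySpec(name, raw))
    return tuple(specs)


# Stage 1: the user writes annotations, defaults, and declarations.
class State:
    x: float
    scale: float = 1.0
    tag: str = ToyDecl(static=True, default="exp-A")


specs = collect_specs(State)
```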

Notes

FieldSpec instances are normally created by field(). Advanced users may construct them directly when building custom decorators.

field(*, static=False, pytree=True, default=MISSING, default_factory=MISSING, init=True, repr=True, compare=True, serialize=None, kw_only=False, doc=None, metadata=None, converter=None, validator=None, derived=None)

Declare metadata for one Struct field.

This function mirrors the ergonomics of dataclasses.field() while adding pytree-aware semantics and runtime hooks (converter, validator, derived).

Typical usage

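import jax
import numpy as np
from neuralqx.utils.struct import Struct, field  # public import path assumed from this module's summary

class State(Struct):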
    x: jax.Array
    tag: str = field(static=True, default="exp-A")
    token: object = field(pytree=False, default_factory=object)
    norm: float = field(
        static=True,
        init=False,
        derived=lambda self: float(np.asarray(self.x).sum()),
    )

Compared with dataclasses.field(), it adds JAX/Struct-specific controls:

  • static and pytree decide field role.

  • converter and validator add runtime hooks.

  • derived supports deterministic computed fields.
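How the two role flags combine can be summarised as a small decision function. This is a sketch, not the library's code: `kind_from_flags` is a hypothetical name, and the returned strings are stand-ins for FieldKind.NODE, FieldKind.STATIC, and FieldKind.OPAQUE. The rejected combination matches the ValueError case documented for field().

```python
def kind_from_flags(static: bool = False, pytree: bool = True) -> str:
    """Map field(static=..., pytree=...) onto a field role (sketch).

    Returns "node", "static", or "opaque" as stand-ins for the
    corresponding FieldKind members.
    """
    if static and not pytree:
        # Documented as invalid: static metadata and opaque data are distinct roles.
        raise ValueError("static=True cannot be combined with pytree=False")
    if static:
        return "static"
    if not pytree:
        return "opaque"
    return "node"
```

With both flags left at their defaults the field is a node, which is why plain annotated fields such as `x: jax.Array` in the usage example are traced by JAX.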

Parameters:

  • static (bool): Mark this field as static pytree metadata. Static values are part of JAX cache keys, so they must be hashable and array-free.

  • pytree (bool): When False, the field is opaque runtime data (not a pytree leaf).

  • default (Any): Concrete default value used when the argument is omitted.

  • default_factory (Callable[[], Any] | MissingType): Zero-argument callable used to create the default lazily.

  • init (bool): Include the field in the generated constructor.

  • repr (bool): Include the field in __repr__ output.

  • compare (bool): Include the field in __eq__ checks.

  • serialize (bool | None): Override whether the field participates in state-dict/export I/O. None applies no override.

  • kw_only (bool): Make the field keyword-only in the generated constructor.

  • doc (str | None): Optional per-field documentation string.

  • metadata (Mapping[str, Any] | None): Free-form immutable metadata mapping for tooling.

  • converter (Callable[..., Any] | None): Optional converter called on assignment paths. It may take either converter(value) or converter(self, value).

  • validator (Union[Callable[..., Any], Sequence[Callable[..., Any]], None]): Optional validator or sequence of validators. Each validator may take either validator(value) or validator(self, value); a validator that returns False causes a ValidationError to be raised.

  • derived (Callable[..., Any] | None): Optional callable for computed fields (requires init=False). It may take either derived() or derived(self), and is recomputed after construction, replacement, and deserialisation.

Returns:

  FieldSpec: a frozen metadata object consumed by struct class processing.

Raises:

  ValueError: if declaration invariants are violated, for example when:

    • both default and default_factory are provided,

    • static=True is combined with pytree=False, or

    • derived-field constraints are broken.
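The default and derived-field invariants can be expressed as a short validation routine. This is a sketch under stated assumptions: `check_declaration` and the local `MISSING` sentinel are hypothetical, `kind` uses plain strings as stand-ins for FieldKind members, and the error messages are illustrative rather than the library's actual wording.

```python
from typing import Any, Callable, Optional

MISSING = object()  # stand-in for the library's MISSING sentinel


def check_declaration(
    *,
    kind: str,  # "node", "static" or "opaque" (stand-ins for FieldKind members)
    default: Any = MISSING,
    default_factory: Any = MISSING,
    init: bool = True,
    derived: Optional[Callable[..., Any]] = None,
) -> None:
    """Raise ValueError for the default/derived declaration errors (sketch)."""
    if default is not MISSING and default_factory is not MISSING:
        raise ValueError("cannot set both default and default_factory")
    if derived is not None:
        if init:
            raise ValueError("derived fields must use init=False")
        if default is not MISSING or default_factory is not MISSING:
            raise ValueError("derived fields cannot define defaults")
        if kind == "node":
            raise ValueError("derived fields cannot be node fields")
```

The `norm` field in the usage example passes these checks: it is static, uses init=False, and has no default.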

Notes

Validators may either:

  • raise exceptions directly, or

  • return False to trigger a ValidationError.
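The two validator styles can be combined in a single loop, sketched below. Assumptions to note: `run_validators` is a hypothetical helper, the local `ValidationError` is a stand-in whose real base class may differ, and this sketch treats only a return value of exactly False as failure (so None, the implicit return of a raising-style validator, counts as success).

```python
import math
from typing import Any, Callable, Iterable


class ValidationError(Exception):
    # Hypothetical stand-in for the library's ValidationError.
    pass


def run_validators(value: Any, validators: Iterable[Callable[[Any], Any]]) -> Any:
    """Apply validators in order (sketch).

    Exceptions raised by a validator propagate unchanged; a validator that
    returns exactly False is turned into ValidationError.
    """
    for check in validators:
        if check(value) is False:
            name = getattr(check, "__name__", repr(check))
            raise ValidationError(f"validator {name} rejected {value!r}")
    return value


def positive(v: float) -> bool:
    # Boolean style: returning False signals failure.
    return v > 0


def not_nan(v: float) -> None:
    # Raising style: the exception propagates as-is; None means success.
    if math.isnan(v):
        raise ValueError("NaN is not allowed")
```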

Converters/validators support ergonomic signatures:

  • fn(value)

  • fn(self, value)
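Supporting both signatures requires inspecting the hook before calling it. The dispatch rule neuralqx uses is not documented here; this sketch assumes it counts required positional parameters with `inspect.signature`, and `call_hook` and `Holder` are hypothetical names.

```python
import inspect
from typing import Any, Callable


def call_hook(fn: Callable[..., Any], owner: Any, value: Any) -> Any:
    """Call fn(self, value) or fn(value) depending on fn's signature (sketch).

    Hooks with two or more required positional parameters are called as
    fn(owner, value); everything else is called as fn(value).
    """
    try:
        params = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        return fn(value)  # no introspectable signature: assume fn(value)
    positional = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    required = [
        p for p in params
        if p.kind in positional and p.default is inspect.Parameter.empty
    ]
    if len(required) >= 2:
        return fn(owner, value)
    return fn(value)


class Holder:
    scale = 10.0


def to_float(v):
    return float(v)


def scaled(self, v):
    return float(v) * self.scale


def with_default(v, factor=2):
    # Optional parameters are not counted, so this is called as fn(value).
    return v * factor
```

Counting only required parameters keeps a converter such as `with_default` from being mistaken for the fn(self, value) form.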