neuralqx.utils.distributed package

JAX-process distributed helpers for neuraLQX.

This package is the migration target for code that previously depended on neuralqx.utils.mpi.

class RuntimeInfo(backend, rank, size, local_rank, jax_available, multihost_utils_available)

Bases: object

Runtime information for distributed execution.

backend: str
rank: int
local_rank: int

runtime_info()

Detect runtime backend and process identity.

Priority:

- JAX process semantics when available.
- Serial fallback otherwise.

Return type:

RuntimeInfo
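
The detection priority described for runtime_info() might look roughly like the sketch below. It is not the package's implementation: `RuntimeInfoSketch` and `sketch_runtime_info` are hypothetical names, the field types for `size`, `jax_available`, and `multihost_utils_available` are guessed from the constructor signature, the backend strings `"jax"` and `"serial"` are assumptions, and `local_rank` is set to 0 because the source does not say how it is derived.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RuntimeInfoSketch:
    """Hypothetical stand-in mirroring the documented RuntimeInfo fields."""

    backend: str
    rank: int
    size: int
    local_rank: int
    jax_available: bool
    multihost_utils_available: bool


def sketch_runtime_info() -> RuntimeInfoSketch:
    """Prefer JAX process semantics; otherwise describe a serial run."""
    try:
        import jax
    except ImportError:
        # Serial fallback: a single process with rank 0 (assumed convention).
        return RuntimeInfoSketch("serial", 0, 1, 0, False, False)

    try:
        from jax.experimental import multihost_utils  # noqa: F401
        has_multihost = True
    except ImportError:
        has_multihost = False

    return RuntimeInfoSketch(
        backend="jax",
        rank=jax.process_index(),
        size=jax.process_count(),
        local_rank=0,  # assumption: the source does not define this value
        jax_available=True,
        multihost_utils_available=has_multihost,
    )
```

In a single-process run, both branches should describe one process with rank 0.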

process_index()

Global process index (rank-equivalent).

Return type:

int

process_count()

Global process count (world-size equivalent).

Return type:

int

is_global_master()

Return True when this process is global rank 0.

Return type:

bool

is_master()

Alias for master-process check.

Kept for migration compatibility with the previous MPI helpers.

Return type:

bool

barrier(name='neuralqx:distributed:barrier')

Global synchronisation across all JAX processes when available.

Serial mode and missing multihost utilities degrade to a no-op.

Return type:

None
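
A barrier with the no-op degradation described above could be sketched as follows. `sketch_barrier` is a hypothetical name, and skipping the synchronisation when only one process exists is an assumption of this sketch rather than documented behaviour. `sync_global_devices` is the JAX multihost utility that takes a name string.

```python
def sketch_barrier(name: str = "neuralqx:distributed:barrier") -> None:
    """Synchronise all JAX processes, or do nothing when that is not possible."""
    try:
        import jax
        from jax.experimental import multihost_utils
    except ImportError:
        # Serial mode or missing multihost utilities: degrade to a no-op.
        return None

    # Assumption: with a single process there is nothing to wait for.
    if jax.process_count() > 1:
        multihost_utils.sync_global_devices(name)
    return None
```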

allgather(value)

Gather value from all processes.

In serial mode this returns a 1-tuple (value,).

Return type:

Any

bcast(value, *, root=0)

Broadcast value from root to all processes.

Uses JAX multihost utils when available. Falls back to an allgather-based path.

Return type:

Any
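
The allgather-based fallback can be pictured as "gather everything, then keep the root's entry". The sketch below is illustrative only: `serial_allgather` and `bcast_via_allgather` are hypothetical names, and the `allgather` parameter exists purely so the idea can be exercised without a multi-process launch.

```python
from typing import Any, Callable, Sequence


def serial_allgather(value: Any) -> Sequence[Any]:
    """Serial stand-in: one process, so the gathered result is a 1-tuple."""
    return (value,)


def bcast_via_allgather(
    value: Any,
    *,
    root: int = 0,
    allgather: Callable[[Any], Sequence[Any]] = serial_allgather,
) -> Any:
    """Every process gathers all values and selects the one held by root."""
    gathered = allgather(value)
    return gathered[root]


# Simulated three-process gather, to show which entry root=1 selects.
fake_gather = lambda _value: ("from-0", "from-1", "from-2")
picked = bcast_via_allgather("ignored", root=1, allgather=fake_gather)
```

This path sends every process's payload to every other process, so it is presumably more expensive than a dedicated broadcast, which would explain why it is only the fallback.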

mpi_bcast(value, *, root=0)

Broadcast value from root to all processes.

Uses JAX multihost utils when available. Falls back to an allgather-based path.

Return type:

Any

mpi_any(value)

Logical OR across processes for a boolean-like value.

Returned type is a Python bool.

Return type:

bool
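
Conceptually, the logical OR reduces the gathered per-process flags to a single Python bool. A minimal sketch, assuming the per-process values are already available as a sequence (`any_across` is a hypothetical name, not part of the package):

```python
from typing import Any, Iterable


def any_across(gathered: Iterable[Any]) -> bool:
    """Return True if any process contributed a truthy value."""
    # bool(...) normalises boolean-like scalars (0/1, NumPy bools) to Python bools.
    return any(bool(v) for v in gathered)
```

A typical use of such a reduction is agreeing on early termination: if any one process reports a non-finite loss, all processes stop together.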

mpi_sum_jax(x, *, token=None, communicator=None)

Compatibility helper: sum value/pytree across processes.

Returns (reduced, token) to mirror the previous mpi_sum_jax signature.

mpi_mean_jax(x, *, token=None, communicator=None)

Compatibility helper: mean value/pytree across processes.

Returns (reduced, token) to mirror the previous mpi_mean_jax signature.
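
The `(reduced, token)` return convention of both helpers can be illustrated without a distributed launch by reducing a list of simulated per-process values. Everything here is a sketch under stated assumptions: `sum_sketch`, `mean_sketch`, and the small tree helpers are hypothetical, the pytree handling covers only dicts, lists, tuples, and numeric leaves, and the token is passed through unchanged.

```python
from typing import Any, Sequence, Tuple


def _tree_add(a: Any, b: Any) -> Any:
    """Add two identically structured nested containers leaf by leaf."""
    if isinstance(a, dict):
        return {k: _tree_add(a[k], b[k]) for k in a}
    if isinstance(a, (list, tuple)):
        return type(a)(_tree_add(x, y) for x, y in zip(a, b))
    return a + b


def _tree_scale(a: Any, factor: float) -> Any:
    """Multiply every leaf of a nested container by factor."""
    if isinstance(a, dict):
        return {k: _tree_scale(v, factor) for k, v in a.items()}
    if isinstance(a, (list, tuple)):
        return type(a)(_tree_scale(v, factor) for v in a)
    return a * factor


def sum_sketch(per_process: Sequence[Any], *, token: Any = None) -> Tuple[Any, Any]:
    """Sum simulated per-process values; return (reduced, token)."""
    total = per_process[0]
    for value in per_process[1:]:
        total = _tree_add(total, value)
    return total, token


def mean_sketch(per_process: Sequence[Any], *, token: Any = None) -> Tuple[Any, Any]:
    """Average simulated per-process values; return (reduced, token)."""
    total, token = sum_sketch(per_process, token=token)
    return _tree_scale(total, 1.0 / len(per_process)), token
```

Callers migrating from the MPI helpers can keep unpacking two values, for example `grads, token = ...`, even if the token is no longer meaningful.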

print0(*args, **kwargs)

Print from global process 0 only.

Return type:

None
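
The rank-gated printing pattern is small enough to sketch directly. `print0_sketch` is a hypothetical name, and the explicit `process_index` parameter is an assumption added so the behaviour can be shown in a single process; the real helper presumably queries the runtime itself.

```python
import io


def print0_sketch(*args, process_index: int = 0, **kwargs) -> None:
    """Forward to print() only when called on global process 0."""
    if process_index == 0:
        print(*args, **kwargs)


# Capture output to show that only process 0 writes anything.
buffer = io.StringIO()
print0_sketch("epoch", 1, process_index=0, file=buffer)
print0_sketch("epoch", 1, process_index=1, file=buffer)
captured = buffer.getvalue()
```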

detect_cpus_per_task()

Return type:

int

check_distributed(return_dict=False, extended=True)

Print (or return) diagnostic information about the distributed runtime.

Works for both CPU-only and GPU-enabled environments.

Return type:

dict[str, str] | None

get_distributed_info_dict()

Small serialisable runtime descriptor for checkpoint metadata.

Return type:

dict[str, Any]

block_until_ready_tree(value)

Block on device work for all leaves exposing block_until_ready.

Return type:

Any
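
One way to read "all leaves exposing block_until_ready" is a duck-typed traversal: visit every leaf and call the method only where it exists. The sketch below assumes nested dicts, lists, and tuples as containers and assumes the input is returned unchanged. `block_tree_sketch` and `FakeDeviceLeaf` are hypothetical names used for illustration.

```python
from typing import Any


def block_tree_sketch(value: Any) -> Any:
    """Call block_until_ready() on every leaf that provides it."""
    if isinstance(value, dict):
        for child in value.values():
            block_tree_sketch(child)
    elif isinstance(value, (list, tuple)):
        for child in value:
            block_tree_sketch(child)
    elif hasattr(value, "block_until_ready"):
        value.block_until_ready()
    # Leaves without the method (ints, strings, None) are left alone.
    return value


class FakeDeviceLeaf:
    """Test double that records whether it was blocked on."""

    def __init__(self) -> None:
        self.blocked = False

    def block_until_ready(self) -> "FakeDeviceLeaf":
        self.blocked = True
        return self


leaf = FakeDeviceLeaf()
tree = {"params": [leaf, 3], "name": "run-a"}
result = block_tree_sketch(tree)
```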

device_get_tree(value)

Convert all JAX device-backed leaves to host-backed values.

Return type:

Any

safe_replicate_for_io(value, *, replicate_to_all_processes=False, root=0, block_until_ready=True)

Safely prepare an arbitrary pytree for host-side I/O.

Steps:

- Optionally block device execution.
- device_get to host memory.
- Optionally broadcast host payload to all processes.

Return type:

Any
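
The three steps listed for safe_replicate_for_io could be combined roughly as in the sketch below. This is an illustration under assumptions, not the package's code: `safe_replicate_sketch` is a hypothetical name, the function returns the input unchanged when JAX is not installed, and the broadcast is skipped when only one process exists. `jax.device_get` and `multihost_utils.broadcast_one_to_all` are existing JAX calls; the latter expects array-like leaves.

```python
from typing import Any


def safe_replicate_sketch(
    value: Any,
    *,
    replicate_to_all_processes: bool = False,
    root: int = 0,
    block_until_ready: bool = True,
) -> Any:
    """Block, copy to host, and optionally broadcast a pytree."""
    try:
        import jax
    except ImportError:
        # Assumption: without JAX the value is already host data.
        return value

    # Step 1: optionally wait for pending device work on each leaf.
    if block_until_ready:
        for leaf in jax.tree_util.tree_leaves(value):
            if hasattr(leaf, "block_until_ready"):
                leaf.block_until_ready()

    # Step 2: copy device-backed leaves to host memory.
    host_value = jax.device_get(value)

    # Step 3: optionally send the root's payload to every process.
    if replicate_to_all_processes and jax.process_count() > 1:
        from jax.experimental import multihost_utils

        host_value = multihost_utils.broadcast_one_to_all(
            host_value, is_source=jax.process_index() == root
        )
    return host_value
```

Blocking before the host copy means that asynchronous device errors would surface at a well-defined point instead of during file writing.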

Submodules