neuralqx.utils.distributed.core module
Distributed runtime helpers based on JAX process semantics.
- class RuntimeInfo(backend, rank, size, local_rank, jax_available, multihost_utils_available)
Bases: object
Runtime information for distributed execution.
- runtime_info()
Detect the runtime backend and the identity of the current process.
Priority: (1) JAX process semantics, when JAX is available; (2) serial fallback otherwise.
- Return type: RuntimeInfo
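A minimal sketch of the detection order above, assuming JAX's `jax.process_index()` and `jax.process_count()` supply the rank and size. The `RuntimeInfoSketch` dataclass, the `"jax"`/`"serial"` backend strings, and reading `local_rank` from a `LOCAL_RANK` environment variable are hypothetical stand-ins, not the module's actual implementation.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class RuntimeInfoSketch:
    # Hypothetical mirror of RuntimeInfo's documented fields.
    backend: str
    rank: int
    size: int
    local_rank: int
    jax_available: bool
    multihost_utils_available: bool


def detect_runtime() -> RuntimeInfoSketch:
    # Assumption: the launcher exports LOCAL_RANK; default to 0 otherwise.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    try:
        import jax
    except ImportError:
        # Serial fallback: a single process with rank 0.
        return RuntimeInfoSketch("serial", 0, 1, local_rank, False, False)

    try:
        from jax.experimental import multihost_utils  # noqa: F401
        has_multihost = True
    except ImportError:
        has_multihost = False

    # JAX process semantics: rank = process index, size = process count.
    return RuntimeInfoSketch(
        backend="jax",
        rank=jax.process_index(),
        size=jax.process_count(),
        local_rank=local_rank,
        jax_available=True,
        multihost_utils_available=has_multihost,
    )


info = detect_runtime()
print(info.backend, info.rank, info.size)
```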
- is_master()
Alias for the master-process check: returns True only on the master process.
Kept for migration compatibility with the previous MPI helpers.
- Return type: bool
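A common use of the master check is to gate side effects such as logging or checkpoint writes. This sketch assumes the master is the process with JAX process index 0 (and the single process in serial mode); `is_master_sketch` and `save_metrics` are hypothetical names.

```python
import json


def is_master_sketch() -> bool:
    # Assumption: the master is the process with index 0.
    try:
        import jax
    except ImportError:
        return True  # serial mode: the only process is the master
    return jax.process_index() == 0


def save_metrics(path: str, metrics: dict) -> bool:
    # Only the master writes, so processes never race on the same file.
    if not is_master_sketch():
        return False
    with open(path, "w") as f:
        json.dump(metrics, f)
    return True
```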
- barrier(name='neuralqx:distributed:barrier')
Synchronise all JAX processes: each call blocks until every process has reached the barrier.
In serial mode, or when the JAX multihost utilities are unavailable, this is a no-op.
- Return type: None
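A sketch of the degrade-to-no-op behaviour, assuming the JAX path uses `jax.experimental.multihost_utils.sync_global_devices`, which blocks until every process reaches a barrier with the same name. `barrier_sketch` is a hypothetical name; the module may implement the check differently.

```python
def barrier_sketch(name: str = "neuralqx:distributed:barrier") -> None:
    try:
        import jax
        from jax.experimental import multihost_utils
    except ImportError:
        return  # serial mode or missing multihost utilities: no-op
    if jax.process_count() == 1:
        return  # a single process has nobody to wait for
    # Every process must call this with the same name, or the run deadlocks.
    multihost_utils.sync_global_devices(name)
```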
- allgather(value)
Gather value from all processes.
In serial mode this returns the 1-tuple (value,).
- Return type: tuple
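A sketch of an allgather with the documented serial behaviour, assuming numeric array-like values on the JAX path. It uses `jax.experimental.multihost_utils.process_allgather(..., tiled=False)`, which stacks each process's array along a new leading axis; `allgather_sketch` is a hypothetical name and does not handle arbitrary Python objects or nested pytrees.

```python
def allgather_sketch(value):
    try:
        import jax
        from jax.experimental import multihost_utils
    except ImportError:
        return (value,)  # serial mode: a 1-tuple
    if jax.process_count() == 1:
        return (value,)
    # Leading axis has one entry per process, ordered by process index.
    stacked = multihost_utils.process_allgather(value, tiled=False)
    return tuple(stacked)
```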
- bcast(value, *, root=0)
Broadcast value from process root to all processes.
Uses the JAX multihost utilities when available and otherwise falls back to an allgather-based path.
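The allgather-based fallback can be sketched as "gather everything, keep root's entry". On the JAX path, `multihost_utils.broadcast_one_to_all` takes an `is_source` flag that would presumably be set on the root process. The helper name `bcast_via_allgather` and the injected `allgather` callable are hypothetical, used so the logic can be shown without a multi-process run.

```python
from typing import Any, Callable, Sequence


def bcast_via_allgather(
    value: Any,
    allgather: Callable[[Any], Sequence[Any]],
    *,
    root: int = 0,
) -> Any:
    # Every process contributes its value; all then keep root's entry.
    gathered = allgather(value)
    if not 0 <= root < len(gathered):
        raise ValueError(f"root={root} out of range for {len(gathered)} processes")
    return gathered[root]


# Serial mode: allgather returns a 1-tuple, so bcast is the identity.
print(bcast_via_allgather(42, lambda v: (v,)))

# Simulated three-process gather, as seen from any one process.
print(bcast_via_allgather("mine", lambda v: ("p0", "p1", "p2"), root=1))
```

This fallback moves every process's value to every process, so it costs more bandwidth than a true broadcast; it is simple because it needs only one collective.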
- mpi_bcast(value, *, root=0)
Broadcast value from process root to all processes.
Uses the JAX multihost utilities when available and otherwise falls back to an allgather-based path.
- mpi_any(value)
Compute the logical OR of a boolean-like value across all processes.
The result is a Python bool.
- Return type: bool
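A sketch of a cross-process OR built on an allgather, assuming each process contributes a scalar that `bool()` can coerce. `any_via_allgather` and the injected `allgather` callable are hypothetical; the real helper may use a different collective.

```python
from typing import Any, Callable, Sequence


def any_via_allgather(value: Any, allgather: Callable[[Any], Sequence[Any]]) -> bool:
    # Coerce each process's contribution to bool, then OR them together.
    return any(bool(v) for v in allgather(value))


# One process reports a truthy flag, so every process sees True.
print(any_via_allgather(0, lambda v: (0, 0, 1)))
```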
- mpi_sum_jax(x, *, token=None, communicator=None)
Compatibility helper: sum a value or pytree across processes.
Returns (reduced, token) to mirror the signature of the previous mpi_sum_jax.
- mpi_mean_jax(x, *, token=None, communicator=None)
Compatibility helper: average a value or pytree across processes.
Returns (reduced, token) to mirror the signature of the previous mpi_mean_jax.
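The sketch below shows the reduction step on pytrees that an allgather has already collected, one per process, and returns the (reduced, token) pair. It assumes the token is simply passed through (and `communicator` ignored), which the compatibility wording suggests but does not state. `tree_map` is a minimal stand-in for `jax.tree_util.tree_map` covering plain dicts, lists, and tuples; `sum_gathered` and `mean_gathered` are hypothetical names, not the module's API.

```python
from typing import Any, Callable, Sequence, Tuple


def tree_map(fn: Callable[..., Any], *trees: Any) -> Any:
    # Minimal stand-in for jax.tree_util.tree_map over dicts, lists, tuples.
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(tree_map(fn, *items) for items in zip(*trees))
    return fn(*trees)


def sum_gathered(gathered: Sequence[Any], token: Any = None) -> Tuple[Any, Any]:
    # Leaf-wise sum over the per-process trees; token is passed through.
    return tree_map(lambda *leaves: sum(leaves), *gathered), token


def mean_gathered(gathered: Sequence[Any], token: Any = None) -> Tuple[Any, Any]:
    # Mean = leaf-wise sum divided by the number of processes.
    n = len(gathered)
    total, token = sum_gathered(gathered, token)
    return tree_map(lambda leaf: leaf / n, total), token


# Two simulated processes, each holding the same tree structure.
per_process = [{"a": 1.0, "b": [2.0, 3.0]}, {"a": 3.0, "b": [4.0, 5.0]}]
print(sum_gathered(per_process))
print(mean_gathered(per_process))
```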
- check_distributed(return_dict=False, extended=True)
Print (or return) diagnostic information about the distributed runtime.
Works for both CPU-only and GPU-enabled environments.
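A sketch of the print-or-return pattern for runtime diagnostics. The dictionary keys are hypothetical, the `extended` option is not modelled, and the JAX calls used (`jax.process_index`, `jax.process_count`, `jax.local_device_count`, `jax.devices` with each device's `.platform`) are assumed to be the kind of information the real function reports.

```python
from typing import Any, Dict, Optional


def check_distributed_sketch(return_dict: bool = False) -> Optional[Dict[str, Any]]:
    # Serial defaults, used when JAX is not installed.
    info: Dict[str, Any] = {
        "jax_available": False,
        "process_index": 0,
        "process_count": 1,
        "local_device_count": 0,
        "platforms": [],
    }
    try:
        import jax
    except ImportError:
        pass
    else:
        info.update(
            jax_available=True,
            process_index=jax.process_index(),
            process_count=jax.process_count(),
            local_device_count=jax.local_device_count(),
            # e.g. ["cpu"] on a CPU-only host, ["gpu"] with CUDA devices.
            platforms=sorted({d.platform for d in jax.devices()}),
        )
    if return_dict:
        return info
    for key, value in info.items():
        print(f"{key}: {value}")
    return None
```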