neuralqx.utils.distributed.io module

block_until_ready_tree(value)

Block until pending device computation has finished for every pytree leaf that exposes a block_until_ready method.

Return type:

Any
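A minimal sketch of the behaviour described above. It assumes the function walks the leaves with jax.tree_util.tree_leaves, detects blocking support by duck typing (hasattr), and returns its input tree unchanged. This is an illustration, not the actual neuralqx implementation.

```python
from typing import Any

import jax
import jax.numpy as jnp


def block_until_ready_tree(value: Any) -> Any:
    """Wait for every leaf of `value` that has a block_until_ready method."""
    for leaf in jax.tree_util.tree_leaves(value):
        # Duck typing: jax.Array exposes block_until_ready(); Python scalars,
        # strings and NumPy arrays do not, so they are skipped.
        if hasattr(leaf, "block_until_ready"):
            leaf.block_until_ready()
    # Assumption: the same tree is returned so calls can be chained.
    return value


state = {"params": jnp.ones((3,)) * 2.0, "step": 10, "name": "run"}
out = block_until_ready_tree(state)
```

Returning the input makes the call easy to chain, for example `device_get(block_until_ready_tree(x))`. Recent JAX versions also provide `jax.block_until_ready`, which applies the same idea to a whole pytree.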

device_get_tree(value)

Copy every JAX device-backed leaf of a pytree to host memory, returning host-backed values.

Return type:

Any
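A minimal sketch, assuming "device-backed" means leaves that are instances of jax.Array and that all other leaves (Python scalars, strings, NumPy arrays) pass through untouched. It is not necessarily how neuralqx implements the function.

```python
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np


def device_get_tree(value: Any) -> Any:
    """Copy every jax.Array leaf to host memory; leave other leaves as-is."""

    def to_host(leaf: Any) -> Any:
        if isinstance(leaf, jax.Array):
            # jax.device_get transfers the buffer and returns a NumPy array.
            return jax.device_get(leaf)
        return leaf

    return jax.tree_util.tree_map(to_host, value)


out = device_get_tree({"w": jnp.arange(3.0), "step": 7})
```

After the call, `out["w"]` is a plain `numpy.ndarray` that can be pickled or written with NumPy I/O. `jax.device_get` applied directly to the whole tree would behave similarly; the explicit `isinstance` check just makes the leaf selection visible.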

safe_replicate_for_io(value, *, replicate_to_all_processes=False, root=0, block_until_ready=True)

Safely prepare an arbitrary pytree for host-side I/O.

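Steps:

- If block_until_ready is True, block until device execution has finished.
- Copy the pytree to host memory with device_get.
- If replicate_to_all_processes is True, broadcast the host payload from process root to all processes.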

Return type:

Any
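A minimal sketch of the three steps, written with JAX's own helpers instead of the two functions above so that it runs on its own. It assumes that root is the index of the source process and that the broadcast uses `jax.experimental.multihost_utils.broadcast_one_to_all`. That helper only handles numeric array-like leaves. Skipping the collective when there is a single process is a design choice made for this sketch, not documented neuralqx behaviour.

```python
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import multihost_utils


def safe_replicate_for_io(
    value: Any,
    *,
    replicate_to_all_processes: bool = False,
    root: int = 0,
    block_until_ready: bool = True,
) -> Any:
    # Step 1 (optional): wait for pending device work on every leaf.
    if block_until_ready:
        value = jax.block_until_ready(value)
    # Step 2: copy array leaves to host memory; non-array leaves pass through.
    value = jax.device_get(value)
    # Step 3 (optional): send process `root`'s payload to every process.
    # Assumption: with one process there is nothing to broadcast, so skip it.
    if replicate_to_all_processes and jax.process_count() > 1:
        value = multihost_utils.broadcast_one_to_all(
            value, is_source=jax.process_index() == root
        )
    return value


state = {"w": jnp.ones((2,)), "step": 3}
out = safe_replicate_for_io(state, replicate_to_all_processes=True)
```

A typical use is to call this on every process right before writing a checkpoint, so that each process holds identical host data. The file write itself can then be guarded with `jax.process_index() == root`.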