neuralqx.utils.jax.sharding module

Helper functions for sharding.

get_abstract_mesh()

Return the active abstract mesh installed by NetKet/JAX.

NetKet 3.20+ uses a single mesh axis named ‘S’.
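
To illustrate what a single-axis mesh named "S" looks like, the following is a minimal sketch that builds one directly with public JAX APIs. It does not call get_abstract_mesh() and makes no claim about how NetKet installs its mesh; the variable names and array shapes are illustrative assumptions.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a stand-in for the mesh described above, with one axis
# named "S" spanning every visible device.
mesh = Mesh(np.array(jax.devices()), ("S",))

# Shard the leading (batch) axis over "S" and keep the sites axis whole.
batch_sharding = NamedSharding(mesh, P("S", None))

# Fully replicated placement: no array axis is split.
replicated = NamedSharding(mesh, P())

# The batch size is chosen to be divisible by the device count.
n_dev = jax.device_count()
x = jax.device_put(np.zeros((4 * n_dev, 6)), batch_sharding)
```

On a single-device machine the mesh has size 1, so the sharded and replicated placements coincide.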

replicate_sharding(f_py)

Decorator replicating the sharding mechanism for jax.pure_callback-based get_conn_padded() implementations.

replicated_sharding()

sharded_sharding_1d()

normalize_pspec(pspec, ndim)
Return type:

tuple

derive_output_shardings(in_sharding, in_ndim)
Derive output shardings for:

xp : (batch, max_conn, n_sites)
mels : (batch, max_conn)

assuming the input x has shape:

x : (batch, n_sites)

We preserve sharding of the batch axis and never shard the sites axis or the inserted max_conn axis.
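
The rule above can be sketched in plain Python on partition specs written as tuples. This is an illustration under stated assumptions, not the module's implementation: pad_pspec and sketch_output_pspecs are hypothetical names, and the padding behaviour (filling missing entries with None) is assumed rather than taken from the source.

```python
from typing import Any, Iterable, Optional, Tuple

Axis = Optional[Any]  # a mesh axis name such as "S", or None for "not sharded"


def pad_pspec(pspec: Optional[Iterable[Axis]], ndim: int) -> Tuple[Axis, ...]:
    """Hypothetical helper: pad a partition spec with None up to ndim entries."""
    entries = tuple(pspec) if pspec is not None else ()
    if len(entries) > ndim:
        raise ValueError(f"spec {entries} has more entries than ndim={ndim}")
    return entries + (None,) * (ndim - len(entries))


def sketch_output_pspecs(in_pspec, in_ndim: int = 2):
    """Return (xp_spec, mels_spec) for an input x of shape (batch, n_sites)."""
    # Only the batch axis of the input carries over to the outputs.
    batch_axis = pad_pspec(in_pspec, in_ndim)[0]
    xp_spec = (batch_axis, None, None)  # (batch, max_conn, n_sites)
    mels_spec = (batch_axis, None)      # (batch, max_conn)
    return xp_spec, mels_spec


# Input sharded over "S" on the batch axis.
xp_spec, mels_spec = sketch_output_pspecs(("S", None))
```

With a batch-sharded input, both outputs keep "S" on their first axis and None elsewhere; with an empty (replicated) input spec, every entry is None.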

is_distributed_array(x)
Return type:

bool

is_sharded_array(x)
Return type:

bool