neuralqx.utils.jax.sharding module
Some helper functions for sharding
- get_abstract_mesh()
Return the active abstract mesh installed by NetKet/JAX.
NetKet 3.20+ uses a single mesh axis named ‘S’.
- replicate_sharding(f_py)
Decorator replicating the sharding mechanism for jax.pure_callback-based get_conn_padded() implementations.
- replicated_sharding()
- sharded_sharding_1d()
- derive_output_shardings(in_sharding, in_ndim)
Derive output shardings for:
xp : (batch, max_conn, n_sites)
mels : (batch, max_conn)
assuming the input x has shape:
x : (batch, n_sites)
We preserve sharding of the batch axis and never shard the sites axis or the inserted max_conn axis.
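The rule described for derive_output_shardings can be illustrated with a minimal sketch. This is not the module's actual implementation: the function name derive_output_shardings_sketch is hypothetical, and the sketch assumes that in_sharding is a jax.sharding.NamedSharding and that every axis of x except the last (sites) axis counts as a batch axis.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def derive_output_shardings_sketch(in_sharding, in_ndim):
    """Hypothetical illustration only, not the neuralqx implementation.

    Assumes `in_sharding` is a NamedSharding for x of shape (batch, n_sites).
    """
    spec = tuple(in_sharding.spec)
    # A PartitionSpec may omit trailing unsharded axes, so pad it to in_ndim.
    spec = spec + (None,) * (in_ndim - len(spec))
    # Keep the partitioning of the batch axes; drop the sites entry.
    batch_spec = spec[:-1]
    # xp: (batch, max_conn, n_sites) -> max_conn and n_sites stay unsharded.
    xp_sharding = NamedSharding(in_sharding.mesh, P(*batch_spec, None, None))
    # mels: (batch, max_conn) -> max_conn stays unsharded.
    mels_sharding = NamedSharding(in_sharding.mesh, P(*batch_spec, None))
    return xp_sharding, mels_sharding


# A 1D mesh over all local devices, with the single axis named 'S'.
mesh = Mesh(np.array(jax.devices()), ("S",))
x_sharding = NamedSharding(mesh, P("S", None))  # batch sharded, sites replicated

xp_sharding, mels_sharding = derive_output_shardings_sketch(x_sharding, in_ndim=2)
print(xp_sharding.spec, mels_sharding.spec)
```

In this sketch only the first entry of each output PartitionSpec can carry the mesh axis 'S'. The remaining entries are always None, which matches the statement that the sites axis and the inserted max_conn axis are never sharded.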