neuralqx.vqs.mc.mc_state.utils module

to_array_numpy(hilbert, apply_fun, variables, *, normalize=True, chunk_size=None)

NumPy-backed variant of to_array:

ψ[i] = exp(apply_fun(variables, basis_state_i))

for all basis states of hilbert. apply_fun is evaluated in JAX one chunk at a time, and each chunk is written directly into a NumPy array of length hilbert.n_states.
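To make the formula concrete, the snippet below evaluates it densely for a small three-spin model. log_amp, the ±1 basis enumeration, and variables are hypothetical and are not part of neuralqx. The L2 normalization is only an assumed reading of normalize=True. Because this version holds every amplitude in a single JAX array, it is the all-at-once approach that to_array_numpy avoids for large Hilbert spaces.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical toy model: log psi(s) = sum_k w_k * s_k for spins s_k in {-1, +1}.
def log_amp(variables, s):
    return s @ variables["w"]

# Hypothetical enumeration of all 2**3 basis states of three spins (one state per row).
n_sites = 3
basis = jnp.array(
    [[1.0 - 2.0 * ((i >> k) & 1) for k in range(n_sites)] for i in range(2**n_sites)]
)
variables = {"w": jnp.array([0.1, -0.2, 0.3])}

# Dense evaluation: one JAX array holds every amplitude at once.
psi = jnp.exp(log_amp(variables, basis))
# Assumed meaning of normalize=True: divide by the L2 norm.
psi = psi / jnp.linalg.norm(psi)
```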

Differences from the original JAX to_array:
  • Returns a NumPy ndarray (not a JAX DeviceArray).

  • Internally processes states in chunks and never materializes a single JAX array of shape (n_states,).

  • Currently supports only single-process execution (_dist.n_nodes == 1) and does not support experimental sharding.

It works for both real and complex amplitudes: the output dtype is inferred from the first evaluated chunk.
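Here is a minimal sketch of the chunked evaluation described above. It makes three assumptions. First, the basis states are available as a hypothetical (n_states, n_sites) array basis_states, whereas the real function takes hilbert. Second, normalize=True divides by the L2 norm. Third, like the documented function, it targets a single process only. to_array_numpy_sketch is a hypothetical name and is not the library's implementation.

```python
import numpy as np
import jax
import jax.numpy as jnp


def to_array_numpy_sketch(apply_fun, variables, basis_states, *, normalize=True, chunk_size=None):
    """Sketch: psi[i] = exp(apply_fun(variables, basis_states[i])), stored in NumPy.

    basis_states is a hypothetical (n_states, n_sites) array standing in for
    enumerating the Hilbert space; single-process only, no sharding.
    """
    basis_states = np.asarray(basis_states)
    n_states = basis_states.shape[0]
    if chunk_size is None:
        chunk_size = n_states  # a single chunk covering every state

    # Compiled once per chunk shape; each call only sees one chunk of states.
    eval_chunk = jax.jit(lambda v, s: jnp.exp(apply_fun(v, s)))

    psi = None
    for start in range(0, n_states, chunk_size):
        stop = min(start + chunk_size, n_states)
        chunk = np.asarray(eval_chunk(variables, jnp.asarray(basis_states[start:stop])))
        if psi is None:
            # The output dtype (real or complex) is taken from the first chunk.
            psi = np.empty(n_states, dtype=chunk.dtype)
        psi[start:stop] = chunk  # host copy; no full-length JAX array is ever built

    if normalize:
        # Assumption: normalize=True means dividing by the L2 norm.
        psi /= np.linalg.norm(psi)
    return psi
```

Passing chunk_size=3 or chunk_size=None to this sketch should give the same vector up to floating-point rounding. The chunk size only bounds how many amplitudes each JAX call evaluates.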

Return type:

ndarray