neuralqx.vqs.mc.mc_state.state module

This file overrides part of the implementation of NetKet's MCState.

NOTE: parts or all of the content of this file are taken from NetKet’s source code;

the original NetKet copyright applies.

compute_chain_length(n_chains, n_samples)
check_chunk_size(n_samples, chunk_size)
jit_evaluate(fun, *args)

Calls fun(*args) inside a jax.jit frame.

Parameters:
  • fun (Callable) – the hashable callable to be evaluated.

  • args – the arguments to the function.

class MCState(sampler, model=None, *, n_samples=None, n_samples_per_rank=None, n_discard_per_chain=None, chunk_size=None, variables=None, init_fun=None, apply_fun=None, seed=None, sampler_seed=None, mutable=False, training_kwargs={}, is_group_averaged=False)

Bases: VariationalState

Variational State for a Variational Neural Quantum State.

The state is sampled according to the provided sampler.

init(seed=None, dtype=None)

Initialises the variational parameters of the variational state.

property model: Module

Returns the model definition of this variational state.

When using model frameworks that encode the parameters directly into the model, such as equinox or flax.nnx, this will return the model including the parameters.

If you want access to the raw model without the parameters, as used internally by NetKet, use MCState._model instead.

property sampler: Sampler

The Monte Carlo sampler used by this Monte Carlo variational state.

property n_samples: int

The total number of samples generated at every sampling step.

property n_samples_per_rank: int

The number of samples generated on every JAX device or MPI rank at every sampling step.

property chain_length: int

Length of the Markov chain used for sampling configurations.

If running under MPI, the total number of samples is n_nodes * chain_length * n_batches, where n_batches is the number of chains per rank.
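The relation between n_samples, the number of chains and chain_length can be sketched as follows. This is a minimal illustration that assumes the chain length is obtained by rounding n_samples / n_chains up, with a warning when the division is not exact. The helpers chain_length_for and total_samples are hypothetical; they are not the module's compute_chain_length, whose exact behaviour is not documented here.

```python
import warnings


def chain_length_for(n_chains: int, n_samples: int) -> int:
    """Hypothetical helper: smallest chain length with n_chains * chain_length >= n_samples."""
    if n_samples <= 0:
        raise ValueError(f"Invalid number of samples: n_samples={n_samples}")
    # Integer ceiling division (assumed rounding rule): never produce fewer samples than requested
    chain_length = -(-n_samples // n_chains)
    if chain_length * n_chains != n_samples:
        warnings.warn(
            f"n_samples={n_samples} is not a multiple of n_chains={n_chains}; "
            f"{chain_length * n_chains} samples will be generated instead."
        )
    return chain_length


def total_samples(n_ranks: int, chain_length: int, n_chains_per_rank: int) -> int:
    # Total over all MPI ranks: n_nodes * chain_length * n_batches
    return n_ranks * chain_length * n_chains_per_rank
```

For example, asking for 1000 samples with 16 chains gives a chain length of 63, so 1008 samples are actually produced.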

property n_discard_per_chain: int

Number of samples discarded at the beginning of each Markov chain.

property chunk_size: int

Suggested maximum size of the chunks used in forward and backward evaluations of the Neural Network model.

If your inputs are smaller than the chunk size, this setting is ignored.

This can be used to lower the memory required to run a computation with a very high number of samples or on a very large lattice. Notice that inputs and outputs must still fit in memory, but the intermediate computations will now require less memory.

This option comes at an increased computational cost. While this cost should be negligible for large-enough chunk sizes, don’t use it unless you are memory bound!

This option is a hint: only some operations support chunking. If you perform an operation that does not support chunking, it falls back to unchunked evaluation. To check whether this happened, set the environment variable NETKET_DEBUG=1.
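Chunked evaluation can be illustrated with a plain-Python sketch: a batched function is applied to at most chunk_size inputs at a time and the outputs are concatenated, so only one chunk's intermediates are alive at once. The helpers apply_in_chunks and check_chunk_divides are hypothetical, and the divisibility rule (a chunk smaller than the per-rank batch must divide it evenly) is an assumption about what the module's check_chunk_size enforces.

```python
from typing import Callable, List, Optional, Sequence


def check_chunk_divides(n_samples_per_rank: int, chunk_size: Optional[int]) -> None:
    # Assumed rule: a chunk smaller than the per-rank batch must divide it evenly
    if (
        chunk_size is not None
        and chunk_size < n_samples_per_rank
        and n_samples_per_rank % chunk_size != 0
    ):
        raise ValueError(
            f"chunk_size={chunk_size} does not divide n_samples_per_rank={n_samples_per_rank}"
        )


def apply_in_chunks(
    batched_fun: Callable[[Sequence[float]], List[float]],
    inputs: Sequence[float],
    chunk_size: Optional[int],
) -> List[float]:
    """Hypothetical helper: evaluate batched_fun on at most chunk_size inputs at a time."""
    if chunk_size is None or len(inputs) <= chunk_size:
        # Inputs no larger than one chunk: chunking is ignored
        return list(batched_fun(inputs))
    check_chunk_divides(len(inputs), chunk_size)
    outputs: List[float] = []
    for start in range(0, len(inputs), chunk_size):
        # Intermediates exist for one chunk only; the full outputs still accumulate in memory
        outputs.extend(batched_fun(inputs[start:start + chunk_size]))
    return outputs
```

The result is identical to the unchunked call; only the peak memory of the intermediate computation changes, at the cost of more (smaller) calls.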

reset()

Resets the sampled states. This method is called automatically every time the parameters or the model state are updated.

sample(*, chain_length=None, n_samples=None, n_discard_per_chain=None)

Sample a certain number of configurations.

If either chain_length or n_samples is given, the corresponding number of samples is generated. Otherwise, the values set internally are used.

Parameters:
  • chain_length (int | None) – The length of the Markov chains.

  • n_samples (int | None) – The total number of samples across all MPI ranks.

  • n_discard_per_chain (int | None) – Number of samples discarded at the beginning of each Markov chain.

Return type:

Array
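To illustrate what chain_length and n_discard_per_chain control, here is a toy single-spin-flip Metropolis sampler in plain Python. It is a sketch only: the ansatz log_psi, the helper metropolis_chains and the list-of-tuples output are hypothetical, whereas NetKet's samplers run vectorised chains in JAX and return arrays. Each chain runs n_discard_per_chain + chain_length steps and keeps only the last chain_length configurations.

```python
import math
import random
from typing import List, Sequence, Tuple


def log_psi(sigma: Sequence[int], a: float = 0.3) -> float:
    # Hypothetical real toy ansatz: log psi(sigma) = a * sum_i sigma_i
    return a * sum(sigma)


def metropolis_chains(
    n_chains: int, chain_length: int, n_discard_per_chain: int, n_sites: int, seed: int = 0
) -> List[List[Tuple[int, ...]]]:
    rng = random.Random(seed)
    chains = []
    for _ in range(n_chains):
        sigma = [rng.choice((-1, 1)) for _ in range(n_sites)]
        chain = []
        for step in range(n_discard_per_chain + chain_length):
            # Symmetric proposal: flip one randomly chosen spin
            i = rng.randrange(n_sites)
            proposal = sigma.copy()
            proposal[i] = -proposal[i]
            # Acceptance ratio |psi(s')|^2 / |psi(s)|^2, computed in log space
            log_ratio = 2.0 * (log_psi(proposal) - log_psi(sigma))
            if log_ratio >= 0 or rng.random() < math.exp(log_ratio):
                sigma = proposal
            if step >= n_discard_per_chain:  # keep only post-burn-in configurations
                chain.append(tuple(sigma))
        chains.append(chain)
    return chains
```

With 4 chains of length 10, this yields n_samples = 40 configurations; the n_discard_per_chain burn-in steps of each chain are thrown away.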

property samples: Array

Returns the set of cached samples.

The returned samples are guaranteed to be valid for the current state of the variational state. If no cached samples are available, they are sampled first and then cached.

To obtain a new set of samples either use reset() or sample().
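The caching contract described above (samples reuses the cache, sample() always redraws, reset() invalidates the cache, and parameter updates trigger reset()) can be mimicked with a small pure-Python class. CachedSamplesState and its random ±1 "sampler" are hypothetical stand-ins for illustration, not the actual MCState implementation.

```python
import random
from typing import Any, Dict, List, Optional


class CachedSamplesState:
    """Hypothetical toy state reproducing the samples / sample() / reset() contract."""

    def __init__(self, n_samples: int, seed: int = 0):
        self.n_samples = n_samples
        self._rng = random.Random(seed)
        self._parameters: Dict[str, Any] = {}
        self._samples: Optional[List[int]] = None

    def sample(self) -> List[int]:
        # Always draws fresh samples and refreshes the cache
        self._samples = [self._rng.choice((-1, 1)) for _ in range(self.n_samples)]
        return self._samples

    def reset(self) -> None:
        # Invalidates the cache; the next access to .samples resamples
        self._samples = None

    @property
    def samples(self) -> List[int]:
        if self._samples is None:
            self.sample()
        return self._samples

    @property
    def parameters(self) -> Dict[str, Any]:
        return self._parameters

    @parameters.setter
    def parameters(self, value: Dict[str, Any]) -> None:
        # Updating the parameters makes cached samples stale
        self._parameters = value
        self.reset()
```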

log_value(σ)

Evaluates the variational state for a batch of states and returns the logarithm of the amplitude of the quantum state.

For pure states, this is \(\log(\langle\sigma|\psi\rangle)\), whereas for mixed states this is \(\log(\langle\sigma_r|\rho|\sigma_c\rangle)\), where \(\psi\) and \(\rho\) are respectively a pure state (wavefunction) and a mixed state (density matrix). For the density matrix, the left- and right-acting states (row and column) are obtained as \(\sigma_r\) = σ[:, :N] and \(\sigma_c\) = σ[:, N:].

Given a batch of inputs (Nb, N), returns a batch of outputs (Nb,).
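The batch shapes and the row/column split can be shown with NumPy. Both ansätze here are hypothetical: log_value_pure uses log ψ(σ) = a Σ_i σ_i, and log_value_mixed assumes the pure-state density matrix ρ = |ψ⟩⟨ψ|, so that log ρ(σ_r, σ_c) = log ψ(σ_r) + conj(log ψ(σ_c)).

```python
import numpy as np


def log_value_pure(sigma: np.ndarray, a: complex) -> np.ndarray:
    # Hypothetical pure-state ansatz: log psi(sigma) = a * sum_i sigma_i; (Nb, N) -> (Nb,)
    return a * sigma.sum(axis=-1)


def log_value_mixed(sigma: np.ndarray, n_sites: int, a: complex) -> np.ndarray:
    # Hypothetical mixed state rho = |psi><psi|; sigma has shape (Nb, 2N)
    sigma_r = sigma[:, :n_sites]  # row (left-acting) configurations
    sigma_c = sigma[:, n_sites:]  # column (right-acting) configurations
    return log_value_pure(sigma_r, a) + np.conj(log_value_pure(sigma_c, a))
```

On the diagonal (σ_r = σ_c) the mixed-state log-value is real, as expected for a density matrix.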

Return type:

Array

local_estimators(op, *, chunk_size=None)

Compute the local estimators for the operator op (also known as local energies when op is the Hamiltonian) at the current configuration samples self.samples.

\[O_\mathrm{loc}(s) = \frac{\langle s | \mathtt{op} | \psi \rangle}{\langle s | \psi \rangle}\]

Warning

The samples differ between MPI processes, so the returned local estimators will also take different values on each process. To compute sample averages and similar quantities, you need to perform explicit reductions over all MPI ranks. (Use functions like self.expect to obtain process-independent quantities without manual reductions.)

Parameters:
  • op (AbstractOperator) – The operator.

  • chunk_size (Optional[int]) – Suggested maximum size of the chunks used in forward and backward evaluations of the model. (Default: self.chunk_size)
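The formula above only needs the configurations s' connected to s by the operator, together with the matrix elements ⟨s|op|s'⟩. The sketch below evaluates it for a single configuration, assuming a real-valued log-amplitude and a hypothetical operator Σ_i X_i whose connected elements are produced by sum_x_connected; in NetKet this role is played by the operator's own connected-element methods, which are not reproduced here.

```python
import math
from typing import Callable, List, Tuple

Config = Tuple[int, ...]


def sum_x_connected(s: Config) -> List[Tuple[Config, float]]:
    """Hypothetical operator sum_i X_i in the sigma^z basis: each term flips one spin."""
    conns = []
    for i in range(len(s)):
        flipped = list(s)
        flipped[i] = -flipped[i]
        conns.append((tuple(flipped), 1.0))
    return conns


def local_estimator(
    s: Config,
    log_psi: Callable[[Config], float],
    get_conn: Callable[[Config], List[Tuple[Config, float]]],
) -> float:
    # O_loc(s) = sum_{s'} <s|op|s'> psi(s') / psi(s), with the ratio taken in log space
    log_psi_s = log_psi(s)
    return sum(mel * math.exp(log_psi(sp) - log_psi_s) for sp, mel in get_conn(s))
```

For a uniform state (constant log_psi), every configuration gives O_loc = N, the eigenvalue of Σ_i X_i on |+…+⟩.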

expect(O)

Estimates the quantum expectation value for a given operator \(O\) or generic observable. For a pure state \(\psi\) and an operator, this is \(\langle O\rangle = \langle \psi|O|\psi\rangle/\langle\psi|\psi\rangle\); for a mixed state \(\rho\), it is \(\langle O\rangle = \textrm{Tr}[\rho O]/\textrm{Tr}[\rho]\).

Parameters:

O (Union[AbstractOperator, Sequence[AbstractOperator]]) – the operator or observable for which to compute the expectation value.

Return type:

Stats

Returns:

An estimation of the quantum expectation value \(\langle O\rangle\).
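For a pure state, the Monte Carlo estimate of \(\langle O\rangle\) is the sample mean of the local estimators over configurations drawn from \(|\psi(s)|^2\). The sketch below assumes a hypothetical product ansatz ψ(s) = exp(a Σ_i s_i) and O = Σ_i X_i, for which the exact value is N / cosh(2a). To stay short it draws exact samples by enumerating all 2^N configurations instead of running a Markov chain, and its error bar is the naive standard error, which ignores autocorrelation (NetKet's Stats additionally reports quantities such as the autocorrelation time).

```python
import itertools
import math
import random
import statistics


def log_psi(s, a=0.3):
    # Hypothetical product-state ansatz: psi(s) = exp(a * sum_i s_i)
    return a * sum(s)


def local_sum_x(s, a=0.3):
    # O_loc(s) for O = sum_i X_i: one connected configuration per single spin flip
    total = 0.0
    for i in range(len(s)):
        flipped = list(s)
        flipped[i] = -flipped[i]
        total += math.exp(log_psi(flipped, a) - log_psi(s, a))
    return total


def mc_expect(n_sites=3, n_samples=20000, a=0.3, seed=0):
    rng = random.Random(seed)
    states = list(itertools.product((-1, 1), repeat=n_sites))
    # Exact sampling from |psi(s)|^2 stands in for the Markov chain here
    weights = [math.exp(2 * log_psi(s, a)) for s in states]
    samples = rng.choices(states, weights=weights, k=n_samples)
    o_loc = [local_sum_x(s, a) for s in samples]
    mean = statistics.fmean(o_loc)
    # Naive standard error: assumes uncorrelated samples
    error = statistics.stdev(o_loc) / math.sqrt(n_samples)
    return mean, error
```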

expect_and_grad(O, *, mutable=None, **kwargs)

Estimates the quantum expectation value and its gradient for a given operator \(O\).

Parameters:
  • O (Union[AbstractOperator, Sequence[AbstractOperator]]) – The operator \(O\) for which expectation value and gradient are computed.

  • mutable (Union[bool, str, Collection[str], DenyList, None]) – Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable: bool: all/no collections are mutable. str: the name of a single mutable collection. list: a list of names of mutable collections. This is used to mutate the state of the model while you train it (for example, to implement BatchNorm; consult Flax’s Module.apply documentation for a more in-depth explanation).

  • use_covariance – (passed through **kwargs) whether to use the covariance formula \(\textrm{Cov}[\partial\log\psi, O_{\textrm{loc}}]\), usually reserved for Hermitian operators.

Return type:

tuple[Stats, Any]

Returns:

An estimate of the quantum expectation value \(\langle O\rangle\). An estimate of the gradient of the quantum expectation value \(\langle O\rangle\).

expect_and_forces(O, *, mutable=None)

Estimates the quantum expectation value and the corresponding force vector for a given operator O.

The force vector \(F_j\) is defined as the covariance of the log-derivative of the trial wave function and the local estimators of the operator. For complex holomorphic states, this is equivalent to the expectation gradient \(\frac{\partial\langle O\rangle}{\partial\theta_j^\star} = F_j\). For real-parameter states, the gradient is given by \(\frac{\partial\langle O\rangle}{\partial\theta_j} = 2\,\textrm{Re}[F_j]\).

Parameters:
  • O (Union[AbstractOperator, Sequence[AbstractOperator]]) – The operator O for which expectation value and force are computed.

  • mutable (Union[bool, str, Collection[str], DenyList, None]) –

    Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable: bool: all/no collections are mutable. str: the name of a single mutable collection. list: a list of names of mutable collections. This is used to mutate the state of the model while you train it (for example, to implement BatchNorm; consult Flax’s Module.apply documentation for a more in-depth explanation).

Return type:

tuple[Stats, Any]

Returns:

An estimate of the quantum expectation value \(\langle O\rangle\). An estimate of the force vector \(F_j = \textrm{Cov}[\partial_j\log\psi, O_{\textrm{loc}}]\).
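The real-parameter relation \(\partial\langle O\rangle/\partial\theta_j = 2\,\textrm{Re}[F_j]\) can be checked numerically on a toy model. The sketch assumes a hypothetical one-parameter ansatz ψ_a(s) = exp(a Σ_i s_i) and O = Σ_i X_i, computes F = Cov[∂_a log ψ, O_loc] exactly by enumerating all configurations (no sampling), and compares 2F with a central finite difference of ⟨O⟩. Because the log-derivative is real here, the conjugation in the covariance has no effect.

```python
import itertools
import math

N_SITES = 3


def log_psi(s, a):
    # Hypothetical real-parameter ansatz: psi_a(s) = exp(a * sum_i s_i)
    return a * sum(s)


def dlog_psi(s, a):
    # Log-derivative d log psi / da
    return float(sum(s))


def local_sum_x(s, a):
    # O_loc(s) for O = sum_i X_i
    total = 0.0
    for i in range(len(s)):
        flipped = list(s)
        flipped[i] = -flipped[i]
        total += math.exp(log_psi(flipped, a) - log_psi(s, a))
    return total


def born_weights(a):
    # Exact |psi(s)|^2 / sum_s |psi(s)|^2 over all 2^N configurations
    states = list(itertools.product((-1, 1), repeat=N_SITES))
    w = [math.exp(2 * log_psi(s, a)) for s in states]
    z = sum(w)
    return states, [x / z for x in w]


def expect_exact(a):
    states, p = born_weights(a)
    return sum(pi * local_sum_x(s, a) for s, pi in zip(states, p))


def force_exact(a):
    # F = E[D O_loc] - E[D] E[O_loc], with D = d log psi / da
    states, p = born_weights(a)
    e_d = sum(pi * dlog_psi(s, a) for s, pi in zip(states, p))
    e_o = sum(pi * local_sum_x(s, a) for s, pi in zip(states, p))
    e_do = sum(pi * dlog_psi(s, a) * local_sum_x(s, a) for s, pi in zip(states, p))
    return e_do - e_d * e_o
```

For this product state ⟨O⟩ = N / cosh(2a), so both 2F and the finite difference should equal -2N sinh(2a) / cosh²(2a).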

quantum_geometric_tensor(qgt_T=None)

Computes an estimate of the quantum geometric tensor \(G_{ij}\). This function returns a linear operator that can be used to apply \(G_{ij}\) to a given vector or can be converted to a full matrix.

Parameters:

qgt_T (Optional[LinearOperator]) – the optional type of the quantum geometric tensor. By default it’s automatically selected.

Returns:

A linear operator representing the quantum geometric tensor.

Return type:

nk.optimizer.LinearOperator
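For the common estimator of the quantum geometric tensor, \(G_{ij} = \mathbb{E}[O_i^\star O_j] - \mathbb{E}[O_i^\star]\,\mathbb{E}[O_j]\) with \(O_k(s) = \partial\log\psi(s)/\partial\theta_k\), the linear-operator view can be sketched with NumPy. qgt_operator is a hypothetical helper: it returns a matrix-vector product that never builds the matrix, plus a to_dense() conversion, mirroring the two uses described above. NetKet's own QGT types (for example QGTOnTheFly) differ in how they compute these products and may add regularisation such as a diagonal shift, so treat this only as an illustration of the estimator.

```python
import numpy as np


def qgt_operator(jac: np.ndarray):
    """G_ij = E[O_i^* O_j] - E[O_i^*] E[O_j] from per-sample log-derivatives.

    jac has shape (n_samples, n_params), with jac[s, k] = d log psi(s) / d theta_k.
    """
    n_samples = jac.shape[0]
    # Subtracting the sample mean turns the second moment into a covariance
    centered = jac - jac.mean(axis=0, keepdims=True)

    def matvec(v: np.ndarray) -> np.ndarray:
        # Two Jacobian products; the (n_params, n_params) matrix is never formed
        return centered.conj().T @ (centered @ v) / n_samples

    def to_dense() -> np.ndarray:
        return centered.conj().T @ centered / n_samples

    return matvec, to_dense
```

The matrix-free product costs O(n_samples × n_params) per application, which is why it is preferred when the number of parameters is large.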

to_array(normalize=True)

Returns the dense-vector representation of this state.

Parameters:

normalize (bool) – If True, the vector is normalized to have L2-norm 1.

Return type:

Array

Returns:

An exponentially large vector representing the state in the computational basis.
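The dense representation amounts to evaluating exp(log ψ(s)) on every basis configuration. The NumPy sketch below assumes spin-1/2 configurations in {-1, +1} enumerated in itertools.product order; NetKet uses the Hilbert space's own ordering, which may differ. to_array_dense is a hypothetical helper, and it shifts the log-amplitudes by their largest real part before exponentiating only when normalising, since that shift is a global rescaling that normalisation removes anyway.

```python
import itertools

import numpy as np


def to_array_dense(log_psi, n_sites: int, normalize: bool = True) -> np.ndarray:
    """Dense vector psi(s) = exp(log_psi(s)) over all 2**n_sites spin configurations."""
    # Basis order: itertools.product order (an assumption for illustration)
    states = np.array(list(itertools.product((-1, 1), repeat=n_sites)))
    logs = np.array([log_psi(s) for s in states], dtype=complex)
    if normalize:
        # Avoid overflow; a global rescaling is undone by the normalisation below
        logs = logs - logs.real.max()
    psi = np.exp(logs)
    if normalize:
        psi = psi / np.linalg.norm(psi)
    return psi
```

The vector has 2^N entries, so this is only feasible for small systems.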

to_array_numpy(normalize=True)

Return type:

ndarray

project(wrap_model, *, reuse_cached_samples=True, **wrapper_kwargs)

Generic model-projection method.

Wraps the current ansatz with an arbitrary wrap_model(self.model, **kwargs) and returns a new MCState whose variables contain the trained parameters under the wrapper’s “base” scope.

Parameters:
  • wrap_model (Callable[..., Any]) – A function that projects the model in any way desired, i.e. wrap_model(model, **kwargs) -> wrapped_model.

  • reuse_cached_samples (bool) – If True and cached samples are available, reuse them. This is only safe for projections, such as symmetry projections, that do not alter the sampling distribution.

  • wrapper_kwargs – Extra keyword arguments forwarded to wrap_model().

Return type:

MCState

Returns:

A new MCState with the wrapped model and transplanted parameters.
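The parameter transplant can be pictured with plain dictionaries: every variable collection of the trained model is nested under the wrapper's "base" scope, and the wrapper un-nests it when calling the base model. Everything below is hypothetical and written as pure functions rather than Flax modules: base_model is a toy ansatz, spin_flip_projection is one possible wrap_model that forms log((ψ(s) + ψ(-s))/2), and transplant shows the assumed nesting.

```python
import math
from typing import Any, Callable, Dict, Sequence

Variables = Dict[str, Any]
LogPsi = Callable[[Variables, Sequence[int]], float]


def base_model(variables: Variables, s: Sequence[int]) -> float:
    # Hypothetical vanilla ansatz: log psi(s) = w * sum_i s_i
    return variables["params"]["w"] * sum(s)


def spin_flip_projection(base: LogPsi) -> LogPsi:
    """Hypothetical wrap_model: log((psi(s) + psi(-s)) / 2), base variables under 'base'."""

    def wrapped(variables: Variables, s: Sequence[int]) -> float:
        # Un-nest each collection to recover the trained model's own variables
        base_vars = {col: tree["base"] for col, tree in variables.items()}
        x = base(base_vars, s)
        y = base(base_vars, [-si for si in s])
        m = max(x, y)
        # Numerically stable log((e^x + e^y) / 2)
        return m + math.log((math.exp(x - m) + math.exp(y - m)) / 2)

    return wrapped


def transplant(variables: Variables, scope: str = "base") -> Variables:
    # Nest every collection (e.g. "params") of the trained model under the wrapper's scope
    return {col: {scope: tree} for col, tree in variables.items()}
```

Because the trained values are reused unchanged, the projected state starts exactly from the vanilla model's parameters.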

to_group_averaged(*, symmetries=None, graph=None, index_perms=None, characters=None, reuse_cached_samples=True)

Returns a new MCState that evaluates the group-projected wavefunction built on top of the trained vanilla model parameters, reusing the same sampler, SamplerState, n_samples, n_discard_per_chain, chunk_size, and mutables policy.

Return type:

MCState
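A group-averaged wavefunction can be sketched as \(\psi_G(\sigma) = \frac{1}{|G|}\sum_{g} \chi_g\,\psi(g\sigma)\), where each group element acts as a permutation of site indices. This exact form (including how characters enter and whether they are conjugated) is an assumption, not a description of to_group_averaged; translation_perms and group_averaged are hypothetical helpers, with translation_perms playing the role of an index_perms table for the translations of a ring.

```python
import cmath
from typing import Callable, List, Optional, Sequence


def translation_perms(n_sites: int) -> List[List[int]]:
    # Hypothetical index_perms table: all translations of a ring of n_sites
    return [[(i + t) % n_sites for i in range(n_sites)] for t in range(n_sites)]


def group_averaged(
    log_psi: Callable[[Sequence[int]], complex],
    perms: List[List[int]],
    characters: Optional[List[complex]] = None,
) -> Callable[[Sequence[int]], complex]:
    """Assumed form: psi_G(s) = (1/|G|) sum_g chi_g psi(g s), with (g s)_i = s[perm_g[i]]."""
    if characters is None:
        characters = [1.0] * len(perms)  # trivial (fully symmetric) representation

    def log_psi_g(s: Sequence[int]) -> complex:
        total = sum(
            chi * cmath.exp(log_psi([s[p] for p in perm]))
            for perm, chi in zip(perms, characters)
        )
        return cmath.log(total / len(perms))

    return log_psi_g
```

With trivial characters, the averaged amplitude takes the same value on every translated copy of a configuration even when the base model does not.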

local_estimators(state, op, *, chunk_size)
serialize_MCState(vstate)
deserialize_MCState(vstate, state_dict, force_load_mpi=False)
cast_MCState(vstate, state_dict, n_samples)