neuralqx.experimental.vqs.mc.mc_state.mtmh_state module

class MultiMCState(states)

Bases: object

Container for multiple MCState objects with overlap diagnostics.

This class groups several independently sampled variational Monte Carlo states that share the same Hilbert space, and provides tools to compute pairwise overlap measures such as the fidelity matrix and derived quantities (overlap magnitude and orthogonality).

All contained states must act on the same Hilbert space; a mismatch raises ValueError.

Parameters are managed per state: no parameters are shared unless you explicitly tie them outside this container.

Parameters:

states (list[MCState]) – List of Monte Carlo variational states to manage.

Raises:

ValueError – If states is empty or if the states do not all share the same Hilbert space.
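The constructor checks described above can be sketched as follows. ToyState and check_shared_hilbert are hypothetical stand-ins, and the sketch assumes that Hilbert-space objects can be compared with ==; it illustrates the documented contract, not the class's actual code.

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class ToyState:
    """Hypothetical stand-in for MCState; only the hilbert attribute is used."""

    hilbert: Any


def check_shared_hilbert(states: Sequence[Any]) -> Any:
    """Return the common Hilbert space or raise ValueError, as the constructor documents."""
    if len(states) == 0:
        raise ValueError("At least one state is required.")
    reference = states[0].hilbert
    for i, state in enumerate(states[1:], start=1):
        # Relies on == between Hilbert-space objects (an assumption of this sketch).
        if state.hilbert != reference:
            raise ValueError(f"State {i} has Hilbert space {state.hilbert!r}, expected {reference!r}.")
    return reference


print(check_shared_hilbert([ToyState("Spin(N=4)"), ToyState("Spin(N=4)")]))  # prints Spin(N=4)
```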

init_parameters(init_fun=None, *, seed=None)

property n_states: int

Number of contained states.

Returns:

The number of managed MCState instances.

property n_samples: list[int]

Return the number of samples for each state.

Returns:

A list containing the number of samples for each state.

property sampler: list[Sampler]

Return the sampler object for each state.

Returns:

A list containing the sampler object for each state.

property model: list[Any] | None

Returns the model definition for each state.

property n_samples_per_rank: list[int]

The number of samples generated per state on each MPI rank at each sampling step.

property chain_length: list[int]

Length of the Markov chain used for sampling configurations for each state.

Note: if running under MPI, the total number of samples for each state is n_nodes * chain_length * n_batches.

property n_discard_per_chain: list[int]

Number of discarded samples at the beginning of the chain for each state.

sample(*, chain_length=None, n_samples=None, n_discard_per_chain=None)

Sample a certain number of configurations for each state.

If chain_length or n_samples is given, the number of samples is set accordingly; otherwise, the values stored internally for each state are used.

Parameters:
  • chain_length (Optional[int]) – the length of the Markov chain used for sampling.

  • n_samples (Optional[int]) – the total number of samples across all MPI ranks for each state.

  • n_discard_per_chain (Optional[int]) – number of discarded samples at the beginning of the chain for each state.
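The relation between chain_length, n_samples and the MPI layout (total samples = n_nodes * chain_length * n_batches) can be sketched for a single state as below. The helper name resolve_sampling, the rounding-up rule and the rejection of passing both arguments are assumptions made for illustration, not the documented behaviour of sample().

```python
import math
from typing import Optional, Tuple


def resolve_sampling(
    n_chains: int,
    default_chain_length: int,
    *,
    chain_length: Optional[int] = None,
    n_samples: Optional[int] = None,
) -> Tuple[int, int]:
    """Return (chain_length, total_samples) for one state.

    n_chains counts chains on all MPI ranks together, i.e. n_nodes * n_batches.
    """
    if chain_length is not None and n_samples is not None:
        # Assumption of this sketch: giving both is rejected.
        raise ValueError("Pass at most one of chain_length and n_samples.")
    if chain_length is None:
        if n_samples is None:
            chain_length = default_chain_length  # fall back to the stored value
        else:
            # Round up so that at least n_samples configurations are drawn.
            chain_length = math.ceil(n_samples / n_chains)
    return chain_length, n_chains * chain_length


# 2 MPI nodes with 8 chains each: n_samples=1000 becomes 63 steps per chain, 1008 samples.
print(resolve_sampling(2 * 8, 10, n_samples=1000))  # (63, 1008)
```

In a MultiMCState the same resolution would simply be applied to each contained state independently.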

property samples: list[Array]

Return the set of cached samples for each state.

log_value(σ)

Evaluate each variational state on a batch of configurations and return the logarithm of the amplitude of the quantum state.

For pure states, this is \(\log(\langle\sigma|\psi\rangle)\), whereas for mixed states this is \(\log(\langle\sigma_r|\rho|\sigma_c\rangle)\), where \(\psi\) and \(\rho\) are respectively a pure state (wavefunction) and a mixed state (density matrix). For the density matrix, the left- and right-acting configurations (row and column) are obtained as \(\sigma_r = \sigma[:, 0:N]\) and \(\sigma_c = \sigma[:, N:]\).

Given a batch of inputs of shape (Nb, N), returns, for each state, a batch of outputs of shape (Nb,).

Return type:

list[Array]
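The split of σ into row and column configurations, and the per-state list returned by log_value, can be illustrated with toy ansätze in NumPy. Both ansätze are hypothetical; the rank-one density matrix \(\rho = |\psi\rangle\langle\psi|\) is chosen only so that the diagonal identity \(\log\langle\sigma|\rho|\sigma\rangle = 2\,\textrm{Re}\log\psi(\sigma)\) is easy to check.

```python
import numpy as np


def toy_log_psi(sigma: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """Hypothetical pure-state ansatz: log psi(sigma) = sum_i theta_i sigma_i, shape (Nb,)."""
    return sigma @ theta


def toy_log_rho(sigma: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """Hypothetical mixed-state ansatz log <sigma_r|rho|sigma_c> with rho = |psi><psi|.

    sigma has shape (Nb, 2N): columns [0, N) are sigma_r, columns [N, 2N) are sigma_c.
    """
    n = sigma.shape[1] // 2
    sigma_r, sigma_c = sigma[:, :n], sigma[:, n:]
    return toy_log_psi(sigma_r, theta) + np.conj(toy_log_psi(sigma_c, theta))


def multi_log_value(log_fn, sigma, thetas):
    """One (Nb,) array per state, matching the list[Array] return type."""
    return [log_fn(sigma, theta) for theta in thetas]


rng = np.random.default_rng(0)
thetas = [rng.normal(size=4) + 1j * rng.normal(size=4) for _ in range(2)]
sigma = rng.choice([-1.0, 1.0], size=(3, 4))  # Nb = 3 configurations of N = 4 spins
pure = multi_log_value(toy_log_psi, sigma, thetas)
mixed = multi_log_value(toy_log_rho, np.hstack([sigma, sigma]), thetas)
```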

local_estimators(op, *, chunk_size=None)

Compute the local estimators for the operator op (also known as local energies when op is the Hamiltonian) at the current configuration samples self.samples, for each state in self.states.

\[O_\mathrm{loc}(s) = \frac{\langle s | \mathtt{op} | \psi \rangle}{\langle s | \psi \rangle}\]

Warning

The samples differ between MPI processes, so the returned local estimators will also take different values on each process. To compute sample averages and similar quantities, you will need to perform explicit operations over all MPI ranks. (Use functions like self.expect to get process-independent quantities without manual reductions.)

Parameters:
  • op (AbstractOperator) – The operator.

  • chunk_size (Optional[int]) – Suggested maximum size of the chunks used in forward and backward evaluations of the model. (Default: self.chunk_size)
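The formula for \(O_\mathrm{loc}\) can be checked on a Hilbert space small enough to store the operator as a dense matrix: averaging the local estimators with the Born weights \(|\psi(s)|^2/\langle\psi|\psi\rangle\) reproduces \(\langle\psi|O|\psi\rangle/\langle\psi|\psi\rangle\). The dense representation and the helper names are illustrative assumptions; operator libraries typically work only with the non-zero (connected) matrix elements of each sample. global_mean shows why the MPI warning matters: a global average needs a sum and a sample count reduced over all ranks, not an unweighted average of per-rank means.

```python
import numpy as np


def local_estimators_dense(op: np.ndarray, psi: np.ndarray, idx: np.ndarray) -> np.ndarray:
    """O_loc(s) = sum_{s'} <s|op|s'> psi(s') / psi(s) for the configurations with indices idx."""
    return (op[idx, :] @ psi) / psi[idx]


def global_mean(per_rank: list) -> complex:
    """Mean over all samples of all (simulated) ranks: a sum and a count reduction."""
    total = sum(v.sum() for v in per_rank)
    count = sum(v.size for v in per_rank)
    return total / count


rng = np.random.default_rng(0)
D = 4  # e.g. two spins-1/2
a = rng.normal(size=(D, D)) + 1j * rng.normal(size=(D, D))
op = a + a.conj().T  # Hermitian operator
psi = rng.normal(size=D) + 1j * rng.normal(size=D)

o_loc = local_estimators_dense(op, psi, np.arange(D))
p = np.abs(psi) ** 2 / np.sum(np.abs(psi) ** 2)  # Born probabilities
exact = (psi.conj() @ op @ psi) / (psi.conj() @ psi)
```

With uneven sample counts per rank, averaging per-rank means would give (1.5 + 3.0) / 2 = 2.25 for the values [1, 2] and [3], while global_mean returns the correct 2.0.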

expect_and_grad(O, *, mutable=None, **kwargs)

Estimates the quantum expectation value and its gradient for a given operator \(O\) for each state in self.states.

Parameters:
  • O (AbstractOperator) – The operator \(O\) for which expectation value and gradient are computed.

  • mutable (Optional[CollectionFilter]) – Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. This is used to mutate the state of the model while you train it (for example, to implement BatchNorm; consult Flax’s Module.apply documentation for a more in-depth explanation).

  • use_covariance (keyword argument passed through **kwargs) – Whether to use the covariance formula \(\textrm{Cov}[\partial\log\psi, O_{\textrm{loc}}]\), usually reserved for Hermitian operators.

Return type:

list[tuple[Stats, PyTree]]

Returns: a list with one tuple per state, each containing
  • estimates of the quantum expectation value <O>.

  • estimates of the gradient of the quantum expectation value <O>.

expect_and_forces(O, *, mutable=None)

Estimates the quantum expectation value and the corresponding force vector for a given operator O for each state in self.states.

The force vector \(F_j\) is defined as the covariance of the log-derivative of the trial wave function and the local estimators of the operator. For complex holomorphic states, this is equivalent to the expectation gradient \(\frac{\partial\langle O\rangle}{\partial(\theta_j)^\star} = F_j\). For real-parameter states, the gradient is given by \(\frac{\partial\langle O\rangle}{\partial\theta_j} = 2\,\textrm{Re}[F_j]\).

Parameters:
  • O (AbstractOperator) – The operator O for which expectation value and force are computed.

  • mutable (Optional[CollectionFilter]) – Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. This is used to mutate the state of the model while you train it (for example, to implement BatchNorm; consult Flax’s Module.apply documentation for a more in-depth explanation).

Return type:

list[tuple[Stats, PyTree]]

Returns: a list with one tuple per state, each containing
  • estimates of the quantum expectation value <O>.

  • estimates of the force vector \(F_j = \textrm{Cov}[\partial_j\log\psi, O_{\textrm{loc}}]\).
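For real parameters, the relation \(\partial\langle O\rangle/\partial\theta_j = 2\,\textrm{Re}[F_j]\) can be verified numerically by replacing Monte Carlo averages with exact Born-weighted sums on a tiny Hilbert space and comparing against central finite differences. The linear toy ansatz \(\log\psi_\theta(s) = \sum_j \theta_j c_j(s)\) and the helper names are hypothetical; the covariance uses the convention \(\textrm{Cov}[A, B] = \mathbb{E}[(A - \mathbb{E}A)^\star (B - \mathbb{E}B)]\).

```python
import numpy as np


def exact_expect(op, log_psi):
    """<O> = <psi|O|psi> / <psi|psi> for a dense Hermitian op."""
    psi = np.exp(log_psi)
    return np.real(psi.conj() @ op @ psi / (psi.conj() @ psi))


def force_vector(op, log_psi, log_derivs):
    """F_j = Cov[d_j log psi, O_loc] using exact Born weights instead of samples.

    log_psi: (D,) log-amplitudes; log_derivs: (D, P) with entries d log psi(s) / d theta_j.
    """
    psi = np.exp(log_psi)
    p = np.abs(psi) ** 2
    p = p / p.sum()
    o_loc = (op @ psi) / psi
    d_o = log_derivs - p @ log_derivs  # centred log-derivatives
    d_loc = o_loc - p @ o_loc  # centred local estimators
    return (p * d_loc) @ d_o.conj()  # sum_s p(s) conj(dO_j(s)) dO_loc(s)


# Hypothetical toy model with real parameters: log psi_theta(s) = sum_j theta_j c_j(s).
rng = np.random.default_rng(1)
D, P = 8, 3
c = 0.5 * (rng.normal(size=(D, P)) + 1j * rng.normal(size=(D, P)))
a = rng.normal(size=(D, D)) + 1j * rng.normal(size=(D, D))
op = a + a.conj().T
theta = rng.normal(size=P)

forces = force_vector(op, c @ theta, c)
eps = 1e-6
finite_diff = np.array([
    (exact_expect(op, c @ (theta + eps * e)) - exact_expect(op, c @ (theta - eps * e))) / (2 * eps)
    for e in np.eye(P)
])
```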

quantum_geometric_tensor(qgt_T=None)

Computes an estimate of the quantum geometric tensor \(G_{ij}\) for each state in self.states. For each state, the result is a linear operator that can be used to apply \(G_{ij}\) to a given vector or can be converted to a full matrix.

Return type:

list[LinearOperator]
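A common sample-based estimate of the quantum geometric tensor is the covariance of the log-derivatives, \(G_{ij} = \mathbb{E}[\Delta O_i^\star\,\Delta O_j]\) with \(\Delta O_j = \partial_j\log\psi - \mathbb{E}[\partial_j\log\psi]\). The sketch below builds it from a hypothetical (Ns, P) array of per-sample log-derivatives and exposes it both as a lazy matrix-vector product and as a dense matrix, mirroring the two uses described above. NetKet's own QGT classes may be implemented differently (for example, with a diagonal shift or on-the-fly Jacobian products).

```python
import numpy as np


def qgt_from_samples(log_derivs: np.ndarray):
    """Estimate G_ij = E[conj(dO_i) dO_j] from per-sample log-derivatives.

    log_derivs has shape (Ns, P): row s holds d log psi(sigma_s) / d theta_j
    for samples sigma_s drawn from |psi|^2.
    """
    n_samples = log_derivs.shape[0]
    centred = (log_derivs - log_derivs.mean(axis=0)) / np.sqrt(n_samples)

    def matvec(v: np.ndarray) -> np.ndarray:
        # G v = X^dagger (X v), never forming the P x P matrix.
        return centred.conj().T @ (centred @ v)

    def to_dense() -> np.ndarray:
        return centred.conj().T @ centred

    return matvec, to_dense


rng = np.random.default_rng(2)
log_derivs = rng.normal(size=(50, 4)) + 1j * rng.normal(size=(50, 4))
matvec, to_dense = qgt_from_samples(log_derivs)
G = to_dense()
v = rng.normal(size=4)
```

The lazy form is the reason a linear operator is returned: for large P, applying \(G\) costs O(Ns P) per product, whereas the dense matrix needs O(P²) memory.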

to_array(normalize=True, *, state=None)

Return type:

Union[list[Array], Array]

project(wrap_model, *, reuse_cached_samples=True, **wrapper_kwargs)

Generic model-projection method.

Wraps the model of each contained state with an arbitrary wrap_model(model, **wrapper_kwargs) and returns a new MultiMCState whose variables contain the trained parameters under the wrapper’s “base” scope.

Parameters:
  • wrap_model (Callable[..., Any]) – A function that projects the model in any way desired, i.e. wrap_model(model, **kwargs) -> wrapped_model.

  • reuse_cached_samples (bool) – If True and cached samples are available, reuse them. This is safe only for (symmetry) projections that leave the sampling distribution unchanged.

  • wrapper_kwargs – Extra keyword arguments forwarded to wrap_model().

Return type:

MultiMCState

Returns:

A new MultiMCState with the wrapped models and transplanted parameters.
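The parameter transplant can be pictured with plain dictionaries. This sketch assumes Flax-style variable collections (such as "params") and a wrapper that holds the original model as a submodule named "base", so every collection of the old variables moves one level deeper under "base". The helpers transplant_under_base and project_all are hypothetical and ignore any parameters the wrapper itself introduces, which would still need initialising.

```python
from typing import Any, Callable, Dict, List, Mapping, Tuple


def transplant_under_base(variables: Mapping[str, Any]) -> Dict[str, Any]:
    """Nest every variable collection under a "base" scope, e.g. params -> params["base"]."""
    return {collection: {"base": tree} for collection, tree in variables.items()}


def project_all(
    models: List[Any],
    variables_list: List[Mapping[str, Any]],
    wrap_model: Callable[..., Any],
    **wrapper_kwargs: Any,
) -> List[Tuple[Any, Dict[str, Any]]]:
    """Apply the same wrapper to every (model, variables) pair, one pair per contained state."""
    return [
        (wrap_model(model, **wrapper_kwargs), transplant_under_base(variables))
        for model, variables in zip(models, variables_list)
    ]


# Stand-in wrapper: records the wrapped model and a keyword argument.
projected = project_all(
    ["model_a", "model_b"],
    [{"params": {"w": 1.0}}, {"params": {"w": 2.0}}],
    lambda model, tag: (tag, model),
    tag="sym",
)
print(projected[0])  # (('sym', 'model_a'), {'params': {'base': {'w': 1.0}}})
```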

to_group_averaged(*, symmetries=None, graph=None, index_perms=None, characters=None, reuse_cached_samples=True)

Return a new MultiMCState whose states evaluate the group-projected wavefunction built on top of the trained parameters of the original (non-projected) model, reusing the same sampler, SamplerState, n_samples, n_discard_per_chain, chunk_size, and mutables policy.

Return type:

MultiMCState
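One way to evaluate a group-averaged wavefunction from a non-symmetric ansatz is a character-weighted average over the symmetry images of each configuration, \(\psi_G(s) = \frac{1}{|G|}\sum_{g} \chi_g\,\psi(g s)\), computed in log space. The sketch below assumes the group is given as index permutations (in the spirit of the index_perms argument) and uses the characters without complex conjugation; conventions differ between references, and this is not necessarily how to_group_averaged builds its model.

```python
import numpy as np


def log_group_averaged(log_psi, sigma, index_perms, characters=None):
    """log psi_G(s) = log((1/|G|) sum_g chi_g psi(g s)), where (g s)_i = s[perm_g[i]].

    log_psi maps a (Nb, N) batch to (Nb,) log-amplitudes; index_perms has shape (|G|, N).
    """
    n_group = index_perms.shape[0]
    if characters is None:
        characters = np.ones(n_group)  # trivial representation: plain average
    # (|G|, Nb): log-amplitude of every permuted copy of every configuration.
    logs = np.stack([log_psi(sigma[:, perm]) for perm in index_perms]).astype(complex)
    shift = logs.real.max(axis=0)  # log-sum-exp shift for numerical stability
    total = np.sum(characters[:, None] * np.exp(logs - shift), axis=0)
    return shift + np.log(total) - np.log(n_group)


# Translations of a 4-site ring, written as index permutations.
n_sites = 4
perms = np.array([np.roll(np.arange(n_sites), -k) for k in range(n_sites)])
weights = np.array([0.3, -0.2, 0.5, 0.1])  # hypothetical, not translation invariant


def toy_log_psi(s):
    return s @ weights


rng = np.random.default_rng(3)
sigma = rng.choice([-1.0, 1.0], size=(5, n_sites))
out = log_group_averaged(toy_log_psi, sigma, perms)
out_shifted = log_group_averaged(toy_log_psi, np.roll(sigma, 1, axis=1), perms)
```

The averaged amplitude is translation invariant (out equals out_shifted) even though toy_log_psi is not, and an already symmetric ansatz is left unchanged.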

property hilbert

Shared Hilbert space.

All contained states are required to have this same Hilbert space.

Returns:

The common Hilbert space object.

property parameters: list[Any]

Parameters of all contained states.

Returns:

A list of parameter pytrees, one per state, in the same order as states.

reset()

Reset all contained states.

This typically discards cached samples and statistics so that subsequent estimators use fresh Monte Carlo data generated by each state’s sampler.

Return type:

None

expect(O)

Compute expectation values of an operator/observable on all states.

This is a convenience wrapper around calling state.expect(O) on each contained state.

Parameters:

O – Operator or observable to evaluate.

Return type:

list[Stats]

Returns:

List of statistics objects, one per state, as returned by each state’s expect.

fidelity_matrix(*, resample=False, return_sigma=False, assume_machine_pow_2=True)

Estimate the pairwise fidelity matrix between all contained states.

The returned matrix F has entries F[i, j] estimating the magnitude-squared overlap between state i and state j. The diagonal is set to 1 by construction.

For machine_pow == 2 in both samplers, the estimator corresponds to the normalised pure-state fidelity

\[F_{ij} = \frac{|\langle \psi_i | \psi_j \rangle|^2} {\langle \psi_i | \psi_i \rangle\,\langle \psi_j | \psi_j \rangle}.\]
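A widely used Monte Carlo estimator with this property multiplies two ratio averages, one over the samples of each state: \(F_{ij} \approx \mathbb{E}_{\sigma\sim|\psi_i|^2}[\psi_j(\sigma)/\psi_i(\sigma)]\,\mathbb{E}_{\sigma\sim|\psi_j|^2}[\psi_i(\sigma)/\psi_j(\sigma)]\). The first factor equals \(\langle\psi_i|\psi_j\rangle/\langle\psi_i|\psi_i\rangle\) (and the second its counterpart) only when the samples follow \(|\psi|^2\), which is why machine_pow == 2 matters. The NumPy sketch below assumes this estimator, which may differ from the one fidelity_matrix() actually uses, and checks it against the exact formula by replacing samples with Born weights.

```python
import numpy as np


def _average(values, weights=None):
    """Sample mean, or an explicit weighted sum when exact probabilities are supplied."""
    return values.mean() if weights is None else np.sum(weights * values)


def pair_fidelity(log_ii, log_ji, log_ij, log_jj, w_i=None, w_j=None):
    """F_ij ~ E_{pi_i}[psi_j / psi_i] * E_{pi_j}[psi_i / psi_j], with pi_k proportional to |psi_k|^2.

    log_ab is log psi_a evaluated on the samples drawn from state b.
    """
    a = _average(np.exp(log_ji - log_ii), w_i)
    b = _average(np.exp(log_ij - log_jj), w_j)
    return float(np.real(a * b))  # any imaginary part is estimator noise


def fidelity_matrix_sketch(log_table, weights=None):
    """log_table[a][b] = log psi_a on the samples of state b. Diagonal fixed to 1."""
    n = len(log_table)
    F = np.eye(n)
    for i in range(n):
        for j in range(i + 1, n):
            w_i = None if weights is None else weights[i]
            w_j = None if weights is None else weights[j]
            F[i, j] = F[j, i] = pair_fidelity(
                log_table[i][i], log_table[j][i], log_table[i][j], log_table[j][j], w_i, w_j
            )
    return F


# Exact check: "sample" every configuration and weight it by its Born probability.
rng = np.random.default_rng(4)
psis = [rng.normal(size=6) + 1j * rng.normal(size=6) for _ in range(3)]
born = [np.abs(p) ** 2 / np.sum(np.abs(p) ** 2) for p in psis]
log_table = [[np.log(psis[a]) for _ in range(3)] for a in range(3)]
F = fidelity_matrix_sketch(log_table, weights=born)
exact = np.array([[abs(np.vdot(p, q)) ** 2 / (np.vdot(p, p).real * np.vdot(q, q).real)
                   for q in psis] for p in psis])
```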

Uncertainty handling

When available, the returned sigma values are taken from the NetKet statistics produced by netket.jax.expect(). These are only as meaningful as the chain structure information provided to the estimator.

Warning

The standard quantum fidelity interpretation is exact only for machine_pow == 2. If assume_machine_pow_2 is True and any state uses a different machine_pow, this method raises ValueError.

Parameters:
  • resample (bool) – If True, call reset() before computing the matrix.

  • return_sigma (bool) – If True, also return an estimated standard error matrix dF.

  • assume_machine_pow_2 (bool) – Enforce machine_pow == 2 for all states.

Return type:

Union[ndarray, Tuple[ndarray, ndarray]]

Returns:

If return_sigma is False, returns F with shape (n_states, n_states). If return_sigma is True, returns (F, dF), where dF is an error estimate with the same shape.

Raises:

ValueError – If assume_machine_pow_2 is True and any state uses machine_pow != 2.

overlap_matrix(*, kind='fidelity', resample=False, return_sigma=False, assume_machine_pow_2=True)

Compute an overlap-style matrix derived from the fidelity matrix.

This is a lightweight post-processing wrapper around fidelity_matrix().

Supported values of kind are:

  • "fidelity": returns F[i, j].

  • "overlap": returns sqrt(F[i, j]) (magnitude of overlap).

  • "orthogonality": returns 1 - F[i, j].

If return_sigma is True and kind == "overlap", a simple error propagation is used:

\[\sigma_{\sqrt{F}} \approx \frac{\sigma_F}{2\sqrt{F}}.\]

Parameters:
  • kind (str) – Which matrix to return: "fidelity", "overlap", or "orthogonality".

  • resample (bool) – If True, call reset() before computing.

  • return_sigma (bool) – If True, also return an estimated standard error matrix.

  • assume_machine_pow_2 (bool) – Forwarded to fidelity_matrix().

Return type:

Union[ndarray, Tuple[ndarray, ndarray]]

Returns:

The requested matrix M (and optionally its uncertainty dM).

Raises:

ValueError – If kind is not one of the supported options.
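The three kinds are elementwise transformations of F. The sketch below assumes F and dF are NumPy arrays as returned by fidelity_matrix(); clipping negative entries before the square root is a choice of this sketch, and keeping dF unchanged for "orthogonality" follows from \(1 - F\) being a constant shift.

```python
import numpy as np


def overlap_from_fidelity(F, dF=None, kind="fidelity"):
    """Post-process a fidelity matrix F (and optional errors dF) into the requested kind."""
    if kind == "fidelity":
        M, dM = F, dF
    elif kind == "overlap":
        M = np.sqrt(np.clip(F, 0.0, None))  # clip small negative Monte Carlo noise (sketch choice)
        dM = None
        if dF is not None:
            with np.errstate(divide="ignore", invalid="ignore"):
                dM = dF / (2.0 * M)  # first-order propagation; diverges as F -> 0
    elif kind == "orthogonality":
        M, dM = 1.0 - F, dF  # a constant shift leaves the standard error unchanged
    else:
        raise ValueError(f"Unsupported kind {kind!r}: use 'fidelity', 'overlap' or 'orthogonality'.")
    return M if dF is None else (M, dM)


F = np.array([[1.0, 0.25], [0.25, 1.0]])
dF = np.array([[0.0, 0.04], [0.04, 0.0]])
M, dM = overlap_from_fidelity(F, dF, kind="overlap")
print(M[0, 1], dM[0, 1])  # 0.5 0.04
```

Because the linearised propagation breaks down near \(F = 0\), uncertainties for nearly orthogonal pairs are better read from dF directly.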

print_overlap_matrix(*, kind='fidelity', resample=False, digits=3, show_sigma=False, assume_machine_pow_2=True)

Pretty-print an overlap-style matrix on the master process.

The matrix is computed on all ranks (to keep any potential collectives consistent), but printing is performed only on the global master as determined by _is_master().

Parameters:
  • kind (str) – Which matrix to print: "fidelity", "overlap", or "orthogonality".

  • resample (bool) – If True, call reset() before computing.

  • digits (int) – Number of decimal digits used for formatting.

  • show_sigma (bool) – If True, also print uncertainty estimates when available.

  • assume_machine_pow_2 (bool) – Forwarded to overlap_matrix().

Return type:

None
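The compute-everywhere, print-on-master pattern can be sketched without MPI by passing the master flag explicitly. compute, is_master, format_matrix and the table layout are hypothetical stand-ins: the documented method consults _is_master() and returns None, whereas this sketch returns the rendered text so that it is easy to test.

```python
import numpy as np


def format_matrix(M, dM=None, digits=3):
    """Render one line per row, optionally as 'value ± error'."""
    lines = []
    for i in range(M.shape[0]):
        cells = []
        for j in range(M.shape[1]):
            cell = f"{M[i, j]:.{digits}f}"
            if dM is not None:
                cell += f" ± {dM[i, j]:.{digits}f}"
            cells.append(cell)
        lines.append("  ".join(cells))
    return "\n".join(lines)


def print_on_master(compute, *, is_master, digits=3, show_sigma=False):
    """compute(show_sigma) runs on every rank; only the master prints."""
    result = compute(show_sigma)  # same call on all ranks keeps collectives in step
    M, dM = result if show_sigma else (result, None)
    text = format_matrix(M, dM, digits)
    if is_master:
        print(text)
    return text  # returned here only to make the sketch easy to test


F = np.array([[1.0, 0.1234], [0.1234, 1.0]])
print_on_master(lambda show_sigma: F, is_master=True, digits=2)
# 1.00  0.12
# 0.12  1.00
```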