neuralqx.experimental.vqs.mc.mc_state.mtmh_state module
- class MultiMCState(states)
Bases: object
Container for multiple MCState objects with overlap diagnostics.
This class groups several independently sampled variational Monte Carlo states that share the same Hilbert space. It provides tools to compute pairwise overlap measures such as the fidelity matrix and derived quantities (overlap magnitude and orthogonality).
All contained states must act on the same Hilbert space; a mismatch raises ValueError.
Parameters are managed per state: no parameters are shared unless you explicitly tie them outside this container.
- Parameters:
states (list[MCState]) – List of Monte Carlo variational states to manage.
- Raises:
ValueError – If states is empty or if any pair of states has a different Hilbert space.
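The validation rule above can be sketched in plain Python. ToyState and ToyMultiState are hypothetical stand-ins, not the library's classes. Hilbert spaces are modelled as strings, on the assumption that the real Hilbert-space objects support == comparison. The sketch only illustrates the empty-list check and the shared-Hilbert-space check.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyState:
    # Hypothetical stand-in for MCState: only the Hilbert space is modelled.
    hilbert: str


class ToyMultiState:
    """Sketch of the constructor checks described above; not the real MultiMCState."""

    def __init__(self, states):
        states = list(states)
        if not states:
            raise ValueError("at least one state is required")
        # Equality is transitive, so comparing against the first state covers every pair.
        reference = states[0].hilbert
        for k, state in enumerate(states[1:], start=1):
            if state.hilbert != reference:
                raise ValueError(f"state {k} acts on a different Hilbert space")
        self.states = states

    @property
    def n_states(self) -> int:
        return len(self.states)

    @property
    def hilbert(self):
        # All states agree, so the first one is representative.
        return self.states[0].hilbert


multi = ToyMultiState([ToyState("spin-1/2 x 4"), ToyState("spin-1/2 x 4")])
print(multi.n_states, multi.hilbert)  # 2 spin-1/2 x 4
```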
- init_parameters(init_fun=None, *, seed=None)
- property n_states: int
Number of contained states.
- Returns:
The number of managed MCState instances.
- property n_samples: list[int]
Return the number of samples for each state.
- Returns:
A list containing the number of samples for each state.
- property sampler: list[Sampler]
Return the sampler object for each state.
- Returns:
A list containing the sampler object for each state.
- property n_samples_per_rank: list[int]
The number of samples generated per state on each MPI rank at each sampling step.
- property chain_length: list[int]
Length of the Markov chain used for sampling configurations for each state.
Note: if running with MPI, the total number of samples per state is n_nodes * chain_length * n_batches, where n_batches is the number of parallel chains on each rank.
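As a worked example of the note above, the arithmetic below assumes that a requested n_samples is rounded up so that every chain on every rank has the same length. This rounding rule is an assumption, and the helper names are hypothetical.

```python
import math


def total_samples(n_nodes: int, chain_length: int, n_batches: int) -> int:
    # Total samples per state across all MPI ranks, following the note above.
    return n_nodes * chain_length * n_batches


def chain_length_for(n_samples: int, n_nodes: int, n_batches: int) -> int:
    # Assumed rounding rule: smallest chain length whose total reaches n_samples.
    return math.ceil(n_samples / (n_nodes * n_batches))


# 4 MPI ranks with 16 parallel chains each, 1000 requested samples.
length = chain_length_for(1000, n_nodes=4, n_batches=16)
print(length, total_samples(4, length, 16))  # 16 1024
```

Rounding up means the realised total (1024 here) can slightly exceed the requested 1000.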
- property n_discard_per_chain: list[int]
Number of discarded samples at the beginning of the chain for each state.
- sample(*, chain_length=None, n_samples=None, n_discard_per_chain=None)
Sample configurations for each state.
If chain_length or n_samples is given, the corresponding number of samples is generated for each state. Otherwise, the value set internally on each state is used.
- log_value(σ)
Evaluate each variational state on a batch of configurations and return the logarithm of the amplitude of each quantum state.
For pure states, this is \(\log(\langle\sigma|\psi\rangle)\), whereas for mixed states this is \(\log(\langle\sigma_r|\rho|\sigma_c\rangle)\), where \(\psi\) and \(\rho\) are respectively a pure state (wavefunction) and a mixed state (density matrix). For the density matrix, the left and right-acting states (row and column) are obtained as
σr = σ[::, 0:N] and σc = σ[::, N:].
Given a batch of inputs of shape (Nb, N), each state returns a batch of outputs of shape (Nb,).
- Return type:
list[Array]
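The row/column split for density matrices can be illustrated with a dense toy ρ. This sketch makes three assumptions:

- configurations are 0/1 occupations;
- configurations map to basis indices in big-endian order;
- the function names are hypothetical.

The library evaluates a neural network rather than indexing a stored matrix. For a density matrix, each input row has length 2N.

```python
import numpy as np


def config_to_index(bits: np.ndarray) -> np.ndarray:
    # Map rows of 0/1 occupations to integer basis indices (big-endian).
    weights = 2 ** np.arange(bits.shape[-1] - 1, -1, -1)
    return bits @ weights


def log_value_density_matrix(rho: np.ndarray, sigma: np.ndarray, N: int) -> np.ndarray:
    """Return log<sigma_r|rho|sigma_c> for a batch sigma of shape (Nb, 2N)."""
    sigma_r = sigma[::, 0:N]   # row (left-acting) configuration
    sigma_c = sigma[::, N:]    # column (right-acting) configuration
    r = config_to_index(sigma_r)
    c = config_to_index(sigma_c)
    return np.log(rho[r, c].astype(complex))   # shape (Nb,)


rng = np.random.default_rng(0)
psi = rng.normal(size=4) + 1j * rng.normal(size=4)
rho = np.outer(psi, psi.conj())                # pure-state density matrix, N = 2
sigma = np.array([[0, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0]])
print(log_value_density_matrix(rho, sigma, N=2).shape)  # (3,)
```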
- local_estimators(op, *, chunk_size=None)
Compute the local estimators for the operator op (also known as local energies when op is the Hamiltonian) at the current configuration samples self.samples, for each state in self.states.
\[O_\mathrm{loc}(s) = \frac{\langle s | \mathtt{op} | \psi \rangle}{\langle s | \psi \rangle}\]
Warning
The samples differ between MPI processes, so the returned local estimators also take different values on each process. To compute sample averages and similar quantities, you need to perform explicit reductions over all MPI ranks. (Use functions like self.expect to get process-independent quantities without manual reductions.)
- Parameters:
op (AbstractOperator) – The operator.
chunk_size (Optional[int]) – Suggested maximum size of the chunks used in forward and backward evaluations of the model. (Default: self.chunk_size)
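For a small system, the local-estimator formula can be checked by brute force. The sketch below assumes dense NumPy arrays indexed by basis-state integers. The library instead works with the sparse connected elements of the operator and with Monte Carlo samples. The sketch verifies that weighting O_loc by the Born probabilities \(|\psi(s)|^2/\langle\psi|\psi\rangle\) reproduces \(\langle O\rangle\). The function name is hypothetical.

```python
import numpy as np


def local_estimators(op: np.ndarray, psi: np.ndarray, samples: np.ndarray) -> np.ndarray:
    """O_loc(s) = <s|op|psi> / <s|psi> for basis-state indices `samples`."""
    op_psi = op @ psi                          # all amplitudes <s|op|psi>
    return op_psi[samples] / psi[samples]


rng = np.random.default_rng(1)
dim = 8
a = rng.normal(size=(dim, dim))
op = a + a.T                                   # real symmetric (Hermitian) operator
psi = rng.normal(size=dim) + 1j * rng.normal(size=dim)

# Exact Born weights stand in for Monte Carlo sampling on this tiny basis.
p = np.abs(psi) ** 2 / np.vdot(psi, psi).real
o_loc = local_estimators(op, psi, np.arange(dim))
exact = np.vdot(psi, op @ psi) / np.vdot(psi, psi)
print(np.allclose(np.sum(p * o_loc), exact))  # True
```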
- expect_and_grad(O, *, mutable=None, **kwargs)
Estimate the quantum expectation value and its gradient for a given operator \(O\), for each state in self.states.
- Parameters:
O (AbstractOperator) – The operator \(O\) for which expectation value and gradient are computed.
mutable (Optional[CollectionFilter]) – Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable:
- bool: all or no collections are mutable.
- str: the name of a single mutable collection.
- list: a list of names of mutable collections.
This is used to mutate the state of the model while you train it, for example to implement BatchNorm. Consult Flax's Module.apply documentation for a more in-depth explanation.
use_covariance (keyword argument passed through **kwargs) – Whether to use the covariance formula \(\textrm{Cov}[\partial\log\psi, O_{\textrm{loc}}]\), which is usually reserved for Hermitian operators.
- Return type:
list[tuple[Stats, PyTree]]
- Returns:
A list with one tuple per state, each containing:
- an estimate of the quantum expectation value <O>;
- an estimate of the gradient of the quantum expectation value <O>.
- expect_and_forces(O, *, mutable=None)
Estimate the quantum expectation value and the corresponding force vector for a given operator O, for each state in self.states.
The force vector \(F_j\) is defined as the covariance of the log-derivative of the trial wave function and the local estimators of the operator. For complex holomorphic states, this is equivalent to the expectation gradient \(\frac{\partial\langle O\rangle}{\partial(\theta_j)^\star} = F_j\). For real-parameter states, the gradient is given by \(\frac{\partial\langle O\rangle}{\partial\theta_j} = 2\,\textrm{Re}[F_j]\).
- Parameters:
O (AbstractOperator) – The operator O for which expectation value and force are computed.
mutable (Optional[CollectionFilter]) –
Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable:
- bool: all or no collections are mutable.
- str: the name of a single mutable collection.
- list: a list of names of mutable collections.
This is used to mutate the state of the model while you train it, for example to implement BatchNorm. Consult Flax's Module.apply documentation for a more in-depth explanation.
- Return type:
list[tuple[Stats, PyTree]]
- Returns:
A list with one tuple per state, each containing:
- an estimate of the quantum expectation value <O>;
- an estimate of the force vector \(F_j = \textrm{Cov}[\partial_j\log\psi, O_{\textrm{loc}}]\).
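For real parameters, the relation \(\partial\langle O\rangle/\partial\theta_j = 2\,\textrm{Re}[F_j]\) can be checked numerically. The sketch makes the following assumptions:

- the ansatz is a hypothetical log-linear toy model, \(\log\psi(s) = \sum_j \theta_j f_j(s)\) with real \(\theta\), so that \(\partial_j\log\psi(s) = f_j(s)\);
- Monte Carlo averages are replaced by exact sums over a small basis.

It compares the covariance-based forces with a central finite-difference gradient.

```python
import numpy as np

rng = np.random.default_rng(2)
dim, n_par = 8, 3
f = rng.normal(size=(dim, n_par))   # toy features f_j(s) for every basis state s
a = rng.normal(size=(dim, dim))
H = a + a.T                         # real symmetric operator playing the role of O


def amplitudes(theta):
    # Hypothetical log-linear ansatz: log psi(s) = sum_j theta_j f_j(s).
    return np.exp(f @ theta)


def expectation(theta):
    psi = amplitudes(theta)
    return psi @ H @ psi / (psi @ psi)


def forces(theta):
    psi = amplitudes(theta)
    p = psi**2 / (psi @ psi)        # Born probabilities |psi(s)|^2 / <psi|psi>
    o_loc = (H @ psi) / psi         # local estimators O_loc(s)
    o = f                           # log-derivatives d_j log psi(s) for this ansatz
    # F_j = E[conj(O_j) O_loc] - E[conj(O_j)] E[O_loc]
    return (p * o_loc) @ o.conj() - (p @ o.conj()) * (p @ o_loc)


theta = 0.3 * rng.normal(size=n_par)
eps = 1e-6
grad_fd = np.array([(expectation(theta + eps * e) - expectation(theta - eps * e)) / (2 * eps)
                    for e in np.eye(n_par)])
print(np.allclose(grad_fd, 2 * np.real(forces(theta)), atol=1e-6))  # True
```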
- quantum_geometric_tensor(qgt_T=None)
Compute an estimate of the quantum geometric tensor G_ij for each state in self.states. Each entry of the returned list is a linear operator that can be applied to a vector or converted to a full matrix.
- Return type:
list[LinearOperator]
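The sample-based estimator behind a QGT linear operator can be sketched with NumPy. The sketch assumes that the per-sample log-derivatives \(O_j(\sigma_k) = \partial_j\log\psi(\sigma_k)\) are available as an (n_samples, n_params) array. From these, the estimate \(S_{ij} = \langle O_i^* O_j\rangle - \langle O_i^*\rangle\langle O_j\rangle\) can be applied to a vector without building the full matrix. The returned pair of closures is a hypothetical stand-in for the LinearOperator objects this method returns.

```python
import numpy as np


def qgt_operator(o_samples: np.ndarray):
    """Return (matvec, to_dense) for S_ij = <O_i^* O_j> - <O_i^*><O_j>.

    o_samples: shape (n_samples, n_params); row k holds d_j log psi(sigma_k).
    """
    n = o_samples.shape[0]
    oc = o_samples - o_samples.mean(axis=0)   # centred log-derivatives

    def matvec(v):
        # S v = Oc^H (Oc v) / n, never forming the n_params x n_params matrix.
        return oc.conj().T @ (oc @ v) / n

    def to_dense():
        return oc.conj().T @ oc / n

    return matvec, to_dense


rng = np.random.default_rng(5)
o = rng.normal(size=(200, 4)) + 1j * rng.normal(size=(200, 4))
matvec, to_dense = qgt_operator(o)
v = rng.normal(size=4)
print(np.allclose(matvec(v), to_dense() @ v))  # True
```

Applying the operator lazily costs O(n_samples * n_params) per product. This is why a linear-operator interface is convenient when the number of parameters is large.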
- project(wrap_model, *, reuse_cached_samples=True, **wrapper_kwargs)
Generic model-projection method.
Wraps the current ansatz with an arbitrary wrap_model(self.model, **kwargs) and returns a new MultiMCState whose variables contain the trained parameters under the wrapper's "base" scope.
The wrapping is applied to every state in this container.
- Parameters:
wrap_model (Callable[..., Any]) – A function that projects the model in any way desired, with signature wrap_model(model, **kwargs) -> wrapped_model.
reuse_cached_samples (bool) – If True and cached samples are available, reuse them. This is safe for symmetry projections that do not alter the sampling distribution.
wrapper_kwargs – Extra keyword arguments forwarded to wrap_model().
- Return type:
MultiMCState
- Returns:
A new MultiMCState with the wrapped model and transplanted parameters.
- to_group_averaged(*, symmetries=None, graph=None, index_perms=None, characters=None, reuse_cached_samples=True)
Return a new MultiMCState that evaluates the group-projected wavefunction built on top of the trained vanilla model parameters. The new state reuses the same sampler, SamplerState, n_samples, n_discard_per_chain, chunk_size, and mutables policy.
- Return type:
MultiMCState
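The group averaging behind to_group_averaged (and, more generally, the wrap_model pattern of project) can be sketched with plain callables. The sketch makes these assumptions:

- a model is a function mapping a batch of configurations to log-amplitudes;
- the symmetry group is given as index permutations;
- the projection is \(\psi_G(\sigma) = \frac{1}{|G|}\sum_g \chi_g\,\psi(g\cdot\sigma)\).

Character conventions (\(\chi_g\) versus its conjugate) differ between libraries. The Flax parameter transplanting under the "base" scope is not reproduced. The names group_average and base_log_psi are hypothetical.

```python
import numpy as np


def group_average(log_psi, perms, characters=None):
    """Wrap log_psi into log[(1/|G|) * sum_g chi_g * psi(sigma[:, perm_g])]."""
    perms = np.asarray(perms)
    if characters is None:
        characters = np.ones(len(perms))        # trivial (fully symmetric) representation
    characters = np.asarray(characters)

    def wrapped(sigma):
        # Evaluate the base model on every permuted copy of the batch: shape (|G|, Nb).
        logs = np.stack([log_psi(sigma[:, p]) for p in perms])
        # Average amplitudes, not log-amplitudes; a production version would use
        # a log-sum-exp formulation for numerical stability.
        amps = np.tensordot(characters, np.exp(logs), axes=1) / len(perms)
        return np.log(amps.astype(complex))

    return wrapped


rng = np.random.default_rng(3)
w = rng.normal(size=4)


def base_log_psi(sigma):
    # Hypothetical toy model that is not translation invariant.
    return sigma @ w + 0.1 * sigma[:, 0] * sigma[:, 1]


translations = [np.roll(np.arange(4), k) for k in range(4)]   # cyclic group on 4 sites
projected = group_average(base_log_psi, translations)
sigma = rng.choice([-1.0, 1.0], size=(5, 4))
shifted = np.roll(sigma, 1, axis=1)
print(np.allclose(np.exp(projected(sigma)), np.exp(projected(shifted))))  # True
```

Shifting σ by one site only permutes the terms of the group sum, so the trivially projected amplitude is translation invariant even though the base model is not.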
- property hilbert
Shared Hilbert space.
All contained states are required to have this same Hilbert space.
- Returns:
The common Hilbert space object.
- property parameters: list[Any]
Parameters of all contained states.
- Returns:
A list of parameter pytrees, one per state, in the same order as
states.
- reset()
Reset all contained states.
This typically discards cached samples and statistics so that subsequent estimators use fresh Monte Carlo data generated by each state’s sampler.
- Return type:
- expect(O)
Compute expectation values of an operator/observable on all states.
This is a convenience wrapper around calling state.expect(O) on each contained state.
- Parameters:
O – Operator or observable to evaluate.
- Return type:
list[Stats]
- Returns:
List of statistics objects, one per state, as returned by each state's expect.
- fidelity_matrix(*, resample=False, return_sigma=False, assume_machine_pow_2=True)
Estimate the pairwise fidelity matrix between all contained states.
The returned matrix F has entries F[i, j] estimating the magnitude-squared overlap between state i and state j. The diagonal is set to 1 by construction.
For machine_pow == 2 in both samplers, the estimator corresponds to the normalised pure-state fidelity
\[F_{ij} = \frac{|\langle \psi_i | \psi_j \rangle|^2} {\langle \psi_i | \psi_i \rangle\,\langle \psi_j | \psi_j \rangle}.\]
Uncertainty handling
When available, the returned sigma values are taken from the NetKet statistics produced by netket.jax.expect(). These values are only as meaningful as the chain-structure information provided to the estimator.
Warning
The standard quantum-fidelity interpretation is exact only for machine_pow == 2. If assume_machine_pow_2 is True and any state uses a different machine_pow, this method raises ValueError.
- Parameters:
resample (bool) – If True, call reset() before computing the matrix.
return_sigma (bool) – If True, also return an estimated standard-error matrix dF.
assume_machine_pow_2 (bool) – Enforce machine_pow == 2 for all states.
- Returns:
If return_sigma is False, returns F with shape (n_states, n_states). If return_sigma is True, returns (F, dF), where dF is an error estimate with the same shape.
- Raises:
ValueError – If assume_machine_pow_2 is True and any state uses machine_pow != 2.
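One standard Monte Carlo estimator for the fidelity formula above samples each state from its own Born distribution and multiplies two ratio averages. It relies on \(\mathbb{E}_{\sigma\sim|\psi_i|^2}[\psi_j(\sigma)/\psi_i(\sigma)] = \langle\psi_i|\psi_j\rangle/\langle\psi_i|\psi_i\rangle\). Whether fidelity_matrix uses exactly this estimator is an assumption. The sketch draws exact samples from dense amplitude vectors, standing in for Markov chains with machine_pow == 2. The function name is hypothetical.

```python
import numpy as np


def mc_fidelity(psi_i, psi_j, n_samples, rng):
    """Estimate F_ij = |<psi_i|psi_j>|^2 / (<psi_i|psi_i><psi_j|psi_j>) by sampling.

    psi_i, psi_j: dense, unnormalised amplitude vectors over a small basis.
    """
    p_i = np.abs(psi_i) ** 2 / np.sum(np.abs(psi_i) ** 2)
    p_j = np.abs(psi_j) ** 2 / np.sum(np.abs(psi_j) ** 2)
    s_i = rng.choice(len(psi_i), size=n_samples, p=p_i)   # samples from |psi_i|^2
    s_j = rng.choice(len(psi_j), size=n_samples, p=p_j)   # samples from |psi_j|^2
    r_ij = np.mean(psi_j[s_i] / psi_i[s_i])   # estimates <psi_i|psi_j>/<psi_i|psi_i>
    r_ji = np.mean(psi_i[s_j] / psi_j[s_j])   # estimates <psi_j|psi_i>/<psi_j|psi_j>
    # Independent samples make the product an unbiased estimate; F_ij itself is real.
    return (r_ij * r_ji).real


rng = np.random.default_rng(4)
psi_a = rng.normal(size=8) + 1j * rng.normal(size=8)
psi_b = psi_a + 0.5 * (rng.normal(size=8) + 1j * rng.normal(size=8))
exact = abs(np.vdot(psi_a, psi_b)) ** 2 / (np.vdot(psi_a, psi_a).real * np.vdot(psi_b, psi_b).real)
print(exact, mc_fidelity(psi_a, psi_b, 200_000, rng))   # the two numbers should be close
```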
- overlap_matrix(*, kind='fidelity', resample=False, return_sigma=False, assume_machine_pow_2=True)
Compute an overlap-style matrix derived from the fidelity matrix.
This is a lightweight post-processing wrapper around fidelity_matrix().
Supported values of kind are:
- "fidelity": returns F[i, j].
- "overlap": returns sqrt(F[i, j]) (magnitude of the overlap).
- "orthogonality": returns 1 - F[i, j].
If return_sigma is True and kind == "overlap", a simple error propagation is used:
\[\sigma_{\sqrt{F}} \approx \frac{\sigma_F}{2\sqrt{F}}.\]
- Parameters:
kind (str) – Which matrix to compute: "fidelity", "overlap", or "orthogonality".
resample, return_sigma, assume_machine_pow_2 – As in fidelity_matrix().
- Returns:
The requested matrix M (and optionally its uncertainty dM).
- Raises:
ValueError – If kind is not one of the supported options.
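The post-processing described above amounts to a few element-wise operations. This sketch takes F and dF as plain NumPy arrays and uses a hypothetical function name. The overlap uncertainty follows the propagation formula above. For orthogonality, the sketch keeps dF unchanged because 1 - F only shifts F; this is our own choice, since the text does not specify it.

```python
import numpy as np


def overlap_from_fidelity(F, dF=None, kind="fidelity"):
    """Post-process a fidelity matrix F (and optional error dF) into the requested kind."""
    F = np.asarray(F, dtype=float)
    if kind == "fidelity":
        M, dM = F, dF
    elif kind == "overlap":
        M = np.sqrt(F)
        # sigma_sqrt(F) ~ sigma_F / (2 sqrt(F)); this diverges where F == 0.
        dM = None if dF is None else np.asarray(dF) / (2.0 * M)
    elif kind == "orthogonality":
        M = 1.0 - F
        dM = dF   # a constant shift leaves the standard error unchanged
    else:
        raise ValueError(f"Unsupported kind: {kind!r}")
    return M if dF is None else (M, dM)


F = np.array([[1.0, 0.25], [0.25, 1.0]])
dF = np.array([[0.0, 0.02], [0.02, 0.0]])
M, dM = overlap_from_fidelity(F, dF, kind="overlap")
print(M[0, 1], dM[0, 1])
```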
- print_overlap_matrix(*, kind='fidelity', resample=False, digits=3, show_sigma=False, assume_machine_pow_2=True)
Pretty-print an overlap-style matrix on the master process.
The matrix is computed on all ranks (to keep any potential collectives consistent), but printing is performed only on the global master, as determined by _is_master().
- Parameters:
kind (str) – Which matrix to print: "fidelity", "overlap", or "orthogonality".
digits (int) – Number of decimal digits used for formatting.
show_sigma (bool) – If True, also print uncertainty estimates when available.
assume_machine_pow_2 (bool) – Forwarded to overlap_matrix().
- Return type:
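The compute-everywhere, print-once pattern can be sketched as follows. The rank is passed in explicitly here; for example, it could come from mpi4py's MPI.COMM_WORLD.Get_rank(). Rank 0 is assumed to be the master. format_matrix and print_on_master are hypothetical helpers, not the library's _is_master() logic.

```python
import numpy as np


def format_matrix(M, digits=3, dM=None):
    """Format M (optionally as value +/- uncertainty) into text rows."""
    rows = []
    for i in range(M.shape[0]):
        cells = []
        for j in range(M.shape[1]):
            cell = f"{M[i, j]:.{digits}f}"
            if dM is not None:
                cell += f" +/- {dM[i, j]:.{digits}f}"
            cells.append(cell)
        rows.append("  ".join(cells))
    return "\n".join(rows)


def print_on_master(M, rank, digits=3, dM=None):
    # Only rank 0 (assumed master) prints; other ranks stay silent.
    text = format_matrix(M, digits=digits, dM=dM)
    if rank == 0:
        print(text)
    return text


F = np.array([[1.0, 0.123456], [0.123456, 1.0]])
print_on_master(F, rank=0, digits=2)
# 1.00  0.12
# 0.12  1.00
```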