neuralqx.vqs.mc.kernels module
This file overrides parts of the implementation of some common kernels used by MCState and MCMixedState in NetKet.
- NOTE: parts of this file, or its entire content, are taken from NetKet’s source code; NetKet’s original copyright applies.
- batch_discrete_kernel(kernel)
Decorator that batches a kernel which only works with one sample, so that it works with a batch of samples.
Works only for discrete kernels that take two args as inputs.
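A minimal NumPy sketch of the batching idea, assuming (as in NetKet's generic discrete kernels) that the per-sample kernel receives args = (σp, mels), where σp holds the configurations connected to one sample and mels their matrix elements. The helper batch_discrete_kernel_sketch and the toy log-amplitude toy_logpsi are hypothetical; in JAX the Python loop would normally be replaced by jax.vmap.

```python
import numpy as np


def batch_discrete_kernel_sketch(kernel):
    """Lift a per-sample kernel to a batch of samples (hypothetical helper)."""

    def batched(logpsi, pars, σ, args):
        σp, mels = args
        # Connected elements may arrive flattened over the batch; reshape them
        # to (n_samples, n_conn, N) and (n_samples, n_conn).
        if σp.ndim != 3:
            σp = σp.reshape((σ.shape[0], -1, σ.shape[-1]))
            mels = mels.reshape(σp.shape[:-1])
        # Apply the single-sample kernel to each (σ_i, σp_i, mels_i) triple;
        # in JAX this loop would typically be replaced by jax.vmap.
        return np.stack([kernel(logpsi, pars, s, (sp, m))
                         for s, sp, m in zip(σ, σp, mels)])

    return batched


def toy_logpsi(pars, x):
    # Hypothetical log-amplitude, linear in the configuration: log ψ(x) = pars · x
    return x @ pars


@batch_discrete_kernel_sketch
def single_sample_kernel(logpsi, pars, σ, args):
    σp, mels = args
    return np.sum(mels * np.exp(logpsi(pars, σp) - logpsi(pars, σ)))


pars = np.array([0.3, -0.2])
σ = np.array([[1.0, -1.0], [-1.0, 1.0]])            # 2 samples, N = 2
σp = np.array([[[-1.0, -1.0], [1.0, 1.0]],          # connected configs of sample 0
               [[1.0, 1.0], [-1.0, -1.0]]])         # connected configs of sample 1
mels = np.ones((2, 2))
O_loc = single_sample_kernel(toy_logpsi, pars, σ, (σp, mels))
print(O_loc.shape)  # (2,)
```

The reshape branch lets the decorated kernel also accept connected elements flattened across the batch, i.e. σp of shape (n_samples · n_conn, N).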
- local_value_kernel(logpsi, pars, σ, args)
local_value kernel for MCState and generic operators
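The generic kernels estimate the local value O_loc(σ) = Σ_η O_ση ψ(η)/ψ(σ) = Σ_η O_ση exp(log ψ(η) − log ψ(σ)), where the η are the configurations connected to σ and O_ση the corresponding matrix elements. The NumPy sketch below is an illustration, not the module's code: it checks a single-sample version against the dense expression (Oψ)(σ)/ψ(σ) on a two-spin toy system. local_value_kernel_sketch and toy_logpsi are hypothetical names.

```python
import itertools

import numpy as np


def local_value_kernel_sketch(logpsi, pars, σ, args):
    # args = (σp, mels): connected configurations η and matrix elements O_ση
    σp, mels = args
    return np.sum(mels * np.exp(logpsi(pars, σp) - logpsi(pars, σ)))


def toy_logpsi(pars, x):
    # Hypothetical log-amplitude: log ψ(x) = pars · x (works on (N,) or (M, N))
    return x @ pars


# Enumerate the 2-spin Hilbert space and build a dense toy operator (σ^x on spin 0).
basis = np.array(list(itertools.product([-1.0, 1.0], repeat=2)))
pars = np.array([0.4, -0.1])
O = np.zeros((4, 4))
for i, s in enumerate(basis):
    flipped = s.copy()
    flipped[0] *= -1
    j = int(np.flatnonzero((basis == flipped).all(axis=1))[0])
    O[i, j] = 1.0

# Connected elements of configuration 0 are the nonzero entries of row 0.
row = 0
cols = np.flatnonzero(O[row])
o_loc = local_value_kernel_sketch(toy_logpsi, pars, basis[row], (basis[cols], O[row, cols]))

psi = np.exp(toy_logpsi(pars, basis))
exact = (O @ psi)[row] / psi[row]
print(np.isclose(o_loc, exact))  # True
```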
- local_value_kernel_jax(logpsi, pars, σ, O)
local_value kernel for MCState for jax-compatible operators
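For JAX-compatible operators the kernel receives the operator O itself and generates the connected elements internally; in NetKet this is done with O.get_conn_padded(σ), which pads every sample to the same number of connections. The NumPy sketch below imitates that layout with a hypothetical ToyFlipOperator (Σ_i σ^x_i on N spins). For the linear toy log-amplitude used here the result has the closed form Σ_i exp(−2 θ_i σ_i), which makes it easy to check.

```python
import numpy as np


class ToyFlipOperator:
    """Hypothetical stand-in for a DiscreteJaxOperator: Σ_i σ^x_i on N spins."""

    def get_conn_padded(self, σ):
        # σ: (..., N) -> σp: (..., N, N) with spin i flipped in row i; mels: (..., N)
        N = σ.shape[-1]
        σp = np.repeat(σ[..., None, :], N, axis=-2)
        idx = np.arange(N)
        σp[..., idx, idx] *= -1
        mels = np.ones(σ.shape[:-1] + (N,))
        return σp, mels


def local_value_kernel_jax_sketch(logpsi, pars, σ, O):
    σp, mels = O.get_conn_padded(σ)
    N = σ.shape[-1]
    # Evaluate log ψ on all connected configurations in one call, then restore
    # the (n_samples, n_conn) layout.
    logpsi_σp = logpsi(pars, σp.reshape(-1, N)).reshape(σp.shape[:-1])
    logpsi_σ = logpsi(pars, σ)[..., None]
    return np.sum(mels * np.exp(logpsi_σp - logpsi_σ), axis=-1)


def toy_logpsi(pars, x):
    # Hypothetical log-amplitude: log ψ(x) = pars · x
    return x @ pars


rng = np.random.default_rng(0)
σ = rng.choice([-1.0, 1.0], size=(3, 4))
pars = np.array([0.1, -0.3, 0.2, 0.05])
O_loc = local_value_kernel_jax_sketch(toy_logpsi, pars, σ, ToyFlipOperator())
print(O_loc.shape)  # (3,)
```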
- local_value_kernel_factored(logpsi_σ, logpsi, pars, σ, args)
Bra-action local estimator using precomputed logpsi(pars, σ).
- local_value_kernel_jax_factored(logpsi_σ, logpsi, pars, σ, O)
JAX-operator bra-action local estimator using precomputed logpsi(pars, σ).
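The factored variants differ only in where log ψ(σ) comes from: the caller evaluates it once and passes it as logpsi_σ, so estimating K operators on the same samples costs one evaluation on σ instead of K. The NumPy sketch below counts log-amplitude calls to make the saving visible; all names are hypothetical, and the batched shapes (σ: (B, N), σp: (B, C, N), mels: (B, C)) are assumptions.

```python
import numpy as np


class CountingLogPsi:
    """Hypothetical log-amplitude log ψ(x) = pars · x that counts its calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, pars, x):
        self.calls += 1
        return x @ pars


def local_value_kernel_sketch(logpsi, pars, σ, args):
    σp, mels = args                               # σp: (B, C, N), mels: (B, C)
    return np.sum(mels * np.exp(logpsi(pars, σp) - logpsi(pars, σ)[..., None]), axis=-1)


def local_value_kernel_factored_sketch(logpsi_σ, logpsi, pars, σ, args):
    # Same estimator, but log ψ(σ) is supplied by the caller instead of recomputed.
    σp, mels = args
    return np.sum(mels * np.exp(logpsi(pars, σp) - logpsi_σ[..., None]), axis=-1)


def flip_conn(σ, i):
    # Connected elements of σ^x_i: one connection per sample, matrix element 1.
    σp = σ.copy()
    σp[:, i] *= -1
    return σp[:, None, :], np.ones((σ.shape[0], 1))


pars = np.array([0.2, -0.4])
σ = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0]])   # B = 3 samples, N = 2
operators = [flip_conn(σ, 0), flip_conn(σ, 1)]          # K = 2 operators

plain = CountingLogPsi()
ref = [local_value_kernel_sketch(plain, pars, σ, a) for a in operators]

factored = CountingLogPsi()
logpsi_σ = factored(pars, σ)                             # evaluated once, shared
fac = [local_value_kernel_factored_sketch(logpsi_σ, factored, pars, σ, a) for a in operators]

print(plain.calls, factored.calls)  # 4 3
```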
- local_value_squared_kernel(logpsi, pars, σ, args)
local_value kernel for MCState and Squared (generic) operators
- local_value_kernel_squared_jax(logpsi, pars, σ, O)
Squared‐operator local estimator for any DiscreteJaxOperator
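Assuming the squared kernels follow NetKet's convention of returning |O_loc(σ)|², averaging them over the Born distribution p(σ) = |ψ(σ)|²/⟨ψ|ψ⟩ gives ⟨ψ|O†O|ψ⟩/⟨ψ|ψ⟩, because each term reduces to |(Oψ)(σ)|². The NumPy check below enumerates a random toy state exactly instead of sampling; it illustrates the identity, not the module's implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
dim = 8                                            # toy Hilbert-space size
psi = rng.normal(size=dim) + 1j * rng.normal(size=dim)
O = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))  # deliberately non-Hermitian

# Local values O_loc(σ) = Σ_η O_ση ψ(η)/ψ(σ), one per basis configuration σ.
O_loc = (O @ psi) / psi
# Squared local estimator, assumed to be |O_loc(σ)|^2.
O2_loc = np.abs(O_loc) ** 2

# Exact average under p(σ) = |ψ(σ)|^2 / ⟨ψ|ψ⟩.
p = np.abs(psi) ** 2 / np.vdot(psi, psi).real
estimate = np.sum(p * O2_loc)
exact = np.vdot(O @ psi, O @ psi).real / np.vdot(psi, psi).real
print(np.isclose(estimate, exact))  # True
```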
- local_value_squared_kernel_factored(logpsi_σ, logpsi, pars, σ, args)
- local_value_kernel_squared_jax_factored(logpsi_σ, logpsi, pars, σ, O)
- local_value_op_op_cost(logpsi, pars, σ, args)
local_value kernel for MCMixedState and generic operators
- local_value_kernel_chunked(logpsi, pars, σ, args, *, chunk_size=None)
local_value kernel for MCState and generic operators
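Chunking trades speed for memory: log ψ is evaluated on at most chunk_size configurations at a time and the pieces are concatenated, so the result should agree with the unchunked kernel up to rounding. A NumPy sketch under that assumption; apply_in_chunks and local_value_kernel_chunked_sketch are hypothetical, and in JAX code the same role is usually played by chunked vmapping utilities such as netket.jax.vmap_chunked.

```python
from functools import partial

import numpy as np


def apply_in_chunks(f, x, chunk_size=None):
    """Apply f to the leading axis of x, at most chunk_size rows at a time."""
    if chunk_size is None:
        return f(x)
    return np.concatenate(
        [f(x[i:i + chunk_size]) for i in range(0, x.shape[0], chunk_size)]
    )


def local_value_kernel_chunked_sketch(logpsi, pars, σ, args, *, chunk_size=None):
    σp, mels = args                    # σp: (B, C, N), mels: (B, C)
    N = σ.shape[-1]
    f = partial(logpsi, pars)
    # Flatten all connected configurations, evaluate log ψ chunk by chunk,
    # then restore the (B, C) layout.
    logpsi_σp = apply_in_chunks(f, σp.reshape(-1, N), chunk_size).reshape(σp.shape[:-1])
    logpsi_σ = apply_in_chunks(f, σ, chunk_size)
    return np.sum(mels * np.exp(logpsi_σp - logpsi_σ[:, None]), axis=-1)


def toy_logpsi(pars, x):
    # Hypothetical log-amplitude: log ψ(x) = pars · x
    return x @ pars


rng = np.random.default_rng(2)
B, C, N = 5, 4, 3
σ = rng.choice([-1.0, 1.0], size=(B, N))
σp = rng.choice([-1.0, 1.0], size=(B, C, N))
mels = rng.normal(size=(B, C))
pars = rng.normal(size=N)

full = local_value_kernel_chunked_sketch(toy_logpsi, pars, σ, (σp, mels))
chunked = local_value_kernel_chunked_sketch(toy_logpsi, pars, σ, (σp, mels), chunk_size=3)
print(np.allclose(full, chunked))  # True
```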
- local_value_kernel_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
Chunked bra-action estimator using precomputed logpsi(pars, σ).
- local_value_kernel_jax_conn_chunked(logpsi, pars, σ, O, chunk_size)
local_value kernel for MCState for jax-compatible operators
- local_value_kernel_jax_conn_chunked_factored(logpsi_σ, logpsi, pars, σ, O, chunk_size)
Chunked JAX-operator bra-action estimator using precomputed logpsi(pars, σ).
- local_value_squared_kernel_chunked(logpsi, pars, σ, args, *, chunk_size=None)
local_value kernel for MCState and Squared (generic) operators
- local_value_squared_kernel_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
- local_value_kernel_squared_jax_conn_chunked(logpsi, pars, σ, O, chunk_size)
Squared version of local_value_kernel_jax_conn_chunked, using the same splitting strategy.
- local_value_kernel_squared_jax_conn_chunked_factored(logpsi_σ, logpsi, pars, σ, O, chunk_size)
- local_value_op_op_cost_chunked(logpsi, pars, σ, args, *, chunk_size=None)
local_value kernel for MCMixedState and generic operators
- local_value_kernel_jax_chunked(logpsi, pars, σ, O, *, chunk_size=None)
local_value kernel for MCState and JAX-compatible operators
- local_value_kernel_jax_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
- local_value_kernel_penalty_cost(logpsi, pars, σ, args)
local_value kernel for MCState and generic operators
- local_value_kernel_penalty_cost_factored(logpsi_σ, logpsi, pars, σ, args)
- local_value_kernel_penalty_cost_chunked(logpsi, pars, σ, args, *, chunk_size=None)
Currently just the regular kernel for a local operator.
- local_value_kernel_penalty_cost_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
- local_value_kernel_variance(logpsi, pars, σ, args)
Local estimator for a VarianceObservable.
args = (local_kernel_O, args_O, local_kernel_O2, args_O2)
Returns L_σ = O²_loc(σ) − ⟨O⟩_MC² (real part only).
The mean ⟨O⟩_MC is taken over the same mini-batch σ that NetKet’s Monte Carlo integration uses, which is the usual unbiased MC estimator employed in the original NetKet routine.
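A NumPy sketch of the estimator described above, with args unpacked as (local_kernel_O, args_O, local_kernel_O2, args_O2). The diagonal toy kernel diag_kernel (total magnetization raised to a power) is hypothetical. For a diagonal observable the batch mean of L_σ reduces to the population (ddof=0) variance of O over the batch, which the check uses.

```python
import numpy as np


def local_value_kernel_variance_sketch(logpsi, pars, σ, args):
    local_kernel_O, args_O, local_kernel_O2, args_O2 = args
    O_loc = local_kernel_O(logpsi, pars, σ, args_O)     # local values of O
    O2_loc = local_kernel_O2(logpsi, pars, σ, args_O2)  # local values of O²
    # Subtract the squared batch mean of O_loc and keep the real part.
    return np.real(O2_loc - np.mean(O_loc) ** 2)


def diag_kernel(logpsi, pars, σ, power):
    # Hypothetical diagonal observable: total magnetization M(σ) raised to `power`.
    return np.sum(σ, axis=-1) ** power


rng = np.random.default_rng(3)
σ = rng.choice([-1.0, 1.0], size=(500, 6))
L = local_value_kernel_variance_sketch(None, None, σ, (diag_kernel, 1, diag_kernel, 2))
M = σ.sum(axis=-1)
print(np.isclose(L.mean(), np.var(M)))  # True
```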
- volume_cost_kernel(logpsi, inner_kernel, pars, σ, inner_args, fprime, g)
Jitted per-sample local estimator for the InverseExpectationCost type:
C_loc(σ) = f′(⟨V⟩) * V_loc(σ) + g
Here logpsi and inner_kernel are static (functions), while inner_args, fprime and g are non-static runtime values (device arrays).
This yields a single, stable compiled executable no matter how often training is re-entered, while still allowing ⟨V⟩ to change from batch to batch.
- Return type:
Array
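The formula follows the chain rule for a cost f(⟨V⟩): the gradient of f(⟨V⟩) is f′(⟨V⟩) times the gradient of ⟨V⟩, so the local values V_loc can simply be rescaled by f′(⟨V⟩). The NumPy sketch below assumes, purely for illustration, f(x) = 1/x and g = f(⟨V⟩) − f′(⟨V⟩)⟨V⟩, a choice that makes the batch mean of C_loc equal f(⟨V⟩); the module's actual f and g are not specified here. Passing fprime and g as ordinary arguments, rather than baking them into the function, is what lets a jitted version reuse one executable while ⟨V⟩ changes.

```python
import numpy as np


def volume_cost_kernel_sketch(V_loc, fprime, g):
    # C_loc(σ) = f'(<V>) * V_loc(σ) + g; fprime and g are runtime values.
    return fprime * V_loc + g


def f(x):
    return 1.0 / x            # assumed cost: inverse of the expectation value


def f_prime(x):
    return -1.0 / x ** 2


rng = np.random.default_rng(4)
V_loc = rng.uniform(1.0, 2.0, size=256)    # hypothetical local values of V
V_mean = V_loc.mean()
g = f(V_mean) - f_prime(V_mean) * V_mean   # assumed offset (see lead-in)
C_loc = volume_cost_kernel_sketch(V_loc, f_prime(V_mean), g)
print(np.isclose(C_loc.mean(), f(V_mean)))  # True
```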
- local_value_kernel_ket_action(logpsi, pars, σ, args)
local_value kernel for MCState and generic operators
- local_value_kernel_jax_ket_action(logpsi, pars, σ, O)
Same as local_value_kernel_ket_action, but for DiscreteJaxOperator-style padded connections.
- local_value_kernel_ket_action_factored(logpsi_σ, logpsi, pars, σ, args)
Ket-action local estimator using precomputed logpsi(pars, σ).
- local_value_kernel_jax_ket_action_factored(logpsi_σ, logpsi, pars, σ, O)
JAX ket-action local estimator using precomputed logpsi(pars, σ).
- local_value_squared_kernel_ket_action(logpsi, pars, σ, args)
- local_value_kernel_squared_jax_ket_action(logpsi, pars, σ, O)
- local_value_squared_kernel_ket_action_factored(logpsi_σ, logpsi, pars, σ, args)
- local_value_kernel_squared_jax_ket_action_factored(logpsi_σ, logpsi, pars, σ, O)
- local_value_kernel_ket_action_chunked(logpsi, pars, σ, args, *, chunk_size=None)
- local_value_kernel_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
- local_value_kernel_jax_ket_action_conn_chunked(logpsi, pars, σ, O, chunk_size)
- local_value_kernel_jax_ket_action_conn_chunked_factored(logpsi_σ, logpsi, pars, σ, O, chunk_size)
- local_value_kernel_jax_ket_action_chunked(logpsi, pars, σ, O, *, chunk_size=None)
- local_value_kernel_jax_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
- local_value_squared_kernel_ket_action_chunked(logpsi, pars, σ, args, *, chunk_size=None)
- local_value_squared_kernel_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
- local_value_kernel_squared_jax_chunked(logpsi, pars, σ, O, *, chunk_size=None)
- local_value_kernel_squared_jax_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
- local_value_kernel_squared_jax_ket_action_chunked(logpsi, pars, σ, O, *, chunk_size=None)
- local_value_kernel_squared_jax_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
- resolve_factored_local_kernel(local_kernel, *, chunked=False)
Return a factored local-estimator kernel if available for local_kernel.
Factored kernels consume precomputed logpsi(pars, σ) and therefore avoid recomputing the wavefunction on the reference samples for every operator.
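One simple way to implement such a lookup is a registry that maps each plain kernel to its factored counterpart, with a separate table for chunked kernels. The sketch below assumes that design and assumes None is returned when no factored variant exists; the module's real mapping and fallback behavior are not documented here, and every kernel name in it is a hypothetical stub.

```python
# Hypothetical stub kernels standing in for the module's real ones.
def toy_kernel(logpsi, pars, σ, args): ...
def toy_kernel_factored(logpsi_σ, logpsi, pars, σ, args): ...
def toy_kernel_chunked(logpsi, pars, σ, args, *, chunk_size=None): ...
def toy_kernel_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None): ...


# Plain kernel -> factored kernel, one table per chunking mode.
_FACTORED = {toy_kernel: toy_kernel_factored}
_FACTORED_CHUNKED = {toy_kernel_chunked: toy_kernel_chunked_factored}


def resolve_factored_local_kernel_sketch(local_kernel, *, chunked=False):
    """Return the factored variant of local_kernel, or None if there is none."""
    table = _FACTORED_CHUNKED if chunked else _FACTORED
    return table.get(local_kernel)


print(resolve_factored_local_kernel_sketch(toy_kernel).__name__)  # toy_kernel_factored
```

A caller can then fall back to the unfactored kernel whenever the lookup returns None.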