neuralqx.vqs.mc.kernels module

This module overrides the implementations of some common kernels used by MCState and MCMixedState in NetKet.

NOTE: parts of this file, or its entire content, are derived from NetKet’s source code.

The original NetKet copyright applies.

batch_discrete_kernel(kernel)

Decorator that batches a kernel which works on a single sample so that it works on a batch of samples.

Works only for discrete kernels that take two args as inputs.
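A minimal sketch of the batching idea, assuming NetKet-style args = (σp, mels) where σp may arrive either grouped per sample or flattened. batch_discrete_kernel_sketch, local_value_single and the toy logpsi are hypothetical illustrations, not this module's code; the sketch closes over logpsi and pars and maps jax.vmap over the remaining arguments.

```python
import jax
import jax.numpy as jnp


def batch_discrete_kernel_sketch(kernel):
    """Lift kernel(logpsi, pars, σ, (σp, mels)) from one sample to a batch.

    Hypothetical illustration: σ has shape (B, N); σp has shape (B, n_conn, N)
    or is flattened to (B * n_conn, N); mels matches σp without its last axis.
    """

    def batched(logpsi, pars, σ, args):
        σp, mels = args
        if σp.ndim != 3:
            # Restore the per-sample grouping of the connected configurations.
            σp = σp.reshape((σ.shape[0], -1, σ.shape[-1]))
            mels = mels.reshape(σp.shape[:-1])
        # logpsi and pars are closed over (shared); σ, σp, mels are mapped on axis 0.
        per_sample = jax.vmap(lambda s, a: kernel(logpsi, pars, s, a))
        return per_sample(σ, (σp, mels))

    return batched


@batch_discrete_kernel_sketch
def local_value_single(logpsi, pars, σ, args):
    # Single-sample bra-action estimator: σ (N,), σp (n_conn, N), mels (n_conn,).
    σp, mels = args
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)))


def logpsi(pars, σ):
    # Toy log-amplitude (hypothetical model): log ψ(σ) = Σ_i pars_i σ_i
    return σ @ pars


pars = jnp.array([0.3, -0.2])
σ = jnp.array([[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]])
σp = σ[:, None, :] * (1.0 - 2.0 * jnp.eye(2))[None]  # flip one site per connection
mels = jnp.ones(σp.shape[:-1])

out = local_value_single(logpsi, pars, σ, (σp, mels))
out_flat = local_value_single(logpsi, pars, σ, (σp.reshape(-1, 2), mels.reshape(-1)))
ref = jnp.array(
    [jnp.sum(mels[i] * jnp.exp(logpsi(pars, σp[i]) - logpsi(pars, σ[i]))) for i in range(3)]
)
```

Both the grouped and the flattened layouts give the same per-sample values.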

local_value_kernel(logpsi, pars, σ, args)

local_value kernel for MCState and generic operators
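For reference, the bra-action local value of an operator O is O_loc(σ) = Σ_σ' ⟨σ|O|σ'⟩ ψ(σ')/ψ(σ), evaluated with log-amplitude differences. The sketch below assumes a toy linear log-amplitude and the operator Σ_i σ^x_i (both hypothetical), and checks the single-sample formula against a dense matrix-vector product over all configurations.

```python
import itertools

import jax.numpy as jnp
import numpy as np


def logpsi(pars, σ):
    # Toy log-amplitude (hypothetical model): log ψ(σ) = Σ_i pars_i σ_i
    return σ @ pars


def local_value_single_sample(logpsi, pars, σ, args):
    """O_loc(σ) = Σ_σ' <σ|O|σ'> ψ(σ')/ψ(σ) for a single sample σ of shape (N,)."""
    σp, mels = args  # σp: (n_conn, N) connected configurations, mels: (n_conn,)
    # Work with log-amplitude differences to avoid overflow in ψ itself.
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)))


def flip_connections(σ):
    """Connections of the toy operator O = Σ_i σ^x_i: one flipped site each, element 1."""
    N = σ.shape[-1]
    σp = σ[None, :] * (1.0 - 2.0 * jnp.eye(N))  # row i has site i flipped
    return σp, jnp.ones(N)


# Cross-check against a dense calculation on all 2**N configurations.
N = 3
pars = jnp.array([0.3, -0.2, 0.5])
configs = np.array(list(itertools.product([1.0, -1.0], repeat=N)))
psi = np.exp(configs @ np.asarray(pars, dtype=np.float64))
# <a|O|b> = 1 when configurations a and b differ on exactly one site.
O_dense = (np.sum(configs[:, None, :] != configs[None, :, :], axis=-1) == 1).astype(float)
dense_loc = (O_dense @ psi) / psi
kernel_loc = np.array(
    [
        local_value_single_sample(logpsi, pars, jnp.asarray(c), flip_connections(jnp.asarray(c)))
        for c in configs
    ]
)
```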

local_value_kernel_jax(logpsi, pars, σ, O)

local_value kernel for MCState and JAX-compatible operators
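A sketch of how padded connections are typically consumed, assuming a get_conn_padded-style callable that returns σp of shape (B, n_max, N) and mels of shape (B, n_max), with unused slots padded by matrix element 0. get_conn_padded_toy and local_value_kernel_jax_sketch are hypothetical; the documented kernel takes the operator O instead of a callable.

```python
import jax.numpy as jnp


def logpsi(pars, σ):
    # Toy log-amplitude (hypothetical model): log ψ(σ) = Σ_i pars_i σ_i
    return σ @ pars


def get_conn_padded_toy(σ, n_max=3):
    """Hypothetical padded connections of O = Σ_i σ^x_i.

    O has exactly N connections per sample; the remaining n_max - N slots
    repeat σ with matrix element 0 so they contribute nothing.
    """
    B, N = σ.shape
    flips = 1.0 - 2.0 * jnp.eye(N)  # row i flips site i
    σp = σ[:, None, :] * flips[None, :, :]  # (B, N, N)
    mels = jnp.ones((B, N))
    pad = n_max - N
    σp = jnp.concatenate([σp, jnp.repeat(σ[:, None, :], pad, axis=1)], axis=1)
    mels = jnp.concatenate([mels, jnp.zeros((B, pad))], axis=1)
    return σp, mels


def local_value_kernel_jax_sketch(logpsi, pars, σ, get_conn_padded):
    σp, mels = get_conn_padded(σ)  # (B, n_max, N), (B, n_max)
    logpsi_σp = logpsi(pars, σp.reshape(-1, σp.shape[-1])).reshape(σp.shape[:-1])
    logpsi_σ = logpsi(pars, σ)[:, None]  # broadcast over the connection axis
    return jnp.sum(mels * jnp.exp(logpsi_σp - logpsi_σ), axis=-1)


pars = jnp.array([0.3, -0.2])
σ = jnp.array([[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]])
# Closed form for this toy model: Σ_i exp(-2 pars_i σ_i).
expected = jnp.sum(jnp.exp(-2 * pars * σ), axis=-1)
```

Because padded slots carry a zero matrix element, the result does not depend on n_max.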

local_value_kernel_factored(logpsi_σ, logpsi, pars, σ, args)

Bra-action local estimator using precomputed logpsi(pars, σ).

local_value_kernel_jax_factored(logpsi_σ, logpsi, pars, σ, O)

JAX-operator bra-action local estimator using precomputed logpsi(pars, σ).
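The point of the factored variants is that logpsi(pars, σ) is evaluated once and shared across operators. The sketch below (hypothetical names, toy model, and a row counter added purely for illustration) computes it once and reuses it for an off-diagonal operator Σ_i σ^x_i and a diagonal operator Σ_i σ^z_i.

```python
import jax.numpy as jnp

# Counts configurations passed to logpsi (Python side effect, eager mode only).
CALLS = {"n_rows": 0}


def logpsi(pars, σ):
    # Toy log-amplitude (hypothetical model): log ψ(σ) = Σ_i pars_i σ_i
    CALLS["n_rows"] += σ.reshape(-1, σ.shape[-1]).shape[0]
    return σ @ pars


def bra_action_factored_sketch(logpsi_σ, logpsi, pars, σ, args):
    """Bra-action estimator where log ψ(σ) of shape (B,) is supplied by the caller."""
    σp, mels = args  # σp: (B, n_conn, N), mels: (B, n_conn)
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi_σ[:, None]), axis=-1)


pars = jnp.array([0.3, -0.2])
σ = jnp.array([[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]])
logpsi_σ = logpsi(pars, σ)  # evaluated once on the reference samples

# Off-diagonal operator Σ_i σ^x_i: one flipped site per connection, element 1.
σp_x = σ[:, None, :] * (1.0 - 2.0 * jnp.eye(2))[None]
O1_loc = bra_action_factored_sketch(logpsi_σ, logpsi, pars, σ, (σp_x, jnp.ones((3, 2))))

# Diagonal operator Σ_i σ^z_i: the only connection is σ itself.
σp_z = σ[:, None, :]
O2_loc = bra_action_factored_sketch(logpsi_σ, logpsi, pars, σ, (σp_z, jnp.sum(σ, axis=-1)[:, None]))
```

Here logpsi sees 3 + 6 + 3 = 12 rows; an unfactored kernel would evaluate the 3 reference rows again for each operator.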

local_value_squared_kernel(logpsi, pars, σ, args)

local_value kernel for MCState and Squared (generic) operators
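In NetKet the local value of a Squared operator O†O is |O_loc(σ)|², since ⟨ψ|O†O|ψ⟩ = Σ_σ |ψ(σ)|² |O_loc(σ)|². Assuming this module follows the same convention, the sketch below (toy model and operator, hypothetical names) checks that identity with an exact sum over the full Hilbert space against a dense calculation.

```python
import itertools

import jax.numpy as jnp
import numpy as np


def logpsi(pars, σ):
    # Toy log-amplitude (hypothetical model): log ψ(σ) = Σ_i pars_i σ_i
    return σ @ pars


def local_value_squared_sketch(logpsi, pars, σ, args):
    """|O_loc(σ)|^2 for a batch: σ (B, N), σp (B, n_conn, N), mels (B, n_conn)."""
    σp, mels = args
    O_loc = jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)[:, None]), axis=-1)
    return jnp.abs(O_loc) ** 2


def flip_connections(σ):
    """Batched connections of O = Σ_i σ^x_i (one flipped site each, element 1)."""
    N = σ.shape[-1]
    σp = σ[:, None, :] * (1.0 - 2.0 * jnp.eye(N))[None, :, :]
    return σp, jnp.ones(σp.shape[:-1])


N = 3
pars = jnp.array([0.3, -0.2, 0.5])
configs = np.array(list(itertools.product([1.0, -1.0], repeat=N)))
σ = jnp.asarray(configs)
weights = np.exp(2 * (configs @ np.asarray(pars, dtype=np.float64)))  # |ψ(σ)|^2
sq_loc = np.asarray(local_value_squared_sketch(logpsi, pars, σ, flip_connections(σ)))
# Exact weighted average of |O_loc|^2 over all configurations.
from_kernel = np.sum(weights * sq_loc) / np.sum(weights)

# Dense reference: <ψ|O†O|ψ>/<ψ|ψ> = ||Oψ||^2 / ||ψ||^2.
psi = np.sqrt(weights)
O_dense = (np.sum(configs[:, None, :] != configs[None, :, :], axis=-1) == 1).astype(float)
Opsi = O_dense @ psi
dense = (Opsi @ Opsi) / (psi @ psi)
```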

local_value_kernel_squared_jax(logpsi, pars, σ, O)

Squared-operator local estimator for any DiscreteJaxOperator

local_value_squared_kernel_factored(logpsi_σ, logpsi, pars, σ, args)
local_value_kernel_squared_jax_factored(logpsi_σ, logpsi, pars, σ, O)
local_value_op_op_cost(logpsi, pars, σ, args)

local_value kernel for MCState and generic operators

local_value_kernel_chunked(logpsi, pars, σ, args, *, chunk_size=None)

local_value kernel for MCState and generic operators
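Chunking bounds memory by evaluating logpsi on at most chunk_size configurations at a time. A sketch under that assumption: the hypothetical helper apply_chunked zero-pads the rows to a multiple of chunk_size and uses jax.lax.map over the chunks; local_value_kernel_chunked_sketch is likewise hypothetical, and NetKet's own chunking utilities may differ in detail.

```python
import jax
import jax.numpy as jnp


def logpsi(pars, σ):
    # Toy log-amplitude (hypothetical model): log ψ(σ) = Σ_i pars_i σ_i
    return σ @ pars


def apply_chunked(f, x, chunk_size):
    """Evaluate f on the rows of x (shape (M, ...)), chunk_size rows at a time."""
    M = x.shape[0]
    n_chunks = -(-M // chunk_size)  # ceiling division
    pad = n_chunks * chunk_size - M
    # Zero-pad so the rows split evenly; the padded outputs are sliced off below.
    x = jnp.concatenate([x, jnp.zeros((pad,) + x.shape[1:], x.dtype)], axis=0)
    chunks = x.reshape((n_chunks, chunk_size) + x.shape[1:])
    out = jax.lax.map(f, chunks)  # loops over chunks, vectorized within each chunk
    return out.reshape((n_chunks * chunk_size,) + out.shape[2:])[:M]


def local_value_kernel_chunked_sketch(logpsi, pars, σ, args, *, chunk_size=None):
    σp, mels = args  # σp: (B, n_conn, N), mels: (B, n_conn)
    N = σ.shape[-1]

    def log_amp(x):
        return logpsi(pars, x)

    def evaluate(x):
        if chunk_size is None:
            return log_amp(x)
        return apply_chunked(log_amp, x, chunk_size)

    logpsi_σ = evaluate(σ.reshape(-1, N)).reshape(σ.shape[:-1] + (1,))
    logpsi_σp = evaluate(σp.reshape(-1, N)).reshape(σp.shape[:-1])
    return jnp.sum(mels * jnp.exp(logpsi_σp - logpsi_σ), axis=-1)


pars = jnp.array([0.3, -0.2])
σ = jnp.array([[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]])
σp = σ[:, None, :] * (1.0 - 2.0 * jnp.eye(2))[None]  # Σ_i σ^x_i connections
mels = jnp.ones(σp.shape[:-1])
# Closed form for this toy model: Σ_i exp(-2 pars_i σ_i).
expected = jnp.sum(jnp.exp(-2 * pars * σ), axis=-1)
```

The chunk size changes only peak memory and loop structure, not the estimator.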

local_value_kernel_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)

Chunked bra-action estimator using precomputed logpsi(pars, σ).

local_value_kernel_jax_conn_chunked(logpsi, pars, σ, O, chunk_size)

local_value kernel for MCState and JAX-compatible operators

local_value_kernel_jax_conn_chunked_factored(logpsi_σ, logpsi, pars, σ, O, chunk_size)

Chunked JAX-operator bra-action estimator using precomputed logpsi(pars, σ).

local_value_squared_kernel_chunked(logpsi, pars, σ, args, *, chunk_size=None)

local_value kernel for MCState and Squared (generic) operators

local_value_squared_kernel_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_squared_jax_conn_chunked(logpsi, pars, σ, O, chunk_size)

Squared version with the same splitting strategy

local_value_kernel_squared_jax_conn_chunked_factored(logpsi_σ, logpsi, pars, σ, O, chunk_size)
local_value_op_op_cost_chunked(logpsi, pars, σ, args, *, chunk_size=None)

local_value kernel for MCMixedState and generic operators

local_value_kernel_jax_chunked(logpsi, pars, σ, O, *, chunk_size=None)

local_value kernel for MCState and JAX-compatible operators

local_value_kernel_jax_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_penalty_cost(logpsi, pars, σ, args)

local_value kernel for MCState and generic operators

local_value_kernel_penalty_cost_factored(logpsi_σ, logpsi, pars, σ, args)
local_value_kernel_penalty_cost_chunked(logpsi, pars, σ, args, *, chunk_size=None)

Currently identical to the regular kernel for a local operator.

local_value_kernel_penalty_cost_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_variance(logpsi, pars, σ, args)

Local estimator for a VarianceObservable.

args = (local_kernel_O, args_O, local_kernel_O2, args_O2)

Returns the real part of L_σ = O²_loc(σ) − ⟨O⟩²_MC.

The mean ⟨O⟩ is taken over the same mini-batch σ that NetKet’s Monte Carlo integration uses, which is the usual MC estimator employed in the original NetKet routine. Because the batch mean is squared, this estimator is consistent but carries a small O(1/N_samples) bias.
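A sketch of the per-sample estimator described above, with the four-tuple args unpacked as documented. precomputed_kernel is a hypothetical stand-in that returns fixed local values so the arithmetic can be checked by hand: the batch mean of O_loc is 2, so L_σ = O²_loc(σ) − 4.

```python
import jax.numpy as jnp


def local_value_kernel_variance_sketch(logpsi, pars, σ, args):
    """Per-sample L_σ = Re[O²_loc(σ) − ⟨O⟩²_MC], with ⟨O⟩ averaged over this batch."""
    local_kernel_O, args_O, local_kernel_O2, args_O2 = args
    O_loc = local_kernel_O(logpsi, pars, σ, args_O)  # (B,)
    O2_loc = local_kernel_O2(logpsi, pars, σ, args_O2)  # (B,)
    mean_O = jnp.mean(O_loc)  # batch estimate of ⟨O⟩
    return jnp.real(O2_loc - mean_O**2)


def precomputed_kernel(logpsi, pars, σ, values):
    # Hypothetical stand-in kernel that simply returns precomputed local values.
    return values


O_loc = jnp.array([1.0, 2.0, 3.0])
O2_loc = jnp.array([2.0, 5.0, 10.0])
L = local_value_kernel_variance_sketch(
    None, None, None, (precomputed_kernel, O_loc, precomputed_kernel, O2_loc)
)
```

Averaging L over the batch gives mean(O²_loc) − mean(O_loc)², here 17/3 − 4 = 5/3.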

volume_cost_kernel(logpsi, inner_kernel, pars, σ, inner_args, fprime, g)

Jitted per-sample local estimator for the InverseExpectationCost type

C_loc(σ) = f′(⟨V⟩) · V_loc(σ) + g

Here, logpsi and inner_kernel are static arguments (Python functions), while inner_args, fprime and g are non-static runtime values (device arrays).

This yields a single, stable compiled executable no matter how often training is re-entered, while still allowing ⟨V⟩ (and hence fprime) to change from batch to batch.

Return type:

Array
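A minimal sketch of the static/non-static split, assuming jax.jit with static_argnums for the two function arguments. The toy inner_kernel (a diagonal Σ_i σ_i observable scaled by inner_args) and the TRACE_COUNT counter are hypothetical; the counter increments only while JAX traces, so it shows that passing a new fprime does not trigger recompilation.

```python
from functools import partial

import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # incremented only while JAX traces (compiles) the function


@partial(jax.jit, static_argnums=(0, 1))
def volume_cost_kernel_sketch(logpsi, inner_kernel, pars, σ, inner_args, fprime, g):
    """C_loc(σ) = f′(⟨V⟩) · V_loc(σ) + g, with fprime and g passed as runtime arrays."""
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs during tracing, not on cached calls
    V_loc = inner_kernel(logpsi, pars, σ, inner_args)
    return fprime * V_loc + g


def logpsi(pars, σ):
    # Toy log-amplitude (hypothetical model), unused by the diagonal inner kernel.
    return σ @ pars


def inner_kernel(logpsi, pars, σ, coupling):
    # Hypothetical diagonal observable: V_loc(σ) = coupling * Σ_i σ_i
    return coupling * jnp.sum(σ, axis=-1)


pars = jnp.array([0.1, 0.2])
σ = jnp.array([[1.0, -1.0], [1.0, 1.0]])
coupling, g = jnp.asarray(0.5), jnp.asarray(1.0)

c1 = volume_cost_kernel_sketch(logpsi, inner_kernel, pars, σ, coupling, jnp.asarray(2.0), g)
# A new f′(⟨V⟩) value with the same shape and dtype reuses the compiled executable.
c2 = volume_cost_kernel_sketch(logpsi, inner_kernel, pars, σ, coupling, jnp.asarray(-3.0), g)
```

Had fprime been marked static instead, every new ⟨V⟩ would produce a new cache entry and a fresh compilation.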

local_value_kernel_ket_action(logpsi, pars, σ, args)

local_value kernel for MCState and generic operators

local_value_kernel_jax_ket_action(logpsi, pars, σ, O)

Same as local_value_kernel_ket_action, but for DiscreteJaxOperator-style padded connections.

local_value_kernel_ket_action_factored(logpsi_σ, logpsi, pars, σ, args)

Ket-action local estimator using precomputed logpsi(pars, σ).

local_value_kernel_jax_ket_action_factored(logpsi_σ, logpsi, pars, σ, O)

JAX ket-action local estimator using precomputed logpsi(pars, σ).

local_value_squared_kernel_ket_action(logpsi, pars, σ, args)
local_value_kernel_squared_jax_ket_action(logpsi, pars, σ, O)
local_value_squared_kernel_ket_action_factored(logpsi_σ, logpsi, pars, σ, args)
local_value_kernel_squared_jax_ket_action_factored(logpsi_σ, logpsi, pars, σ, O)
local_value_kernel_ket_action_chunked(logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_jax_ket_action_conn_chunked(logpsi, pars, σ, O, chunk_size)
local_value_kernel_jax_ket_action_conn_chunked_factored(logpsi_σ, logpsi, pars, σ, O, chunk_size)
local_value_kernel_jax_ket_action_chunked(logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_jax_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
local_value_squared_kernel_ket_action_chunked(logpsi, pars, σ, args, *, chunk_size=None)
local_value_squared_kernel_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_squared_jax_chunked(logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_squared_jax_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_squared_jax_ket_action_chunked(logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_squared_jax_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
resolve_factored_local_kernel(local_kernel, *, chunked=False)

Return the factored local-estimator kernel corresponding to local_kernel, if one is available.

Factored kernels consume precomputed logpsi(pars, σ) and therefore avoid recomputing the wavefunction on the reference samples for every operator.
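One possible shape for such a lookup, assuming a registry keyed on (kernel, chunked) that returns None when no factored counterpart exists. The registry, toy_kernel, toy_kernel_factored and resolve_factored_sketch are hypothetical; the fallback behaviour of the real function is not documented here.

```python
import jax.numpy as jnp


def toy_kernel(logpsi, pars, σ, args):
    """Hypothetical unfactored bra-action kernel: recomputes log ψ(σ) itself."""
    σp, mels = args
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)[:, None]), axis=-1)


def toy_kernel_factored(logpsi_σ, logpsi, pars, σ, args):
    """Factored counterpart: takes the precomputed log ψ(σ) as its first argument."""
    σp, mels = args
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi_σ[:, None]), axis=-1)


# Hypothetical registry mapping (kernel, chunked) to its factored counterpart.
_FACTORED_KERNELS = {(toy_kernel, False): toy_kernel_factored}


def resolve_factored_sketch(local_kernel, *, chunked=False):
    """Return the registered factored kernel, or None if there is none (assumed fallback)."""
    return _FACTORED_KERNELS.get((local_kernel, chunked))


def logpsi(pars, σ):
    # Toy log-amplitude (hypothetical model): log ψ(σ) = Σ_i pars_i σ_i
    return σ @ pars


pars = jnp.array([0.3, -0.2])
σ = jnp.array([[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]])
args = (σ[:, None, :] * (1.0 - 2.0 * jnp.eye(2))[None], jnp.ones((3, 2)))

factored = resolve_factored_sketch(toy_kernel)
logpsi_σ = logpsi(pars, σ)  # computed once, reusable for every resolved kernel
factored_out = factored(logpsi_σ, logpsi, pars, σ, args)
direct_out = toy_kernel(logpsi, pars, σ, args)
```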