neuralqx.vqs.mc.kernels module

This file overrides the implementations of some common kernels used by MCState and MCMixedState in NetKet.

NOTE: parts of this file, or its entire content, are taken from NetKet’s source code.

The original NetKet copyright therefore applies.

batch_discrete_kernel(kernel)

Batch a kernel that only works with a single sample so that it works with a batch of samples.

Works only for discrete kernels that take two arguments as inputs.
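A minimal sketch of the batching idea, written with jax.vmap. It assumes, as in NetKet's upstream helper, that args is a pair (σp, mels) holding the connected configurations and matrix elements of each sample. batch_discrete_kernel_sketch, single_sample_kernel and toy_logpsi are hypothetical names for illustration, not this module's code.

```python
import jax
import jax.numpy as jnp


def batch_discrete_kernel_sketch(kernel):
    """Lift a kernel written for one sample to a batch of samples."""

    def batched(logpsi, pars, σ, args):
        σp, mels = args  # σ: (N, L), σp: (N, M, L), mels: (N, M)

        # Close over logpsi and pars so that only array arguments are mapped.
        def one_sample(s, sp, m):
            return kernel(logpsi, pars, s, (sp, m))

        return jax.vmap(one_sample)(σ, σp, mels)

    return batched


def toy_logpsi(pars, x):
    # Hypothetical log-amplitude: log ψ(x) = Σ_i pars_i x_i (any leading shape)
    return jnp.sum(pars * x, axis=-1)


def single_sample_kernel(logpsi, pars, σ, args):
    σp, mels = args  # σ: (L,), σp: (M, L), mels: (M,)
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)))


batched_kernel = batch_discrete_kernel_sketch(single_sample_kernel)

pars = jnp.array([0.1, -0.2])
σ = jnp.array([[1.0, -1.0], [1.0, 1.0], [-1.0, -1.0]])
σp = jnp.stack([-σ, σ * jnp.array([-1.0, 1.0])], axis=1)  # two connections per sample
mels = jnp.ones((3, 2))

out = batched_kernel(toy_logpsi, pars, σ, (σp, mels))
reference = jnp.array(
    [single_sample_kernel(toy_logpsi, pars, σ[i], (σp[i], mels[i])) for i in range(3)]
)
```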

local_value_kernel(logpsi, pars, σ, args)

local_value kernel for MCState and generic operators
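Such a kernel returns the standard local estimator O_loc(σ) = Σ_σ' ⟨σ|O|σ'⟩ ψ(σ')/ψ(σ). The sketch below checks this on a single spin with a hypothetical log-amplitude log ψ(x) = a·x: averaging O_loc over the Born distribution |ψ(σ)|² reproduces the exact expectation of σ^x. The Hilbert space is enumerated instead of sampled, and all names are illustrative only.

```python
import jax.numpy as jnp


def local_value_single(logpsi, pars, σ, σp, mels):
    # O_loc(σ) = Σ_σ' <σ|O|σ'> ψ(σ')/ψ(σ), evaluated in log space for stability
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)))


def logpsi(a, x):
    # Hypothetical one-parameter log-amplitude: log ψ(x) = a * Σ_i x_i
    return a * jnp.sum(x, axis=-1)


a = 0.3
configs = jnp.array([[1.0], [-1.0]])  # full Hilbert space of a single spin

# O = σ^x: the only connected configuration is the flipped spin, with <σ|O|σ'> = 1
oloc = jnp.array(
    [local_value_single(logpsi, a, s, -s[None, :], jnp.ones(1)) for s in configs]
)

# Born weights p(σ) = |ψ(σ)|² / Σ|ψ|², enumerated exactly instead of sampled
w = jnp.exp(2 * jnp.real(logpsi(a, configs)))
p = w / jnp.sum(w)

estimate = jnp.sum(p * oloc)   # Σ_σ p(σ) O_loc(σ)
exact = 1.0 / jnp.cosh(2 * a)  # <ψ|σ^x|ψ>/<ψ|ψ> for ψ(±1) = e^{±a}
```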

local_value_kernel_jax(logpsi, pars, σ, O)

local_value kernel for MCState for jax-compatible operators
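For JAX-compatible operators the connections come back padded to a fixed count per sample, so log ψ can be evaluated on all of them in one flat batch. In this sketch, get_conn_padded_sx is a hypothetical stand-in for the operator's padded-connection method (NetKet operators expose get_conn_padded), implementing O = Σ_i σ^x_i. Padded entries, when an operator needs them, typically carry zero matrix elements and drop out of the sum.

```python
import jax.numpy as jnp


def get_conn_padded_sx(σ):
    """Hypothetical stand-in for a padded-connection method with O = Σ_i σ^x_i.

    Every sample has exactly L connections here, so no padding is needed.
    """
    L = σ.shape[-1]
    flips = 1.0 - 2.0 * jnp.eye(L)  # row i flips site i
    σp = σ[..., None, :] * flips    # (N, L, L)
    mels = jnp.ones(σp.shape[:-1])  # (N, L)
    return σp, mels


def local_value_jax_sketch(logpsi, pars, σ, get_conn_padded):
    σp, mels = get_conn_padded(σ)
    # Evaluate logpsi on all connected configurations as one flat batch, then restore (N, M)
    logpsi_σp = logpsi(pars, σp.reshape(-1, σp.shape[-1])).reshape(σp.shape[:-1])
    logpsi_σ = logpsi(pars, σ)  # (N,)
    return jnp.sum(mels * jnp.exp(logpsi_σp - logpsi_σ[:, None]), axis=-1)


def logpsi(a, x):
    # Hypothetical log-amplitude: log ψ(x) = a * Σ_i x_i
    return a * jnp.sum(x, axis=-1)


a = 0.3
σ = jnp.array([[1.0, 1.0], [1.0, -1.0]])
oloc = local_value_jax_sketch(logpsi, a, σ, get_conn_padded_sx)
# Flipping site i changes log ψ by -2 a σ_i, so O_loc(σ) = Σ_i exp(-2 a σ_i)
expected = jnp.sum(jnp.exp(-2 * a * σ), axis=-1)
```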

local_value_kernel_factored(logpsi_σ, logpsi, pars, σ, args)

Bra-action local estimator using precomputed logpsi(pars, σ).
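A call counter makes the saving from the factored form concrete. This sketch uses a toy log-amplitude and two hypothetical single-flip operators. The unfactored kernel evaluates log ψ twice per operator, while the factored one receives logpsi(pars, σ) from the caller and only evaluates the connected configurations. The function names are illustrative, not this module's.

```python
import jax.numpy as jnp

calls = {"n": 0}


def logpsi(pars, x):
    # Hypothetical toy log-amplitude that counts how often it is evaluated
    calls["n"] += 1
    return jnp.sum(pars * x, axis=-1)


def local_value(logpsi, pars, σ, args):
    σp, mels = args  # σ: (N, L), σp: (N, M, L), mels: (N, M)
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)[:, None]), axis=-1)


def local_value_factored(logpsi_σ, logpsi, pars, σ, args):
    σp, mels = args
    # logpsi(pars, σ) comes from the caller instead of being recomputed here
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi_σ[:, None]), axis=-1)


pars = jnp.array([0.2, -0.1])
σ = jnp.array([[1.0, -1.0], [1.0, 1.0]])
# Two hypothetical operators, σ^x on site 0 and σ^x on site 1 (one connection each)
ops = [
    (σ[:, None, :] * jnp.array([-1.0, 1.0]), jnp.ones((2, 1))),
    (σ[:, None, :] * jnp.array([1.0, -1.0]), jnp.ones((2, 1))),
]

calls["n"] = 0
plain = [local_value(logpsi, pars, σ, op_args) for op_args in ops]
calls_plain = calls["n"]  # two evaluations per operator

calls["n"] = 0
logpsi_σ = logpsi(pars, σ)  # evaluated once and shared
factored = [local_value_factored(logpsi_σ, logpsi, pars, σ, op_args) for op_args in ops]
calls_factored = calls["n"]  # one shared evaluation plus one per operator
```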

local_value_kernel_jax_factored(logpsi_σ, logpsi, pars, σ, O)

JAX-operator bra-action local estimator using precomputed logpsi(pars, σ).

local_value_squared_kernel(logpsi, pars, σ, args)

local_value kernel for MCState and Squared (generic) operators
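Here is a sketch that assumes the Squared kernel follows NetKet's upstream convention of returning |O_loc(σ)|² per sample. This works because Σ_σ p(σ)|O_loc(σ)|² = Σ_σ |(Oψ)(σ)|² / Σ_σ |ψ(σ)|² = ⟨O†O⟩ whenever ψ(σ) ≠ 0. The dense three-state check below uses an arbitrary, hypothetical ψ and O.

```python
import jax.numpy as jnp

# Hypothetical wavefunction and (non-Hermitian) operator on a 3-state Hilbert space
psi = jnp.array([0.5 + 0.1j, -0.3 + 0.4j, 0.8 - 0.2j])
O = jnp.array([[0.0, 1.0, 0.5j], [2.0, 0.0, 0.0], [0.0, -1.0j, 1.5]])

# O_loc(σ) = Σ_σ' O[σ, σ'] ψ(σ')/ψ(σ) for every basis state σ
oloc = (O @ psi) / psi
p = jnp.abs(psi) ** 2 / jnp.sum(jnp.abs(psi) ** 2)  # Born weights, enumerated exactly

squared_local = jnp.abs(oloc) ** 2  # per-sample estimator for O†O
estimate = jnp.sum(p * squared_local)
exact = jnp.vdot(psi, O.conj().T @ O @ psi).real / jnp.vdot(psi, psi).real
```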

local_value_kernel_squared_jax(logpsi, pars, σ, O)

Squared-operator local estimator for any DiscreteJaxOperator

local_value_squared_kernel_factored(logpsi_σ, logpsi, pars, σ, args)
local_value_kernel_squared_jax_factored(logpsi_σ, logpsi, pars, σ, O)
local_value_op_op_cost(logpsi, pars, σ, args)

local_value kernel for MCMixedState and generic operators
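The chunked counterpart of this kernel is described as targeting MCMixedState. On that reading, the local value is taken on the diagonal of a density matrix: O_loc(σ) = Σ_σ' ⟨σ|O|σ'⟩ ρ(σ', σ)/ρ(σ, σ). Averaging it with weights ρ(σ, σ)/Tr ρ gives Tr(Oρ)/Tr ρ. The sketch assumes, as in NetKet's upstream op-op kernel, that the mixed state's log-amplitude is evaluated on the row and column configurations concatenated. Here a dense, hypothetical 2×2 ρ is indexed directly instead.

```python
import jax.numpy as jnp

# Hypothetical 2-state density matrix ρ[row, column] and operator O
rho = jnp.array([[0.7, 0.2 - 0.1j], [0.2 + 0.1j, 0.3]])
O = jnp.array([[1.0, 0.5], [0.5, -1.0]])


def log_rho(row, col):
    # Stand-in for evaluating the mixed-state log-amplitude on (row, col)
    return jnp.log(rho[row, col])


def op_op_local_value(σ):
    conn = jnp.arange(2)  # in this dense toy every state is connected to σ
    mels = O[σ, conn]     # <σ|O|σ'>
    # Connections act on the row index; the column index stays at σ
    return jnp.sum(mels * jnp.exp(log_rho(conn, σ) - log_rho(σ, σ)))


oloc = jnp.array([op_op_local_value(s) for s in range(2)])
p = jnp.real(jnp.diag(rho)) / jnp.real(jnp.trace(rho))  # diagonal sampling weights
estimate = jnp.sum(p * oloc)
exact = jnp.trace(O @ rho) / jnp.trace(rho)  # Tr(Oρ)/Tr(ρ)
```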

local_value_kernel_chunked(logpsi, pars, σ, args, *, chunk_size=None)

local_value kernel for MCState and generic operators
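Chunking bounds memory by evaluating the local values a few samples at a time. This is a minimal sketch with jax.lax.map. For simplicity it assumes that chunk_size divides the number of samples; a real implementation would also handle a remainder. local_value_chunked_sketch and the toy log-amplitude are hypothetical.

```python
import jax
import jax.numpy as jnp


def local_value(logpsi, pars, σ, args):
    σp, mels = args  # σ: (N, L), σp: (N, M, L), mels: (N, M)
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)[:, None]), axis=-1)


def local_value_chunked_sketch(logpsi, pars, σ, args, *, chunk_size=None):
    n = σ.shape[0]
    if chunk_size is None or chunk_size >= n:
        return local_value(logpsi, pars, σ, args)
    if n % chunk_size != 0:
        raise ValueError("this sketch assumes chunk_size divides the number of samples")
    σp, mels = args

    def to_chunks(x):
        return x.reshape((n // chunk_size, chunk_size) + x.shape[1:])

    # lax.map processes the chunks one after another, so only one chunk's
    # connected configurations are evaluated at a time.
    out = jax.lax.map(
        lambda c: local_value(logpsi, pars, c[0], (c[1], c[2])),
        (to_chunks(σ), to_chunks(σp), to_chunks(mels)),
    )
    return out.reshape(n)


def logpsi(pars, x):
    # Hypothetical toy log-amplitude
    return jnp.sum(pars * x, axis=-1)


pars = jnp.array([0.3, -0.2, 0.1])
σ = jnp.array([[1.0, -1.0, 1.0], [1.0, 1.0, 1.0], [-1.0, -1.0, 1.0], [-1.0, 1.0, -1.0]])
σp = σ[:, None, :] * (1.0 - 2.0 * jnp.eye(3))  # single spin flips: (4, 3, 3)
mels = jnp.ones((4, 3))

full = local_value(logpsi, pars, σ, (σp, mels))
chunked = local_value_chunked_sketch(logpsi, pars, σ, (σp, mels), chunk_size=2)
```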

local_value_kernel_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)

Chunked bra-action estimator using precomputed logpsi(pars, σ).

local_value_kernel_jax_conn_chunked(logpsi, pars, σ, O, chunk_size)

local_value kernel for MCState for jax-compatible operators
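Based on the name, conn_chunked presumably chunks over the connected configurations σ' of each sample; that reading is an assumption. Under it, the sum over σ' can be accumulated chunk by chunk with jax.lax.scan, while logpsi(pars, σ) is computed only once. The sketch assumes chunk_size divides the number of connections, and all names are hypothetical.

```python
import jax
import jax.numpy as jnp


def local_value_conn_chunked_sketch(logpsi, pars, σ, σp, mels, chunk_size):
    n, m, L = σp.shape
    if m % chunk_size != 0:
        raise ValueError("this sketch assumes chunk_size divides the number of connections")
    logpsi_σ = logpsi(pars, σ)[:, None]  # (N, 1), computed once for all chunks
    # Reorder to (num_chunks, N, chunk_size, ...) so lax.scan iterates over chunks
    σp_c = σp.reshape(n, m // chunk_size, chunk_size, L).swapaxes(0, 1)
    mels_c = mels.reshape(n, m // chunk_size, chunk_size).swapaxes(0, 1)

    def body(acc, chunk):
        sp, me = chunk
        # Partial sum of mel * ψ(σ')/ψ(σ) over this chunk of connections
        return acc + jnp.sum(me * jnp.exp(logpsi(pars, sp) - logpsi_σ), axis=-1), None

    init = jnp.zeros(n, dtype=jnp.result_type(mels, logpsi_σ))
    total, _ = jax.lax.scan(body, init, (σp_c, mels_c))
    return total


def logpsi(pars, x):
    # Hypothetical toy log-amplitude
    return jnp.sum(pars * x, axis=-1)


pars = jnp.array([0.3, -0.2, 0.1, 0.05])
σ = jnp.array([[1.0, -1.0, 1.0, -1.0], [1.0, 1.0, -1.0, -1.0]])
σp = σ[:, None, :] * (1.0 - 2.0 * jnp.eye(4))  # O = Σ_i σ^x_i: 4 connections per sample
mels = jnp.ones((2, 4))

chunked = local_value_conn_chunked_sketch(logpsi, pars, σ, σp, mels, chunk_size=2)
direct = jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)[:, None]), axis=-1)
```

Under this reading, a squared variant with the same splitting strategy would accumulate the complex sum first and take |·|² only at the end, because the square of a sum is not the sum of the squares.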

local_value_kernel_jax_conn_chunked_factored(logpsi_σ, logpsi, pars, σ, O, chunk_size)

Chunked JAX-operator bra-action estimator using precomputed logpsi(pars, σ).

local_value_squared_kernel_chunked(logpsi, pars, σ, args, *, chunk_size=None)

local_value kernel for MCState and Squared (generic) operators

local_value_squared_kernel_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_squared_jax_conn_chunked(logpsi, pars, σ, O, chunk_size)

Squared version with the same splitting strategy

local_value_kernel_squared_jax_conn_chunked_factored(logpsi_σ, logpsi, pars, σ, O, chunk_size)
local_value_op_op_cost_chunked(logpsi, pars, σ, args, *, chunk_size=None)

local_value kernel for MCMixedState and generic operators

local_value_kernel_jax_chunked(logpsi, pars, σ, O, *, chunk_size=None)

local_value kernel for MCState and JAX-compatible operators

local_value_kernel_jax_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_penalty_cost(logpsi, pars, σ, args)

local_value kernel for MCState and generic operators

local_value_kernel_penalty_cost_factored(logpsi_σ, logpsi, pars, σ, args)
local_value_kernel_penalty_cost_chunked(logpsi, pars, σ, args, *, chunk_size=None)

For now, this is just the regular kernel for a local operator.

local_value_kernel_penalty_cost_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_variance(logpsi, pars, σ, args)

Local estimator for a VarianceObservable.

args = (local_kernel_O, args_O, local_kernel_O2, args_O2)

Returns L_σ = O²_loc(σ) − ⟨O⟩_MC² (real part only).

The mean ⟨O⟩ is taken over the same mini-batch σ that NetKet’s Monte Carlo integration uses, which makes this the usual plug-in MC estimator employed in the original NetKet routine.
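This sketch applies the estimator above to precomputed local values. The numbers are hypothetical and describe a diagonal operator, for which the local value of O² is simply O_loc(σ)². Averaging L_σ over the batch then gives mean(O²) − mean(O)², the sample variance.

```python
import jax.numpy as jnp


def variance_local_values(oloc, o2loc):
    # L_σ = O²_loc(σ) − ⟨O⟩_MC², real part only, with ⟨O⟩_MC averaged over the same batch
    mean_o = jnp.mean(oloc)
    return jnp.real(o2loc - mean_o**2)


# Hypothetical local values of a diagonal operator, where O²_loc(σ) = O_loc(σ)²
oloc = jnp.array([1.0, 2.0, 3.0, 4.0])
o2loc = oloc**2
L_sigma = variance_local_values(oloc, o2loc)
variance_estimate = jnp.mean(L_sigma)  # mean(O²) − mean(O)² = 7.5 − 6.25 = 1.25
```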

affine_penalty_cost_kernel(logpsi, inner_kernel, pars, σ, inner_args, scale, shift)

Jitted per-sample local estimator for expectation-level penalty costs.

C_loc(σ) = scale * O_loc(σ) + shift

Here logpsi and inner_kernel are static functions and inner_args, scale and shift are non-static runtime values. This preserves one stable compiled executable while allowing the affine coefficients to change from batch to batch.

InverseExpectationCost is one specialization of this more general protocol.

Return type:

Array
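The static/non-static split can be demonstrated with jax.jit and a trace counter. In this sketch, logpsi and inner_kernel are marked static via static_argnums, while scale and shift are ordinary traced arguments. Changing scale and shift therefore reuses the compiled executable. affine_penalty_sketch and the toy kernels are hypothetical, not this module's implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = {"n": 0}


@partial(jax.jit, static_argnums=(0, 1))
def affine_penalty_sketch(logpsi, inner_kernel, pars, σ, inner_args, scale, shift):
    # This Python body only runs while JAX traces (and then compiles) the function
    trace_count["n"] += 1
    return scale * inner_kernel(logpsi, pars, σ, inner_args) + shift


def logpsi(pars, x):
    # Hypothetical toy log-amplitude
    return jnp.sum(pars * x, axis=-1)


def local_value(logpsi, pars, σ, args):
    σp, mels = args
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)[:, None]), axis=-1)


pars = jnp.array([0.2, -0.1])
σ = jnp.array([[1.0, -1.0], [1.0, 1.0]])
args = (σ[:, None, :] * (1.0 - 2.0 * jnp.eye(2)), jnp.ones((2, 2)))

c1 = affine_penalty_sketch(logpsi, local_value, pars, σ, args, 2.0, 0.5)
c2 = affine_penalty_sketch(logpsi, local_value, pars, σ, args, 3.0, -1.0)  # new coefficients
oloc = local_value(logpsi, pars, σ, args)
```

If scale or shift were marked static instead, every new value would trigger a fresh trace and compilation.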

volume_cost_kernel(logpsi, inner_kernel, pars, σ, inner_args, fprime, g)

Backward-compatible name for the affine penalty local estimator.

Historically this helper was specific to InverseExpectationCost:

C_loc(σ) = f′(⟨V⟩) * V_loc(σ) + g

Here logpsi and inner_kernel are static (functions), while inner_args, fprime and g are non-static runtime values (device arrays).

This ensures a single, stable compiled executable no matter how often training is re-entered, while still allowing ⟨V⟩ to change from batch to batch.

Return type:

Array
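Both formulas have the same affine form, so the legacy parameterization presumably maps fprime onto scale and g onto shift. The sketch works on precomputed local values rather than the kernel signature. It uses jax.grad to obtain f′(⟨V⟩) for a hypothetical penalty f(v) = 1/v. Neither helper is this module's API.

```python
import jax
import jax.numpy as jnp


def affine_local_cost(o_loc, scale, shift):
    # C_loc(σ) = scale * O_loc(σ) + shift
    return scale * o_loc + shift


def volume_cost_local(v_loc, fprime, g):
    # Legacy parameterization: scale = f'(<V>), shift = g
    return affine_local_cost(v_loc, scale=fprime, shift=g)


def f(v):
    # Hypothetical penalty f(<V>), for illustration only
    return 1.0 / v


v_loc = jnp.array([1.5, 2.0, 2.5])  # hypothetical local values V_loc(σ)
v_mean = jnp.mean(v_loc)            # Monte Carlo estimate of <V>: 2.0
fprime = jax.grad(f)(v_mean)        # f'(<V>) = -1/<V>**2 = -0.25
c_loc = volume_cost_local(v_loc, fprime, g=0.0)
```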

local_value_kernel_ket_action(logpsi, pars, σ, args)

local_value kernel for MCState and generic operators

local_value_kernel_jax_ket_action(logpsi, pars, σ, O)

Same as local_value_kernel_ket_action, but for DiscreteJaxOperator-style padded connections.

local_value_kernel_ket_action_factored(logpsi_σ, logpsi, pars, σ, args)

Ket-action local estimator using precomputed logpsi(pars, σ).

local_value_kernel_jax_ket_action_factored(logpsi_σ, logpsi, pars, σ, O)

JAX ket-action local estimator using precomputed logpsi(pars, σ).

local_value_squared_kernel_ket_action(logpsi, pars, σ, args)
local_value_kernel_squared_jax_ket_action(logpsi, pars, σ, O)
local_value_squared_kernel_ket_action_factored(logpsi_σ, logpsi, pars, σ, args)
local_value_kernel_squared_jax_ket_action_factored(logpsi_σ, logpsi, pars, σ, O)
local_value_kernel_ket_action_chunked(logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_jax_ket_action_conn_chunked(logpsi, pars, σ, O, chunk_size)
local_value_kernel_jax_ket_action_conn_chunked_factored(logpsi_σ, logpsi, pars, σ, O, chunk_size)
local_value_kernel_jax_ket_action_chunked(logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_jax_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
local_value_squared_kernel_ket_action_chunked(logpsi, pars, σ, args, *, chunk_size=None)
local_value_squared_kernel_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, args, *, chunk_size=None)
local_value_kernel_squared_jax_chunked(logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_squared_jax_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_squared_jax_ket_action_chunked(logpsi, pars, σ, O, *, chunk_size=None)
local_value_kernel_squared_jax_ket_action_chunked_factored(logpsi_σ, logpsi, pars, σ, O, *, chunk_size=None)
resolve_factored_local_kernel(local_kernel, *, chunked=False)

Return the factored local-estimator kernel for local_kernel, if one is available.

Factored kernels consume precomputed logpsi(pars, σ) and therefore avoid recomputing the wavefunction on the reference samples for every operator.
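One way to organize the lookup is a registry keyed by the unfactored kernel and the chunked flag. The caller falls back to the unfactored kernel when no factored variant exists. This is a hypothetical sketch: the registry, the None return value for unknown kernels, and all function names are assumptions, not this module's implementation.

```python
import jax.numpy as jnp


def local_value(logpsi, pars, σ, args):
    σp, mels = args
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)[:, None]), axis=-1)


def local_value_factored(logpsi_σ, logpsi, pars, σ, args):
    σp, mels = args
    return jnp.sum(mels * jnp.exp(logpsi(pars, σp) - logpsi_σ[:, None]), axis=-1)


# Hypothetical registry: (unfactored kernel, chunked?) -> factored kernel
_FACTORED = {(local_value, False): local_value_factored}


def resolve_factored_sketch(local_kernel, *, chunked=False):
    # Returns None when no factored variant is known (an assumption of this sketch)
    return _FACTORED.get((local_kernel, chunked))


def evaluate(local_kernel, logpsi, pars, σ, args, logpsi_σ):
    factored = resolve_factored_sketch(local_kernel)
    if factored is not None:
        return factored(logpsi_σ, logpsi, pars, σ, args)
    return local_kernel(logpsi, pars, σ, args)  # fall back to the unfactored kernel


def logpsi(pars, x):
    # Hypothetical toy log-amplitude
    return jnp.sum(pars * x, axis=-1)


def other_kernel(logpsi, pars, σ, args):
    # Hypothetical kernel with no registered factored variant
    return 2.0 * local_value(logpsi, pars, σ, args)


pars = jnp.array([0.2, -0.1])
σ = jnp.array([[1.0, -1.0], [1.0, 1.0]])
args = (σ[:, None, :] * (1.0 - 2.0 * jnp.eye(2)), jnp.ones((2, 2)))
logpsi_σ = logpsi(pars, σ)  # computed once, shared by every factored kernel

via_factored = evaluate(local_value, logpsi, pars, σ, args, logpsi_σ)
via_fallback = evaluate(other_kernel, logpsi, pars, σ, args, logpsi_σ)
```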