neuralqx.profile package

Instrumentation and profiling tools for neuraLQX workflows.

Typical usage (public API):

    from neuralqx.profiling import profile, section, step

    @profile(cat="vmc")
    def vmc_iteration(...):
        with section("sampling", cat="vmc"):
            ...
        with section("expectation", cat="vmc"):
            ...

    for i in range(n_steps):
        with step(i, name="VMC"):
            vmc_iteration(...)

Configuration via environment variables:

  • NQX_PROFILE=1

  • NQX_PROFILE_DIR=/path

  • NQX_PROFILE_TRACE=1

  • NQX_PROFILE_NVTX=1

  • NQX_PROFILE_JAX_TRACE=1

  • NQX_PROFILE_METRICS=1

  • NQX_PROFILE_SYNC=1
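A minimal sketch of enabling profiling for a child process from Python. It assumes that "1" is the value that switches each flag on, as in the list above, and that NQX_PROFILE_DIR names the output directory for artifacts. The helper name profiled_env is hypothetical.

```python
import os


def profiled_env(out_dir: str, *, trace: bool = True, sync: bool = False) -> dict:
    """Hypothetical helper: copy the current environment and set NQX_PROFILE_* flags."""
    env = dict(os.environ)
    env["NQX_PROFILE"] = "1"          # turn profiling on
    env["NQX_PROFILE_DIR"] = out_dir  # assumed: output directory for artifacts
    if trace:
        env["NQX_PROFILE_TRACE"] = "1"
    if sync:
        env["NQX_PROFILE_SYNC"] = "1"
    return env
```

Pass the result as env= to subprocess.run when launching your own neuraLQX script, so the flags apply to that run only and the parent shell stays unchanged.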

class Profiler

Bases: object

Process-local profiler with hierarchical section timing and optional trace/telemetry.

Thread-safety:
  • Each thread keeps its own section stack, which is used to compute exclusive (self) time.

  • Summary aggregation merges per-thread results into a shared tree. Updates are guarded by a lock and happen only at section exit; trace emission is optional.
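The thread-safety design above can be illustrated with a small, self-contained sketch: a thread-local stack yields exclusive time (inclusive time minus time spent in child sections), and a lock guards the shared summary tree, which is only touched at section exit. SketchProfiler is a hypothetical simplification, not the neuraLQX Profiler; it omits tracing, categories, and counters.

```python
import threading
import time
from contextlib import contextmanager


class SketchProfiler:
    """Hypothetical, simplified stand-in for Profiler (not the neuraLQX code)."""

    def __init__(self):
        self._local = threading.local()  # each thread sees its own attributes
        self._lock = threading.Lock()    # guards the shared summary tree
        # Shared tree keyed by section path: path -> [calls, inclusive_s, exclusive_s]
        self.summary = {}

    def _stack(self):
        # Lazily create this thread's private section stack.
        if not hasattr(self._local, "stack"):
            self._local.stack = []
        return self._local.stack

    @contextmanager
    def section(self, name):
        stack = self._stack()
        parent = stack[-1]["path"] if stack else ()
        frame = {"path": parent + (name,), "child_s": 0.0}
        stack.append(frame)
        start = time.perf_counter()
        try:
            yield
        finally:
            inclusive = time.perf_counter() - start
            stack.pop()
            if stack:
                # Charge this section's time to the parent's children total.
                stack[-1]["child_s"] += inclusive
            exclusive = inclusive - frame["child_s"]
            # Shared state is touched only here, at section exit, under the lock.
            with self._lock:
                entry = self.summary.setdefault(frame["path"], [0, 0.0, 0.0])
                entry[0] += 1
                entry[1] += inclusive
                entry[2] += exclusive
```

Keeping the stack per thread means no locking is needed while sections are open; only the short merge at exit contends for the lock.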

enabled()
Return type:

bool

section(name, cat='', *, args=None, flops=0.0, bytes=0.0)

Create a profiling section context manager.

step(step_num, name='step', cat='step', *, args=None)

Mark a repeated step. Attaches step metadata and, optionally, a JAX step trace annotation.

flush()
Return type:

None

get_profiler()
Return type:

Profiler | _DisabledProfiler

enabled()

True if profiling is enabled.

Important

  • This function is intentionally NOT cached.

It reads the environment variable on every call instead, which is safe because neuraLQX’s ConfigManager always keeps os.environ["NQX_PROFILE"] in sync; changes made at runtime therefore take effect immediately.

Return type:

bool
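Why caching would be wrong here: a minimal sketch comparing an uncached read with a functools.lru_cache version. It assumes the flag is on when NQX_PROFILE is exactly "1"; the real enabled() may accept other values. Both function names are hypothetical.

```python
import os
from functools import lru_cache


def enabled_uncached() -> bool:
    # Reads the environment on every call, so runtime changes are visible.
    return os.environ.get("NQX_PROFILE", "0") == "1"


@lru_cache(maxsize=None)
def enabled_cached() -> bool:
    # Remembers the first answer; later changes to os.environ are ignored.
    return os.environ.get("NQX_PROFILE", "0") == "1"


if __name__ == "__main__":
    os.environ["NQX_PROFILE"] = "0"
    print(enabled_uncached(), enabled_cached())  # False False
    os.environ["NQX_PROFILE"] = "1"
    print(enabled_uncached(), enabled_cached())  # True False (stale cache)
```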

flush()

Force writing profiling artifacts now.

In scripts, the profiler flushes automatically at process exit. In Jupyter notebooks the kernel usually stays alive between cells, so call flush() to write the output files immediately.

Return type:

None
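A sketch of the flush-at-exit pattern, assuming an atexit hook plus an explicit flush() that is safe to call repeatedly. BufferedWriter and the events.jsonl file name are hypothetical; neuraLQX's actual artifact formats and file names are not described here.

```python
import atexit
import json
import os
import tempfile


class BufferedWriter:
    """Hypothetical event buffer that is written out at exit or on demand."""

    def __init__(self, out_dir: str):
        self.out_dir = out_dir
        self.events = []
        # Runs at normal interpreter exit; a long-lived kernel may never get there.
        atexit.register(self.flush)

    def record(self, event: dict) -> None:
        self.events.append(event)

    def flush(self) -> None:
        if not self.events:
            return  # nothing buffered, so repeated calls are harmless
        path = os.path.join(self.out_dir, "events.jsonl")
        with open(path, "a", encoding="utf-8") as f:
            for event in self.events:
                f.write(json.dumps(event) + "\n")
        self.events.clear()


if __name__ == "__main__":
    w = BufferedWriter(tempfile.mkdtemp())
    w.record({"name": "sampling", "dur_s": 0.5})
    w.flush()  # file exists now, without waiting for process exit
```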

profile(name=None, cat='', *, sync=None, args=None, flops=0.0, bytes=0.0)

Decorator to profile a function.

sync:

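If True, call block_until_ready() on the returned value inside the timed region. Because JAX dispatches work asynchronously, this makes the measurement approximate device execution time instead of dispatch time, but the added synchronization can stall the host and slow the program. If None, the global configuration is used (default False).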
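The effect of sync can be sketched with a plain timing decorator. Without a barrier, the timer around a JAX call mostly measures dispatch; calling block_until_ready() on the result (a real method on jax.Array) keeps device execution inside the timed region. profile_sketch and its sink parameter are hypothetical; for pytree outputs, jax.block_until_ready(out) is the more general barrier. The sketch only calls the method if the result has one, so it runs without JAX installed.

```python
import functools
import time


def profile_sketch(name=None, *, sync=False, sink=None):
    """Hypothetical timing decorator; not the neuraLQX implementation."""
    sink = sink if sink is not None else []

    def decorator(fn):
        label = name or fn.__qualname__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            out = fn(*args, **kwargs)
            if sync and hasattr(out, "block_until_ready"):
                # Wait for asynchronous device work before stopping the clock.
                out.block_until_ready()
            sink.append((label, time.perf_counter() - start))
            return out

        wrapper.timings = sink
        return wrapper

    return decorator
```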

section(name, cat='', *, args=None, flops=0.0, bytes=0.0)

Context manager to profile a section of code.

Parameters:
  • name (str) – Section name

  • cat (str) – Category (e.g. “vmc”, “sampling”, “jax”, …)

  • args (Optional[Dict[str, Any]]) – Additional metadata to attach to trace events

  • flops (float) – Optional floating-point operation count for the section, used for roofline-style derived metrics

  • bytes (float) – Optional count of bytes moved by the section, used for roofline-style derived metrics
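The flops and bytes counters enable the usual roofline quantities: achieved throughput (FLOP/s), achieved bandwidth (bytes/s), and arithmetic intensity (FLOP/byte). The helper below is a hypothetical illustration of that arithmetic; which derived metrics neuraLQX actually reports is not specified here. The example counts for a dense n × n float32 matrix product (2n³ FLOPs, and 3 · 4n² bytes for reading A and B and writing C once) are idealized assumptions.

```python
def roofline_metrics(flops: float, bytes_moved: float, seconds: float) -> dict:
    """Hypothetical helper: derived metrics from per-section counters."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return {
        "gflops_per_s": flops / seconds / 1e9,
        "gbytes_per_s": bytes_moved / seconds / 1e9,
        # FLOP per byte; compares the section against a machine's roofline.
        "arithmetic_intensity": flops / bytes_moved if bytes_moved else float("inf"),
    }


if __name__ == "__main__":
    n = 1024
    m = roofline_metrics(flops=2 * n**3, bytes_moved=3 * 4 * n**2, seconds=0.01)
    print(m["arithmetic_intensity"])  # n / 6, about 170.67 FLOP/byte
```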

step(step_num, name='step', cat='step', *, args=None)

Context manager to mark a repeated step (e.g. VMC iteration number).
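A sketch of a step marker that records metadata and, when JAX is importable, wraps the step in jax.profiler.StepTraceAnnotation (a real JAX API that labels a region with a step number in profiler traces). step_sketch and its log and use_jax_trace parameters are hypothetical; the sketch falls back to a no-op context when JAX is absent.

```python
import contextlib

try:
    import jax.profiler as jax_profiler  # optional dependency
except ImportError:
    jax_profiler = None


@contextlib.contextmanager
def step_sketch(step_num, name="step", *, log=None, use_jax_trace=True):
    """Hypothetical step marker; not the neuraLQX implementation."""
    if log is not None:
        log.append({"name": name, "step": step_num})  # step metadata
    if use_jax_trace and jax_profiler is not None:
        # Annotates this region as step `step_num` in JAX profiler traces.
        ctx = jax_profiler.StepTraceAnnotation(name, step_num=step_num)
    else:
        ctx = contextlib.nullcontext()
    with ctx:
        yield
```

In a loop this mirrors the usage at the top of the page: for i in range(n_steps), enter step_sketch(i, name="VMC", log=log) around one iteration.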

wrap_callable(fn, *, name=None, cat='', sync=None, args=None, flops=0.0, bytes=0.0, deep=None, deep_cat='python', deep_include=None, deep_exclude=None, deep_max_depth=None)

Wrap an arbitrary callable with profiling sections.

Useful to instrument external library functions without editing their source.

Return type:

Callable[..., Any]
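The idea behind wrap_callable can be sketched with functools.wraps: return a new callable that times each invocation and delegates to the original, leaving the original function's source and module untouched. wrap_timed and its timings list are hypothetical and do not reproduce the deep-tracing options.

```python
import functools
import json
import time


def wrap_timed(fn, *, name=None, timings=None):
    """Hypothetical analogue of wrap_callable: time each call of fn."""
    label = name or getattr(fn, "__qualname__", repr(fn))
    timings = timings if timings is not None else []

    @functools.wraps(fn)  # keep __name__, __doc__, etc. of the original
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            # Recorded even if fn raises.
            timings.append((label, time.perf_counter() - start))

    wrapper.timings = timings
    return wrapper


# Instrument a standard-library function without touching its source.
timed_dumps = wrap_timed(json.dumps, name="json.dumps")
```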

profile_call(fn, *a, name=None, cat='', sync=None, args=None, flops=0.0, bytes=0.0, deep=None, deep_cat='python', deep_include=None, deep_exclude=None, deep_max_depth=None, **k)

Call fn(*a, **k) once under profiling and return its result.

Return type:

Any

patch_method(obj, method_name, *, name=None, cat='', sync=None, args=None, flops=0.0, bytes=0.0, deep=None, deep_cat='python', deep_include=None, deep_exclude=None, deep_max_depth=None)

Temporarily patch obj.method_name with a profiled wrapper.

patch_attr(target, attr_path, *, name=None, cat='', sync=None, args=None, flops=0.0, bytes=0.0, deep=None, deep_cat='python', deep_include=None, deep_exclude=None, deep_max_depth=None)

Temporarily patch a callable attribute path on an object or module.

attr_path can be dotted, for example "sampler.sample".
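A sketch of how a temporary, dotted-path patch can work: resolve every component except the last with getattr, swap in a timing wrapper, and restore the original in a finally block so the patch is undone even on error. patch_attr_sketch is hypothetical; patch_method corresponds to the single-component case. The sketch assumes a plain function or method attribute and does not handle staticmethod or classmethod descriptors.

```python
import contextlib
import functools
import time
import types


@contextlib.contextmanager
def patch_attr_sketch(target, attr_path, timings):
    """Hypothetical analogue of patch_attr: temporarily time target.<attr_path>."""
    *parents, attr = attr_path.split(".")
    owner = functools.reduce(getattr, parents, target)  # "a.b" -> target.a
    original = getattr(owner, attr)
    had_own = attr in getattr(owner, "__dict__", {})  # vs. inherited from a class

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return original(*args, **kwargs)
        finally:
            timings.append((attr_path, time.perf_counter() - start))

    setattr(owner, attr, wrapper)
    try:
        yield wrapper
    finally:
        if had_own:
            setattr(owner, attr, original)
        else:
            delattr(owner, attr)  # drop the shadowing instance attribute


if __name__ == "__main__":
    class Sampler:
        def sample(self, n):
            return list(range(n))

    model = types.SimpleNamespace(sampler=Sampler())
    timings = []
    with patch_attr_sketch(model, "sampler.sample", timings):
        model.sampler.sample(4)
    print(timings)  # one ("sampler.sample", seconds) entry
```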

python_call_trace(*, cat='python', include_prefixes=None, exclude_prefixes=None, max_depth=None)

Profile nested Python function calls in the current thread.

This is intentionally opt-in and can add noticeable overhead.
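A minimal sketch of opt-in call tracing with sys.setprofile, which installs a hook for the calling thread only and fires on every Python call and return; that per-call hook is where the overhead comes from. call_trace_sketch is hypothetical, and the choices that prefixes match "module.function" names and that depth counts from 1 inside the traced block are this sketch's assumptions, not neuraLQX behavior.

```python
import sys
import time
from contextlib import contextmanager


@contextmanager
def call_trace_sketch(records, *, include_prefixes=(), exclude_prefixes=(), max_depth=None):
    """Hypothetical analogue of python_call_trace (not the neuraLQX code).

    Appends (name, depth, seconds) to `records` for each Python-level call
    made in the current thread while the block runs.
    """
    include, exclude = tuple(include_prefixes), tuple(exclude_prefixes)
    stack = []  # entries: (name, depth, start_time, keep)

    def hook(frame, event, arg):
        if event == "call":
            name = f"{frame.f_globals.get('__name__', '?')}.{frame.f_code.co_name}"
            depth = len(stack) + 1
            keep = ((not include or name.startswith(include))
                    and not (exclude and name.startswith(exclude))
                    and (max_depth is None or depth <= max_depth))
            stack.append((name, depth, time.perf_counter(), keep))
        elif event == "return" and stack:  # skip frames entered before tracing
            name, depth, start, keep = stack.pop()
            if keep:
                records.append((name, depth, time.perf_counter() - start))

    old = sys.getprofile()
    sys.setprofile(hook)  # affects only the calling thread
    try:
        yield records
    finally:
        sys.setprofile(old)
```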

Submodules