neuralqx.driver.abstract_variational_driver module
A subclass of NetKet’s AbstractVariationalDriver which provides variational inference with profiling.
- class AbstractVariationalDriver(variational_state, optimizer, minimized_quantity_name='loss')
Bases: AbstractVariationalDriver
- iter(n_steps, step=1)
Returns a generator which advances the VMC optimization, yielding once per block of step optimization steps.
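The yield pattern of iter can be modelled without NetKet. The sketch below uses a hypothetical ToyDriver that mirrors the loop structure of NetKet's base-class iter (compute an update, yield the current step count on the first step of each block, then apply the update). It assumes the neuralqx override keeps that structure, and it replaces the network, sampler and optimizer with a constant toy gradient.

```python
class ToyDriver:
    """Hypothetical stand-in mirroring the loop structure of NetKet's
    AbstractVariationalDriver.iter; not the neuralqx implementation."""

    def __init__(self):
        self.step_count = 0
        self.params = 0.0

    def _forward_and_backward(self):
        # Toy gradient: a constant, so the example needs no network or sampler.
        return 0.1

    def update_parameters(self, dp):
        # Plain gradient-descent step standing in for the optimizer.
        self.params -= dp

    def iter(self, n_steps, step=1):
        # Outer loop: one yield per block of `step` optimization steps.
        for _ in range(0, n_steps, step):
            for i in range(step):
                dp = self._forward_and_backward()
                if i == 0:
                    # Yield the step count before the block's first update.
                    yield self.step_count
                self.step_count += 1
                self.update_parameters(dp)


driver = ToyDriver()
for count in driver.iter(6, step=2):
    print("block starts at step", count)
```

In this model the outer range(0, n_steps, step) always runs whole blocks, so iter(5, step=2) yields 0, 2 and 4 and applies six updates; choosing n_steps as a multiple of step avoids the overshoot.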
- run(n_iter, out=(), obs=None, step_size=1, show_progress=True, save_params_every=50, write_every=50, callback=<function AbstractVariationalDriver.<lambda>>, timeit=False)
Runs this variational driver for n_iter steps, updating the weights of the network stored in the driver and logging the values of the observables obs to the output logger.
It is possible to control more specifically what quantities are logged, when to stop the optimisation, or to execute arbitrary code at every step by specifying one or more callbacks, passed to the keyword argument callback either as a single function or as a list of functions.
Callbacks are functions that follow this signature:

    def callback(step, log_data, driver) -> bool:
        ...
        return True  # or False to stop the optimisation
If a callback returns True, the optimisation continues, otherwise it is stopped. The log_data is a dictionary that can be modified in-place to change what is logged at every step. For example, this can be used to log additional quantities such as the acceptance rate of a sampler.
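As a minimal sketch of the callback protocol, the hypothetical LossThresholdStop class below stops the run once the minimized quantity drops below a threshold and records an extra entry in log_data. It assumes the loss is logged under the key given by minimized_quantity_name (default 'loss') and that the entry is either a plain number or a Stats-like object exposing a mean attribute, as NetKet's statistics objects do.

```python
class LossThresholdStop:
    """Hypothetical callback: stop once the logged loss falls below `threshold`.

    Assumes log_data[key] is either a plain number or a Stats-like object
    exposing the estimate as a `.mean` attribute.
    """

    def __init__(self, threshold, key="loss"):
        self.threshold = threshold
        self.key = key  # should match minimized_quantity_name

    def __call__(self, step, log_data, driver):
        value = log_data.get(self.key)
        if value is None:
            return True  # nothing logged under this key: keep optimising
        mean = getattr(value, "mean", value)
        if callable(mean):
            mean = value  # numpy scalars expose .mean() as a method, not a field
        loss = complex(mean).real  # energies may be complex; keep the real part
        # Modify log_data in place to log an extra quantity at this step.
        log_data["below_threshold"] = loss < self.threshold
        return loss >= self.threshold  # True continues, False stops
```

It would be passed as, for example, driver.run(n_iter=300, out="run", callback=[LossThresholdStop(-1.0)]), where driver is an already-constructed driver and the threshold is purely illustrative.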
Loggers are passed to the keyword argument out, either as a single logger or as an iterable of loggers. If only a string is specified, a nk.logging.JsonLog is created by default; see its documentation for the output format. The logger object is also returned at the end of this function, so that you can inspect the results without reading the JSON output.

When running across multiple MPI ranks or JAX devices, the logging logic is executed on all nodes, but only root-rank loggers should write to files or perform expensive I/O operations.
Note
Before NetKet 3.15, loggers were automatically ‘ignored’ on non-root ranks. Starting with NetKet 3.15, it is the responsibility of each logger to check whether it is executing on a non-root rank and to ‘do nothing’ if that is the case.
This change was required for correct and efficient operation with sharding. It only affects users who define custom loggers themselves.
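For custom loggers, the rule above can be sketched as follows. InMemoryLog is hypothetical: it assumes loggers are called as logger(step, log_data, variational_state) and flushed with logger.flush(variational_state), as in NetKet's AbstractLog interface, and it treats JAX process 0 as the root rank. A production logger would normally subclass nk.logging.AbstractLog, and an MPI run would check the MPI rank instead.

```python
class InMemoryLog:
    """Hypothetical duck-typed logger that keeps values in memory and does
    nothing on non-root ranks (the NetKet >= 3.15 convention)."""

    def __init__(self, is_root=None):
        if is_root is None:
            # Assumption: JAX process 0 is the root rank; without JAX,
            # assume a single-process run.
            try:
                import jax
                is_root = jax.process_index() == 0
            except ImportError:
                is_root = True
        self.is_root = is_root
        self.data = {}

    def __call__(self, step, item, variational_state=None):
        if not self.is_root:
            return  # non-root rank: do nothing
        for key, value in item.items():
            series = self.data.setdefault(key, {"iters": [], "value": []})
            series["iters"].append(step)
            series["value"].append(value)

    def flush(self, variational_state=None):
        # Nothing is buffered for disk, so flushing is a no-op.
        pass
```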
- Parameters:
n_iter (int) – The total number of iterations to be performed during this run.
out (AbstractLog | Iterable[AbstractLog] | str | None) – A logger object, or an iterable of loggers, used to store the simulation log and data. If this argument is a string, it is used as the output prefix for the standard JSON logger.
obs (dict[str, AbstractObservable] | None) – A dictionary of the observables that should be computed, keyed by name.
step_size (int) – Interval, in steps, at which observables are logged (default=1).
callback (Callable[[int, dict, AbstractVariationalDriver], bool] | Iterable[Callable[[int, dict, AbstractVariationalDriver], bool]]) – A callable, or a list of callables, invoked at every step; if a callback returns False, training stops.
show_progress (bool) – If True, displays a progress bar (default=True).
save_params_every (int) – Interval, in steps, at which the network parameters are serialized to disk (ignored if a logger is provided).
write_every (int) – Interval, in steps, at which the JSON data is flushed to disk (ignored if a logger is provided).
timeit (bool) – If True, provides timing information.
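The rules for out, together with the fact that save_params_every and write_every are ignored when a logger is provided, can be sketched as a small normalization function. normalize_out and its make_json_log parameter are hypothetical; make_json_log stands in for the nk.logging.JsonLog constructor so the sketch runs without NetKet, and the actual driver code may be organized differently.

```python
def normalize_out(out, make_json_log, save_params_every=50, write_every=50):
    """Hypothetical sketch of the documented rules for run's `out` argument.

    make_json_log stands in for a JSON-logger constructor; it is injected so
    the sketch runs without NetKet.
    """
    if out is None:
        return ()  # no loggers
    if isinstance(out, str):
        # A string is an output prefix for the default JSON logger; this is the
        # only case in which save_params_every and write_every are used.
        return (
            make_json_log(
                out, save_params_every=save_params_every, write_every=write_every
            ),
        )
    try:
        return tuple(out)  # an iterable of logger objects
    except TypeError:
        return (out,)  # a single logger object
```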
- update_parameters(dp)
Updates the parameters of the machine using the optimizer in this driver.
- Parameters:
dp – the pytree containing the updates to the parameters
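A minimal, dependency-free sketch of such a step, assuming dp holds gradients with the same tree structure as the parameters (as in NetKet, where the optimizer is an optax transformation that turns gradients into updates). The SGD class and update_parameters function below are hypothetical stand-ins for an optax optimizer such as optax.sgd and for the driver method; with optax, the same two steps are optimizer.update(dp, opt_state, params) followed by optax.apply_updates(params, updates).

```python
def tree_map(fn, *trees):
    """Apply fn leaf-wise over matching nested dicts/lists/tuples
    (a tiny stand-in for jax.tree_util.tree_map)."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(tree_map(fn, *leaves) for leaves in zip(*trees))
    return fn(*trees)


class SGD:
    """Hypothetical optimizer with an optax-like update(grads, state, params)."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def init(self, params):
        return None  # plain SGD keeps no state

    def update(self, grads, state, params=None):
        # Like optax.sgd: each update is -learning_rate * gradient.
        return tree_map(lambda g: -self.learning_rate * g, grads), state


def update_parameters(params, opt_state, optimizer, dp):
    """Sketch of the driver step: turn the gradient pytree dp into updates
    and add them to the parameters, returning new trees."""
    updates, opt_state = optimizer.update(dp, opt_state, params)
    new_params = tree_map(lambda p, u: p + u, params, updates)
    return new_params, opt_state
```

Returning new trees instead of mutating params mirrors the functional style of JAX and optax, where parameters are immutable pytrees.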