Source code for glompo.optimizers.nevergrad

import warnings
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing import Event, Queue
from multiprocessing.connection import Connection
from pathlib import Path
from typing import Callable, Optional, Sequence, Set, Union

import nevergrad as ng
import numpy as np

from .baseoptimizer import BaseOptimizer, MinimizeResult

# Public API of this module: only the optimizer class is exported by ``import *``.
# The underscore-prefixed callbacks wrapper defined in this file is deliberately not listed.
__all__ = ('Nevergrad',)


class Nevergrad(BaseOptimizer):
    """ Provides access to the optimizers available through the
    `nevergrad <https://facebookresearch.github.io/nevergrad/>`_ package.

    Parameters
    ----------
    Inherited, _opt_id _signal_pipe _results_queue _pause_flag workers backend is_log_detailed
        See :class:`.BaseOptimizer`.
    optimizer
        String key to the desired optimizer. See nevergrad documentation for a list of available algorithms.
    zero
        Will stop the optimization when this cost function value is reached.
    """

    def __init__(self,
                 _opt_id: int = None,
                 _signal_pipe: Connection = None,
                 _results_queue: Queue = None,
                 _pause_flag: Event = None,
                 workers: int = 1,
                 backend: str = 'processes',
                 is_log_detailed: bool = False,
                 optimizer: str = 'TBPSA',
                 zero: float = -float('inf')):
        super().__init__(_opt_id, _signal_pipe, _results_queue, _pause_flag, workers, backend, is_log_detailed)

        # Only the algorithm *class* is looked up here; the optimizer instance is built lazily in minimize().
        self.opt_algo = ng.optimizers.registry[optimizer]
        self.optimizer = None
        if self.opt_algo.no_parallelization is True:
            warnings.warn("The selected algorithm does not support parallel execution, workers overwritten and set to"
                          " one.", RuntimeWarning)
            self.workers = 1

        self.zero = zero
        self.stop = False
        self.ng_callbacks = None

    @staticmethod
    def _as_callback_list(callbacks) -> list:
        """ Normalises ``None``, a single callable or an iterable of callables into a list of callables. """
        if callbacks is None:
            return []
        if callable(callbacks):
            return [callbacks]
        return list(callbacks)

    def minimize(self, function, x0, bounds, callbacks=None, **kwargs) -> MinimizeResult:
        """ Runs the selected nevergrad algorithm on ``function``.

        Parameters
        ----------
        function
            Callable to minimise. It is called by nevergrad and once more on the final point.
        x0
            Initial point, used to initialise the nevergrad ``Array`` parametrization.
        bounds
            Sequence which transposes into a (lower, upper) pair of bound vectors.
        callbacks
            ``None``, a single callable or a sequence of callables, each called as ``cb(opt, x, fx)``.
            A truthy return value from any of them stops the optimization.
        **kwargs
            Passed through to the constructor of the selected nevergrad optimizer.

        Returns
        -------
        MinimizeResult
            ``x`` holds the recommended point and ``fx`` the function value there. ``success`` is set to True only
            when ``fx`` is below ``inf``.
        """
        lower, upper = np.transpose(bounds)
        parametrization = ng.p.Array(init=x0)
        parametrization.set_bounds(lower, upper)

        if self.is_restart and self.optimizer:
            # Re-attach what checkpoint_save() detached. The callbacks are normalised to a list because the wrapper
            # iterates over them on every call; assigning a raw None / single callable would raise a TypeError there.
            self.ng_callbacks.parent = self
            self.ng_callbacks.callbacks = self._as_callback_list(callbacks)
            self.stop = False
            self.logger.debug("Loaded nevergrad optimizer")
        else:
            # The budget is set astronomically high so that termination is driven by the registered callback
            # (which manipulates the optimizer's ask counter) rather than by nevergrad's budget.
            self.optimizer = self.opt_algo(parametrization=parametrization,
                                           budget=int(4e50),
                                           num_workers=self.workers,
                                           **kwargs)
            self.ng_callbacks = _NevergradCallbacksWrapper(self, callbacks)
            self.logger.debug("Created nevergrad optimizer object")
            self.logger.debug("Created callbacks object")
            self.optimizer.register_callback('tell', self.ng_callbacks)
            self.logger.debug("Callbacks registered with optimizer")

        if self.workers > 1:
            self.logger.debug("Executing within pool with %d workers", self.workers)
            # Any backend other than 'processes' falls back to a thread pool.
            pool_cls = ProcessPoolExecutor if self._backend == 'processes' else ThreadPoolExecutor
            with pool_cls(max_workers=self.workers) as executor:
                opt_vec = self.optimizer.minimize(function, executor=executor, batch_mode=False)
        else:
            self.logger.debug("Executing serially.")
            opt_vec = self.optimizer.minimize(function, batch_mode=False)

        self.logger.debug("Optimization complete. Formatting into MinimizeResult instance")
        results = MinimizeResult()
        results.x = opt_vec.value
        # The recommended point is evaluated once more to report its function value.
        results.fx = function(results.x)
        if results.fx < float('inf'):
            results.success = True

        return results

    def callstop(self, *args):
        """ Flags the optimizer to stop. Any positional arguments are ignored. """
        self.stop = True

    def checkpoint_save(self, path: Union[Path, str], force: Optional[Set[str]] = None):
        """ Saves the optimizer state through :meth:`.BaseOptimizer.checkpoint_save`.

        Parameters
        ----------
        path
            Forwarded unchanged to the parent implementation.
        force
            Optional extra attribute names. They are merged with ``'opt_algo'`` and ``'ng_callbacks'``, which this
            class always passes as the second argument of the parent implementation.
            NOTE(review): assumes the parent interprets this set the same way as its own ``force`` argument.
        """
        forced = {'opt_algo', 'ng_callbacks'}
        if force:
            forced |= set(force)

        # Remove attributes which should not be saved
        wrapper = self.ng_callbacks
        saved_callbacks = None
        if wrapper is not None:
            wrapper.parent = None
            saved_callbacks = wrapper.callbacks
            wrapper.callbacks = None

        try:
            super().checkpoint_save(path, forced)
        finally:
            # Restore attributes even if saving failed, otherwise the live optimizer would be left with a
            # callbacks wrapper that has no parent and no callbacks.
            if wrapper is not None:
                wrapper.parent = self
                wrapper.callbacks = saved_callbacks
class _NevergradCallbacksWrapper:
    """ Wraps all the components needed by GloMPO to be called after each iteration into a single object which can be
    registered as a nevergrad callback.

    Parameters
    ----------
    parent
        The :class:`Nevergrad` instance whose ``stop``, ``zero``, ``logger`` and messaging methods are used.
    callbacks
        ``None``, a single callable or a sequence of callables. Each is called as ``cb(opt, x, fx)`` and a truthy
        return value requests termination.
    """

    def __init__(self, parent: Nevergrad,
                 callbacks: Union[None,
                                  Callable[[ng.optimizers.base.Optimizer, Sequence[float], float], bool],
                                  Sequence[Callable[[ng.optimizers.base.Optimizer,
                                                     Sequence[float], float], bool]]] = None):
        self.parent = parent
        self.i_fcalls = 0

        if callbacks is None:
            self.callbacks = []
        elif callable(callbacks):
            self.callbacks = [callbacks]
        else:
            # A sequence of callbacks, as allowed by the annotation. Without this branch the attribute would never
            # be set and the first call of the wrapper would raise an AttributeError.
            self.callbacks = list(callbacks)

    def __call__(self, opt: ng.optimizers.base.Optimizer, x: ng.p.Array, fx: float):
        """ Evaluates all stop conditions for the point ``x`` with value ``fx`` just reported to ``opt``.

        Does nothing once the parent has already been flagged to stop.
        """
        if self.parent.stop:
            return

        stop_cond = None

        # Normal termination condition
        if fx >= 1e30 or fx <= self.parent.zero:
            stop_cond = f"Nevergrad termination conditions:\n" \
                        f"(fx >= 1e30) = {fx >= 1e30}\n" \
                        f"(fx <= {self.parent.zero}) = {fx <= self.parent.zero}"
        self.parent.logger.debug("Stop = %s at convergence condition", bool(stop_cond))

        # User sent callbacks
        # A list (not a generator) is built on purpose so that every callback is called, even after one of them has
        # already requested a stop. The attribute is also reassigned from outside this class (on restart and around
        # checkpointing) and may be None, hence the fallback to an empty tuple.
        if not stop_cond and any([cb(opt, x, fx) for cb in self.callbacks or ()]):
            stop_cond = "Direct user callbacks"
        self.parent.logger.debug("Stop = %s at user callbacks (iter: %d)", bool(stop_cond), opt.num_tell)

        # GloMPO specific callbacks
        if self.parent._results_queue:
            self.parent._pause_signal.wait()
            self.parent.check_messages()
            if not stop_cond and self.parent.stop:
                stop_cond = "GloMPO termination signal."
            self.parent.logger.debug("Stop = %s after message check from manager", bool(stop_cond))

        self.i_fcalls = opt.num_tell + 1

        if stop_cond:
            self.parent.logger.debug("Stop is True so shutting down optimizer.")
            self.parent.stop = True
            # Moves the ask counter to the end of the budget.
            # NOTE(review): presumably this makes nevergrad's own minimize loop terminate - confirm against the
            # installed nevergrad version, since a private attribute is being written.
            opt._num_ask = opt.budget - 1
            self.parent.message_manager(0, stop_cond)

    @property
    def n_iter(self):
        """ Iteration counter as last recorded by a call of this wrapper (``opt.num_tell + 1``). """
        return self.i_fcalls