|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import json |
| 8 | +import logging |
8 | 9 | import math |
9 | 10 | import time |
10 | 11 | from collections.abc import Callable |
|
24 | 25 | if TYPE_CHECKING: |
25 | 26 | import pandas as pd |
26 | 27 |
|
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
| 30 | + |
| 31 | +def _wrap_with_catch( |
| 32 | + objective_function: Callable, |
| 33 | + catch: dict[type[Exception], int | float], |
| 34 | +) -> Callable: |
| 35 | + """Wrap objective function to catch exceptions and return fallback scores.""" |
| 36 | + catch_types = tuple(catch.keys()) |
| 37 | + |
| 38 | + def wrapped(params): |
| 39 | + try: |
| 40 | + return objective_function(params) |
| 41 | + except catch_types as e: |
| 42 | + for exc_type, fallback_score in catch.items(): |
| 43 | + if isinstance(e, exc_type): |
| 44 | + logger.warning( |
| 45 | + "Caught %s in objective function: %s. " |
| 46 | + "Using fallback score: %s", |
| 47 | + type(e).__name__, |
| 48 | + e, |
| 49 | + fallback_score, |
| 50 | + ) |
| 51 | + return fallback_score |
| 52 | + raise |
| 53 | + |
| 54 | + return wrapped |
| 55 | + |
27 | 56 |
|
28 | 57 | class Search(TimesTracker, SearchStatistics): |
29 | 58 | """ |
@@ -140,6 +169,7 @@ def search( |
140 | 169 | ], |
141 | 170 | optimum: Literal["maximum", "minimum"] = "maximum", |
142 | 171 | callbacks: list[Callable[[CallbackInfo], bool | None]] | None = None, |
| 172 | + catch: dict[type[Exception], int | float] | None = None, |
143 | 173 | ) -> None: |
144 | 174 | """Run the optimization loop. |
145 | 175 |
|
@@ -288,6 +318,24 @@ def stop_early(info): |
288 | 318 | opt.search(objective, n_iter=100, |
289 | 319 | callbacks=[log_progress, stop_early]) |
290 | 320 |
|
| 321 | + catch : dict[type, float] or None, default=None |
| 322 | + Error handling for the objective function. Maps exception |
| 323 | + types to fallback scores. When the objective function raises |
| 324 | + a caught exception, the optimizer records the fallback score |
| 325 | + instead of crashing. Exception subclasses are matched via |
| 326 | + ``isinstance``, so ``{Exception: ...}`` catches all. |
| 327 | +
|
| 328 | + The fallback score is in the user's original units (before |
| 329 | + any negation from ``optimum="minimum"``). Use |
| 330 | + ``float('nan')`` or ``float('inf')`` to mark positions as |
| 331 | + invalid without inventing an artificial score. |
| 332 | +
|
| 333 | + Example:: |
| 334 | +
|
| 335 | + catch = {ValueError: -1000, RuntimeError: float('nan')} |
| 336 | +
|
| 337 | + opt.search(objective, n_iter=100, catch=catch) |
| 338 | +
|
291 | 339 | Examples |
292 | 340 | -------- |
293 | 341 | Basic usage with default settings: |
@@ -321,6 +369,7 @@ def stop_early(info): |
321 | 369 | memory, |
322 | 370 | memory_warm_start, |
323 | 371 | verbosity, |
| 372 | + catch, |
324 | 373 | ) |
325 | 374 |
|
326 | 375 | for nth_trial in range(n_iter): |
@@ -368,7 +417,11 @@ def _init_search( |
368 | 417 | memory: bool, |
369 | 418 | memory_warm_start: pd.DataFrame | None, |
370 | 419 | verbosity: list[str] | Literal[False], |
| 420 | + catch: dict[type[Exception], int | float] | None = None, |
371 | 421 | ) -> None: |
| 422 | + if catch: |
| 423 | + objective_function = _wrap_with_catch(objective_function, catch) |
| 424 | + |
372 | 425 | if getattr(self, "optimum", "maximum") == "minimum": |
373 | 426 | self.objective_function = lambda pos: -objective_function(pos) |
374 | 427 | else: |
|
0 commit comments