Skip to content

Commit cf1b459

Browse files
committed
impl. catch error
1 parent 65df69e commit cf1b459

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

src/gradient_free_optimizers/search.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import json
8+
import logging
89
import math
910
import time
1011
from collections.abc import Callable
@@ -24,6 +25,34 @@
2425
if TYPE_CHECKING:
2526
import pandas as pd
2627

28+
logger = logging.getLogger(__name__)
29+
30+
31+
def _wrap_with_catch(
32+
objective_function: Callable,
33+
catch: dict[type[Exception], int | float],
34+
) -> Callable:
35+
"""Wrap objective function to catch exceptions and return fallback scores."""
36+
catch_types = tuple(catch.keys())
37+
38+
def wrapped(params):
39+
try:
40+
return objective_function(params)
41+
except catch_types as e:
42+
for exc_type, fallback_score in catch.items():
43+
if isinstance(e, exc_type):
44+
logger.warning(
45+
"Caught %s in objective function: %s. "
46+
"Using fallback score: %s",
47+
type(e).__name__,
48+
e,
49+
fallback_score,
50+
)
51+
return fallback_score
52+
raise
53+
54+
return wrapped
55+
2756

2857
class Search(TimesTracker, SearchStatistics):
2958
"""
@@ -140,6 +169,7 @@ def search(
140169
],
141170
optimum: Literal["maximum", "minimum"] = "maximum",
142171
callbacks: list[Callable[[CallbackInfo], bool | None]] | None = None,
172+
catch: dict[type[Exception], int | float] | None = None,
143173
) -> None:
144174
"""Run the optimization loop.
145175
@@ -288,6 +318,24 @@ def stop_early(info):
288318
opt.search(objective, n_iter=100,
289319
callbacks=[log_progress, stop_early])
290320
321+
catch : dict[type, float] or None, default=None
322+
Error handling for the objective function. Maps exception
323+
types to fallback scores. When the objective function raises
324+
a caught exception, the optimizer records the fallback score
325+
instead of crashing. Exception subclasses are matched via
326+
``isinstance``, so ``{Exception: ...}`` catches all.
327+
328+
The fallback score is in the user's original units (before
329+
any negation from ``optimum="minimum"``). Use
330+
``float('nan')`` or ``float('inf')`` to mark positions as
331+
invalid without inventing an artificial score.
332+
333+
Example::
334+
335+
catch = {ValueError: -1000, RuntimeError: float('nan')}
336+
337+
opt.search(objective, n_iter=100, catch=catch)
338+
291339
Examples
292340
--------
293341
Basic usage with default settings:
@@ -321,6 +369,7 @@ def stop_early(info):
321369
memory,
322370
memory_warm_start,
323371
verbosity,
372+
catch,
324373
)
325374

326375
for nth_trial in range(n_iter):
@@ -368,7 +417,11 @@ def _init_search(
368417
memory: bool,
369418
memory_warm_start: pd.DataFrame | None,
370419
verbosity: list[str] | Literal[False],
420+
catch: dict[type[Exception], int | float] | None = None,
371421
) -> None:
422+
if catch:
423+
objective_function = _wrap_with_catch(objective_function, catch)
424+
372425
if getattr(self, "optimum", "maximum") == "minimum":
373426
self.objective_function = lambda pos: -objective_function(pos)
374427
else:

0 commit comments

Comments
 (0)