|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import json |
8 | | -import logging |
9 | 8 | import math |
10 | 9 | import time |
11 | 10 | from collections.abc import Callable |
12 | 11 | from typing import TYPE_CHECKING, Any, Literal |
13 | 12 |
|
14 | 13 | from ._callback import CallbackInfo |
| 14 | +from ._catch import wrap_with_catch |
15 | 15 | from ._data import DataAccessor, SearchTracker |
16 | 16 | from ._memory import CachedObjectiveAdapter |
17 | 17 | from ._objective_adapter import ObjectiveAdapter |
|
25 | 25 | if TYPE_CHECKING: |
26 | 26 | import pandas as pd |
27 | 27 |
|
28 | | -logger = logging.getLogger(__name__) |
29 | | - |
30 | | - |
31 | | -def _wrap_with_catch( |
32 | | - objective_function: Callable, |
33 | | - catch: dict[type[Exception], int | float], |
34 | | -) -> Callable: |
35 | | - """Wrap objective function to catch exceptions and return fallback scores.""" |
36 | | - catch_types = tuple(catch.keys()) |
37 | | - |
38 | | - def wrapped(params): |
39 | | - try: |
40 | | - return objective_function(params) |
41 | | - except catch_types as e: |
42 | | - for exc_type, fallback_score in catch.items(): |
43 | | - if isinstance(e, exc_type): |
44 | | - logger.warning( |
45 | | - "Caught %s in objective function: %s. " |
46 | | - "Using fallback score: %s", |
47 | | - type(e).__name__, |
48 | | - e, |
49 | | - fallback_score, |
50 | | - ) |
51 | | - return fallback_score |
52 | | - raise |
53 | | - |
54 | | - return wrapped |
55 | | - |
56 | 28 |
|
57 | 29 | class Search(TimesTracker, SearchStatistics): |
58 | 30 | """ |
@@ -420,7 +392,7 @@ def _init_search( |
420 | 392 | catch: dict[type[Exception], int | float] | None = None, |
421 | 393 | ) -> None: |
422 | 394 | if catch: |
423 | | - objective_function = _wrap_with_catch(objective_function, catch) |
| 395 | + objective_function = wrap_with_catch(objective_function, catch) |
424 | 396 |
|
425 | 397 | if getattr(self, "optimum", "maximum") == "minimum": |
426 | 398 | self.objective_function = lambda pos: -objective_function(pos) |
|
0 commit comments