Skip to content

Commit 632d826

Browse files
committed
move function
1 parent cf1b459 commit 632d826

File tree

2 files changed

+34
-30
lines changed

2 files changed

+34
-30
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from collections.abc import Callable
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def wrap_with_catch(
10+
objective_function: Callable,
11+
catch: dict[type[Exception], int | float],
12+
) -> Callable:
13+
"""Wrap objective function to catch exceptions and return fallback scores."""
14+
catch_types = tuple(catch.keys())
15+
16+
def wrapped(params):
17+
try:
18+
return objective_function(params)
19+
except catch_types as e:
20+
for exc_type, fallback_score in catch.items():
21+
if isinstance(e, exc_type):
22+
logger.warning(
23+
"Caught %s in objective function: %s. "
24+
"Using fallback score: %s",
25+
type(e).__name__,
26+
e,
27+
fallback_score,
28+
)
29+
return fallback_score
30+
raise
31+
32+
return wrapped

src/gradient_free_optimizers/search.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from __future__ import annotations
66

77
import json
8-
import logging
98
import math
109
import time
1110
from collections.abc import Callable
1211
from typing import TYPE_CHECKING, Any, Literal
1312

1413
from ._callback import CallbackInfo
14+
from ._catch import wrap_with_catch
1515
from ._data import DataAccessor, SearchTracker
1616
from ._memory import CachedObjectiveAdapter
1717
from ._objective_adapter import ObjectiveAdapter
@@ -25,34 +25,6 @@
2525
if TYPE_CHECKING:
2626
import pandas as pd
2727

28-
logger = logging.getLogger(__name__)
29-
30-
31-
def _wrap_with_catch(
32-
objective_function: Callable,
33-
catch: dict[type[Exception], int | float],
34-
) -> Callable:
35-
"""Wrap objective function to catch exceptions and return fallback scores."""
36-
catch_types = tuple(catch.keys())
37-
38-
def wrapped(params):
39-
try:
40-
return objective_function(params)
41-
except catch_types as e:
42-
for exc_type, fallback_score in catch.items():
43-
if isinstance(e, exc_type):
44-
logger.warning(
45-
"Caught %s in objective function: %s. "
46-
"Using fallback score: %s",
47-
type(e).__name__,
48-
e,
49-
fallback_score,
50-
)
51-
return fallback_score
52-
raise
53-
54-
return wrapped
55-
5628

5729
class Search(TimesTracker, SearchStatistics):
5830
"""
@@ -420,7 +392,7 @@ def _init_search(
420392
catch: dict[type[Exception], int | float] | None = None,
421393
) -> None:
422394
if catch:
423-
objective_function = _wrap_with_catch(objective_function, catch)
395+
objective_function = wrap_with_catch(objective_function, catch)
424396

425397
if getattr(self, "optimum", "maximum") == "minimum":
426398
self.objective_function = lambda pos: -objective_function(pos)

0 commit comments

Comments
 (0)