Skip to content

Commit ae53374

Browse files
committed
add constr. tests for mixed spaces and multiple constr.
1 parent ba41cb4 commit ae53374

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

tests/test_main/test_optimizers/test_constr_opt.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,67 @@ def test_constraint_tracking_consistency(Optimizer):
101101
assert n_current <= n_new
102102
assert n_best == n_best_scores
103103
assert n_best <= n_new
104+
105+
106+
@pytest.mark.parametrize(*optimizers)
107+
def test_constraint_categorical(Optimizer):
108+
search_space = {
109+
"algo": ["adam", "sgd", "rmsprop", "adagrad"],
110+
"lr": np.arange(0.001, 0.1, 0.001),
111+
}
112+
opt = Optimizer(
113+
search_space,
114+
constraints=[lambda p: p["algo"] in ["adam", "sgd"]],
115+
random_state=42,
116+
)
117+
opt.search(lambda p: -p["lr"], n_iter=30, verbosity=False)
118+
119+
values = opt.search_data["algo"].values
120+
assert all(v in ["adam", "sgd"] for v in values)
121+
122+
123+
@pytest.mark.parametrize(*optimizers)
124+
def test_constraint_mixed_dimensions(Optimizer):
125+
"""Continuous tuple + discrete array + categorical list."""
126+
search_space = {
127+
"x": np.arange(-10, 10, 1),
128+
"y": (-5.0, 5.0),
129+
"mode": ["fast", "slow", "medium"],
130+
}
131+
opt = Optimizer(
132+
search_space,
133+
constraints=[lambda p: p["x"] > 0, lambda p: p["y"] > 0],
134+
random_state=42,
135+
)
136+
opt.search(
137+
lambda p: -(p["x"] ** 2 + p["y"] ** 2),
138+
n_iter=50,
139+
verbosity=False,
140+
)
141+
142+
data = opt.search_data
143+
assert np.all(data["x"].values > 0)
144+
assert np.all(data["y"].values > 0)
145+
146+
147+
@pytest.mark.parametrize(*optimizers)
148+
def test_constraint_cross_parameter(Optimizer):
149+
"""Single constraint referencing multiple parameters."""
150+
search_space = {
151+
"x1": np.arange(-10, 10, 1),
152+
"x2": np.arange(-10, 10, 1),
153+
}
154+
opt = Optimizer(
155+
search_space,
156+
constraints=[lambda p: p["x1"] + p["x2"] > 0],
157+
random_state=42,
158+
)
159+
opt.search(
160+
lambda p: -(p["x1"] ** 2 + p["x2"] ** 2),
161+
n_iter=50,
162+
verbosity=False,
163+
)
164+
165+
data = opt.search_data
166+
sums = data["x1"].values + data["x2"].values
167+
assert np.all(sums > 0)

0 commit comments

Comments
 (0)