@@ -101,3 +101,67 @@ def test_constraint_tracking_consistency(Optimizer):
101101 assert n_current <= n_new
102102 assert n_best == n_best_scores
103103 assert n_best <= n_new
104+
105+
106+ @pytest .mark .parametrize (* optimizers )
107+ def test_constraint_categorical (Optimizer ):
108+ search_space = {
109+ "algo" : ["adam" , "sgd" , "rmsprop" , "adagrad" ],
110+ "lr" : np .arange (0.001 , 0.1 , 0.001 ),
111+ }
112+ opt = Optimizer (
113+ search_space ,
114+ constraints = [lambda p : p ["algo" ] in ["adam" , "sgd" ]],
115+ random_state = 42 ,
116+ )
117+ opt .search (lambda p : - p ["lr" ], n_iter = 30 , verbosity = False )
118+
119+ values = opt .search_data ["algo" ].values
120+ assert all (v in ["adam" , "sgd" ] for v in values )
121+
122+
123+ @pytest .mark .parametrize (* optimizers )
124+ def test_constraint_mixed_dimensions (Optimizer ):
125+ """Continuous tuple + discrete array + categorical list."""
126+ search_space = {
127+ "x" : np .arange (- 10 , 10 , 1 ),
128+ "y" : (- 5.0 , 5.0 ),
129+ "mode" : ["fast" , "slow" , "medium" ],
130+ }
131+ opt = Optimizer (
132+ search_space ,
133+ constraints = [lambda p : p ["x" ] > 0 , lambda p : p ["y" ] > 0 ],
134+ random_state = 42 ,
135+ )
136+ opt .search (
137+ lambda p : - (p ["x" ] ** 2 + p ["y" ] ** 2 ),
138+ n_iter = 50 ,
139+ verbosity = False ,
140+ )
141+
142+ data = opt .search_data
143+ assert np .all (data ["x" ].values > 0 )
144+ assert np .all (data ["y" ].values > 0 )
145+
146+
147+ @pytest .mark .parametrize (* optimizers )
148+ def test_constraint_cross_parameter (Optimizer ):
149+ """Single constraint referencing multiple parameters."""
150+ search_space = {
151+ "x1" : np .arange (- 10 , 10 , 1 ),
152+ "x2" : np .arange (- 10 , 10 , 1 ),
153+ }
154+ opt = Optimizer (
155+ search_space ,
156+ constraints = [lambda p : p ["x1" ] + p ["x2" ] > 0 ],
157+ random_state = 42 ,
158+ )
159+ opt .search (
160+ lambda p : - (p ["x1" ] ** 2 + p ["x2" ] ** 2 ),
161+ n_iter = 50 ,
162+ verbosity = False ,
163+ )
164+
165+ data = opt .search_data
166+ sums = data ["x1" ].values + data ["x2" ].values
167+ assert np .all (sums > 0 )
0 commit comments