Skip to content

Commit 86123ac

Browse files
[RLlib] - Fix broken tests for algorithm checkpoint, policy, and actor manager. (#52503)
1 parent 770035b commit 86123ac

File tree

3 files changed

+12
-17
lines changed

3 files changed

+12
-17
lines changed

rllib/algorithms/tests/test_algorithm_export_checkpoint.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def save_test(alg_name, framework="tf", multi_agent=False):
2121
config = (
2222
get_trainable_cls(alg_name)
2323
.get_default_config()
24+
.api_stack(
25+
enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
26+
)
2427
.framework(framework)
2528
# Switch on saving native DL-framework (tf, torch) model files.
2629
.checkpointing(export_native_model_files=True)
@@ -63,7 +66,7 @@ def save_test(alg_name, framework="tf", multi_agent=False):
6366

6467
# Test loading exported model and perform forward pass.
6568
filename = os.path.join(model_dir, "model.pt")
66-
model = torch.load(filename)
69+
model = torch.load(filename, weights_only=False)
6770
assert model
6871
results = model(
6972
input_dict={"obs": torch.from_numpy(test_obs)},

rllib/policy/tests/test_policy.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@ def tearDownClass(cls) -> None:
1919
ray.shutdown()
2020

2121
def test_policy_get_and_set_state(self):
22-
config = PPOConfig().environment("CartPole-v1")
22+
config = (
23+
PPOConfig()
24+
.environment("CartPole-v1")
25+
.api_stack(
26+
enable_env_runner_and_connector_v2=False,
27+
enable_rl_module_and_learner=False,
28+
)
29+
)
2330
algo = config.build()
2431
policy = algo.get_policy()
2532
state1 = policy.get_state()

rllib/utils/tests/test_actor_manager.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -331,21 +331,6 @@ def f(id, _):
331331

332332
manager.clear()
333333

334-
def test_len_of_func_not_match_len_of_actors(self):
335-
"""Test healthy only mode works when a list of funcs are provided."""
336-
actors = [Actor.remote(i) for i in range(4)]
337-
manager = FaultTolerantActorManager(actors=actors)
338-
339-
def f(id, _):
340-
return id
341-
342-
func = [functools.partial(f, i) for i in range(3)]
343-
344-
with self.assertRaisesRegexp(AssertionError, "same number of callables") as _:
345-
manager.foreach_actor_async(func, healthy_only=True)
346-
347-
manager.clear()
348-
349334
def test_probe_unhealthy_actors(self):
350335
"""Test probe brings back unhealthy actors."""
351336
actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]

0 commit comments

Comments
 (0)