diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 6ae85cfaf20..5847c377d1d 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -278,7 +278,9 @@ def _make_envs( transformed_in, transformed_out, N, - device="cpu", + p_env_device=None, + env_device=None, + # device="cpu", kwargs=None, local_mp_ctx=mp_ctx, ): @@ -286,13 +288,13 @@ def _make_envs( if not transformed_in: def create_env_fn(): - return GymEnv(env_name, frame_skip=frame_skip, device=device) + return GymEnv(env_name, frame_skip=frame_skip, device=env_device) else: if env_name == PONG_VERSIONED(): def create_env_fn(): - base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) + base_env = GymEnv(env_name, frame_skip=frame_skip, device=env_device) in_keys = list(base_env.observation_spec.keys(True, True))[:1] return TransformedEnv( base_env, @@ -303,7 +305,7 @@ def create_env_fn(): def create_env_fn(): - base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) + base_env = GymEnv(env_name, frame_skip=frame_skip, device=env_device) in_keys = list(base_env.observation_spec.keys(True, True))[:1] return TransformedEnv( @@ -316,9 +318,15 @@ def create_env_fn(): env0 = create_env_fn() env_parallel = ParallelEnv( - N, create_env_fn, create_env_kwargs=kwargs, mp_start_method=local_mp_ctx + N, + create_env_fn, + create_env_kwargs=kwargs, + mp_start_method=local_mp_ctx, + device=p_env_device, + ) + env_serial = SerialEnv( + N, create_env_fn, create_env_kwargs=kwargs, device=p_env_device ) - env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs) for key in env0.observation_spec.keys(True, True): obs_key = key diff --git a/test/test_env.py b/test/test_env.py index 48509aff4bf..a1f84ca2692 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1471,12 +1471,29 @@ def make_env(): "transformed_in,transformed_out", [[True, True], [False, False]] ) # 1226: effociency @pytest.mark.parametrize("static_seed", [False, True]) + @pytest.mark.parametrize("penv_device", ["cpu", None]) + @pytest.mark.parametrize("env_device", ["cpu", None]) + @pytest.mark.parametrize("bwad", [True, False]) def test_parallel_env_seed( - self, env_name, frame_skip, transformed_in, transformed_out, static_seed + self, + env_name, + frame_skip, + transformed_in, + transformed_out, + static_seed, + penv_device, + env_device, + bwad, ): env_name = env_name() env_parallel, env_serial, _, _ = _make_envs( - env_name, frame_skip, transformed_in, transformed_out, 5 + env_name, + frame_skip, + transformed_in, + transformed_out, + 5, + p_env_device=penv_device, + env_device=env_device, ) try: out_seed_serial = env_serial.set_seed(0, static_seed=static_seed) @@ -1486,7 +1503,10 @@ def test_parallel_env_seed( torch.manual_seed(0) td_serial = env_serial.rollout( - max_steps=10, auto_reset=False, tensordict=td0_serial + max_steps=10, + auto_reset=False, + tensordict=td0_serial, + break_when_any_done=bwad, ).contiguous() key = "pixels" if "pixels" in td_serial.keys() else "observation" torch.testing.assert_close( @@ -1501,7 +1521,10 @@ def test_parallel_env_seed( torch.manual_seed(0) assert out_seed_parallel == out_seed_serial td_parallel = env_parallel.rollout( - max_steps=10, auto_reset=False, tensordict=td0_parallel + max_steps=10, + auto_reset=False, + tensordict=td0_parallel, + break_when_any_done=bwad, ).contiguous() torch.testing.assert_close( td_parallel[:, :-1].get(("next", key)), td_parallel[:, 1:].get(key) @@ -1677,7 +1700,7 @@ def test_parallel_env_device( frame_skip, transformed_in=transformed_in, transformed_out=transformed_out, - device=device, + env_device=device, N=N, local_mp_ctx="spawn", )