Skip to content

Commit afb81de

Browse files
author
Vincent Moens
committed
[BugFix] Fix partial device transfers in collector
ghstack-source-id: 2cd74c2 Pull Request resolved: #2703
1 parent c5f1565 commit afb81de

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

test/test_collector.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2233,7 +2233,9 @@ def __init__(self):
22332233
self.out_keys = ["action"]
22342234

22352235
def forward(self, td):
2236-
td["action"] = (self.param + self.buf).expand(td.shape)
2236+
td["action"] = (self.param + self.buf.to(self.param.device)).expand(
2237+
td.shape
2238+
)
22372239
return td
22382240

22392241
@pytest.mark.parametrize(
@@ -2288,6 +2290,64 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
22882290
col.shutdown()
22892291
del col
22902292

2293+
@pytest.mark.parametrize(
2294+
"collector",
2295+
[
2296+
functools.partial(MultiSyncDataCollector, cat_results="stack"),
2297+
MultiaSyncDataCollector,
2298+
],
2299+
)
2300+
@pytest.mark.parametrize("give_weights", [True, False])
2301+
@pytest.mark.parametrize(
2302+
"policy_device,env_device",
2303+
[
2304+
["cpu", get_default_devices()[0]],
2305+
[get_default_devices()[0], "cpu"],
2306+
# ["cpu", "cuda:0"], # 1226: faster execution
2307+
# ["cuda:0", "cpu"],
2308+
# ["cuda", "cuda:0"],
2309+
# ["cuda:0", "cuda"],
2310+
],
2311+
)
2312+
def test_param_sync_mixed_device(
2313+
self, give_weights, collector, policy_device, env_device
2314+
):
2315+
with torch.device("cpu"):
2316+
policy = TestUpdateParams.Policy()
2317+
policy.param = nn.Parameter(policy.param.data.to(policy_device))
2318+
assert policy.buf.device == torch.device("cpu")
2319+
2320+
env = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device))
2321+
device = env().device
2322+
env = [env]
2323+
col = collector(
2324+
env, policy, device=device, total_frames=200, frames_per_batch=10
2325+
)
2326+
try:
2327+
for i, data in enumerate(col):
2328+
if i == 0:
2329+
assert (data["action"] == 0).all()
2330+
# update policy
2331+
policy.param.data += 1
2332+
policy.buf.data += 2
2333+
assert policy.buf.device == torch.device("cpu")
2334+
if give_weights:
2335+
p_w = TensorDict.from_module(policy)
2336+
else:
2337+
p_w = None
2338+
col.update_policy_weights_(p_w)
2339+
elif i == 20:
2340+
if (data["action"] == 1).all():
2341+
raise RuntimeError("Failed to update buffer")
2342+
elif (data["action"] == 2).all():
2343+
raise RuntimeError("Failed to update params")
2344+
elif (data["action"] == 0).all():
2345+
raise RuntimeError("Failed to update params and buffers")
2346+
assert (data["action"] == 3).all()
2347+
finally:
2348+
col.shutdown()
2349+
del col
2350+
22912351

22922352
class TestAggregateReset:
22932353
def test_aggregate_reset_to_root(self):

torchrl/collectors/collectors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ def map_weight(
232232

233233
# Create a stateless policy, then populate this copy with params on device
234234
get_original_weights = functools.partial(TensorDict.from_module, policy)
235-
with param_and_buf.to("meta").to_module(policy):
235+
# We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function
236+
with param_and_buf.data.to("meta").to_module(policy):
236237
policy = deepcopy(policy)
237238

238239
param_and_buf.apply(

0 commit comments

Comments (0)