@@ -2233,7 +2233,9 @@ def __init__(self):
2233
2233
self .out_keys = ["action" ]
2234
2234
2235
2235
def forward(self, td):
    """Write this policy's action into ``td`` and return ``td``.

    The action is ``self.param + self.buf`` expanded to ``td.shape``.
    The buffer is moved to the parameter's device before the addition, so
    the sum is computed on ``self.param.device`` even when the parameter
    and the buffer live on different devices.
    """
    # Align the buffer with the parameter's device; adding tensors that sit
    # on different devices would otherwise fail.
    buf = self.buf.to(self.param.device)
    td["action"] = (self.param + buf).expand(td.shape)
    return td
2238
2240
2239
2241
@pytest .mark .parametrize (
@@ -2288,6 +2290,64 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
2288
2290
col .shutdown ()
2289
2291
del col
2290
2292
2293
@pytest.mark.parametrize(
    "collector",
    [
        functools.partial(MultiSyncDataCollector, cat_results="stack"),
        MultiaSyncDataCollector,
    ],
)
@pytest.mark.parametrize("give_weights", [True, False])
@pytest.mark.parametrize(
    "policy_device,env_device",
    [
        ["cpu", get_default_devices()[0]],
        [get_default_devices()[0], "cpu"],
        # ["cpu", "cuda:0"],  # 1226: faster execution
        # ["cuda:0", "cpu"],
        # ["cuda", "cuda:0"],
        # ["cuda:0", "cuda"],
    ],
)
def test_param_sync_mixed_device(
    self, give_weights, collector, policy_device, env_device
):
    """Check weight sync when the policy's parameter and buffer sit on different devices.

    The policy is built under a CPU device context, then its parameter is
    re-created on ``policy_device`` while the buffer is asserted to remain
    on CPU. After the first batch the parameter is incremented by 1 and the
    buffer by 2, and the collector is asked to refresh its weights (with an
    explicit ``TensorDict`` when ``give_weights`` is true, ``None``
    otherwise). The last batch must then report an action of 3.
    """
    total_frames = 200
    frames_per_batch = 10
    # Verify on the final batch. This assumes the collector yields
    # total_frames // frames_per_batch batches (indices 0..19 here); a
    # hard-coded index of 20 would never be reached with this frame budget.
    last_batch = total_frames // frames_per_batch - 1

    with torch.device("cpu"):
        policy = TestUpdateParams.Policy()
    policy.param = nn.Parameter(policy.param.data.to(policy_device))
    assert policy.buf.device == torch.device("cpu")

    env_creator = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device))
    device = env_creator().device
    col = collector(
        [env_creator],
        policy,
        device=device,
        total_frames=total_frames,
        frames_per_batch=frames_per_batch,
    )
    sync_checked = False
    try:
        for i, data in enumerate(col):
            if i == 0:
                assert (data["action"] == 0).all()
                # update policy
                policy.param.data += 1
                policy.buf.data += 2
                assert policy.buf.device == torch.device("cpu")
                p_w = TensorDict.from_module(policy) if give_weights else None
                col.update_policy_weights_(p_w)
            elif i == last_batch:
                # 1 -> only the parameter moved, 2 -> only the buffer moved,
                # 0 -> nothing was propagated to the workers.
                if (data["action"] == 1).all():
                    raise RuntimeError("Failed to update buffer")
                elif (data["action"] == 2).all():
                    raise RuntimeError("Failed to update params")
                elif (data["action"] == 0).all():
                    raise RuntimeError("Failed to update params and buffers")
                assert (data["action"] == 3).all()
                sync_checked = True
        # Guard against the check being silently skipped.
        assert sync_checked, "collector ended before the weight-sync check ran"
    finally:
        col.shutdown()
        del col
2350
+
2291
2351
2292
2352
class TestAggregateReset :
2293
2353
def test_aggregate_reset_to_root (self ):
0 commit comments