Commit 539c215

Author: Vincent Moens (committed)
[Example] Using Collector's device args
ghstack-source-id: 9aec8da
Pull Request resolved: #2705
1 parent 1d45117 commit 539c215

File tree: 1 file changed, 215 additions, 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
7+
Using the SyncDataCollector with Different Device Combinations
8+
==============================================================
9+
10+
TorchRL's SyncDataCollector allows you to specify the devices on which different components of the data collection
11+
process are executed. This example demonstrates how to use the collector with various device combinations.
12+

Understanding Device Precedence
-------------------------------

When creating a SyncDataCollector, you can specify the devices for the environment (env_device), policy (policy_device),
and data collection (device). The device argument serves as a default value for any unspecified devices. However, if you
provide env_device or policy_device, they take precedence over the device argument for their respective components.

For example:

- If you set device="cuda", all components will be executed on the CUDA device unless you specify otherwise.
- If you set env_device="cpu" and device="cuda", the environment will be executed on the CPU, while the policy and data
  collection will be executed on the CUDA device.
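
As a minimal sketch of the second case (make_env() and policy stand in for any environment constructor and policy
module; the keyword values are only illustrative):

    >>> collector = SyncDataCollector(
    ...     make_env(),
    ...     policy,
    ...     frames_per_batch=128,
    ...     device="cuda",     # default for every component not specified below
    ...     env_device="cpu",  # overrides `device` for the environment only
    ... )

Here the environment steps on CPU, while the policy and the collected data live on CUDA.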

Keeping Policy Parameters in Sync
---------------------------------

When using a policy with buffers or other attributes that are not automatically updated when moving the policy's
parameters to a different device, it's essential to keep the policy's parameters in sync between the main workspace and
the collector.

To do this, call update_policy_weights_() anytime the policy's parameters (and buffers!) are updated. This ensures that
the policy used by the collector has the same parameters as the policy in the main workspace.
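
A minimal sketch of where this call typically sits, assuming a training loop with an optimizer and a collector like
the one built below:

    >>> for data in collector:
    ...     ...  # extend the replay buffer, compute the loss, backpropagate
    ...     optimizer.step()                    # policy weights change here
    ...     collector.update_policy_weights_()  # push the fresh weights to the collector's copy of the policy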

Example Use Cases
-----------------

This script demonstrates the SyncDataCollector with the following device combinations:

- Collector on CUDA
- Collector on CPU
- Mixed collector: policy on CUDA, env untouched (i.e., unmarked CPU, env.device == None)
- Mixed collector: policy on CUDA, env on CPU (env.device == "cpu")
- Mixed collector: all on CUDA, except the env on CPU

For each configuration, we run a DQN algorithm and check that it converges.
By following this example, you can learn how to use the SyncDataCollector with different device combinations and ensure
that your policy's parameters are kept in sync.

"""

import logging
import time

import torch.cuda
import torch.nn as nn
import torch.optim as optim

from tensordict.nn import TensorDictSequential as TDSeq

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import Compose, GymEnv, RewardSum, StepCounter, TransformedEnv
from torchrl.modules import EGreedyModule, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate


logging.basicConfig(level=logging.INFO)
my_logger = logging.getLogger(__name__)

ENV_NAME = "CartPole-v1"

INIT_RND_STEPS = 5_120
FRAMES_PER_BATCH = 128
BUFFER_SIZE = 100_000

GAMMA = 0.98
OPTIM_STEPS = 10
BATCH_SIZE = 128

SOFTU_EPS = 0.99
LR = 0.02

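
# A small MLP that maps an observation to one Q-value per action.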
class Net(nn.Module):
    def __init__(self, obs_size: int, n_actions: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
        orig_shape_unbatched = len(x.shape) == 1
        if orig_shape_unbatched:
            x = x.unsqueeze(0)

        out = self.net(x)

        if orig_shape_unbatched:
            out = out.squeeze(0)
        return out


def make_env(env_name: str):
    return TransformedEnv(GymEnv(env_name), Compose(StepCounter(), RewardSum()))


if __name__ == "__main__":

    for env_device, policy_device, device in (
        (None, None, "cuda"),
        (None, None, "cpu"),
        (None, "cuda", None),
        ("cpu", "cuda", None),
        ("cpu", None, "cuda"),
        # These configs don't run because the collector needs to be told that the policy is on CUDA.
        # This is not an issue for the env, whose specs are associated with a device, so the data can
        # be transferred automatically. The policy does not, in general, have a spec indicating what
        # its input and output devices are, so this must be told to the collector.
        # (None, None, None),
        # ("cpu", None, None),
    ):
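        # Each iteration of this loop trains a small DQN end-to-end with one device configuration.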
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        env = make_env(ENV_NAME)
        env.set_seed(0)

        n_obs = env.observation_spec["observation"].shape[-1]
        n_act = env.action_spec.shape[-1]

        net = Net(n_obs, n_act).to(device="cuda:0")
        agent = QValueActor(net, spec=env.action_spec.to("cuda:0"))

        # policy_explore has buffers on CPU - we will need to call collector.update_policy_weights_()
        # to sync them during data collection.
        policy_explore = EGreedyModule(env.action_spec)
        agent_explore = TDSeq(agent, policy_explore)

144+
collector = SyncDataCollector(
145+
env,
146+
agent_explore,
147+
frames_per_batch=FRAMES_PER_BATCH,
148+
init_random_frames=INIT_RND_STEPS,
149+
device=device,
150+
env_device=env_device,
151+
policy_device=policy_device,
152+
)
153+
exp_buffer = ReplayBuffer(
154+
storage=LazyTensorStorage(BUFFER_SIZE, device="cuda:0")
155+
)
156+
157+
loss = DQNLoss(
158+
value_network=agent, action_space=env.action_spec, delay_value=True
159+
)
160+
loss.make_value_estimator(gamma=GAMMA)
161+
target_updater = SoftUpdate(loss, eps=SOFTU_EPS)
162+
optimizer = optim.Adam(loss.parameters(), lr=LR)
163+
164+
        total_count = 0
        total_episodes = 0
        t0 = time.time()
        for i, data in enumerate(collector):
            # Check the data devices
            if device is None:
                assert data["action"].device == torch.device("cuda:0")
                assert data["observation"].device == torch.device("cpu")
                assert data["done"].device == torch.device("cpu")
            elif device == "cpu":
                assert data["action"].device == torch.device("cpu")
                assert data["observation"].device == torch.device("cpu")
                assert data["done"].device == torch.device("cpu")
            else:
                assert data["action"].device == torch.device("cuda:0")
                assert data["observation"].device == torch.device("cuda:0")
                assert data["done"].device == torch.device("cuda:0")

            exp_buffer.extend(data)
            max_length = exp_buffer["next", "step_count"].max()
            max_reward = exp_buffer["next", "episode_reward"].max()
            if len(exp_buffer) > INIT_RND_STEPS:
                for _ in range(OPTIM_STEPS):
                    optimizer.zero_grad()
                    sample = exp_buffer.sample(batch_size=BATCH_SIZE)

                    loss_vals = loss(sample)
                    loss_vals["loss"].backward()
                    optimizer.step()

                    agent_explore[1].step(data.numel())
                    target_updater.step()

                    total_count += data.numel()
                    total_episodes += data["next", "done"].sum()

                if i % 10 == 0:
                    my_logger.info(
                        f"Step: {i}, max. count / epi reward: {max_length} / {max_reward}."
                    )
            collector.update_policy_weights_()
            if max_length > 200:
                t1 = time.time()
                my_logger.info(f"SOLVED in {t1 - t0}s!! MaxLen: {max_length}!")
                my_logger.info(f"With {max_reward} Reward!")
                my_logger.info(f"In {total_episodes} Episodes!")
                my_logger.info(f"Using devices {(env_device, policy_device, device)}")
                break
        else:
            raise RuntimeError(
                f"Failed to converge with config {(env_device, policy_device, device)}"
            )
