Skip to content
This repository was archived by the owner on Sep 1, 2024. It is now read-only.

Commit e51fd41

Browse files
authored
Merge pull request #1 from fairinternal/add_hydra
Added hydra configuration for MBPO
2 parents 2998853 + 4b39a60 commit e51fd41

File tree

8 files changed

+256
-288
lines changed

8 files changed

+256
-288
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
*personal*
2-
.idea
2+
.idea
3+
notebooks/.ipynb_checkpoints

conf/agent/sac.yaml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# @package _global_
2+
agent:
3+
_target_: pytorch_sac.agent.sac.SACAgent
4+
obs_dim: ??? # to be specified later
5+
action_dim: ??? # to be specified later
6+
action_range: ??? # to be specified later
7+
device: ${device}
8+
critic_cfg: ${double_q_critic}
9+
actor_cfg: ${diag_gaussian_actor}
10+
discount: 0.99
11+
init_temperature: 0.1
12+
alpha_lr: 1e-4
13+
alpha_betas: [0.9, 0.999]
14+
actor_lr: 1e-4
15+
actor_betas: [0.9, 0.999]
16+
actor_update_frequency: 1
17+
critic_lr: 1e-4
18+
critic_betas: [0.9, 0.999]
19+
critic_tau: 0.005
20+
critic_target_update_frequency: 2
21+
batch_size: 1024
22+
learnable_temperature: true
23+
target_entropy: -1
24+
25+
double_q_critic:
26+
_target_: pytorch_sac.agent.critic.DoubleQCritic
27+
obs_dim: ${agent.obs_dim}
28+
action_dim: ${agent.action_dim}
29+
hidden_dim: 1024
30+
hidden_depth: 2
31+
32+
diag_gaussian_actor:
33+
_target_: pytorch_sac.agent.actor.DiagGaussianActor
34+
obs_dim: ${agent.obs_dim}
35+
action_dim: ${agent.action_dim}
36+
hidden_depth: 2
37+
hidden_dim: 1024
38+
log_std_bounds: [-5, 2]

conf/mbpo.yaml

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
defaults:
2+
- agent: sac
3+
4+
model:
5+
_target_: mbrl.models.Ensemble
6+
ensemble_size: 7
7+
in_size: ???
8+
out_size: ???
9+
member_cfg: ${member_cfg}
10+
device: ${device}
11+
optim_lr: 0.0075 # TODO this should be moved out of the ensemble class
12+
13+
member_cfg:
14+
_target_: mbrl.models.GaussianMLP
15+
device: ${model.device}
16+
num_layers: 4
17+
in_size: ${model.in_size}
18+
out_size: ${model.out_size}
19+
hid_size: 200
20+
21+
env: "hopper--stand"
22+
23+
env_dataset_size: 1000
24+
validation_ratio: 0.1
25+
dynamics_model_batch_size: 256
26+
initial_exploration_steps: 20
27+
num_epochs: 100
28+
freq_train_dyn_model: 100
29+
patience: 50
30+
rollouts_per_step: 40
31+
rollout_horizon: 15 # TODO replace by thresholded linear
32+
rollout_batch_size: 32
33+
sac_buffer_capacity: ???
34+
sac_samples_action: true
35+
num_sac_updates_per_rollout: 100
36+
37+
seed: 0
38+
39+
device: "cuda:0"
40+
41+
log_frequency: 100
42+
log_save_tb: false
43+
44+
45+
experiment: test_exp
46+
47+
hydra:
48+
run:
49+
dir: ./exp/mbrl/${env}/${now:%Y.%m.%d}/${now:%H%M}_${experiment}

mbrl/mbpo.py

Lines changed: 116 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
import os
12
from typing import Callable, Tuple
23

3-
import dmc2gym
44
import gym
5+
import hydra.utils
56
import numpy as np
7+
import omegaconf
68
import pytorch_sac
79
import torch
810

9-
import mbrl.env.termination_fns as termination_fns
1011
import mbrl.models as models
1112
import mbrl.replay_buffer as replay_buffer
1213

@@ -17,8 +18,9 @@ def collect_random_trajectories(
1718
env_dataset_test: replay_buffer.IterableReplayBuffer,
1819
steps_to_collect: int,
1920
val_ratio: float,
21+
rng: np.random.RandomState,
2022
):
21-
indices = np.random.permutation(steps_to_collect)
23+
indices = rng.permutation(steps_to_collect)
2224
n_train = int(steps_to_collect * (1 - val_ratio))
2325
indices_train = set(indices[:n_train])
2426

@@ -39,109 +41,137 @@ def collect_random_trajectories(
3941
return
4042

4143

42-
def rollout_model(
43-
env: gym.Env,
44-
model: models.Model,
44+
def rollout_model_and_populate_sac_buffer(
45+
model_env: models.ModelEnv,
4546
env_dataset: replay_buffer.BootstrapReplayBuffer,
46-
termination_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
47-
obs_shape: Tuple[int],
48-
act_shape: Tuple[int],
49-
sac_buffer_capacity: int,
50-
num_rollouts: int,
47+
agent: pytorch_sac.SACAgent,
48+
sac_buffer: pytorch_sac.ReplayBuffer,
49+
sac_samples_action: bool,
5150
rollout_horizon: int,
5251
batch_size: int,
53-
device: torch.device,
54-
) -> pytorch_sac.ReplayBuffer:
55-
model_env = models.ModelEnv(env, model, termination_fn)
56-
sac_buffer = pytorch_sac.ReplayBuffer(
57-
obs_shape, act_shape, sac_buffer_capacity, device
58-
)
59-
for _ in range(num_rollouts):
60-
initial_obs, action, *_ = env_dataset.sample(batch_size, ensemble=False)
61-
obs = model_env.reset(initial_obs_batch=initial_obs)
62-
for i in range(rollout_horizon):
63-
pred_next_obs, pred_rewards, pred_dones, _ = model_env.step(action)
64-
# TODO consider changing sac_buffer to vectorize this loop
65-
for j in range(batch_size):
66-
sac_buffer.add(
67-
obs[j],
68-
action[j],
69-
pred_rewards[j],
70-
pred_next_obs[j],
71-
pred_dones[j],
72-
pred_dones[j],
73-
)
74-
obs = pred_next_obs
52+
):
7553

76-
return sac_buffer
54+
initial_obs, action, *_ = env_dataset.sample(batch_size, ensemble=False)
55+
obs = model_env.reset(initial_obs_batch=initial_obs)
56+
for i in range(rollout_horizon):
57+
action = agent.act(obs, sample=sac_samples_action, batched=True)
58+
pred_next_obs, pred_rewards, pred_dones, _ = model_env.step(action)
59+
# TODO change sac_buffer to vectorize this loop (the batch size will be really large)
60+
for j in range(batch_size):
61+
sac_buffer.add(
62+
obs[j],
63+
action[j],
64+
pred_rewards[j],
65+
pred_next_obs[j],
66+
pred_dones[j],
67+
pred_dones[j],
68+
)
69+
obs = pred_next_obs
7770

7871

79-
def mbpo(
72+
def train(
8073
env: gym.Env,
8174
termination_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
8275
device: torch.device,
76+
cfg: omegaconf.DictConfig,
8377
):
78+
# ------------------- Initialization -------------------
8479
obs_shape = env.observation_space.shape
8580
act_shape = env.action_space.shape
8681

87-
# PARAMS TO MOVE TO A CONFIG FILE
88-
ensemble_size = 7
89-
val_ratio = 0.1
90-
buffer_capacity = 1000
91-
batch_size = 256
92-
steps_to_collect = 100
93-
num_epochs = 100
94-
freq_train_dyn_model = 10
95-
patience = 50
96-
rollouts_per_step = 40
97-
rollout_horizon = 15
98-
sac_buffer_capacity = 10000
99-
100-
# Agent
101-
# agent = pytorch_sac.SACAgent()
102-
103-
# Creating environment datasets
82+
cfg.agent.obs_dim = obs_shape[0]
83+
cfg.agent.action_dim = act_shape[0]
84+
cfg.agent.action_range = [
85+
float(env.action_space.low.min()),
86+
float(env.action_space.high.max()),
87+
]
88+
agent = hydra.utils.instantiate(cfg.agent)
89+
90+
work_dir = os.getcwd()
91+
logger = pytorch_sac.Logger(
92+
work_dir, save_tb=cfg.log_save_tb, log_frequency=cfg.log_frequency, agent="sac"
93+
)
94+
95+
rng = np.random.RandomState(cfg.seed)
96+
97+
# -------------- Create initial env. dataset --------------
10498
env_dataset_train = replay_buffer.BootstrapReplayBuffer(
105-
buffer_capacity, batch_size, ensemble_size, obs_shape, act_shape
99+
cfg.env_dataset_size,
100+
cfg.dynamics_model_batch_size,
101+
cfg.model.ensemble_size,
102+
obs_shape,
103+
act_shape,
106104
)
105+
val_buffer_capacity = int(cfg.env_dataset_size * cfg.validation_ratio)
107106
env_dataset_val = replay_buffer.IterableReplayBuffer(
108-
int(buffer_capacity * val_ratio), batch_size, obs_shape, act_shape
107+
val_buffer_capacity, cfg.dynamics_model_batch_size, obs_shape, act_shape
109108
)
109+
# TODO replace this with some exploration policy
110110
collect_random_trajectories(
111-
env, env_dataset_train, env_dataset_val, steps_to_collect, val_ratio
111+
env,
112+
env_dataset_train,
113+
env_dataset_val,
114+
cfg.initial_exploration_steps,
115+
cfg.validation_ratio,
116+
rng,
112117
)
113118

114-
# Training loop
115-
model_in_size = obs_shape[0] + act_shape[0]
116-
model_out_size = obs_shape[0] + 1
117-
ensemble = models.Ensemble(
118-
models.GaussianMLP, ensemble_size, model_in_size, model_out_size, device
119+
# ---------------------------------------------------------
120+
# --------------------- Training Loop ---------------------
121+
cfg.model.in_size = obs_shape[0] + act_shape[0]
122+
cfg.model.out_size = obs_shape[0] + 1
123+
124+
ensemble = hydra.utils.instantiate(cfg.model)
125+
126+
sac_buffer_capacity = (
127+
cfg.rollouts_per_step * cfg.rollout_horizon * cfg.rollout_batch_size
119128
)
120-
for epoch in range(num_epochs):
121-
if epoch % freq_train_dyn_model == 0:
122-
train_loss, val_score = models.train_dyn_ensemble(
123-
ensemble,
124-
env_dataset_train,
125-
device,
126-
dataset_val=env_dataset_val,
127-
patience=patience,
129+
130+
updates_made = 0
131+
env_steps = 0
132+
model_env = models.ModelEnv(env, ensemble, termination_fn)
133+
for epoch in range(cfg.num_epochs):
134+
obs = env.reset()
135+
done = False
136+
while not done:
137+
# --------------- Env. Step and adding to model dataset -----------------
138+
action = agent.act(obs)
139+
next_obs, reward, done, _ = env.step(action)
140+
if rng.random() < cfg.validation_ratio:
141+
env_dataset_val.add(obs, action, next_obs, reward, done)
142+
else:
143+
env_dataset_train.add(obs, action, next_obs, reward, done)
144+
obs = next_obs
145+
146+
# --------------- Model Training -----------------
147+
if env_steps % cfg.freq_train_dyn_model == 0:
148+
train_loss, val_score = models.train_dyn_ensemble(
149+
ensemble,
150+
env_dataset_train,
151+
device,
152+
dataset_val=env_dataset_val,
153+
patience=cfg.patience,
154+
)
155+
156+
# --------------- Agent Training -----------------
157+
sac_buffer = pytorch_sac.ReplayBuffer(
158+
obs_shape, act_shape, sac_buffer_capacity, device
128159
)
160+
for _ in range(cfg.rollouts_per_step):
161+
rollout_model_and_populate_sac_buffer(
162+
model_env,
163+
env_dataset_train,
164+
agent,
165+
sac_buffer,
166+
cfg.sac_samples_action,
167+
cfg.rollout_horizon,
168+
cfg.rollout_batch_size,
169+
)
170+
171+
for _ in range(cfg.num_sac_updates_per_rollout):
172+
agent.update(sac_buffer, logger, updates_made)
173+
updates_made += 1
174+
175+
logger.dump(updates_made, save=True)
129176

130-
sac_buffer = rollout_model(
131-
env,
132-
ensemble,
133-
env_dataset_train,
134-
termination_fn,
135-
obs_shape,
136-
act_shape,
137-
sac_buffer_capacity,
138-
rollouts_per_step,
139-
rollout_horizon,
140-
batch_size,
141-
device,
142-
)
143-
144-
145-
if __name__ == "__main__":
146-
_env = dmc2gym.make(domain_name="hopper", task_name="stand")
147-
mbpo(_env, termination_fns.hopper, torch.device("cuda:0"))
177+
env_steps += 1

mbrl/models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from typing import Tuple, List, Optional, Type, Dict, Sequence, Union
44

55
import gym
6+
import hydra.utils
67
import numpy as np
8+
import omegaconf
79
import torch
810
from torch import nn as nn, optim as optim
911
from torch.nn import functional as F
@@ -87,24 +89,22 @@ def eval_score(self, model_in: torch.Tensor, target: torch.Tensor) -> float:
8789
class Ensemble(Model):
8890
def __init__(
8991
self,
90-
cls: Type[Model],
91-
num_members: int,
92+
ensemble_size: int,
9293
in_size: int,
9394
out_size: int,
9495
device: torch.device,
95-
*model_args,
96+
member_cfg: omegaconf.DictConfig,
9697
optim_lr: float = 0.0075,
97-
seed: Optional[int] = None,
98-
**model_kwargs,
9998
):
10099
super().__init__(in_size, out_size, device)
101100
self.members = []
102101
self.optimizers = []
103-
for i in range(num_members):
104-
model = cls(in_size, out_size, device, *model_args, **model_kwargs)
102+
for i in range(ensemble_size):
103+
model = hydra.utils.instantiate(member_cfg)
104+
# model = member_cls(in_size, out_size, device, *model_args, **model_kwargs)
105105
self.members.append(model.to(device))
106106
self.optimizers.append(optim.Adam(model.parameters(), lr=optim_lr))
107-
self.rng = np.random.RandomState(seed)
107+
self.rng = np.random.RandomState()
108108

109109
def __len__(self):
110110
return len(self.members)
@@ -261,7 +261,7 @@ def step(self, actions: np.ndarray):
261261
model_in = torch.from_numpy(
262262
np.concatenate([self._current_obs, actions], axis=1)
263263
).to(self.model.device)
264-
model_out = self.model(model_in).cpu().numpy()[0]
264+
model_out = self.model(model_in)[0].cpu().numpy()
265265
next_observs = model_out[:, :-1]
266266
rewards = model_out[:, -1]
267267
dones = self.termination_fn(actions, next_observs)

0 commit comments

Comments
 (0)