1+ import os
12from typing import Callable , Tuple
23
3- import dmc2gym
44import gym
5+ import hydra .utils
56import numpy as np
7+ import omegaconf
68import pytorch_sac
79import torch
810
9- import mbrl .env .termination_fns as termination_fns
1011import mbrl .models as models
1112import mbrl .replay_buffer as replay_buffer
1213
@@ -17,8 +18,9 @@ def collect_random_trajectories(
1718 env_dataset_test : replay_buffer .IterableReplayBuffer ,
1819 steps_to_collect : int ,
1920 val_ratio : float ,
21+ rng : np .random .RandomState ,
2022):
21- indices = np . random .permutation (steps_to_collect )
23+ indices = rng .permutation (steps_to_collect )
2224 n_train = int (steps_to_collect * (1 - val_ratio ))
2325 indices_train = set (indices [:n_train ])
2426
@@ -39,109 +41,137 @@ def collect_random_trajectories(
3941 return
4042
4143
def rollout_model_and_populate_sac_buffer(
    model_env: models.ModelEnv,
    env_dataset: replay_buffer.BootstrapReplayBuffer,
    agent: pytorch_sac.SACAgent,
    sac_buffer: pytorch_sac.ReplayBuffer,
    sac_samples_action: bool,
    rollout_horizon: int,
    batch_size: int,
):
    """Roll out the learned model and store predicted transitions in the SAC buffer.

    Samples a batch of ``batch_size`` observations from ``env_dataset`` to use as
    rollout starting states, resets ``model_env`` to that batch, then steps the
    model ``rollout_horizon`` times with actions from ``agent``. Every predicted
    transition (obs, action, reward, next_obs, done, done) is added to
    ``sac_buffer``.

    Args:
        model_env: model-based environment whose ``step`` returns
            ``(next_obs, rewards, dones, info)`` batches.
        env_dataset: buffer of real transitions; only the sampled observations
            are used as starting states.
        agent: SAC agent; its ``act(obs, sample=..., batched=True)`` provides
            the rollout actions.
        sac_buffer: replay buffer receiving the model-predicted transitions.
        sac_samples_action: forwarded as ``sample=`` to ``agent.act`` (whether
            the agent samples its action distribution or acts deterministically).
        rollout_horizon: number of model steps per rollout.
        batch_size: number of parallel rollouts (and size of the sampled batch).
    """
    # Only the observations from the sample are needed; actions are produced
    # by the agent at every step below.
    initial_obs, *_ = env_dataset.sample(batch_size, ensemble=False)
    obs = model_env.reset(initial_obs_batch=initial_obs)
    for _ in range(rollout_horizon):
        action = agent.act(obs, sample=sac_samples_action, batched=True)
        pred_next_obs, pred_rewards, pred_dones, _ = model_env.step(action)
        # TODO change sac_buffer to vectorize this loop (the batch size will be really large)
        for j in range(batch_size):
            # NOTE(review): pred_dones fills both the `done` and `done_no_max`
            # slots of the SAC buffer.
            sac_buffer.add(
                obs[j],
                action[j],
                pred_rewards[j],
                pred_next_obs[j],
                pred_dones[j],
                pred_dones[j],
            )
        obs = pred_next_obs
def train(
    env: gym.Env,
    termination_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
    device: torch.device,
    cfg: omegaconf.DictConfig,
):
    """Run the MBPO-style training loop: model learning + SAC on model rollouts.

    Builds a SAC agent and a dynamics-model ensemble from ``cfg`` (via hydra),
    collects an initial dataset with random trajectories, then alternates for
    each environment step between (a) adding the real transition to the
    train/validation datasets, (b) periodically retraining the dynamics
    ensemble, and (c) training the SAC agent on short model rollouts stored in
    a fresh model-generated replay buffer.

    Args:
        env: real gym environment to collect transitions from.
        termination_fn: termination function handed to ``models.ModelEnv``.
        device: torch device used for the SAC buffer (and passed to model
            training).
        cfg: hydra/omegaconf configuration; must provide the ``agent`` and
            ``model`` nodes plus the scalar options read below (seed,
            dataset sizes, rollout and update frequencies, etc.).
    """
    # ------------------- Initialization -------------------
    obs_shape = env.observation_space.shape
    act_shape = env.action_space.shape

    # Fill in the agent config entries that depend on the environment.
    cfg.agent.obs_dim = obs_shape[0]
    cfg.agent.action_dim = act_shape[0]
    cfg.agent.action_range = [
        float(env.action_space.low.min()),
        float(env.action_space.high.max()),
    ]
    agent = hydra.utils.instantiate(cfg.agent)

    # Log into hydra's working directory for this run.
    work_dir = os.getcwd()
    logger = pytorch_sac.Logger(
        work_dir, save_tb=cfg.log_save_tb, log_frequency=cfg.log_frequency, agent="sac"
    )

    rng = np.random.RandomState(cfg.seed)

    # -------------- Create initial env. dataset --------------
    env_dataset_train = replay_buffer.BootstrapReplayBuffer(
        cfg.env_dataset_size,
        cfg.dynamics_model_batch_size,
        cfg.model.ensemble_size,
        obs_shape,
        act_shape,
    )
    val_buffer_capacity = int(cfg.env_dataset_size * cfg.validation_ratio)
    env_dataset_val = replay_buffer.IterableReplayBuffer(
        val_buffer_capacity, cfg.dynamics_model_batch_size, obs_shape, act_shape
    )
    # TODO replace this with some exploration policy
    collect_random_trajectories(
        env,
        env_dataset_train,
        env_dataset_val,
        cfg.initial_exploration_steps,
        cfg.validation_ratio,
        rng,
    )

    # ---------------------------------------------------------
    # --------------------- Training Loop ---------------------
    # The model maps (obs, action) -> (next_obs, reward), hence the +1 output.
    cfg.model.in_size = obs_shape[0] + act_shape[0]
    cfg.model.out_size = obs_shape[0] + 1

    ensemble = hydra.utils.instantiate(cfg.model)

    # Capacity to hold exactly one round of model-generated rollouts.
    sac_buffer_capacity = (
        cfg.rollouts_per_step * cfg.rollout_horizon * cfg.rollout_batch_size
    )

    updates_made = 0
    env_steps = 0
    model_env = models.ModelEnv(env, ensemble, termination_fn)
    for epoch in range(cfg.num_epochs):
        obs = env.reset()
        done = False
        while not done:
            # --------------- Env. Step and adding to model dataset -----------------
            action = agent.act(obs)
            next_obs, reward, done, _ = env.step(action)
            # Randomly route each real transition to train or validation data.
            if rng.random() < cfg.validation_ratio:
                env_dataset_val.add(obs, action, next_obs, reward, done)
            else:
                env_dataset_train.add(obs, action, next_obs, reward, done)
            obs = next_obs

            # --------------- Model Training -----------------
            if env_steps % cfg.freq_train_dyn_model == 0:
                train_loss, val_score = models.train_dyn_ensemble(
                    ensemble,
                    env_dataset_train,
                    device,
                    dataset_val=env_dataset_val,
                    patience=cfg.patience,
                )

            # --------------- Agent Training -----------------
            # Fresh buffer each step so the agent only trains on rollouts
            # generated by the current model.
            sac_buffer = pytorch_sac.ReplayBuffer(
                obs_shape, act_shape, sac_buffer_capacity, device
            )
            for _ in range(cfg.rollouts_per_step):
                rollout_model_and_populate_sac_buffer(
                    model_env,
                    env_dataset_train,
                    agent,
                    sac_buffer,
                    cfg.sac_samples_action,
                    cfg.rollout_horizon,
                    cfg.rollout_batch_size,
                )

            for _ in range(cfg.num_sac_updates_per_rollout):
                agent.update(sac_buffer, logger, updates_made)
                updates_made += 1

            logger.dump(updates_made, save=True)

            env_steps += 1