Skip to content

Commit 0c5becd

Browse files
authored
Add monitor keyword argument parameter (#304)
* Allow custom monitor kwargso * Add test and doc * Update changelog * Fix test * Update README * Allow `python -m rl_zoo3.cli` to be called directly
1 parent b826704 commit 0c5becd

File tree

6 files changed

+39
-5
lines changed

6 files changed

+39
-5
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
## Release 1.7.0a0 (WIP)
1+
## Release 1.7.0a1 (WIP)
22

33
### Breaking Changes
44

55
### New Features
66
- Specifying custom policies in yaml file is now supported (@Rick-v-E)
7+
- Added ``monitor_kwargs`` parameter
78

89
### Bug fixes
10+
- Allow `python -m rl_zoo3.cli` to be called directly
911

1012
### Documentation
1113

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,23 @@ env_wrapper:
296296
297297
Note that you can easily specify parameters too.
298298
299+
By default, the environment is wrapped with a `Monitor` wrapper to record episode statistics.
300+
You can specify arguments to it using `monitor_kwargs` parameter to log additional data.
301+
That data *must* be present in the info dictionary at the last step of each episode.
302+
303+
For instance, for recording success with goal envs (e.g. `FetchReach-v1`):
304+
305+
```yaml
306+
monitor_kwargs: dict(info_keywords=('is_success',))
307+
```
308+
309+
or recording final x position with `Ant-v3`:
310+
```yaml
311+
monitor_kwargs: dict(info_keywords=('x_position',))
312+
```
313+
314+
Note: for known `GoalEnv` like `FetchReach`, `info_keywords=('is_success',)` is actually the default.
315+
299316
## VecEnvWrapper
300317

301318
You can specify which `VecEnvWrapper` to use in the config, the same way as for env wrappers (see above), using the `vec_env_wrapper` key:

rl_zoo3/cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ def main():
2020
if script_name not in known_scripts.keys():
2121
raise ValueError(f"The script {script_name} is unknown, please use one of {known_scripts.keys()}")
2222
known_scripts[script_name]()
23+
24+
25+
if __name__ == "__main__":
26+
main()

rl_zoo3/exp_manager.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __init__(
135135
self.n_envs = 1 # it will be updated when reading hyperparams
136136
self.n_actions = None # For DDPG/TD3 action noise objects
137137
self._hyperparams = {}
138+
self.monitor_kwargs = {}
138139

139140
self.trained_agent = trained_agent
140141
self.continue_training = trained_agent.endswith(".zip") and os.path.isfile(trained_agent)
@@ -381,6 +382,14 @@ def _preprocess_hyperparams(
381382
if kwargs_key in hyperparams.keys() and isinstance(hyperparams[kwargs_key], str):
382383
hyperparams[kwargs_key] = eval(hyperparams[kwargs_key])
383384

385+
# Preprocess monitor kwargs
386+
if "monitor_kwargs" in hyperparams.keys():
387+
self.monitor_kwargs = hyperparams["monitor_kwargs"]
388+
# Convert str to python code
389+
if isinstance(self.monitor_kwargs, str):
390+
self.monitor_kwargs = eval(self.monitor_kwargs)
391+
del hyperparams["monitor_kwargs"]
392+
384393
# Delete keys so the dict can be pass to the model constructor
385394
if "n_envs" in hyperparams.keys():
386395
del hyperparams["n_envs"]
@@ -550,14 +559,14 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
550559
# Do not log eval env (issue with writing the same file)
551560
log_dir = None if eval_env or no_log else self.save_path
552561

553-
monitor_kwargs = {}
554562
# Special case for GoalEnvs: log success rate too
555563
if (
556564
"Neck" in self.env_name.gym_id
557565
or self.is_robotics_env(self.env_name.gym_id)
558566
or "parking-v0" in self.env_name.gym_id
567+
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs
559568
):
560-
monitor_kwargs = dict(info_keywords=("is_success",))
569+
self.monitor_kwargs = dict(info_keywords=("is_success",))
561570

562571
# On most env, SubprocVecEnv does not help and is quite memory hungry
563572
# therefore we use DummyVecEnv by default
@@ -570,7 +579,7 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
570579
wrapper_class=self.env_wrapper,
571580
vec_env_cls=self.vec_env_class,
572581
vec_env_kwargs=self.vec_env_kwargs,
573-
monitor_kwargs=monitor_kwargs,
582+
monitor_kwargs=self.monitor_kwargs,
574583
)
575584

576585
if self.vec_env_wrapper is not None:

rl_zoo3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.7.0a0
1+
1.7.0a1

tests/test_train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def test_parallel_train(tmp_path):
9696
"--log-folder",
9797
tmp_path,
9898
"-params",
99+
# Test custom argument for the monitor too
100+
"monitor_kwargs:'dict(info_keywords=(\"TimeLimit.truncated\",))'",
99101
"callback:'rl_zoo3.callbacks.ParallelTrainCallback'",
100102
]
101103

0 commit comments

Comments
 (0)