Skip to content

Commit b826704

Browse files
Rick-v-Earaffin
andauthored
Add support for custom policy in yaml file (#303)
* Add support for custom policy in yaml file * Update changelog * Add test for custom policies Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
1 parent 3663174 commit b826704

File tree

6 files changed

+53
-11
lines changed

6 files changed

+53
-11
lines changed

CHANGELOG.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
## Release 1.7.0a0 (WIP)
2+
3+
### Breaking Changes
4+
5+
### New Features
6+
- Specifying custom policies in yaml file is now supported (@Rick-v-E)
7+
8+
### Bug fixes
9+
10+
### Documentation
11+
12+
### Other
13+
14+
115
## Release 1.6.3 (2022-10-13)
216

317
### Breaking Changes
@@ -13,6 +27,7 @@
1327
### Other
1428
- Used issue forms instead of issue templates
1529

30+
1631
## Release 1.6.2.post2 (2022-10-10)
1732

1833
### Breaking Changes

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,13 @@ Specify a different activation function for the network:
206206
policy_kwargs: "dict(activation_fn=nn.ReLU)"
207207
```
208208
209+
For a custom policy:
210+
211+
```yaml
212+
policy: my_package.MyCustomPolicy # for instance stable_baselines3.ppo.MlpPolicy
213+
```
214+
215+
209216
## Hyperparameter Tuning
210217
211218
We use [Optuna](https://optuna.org/) for optimizing the hyperparameters.

rl_zoo3/exp_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
import rl_zoo3.import_envs # noqa: F401 pytype: disable=import-error
4848
from rl_zoo3.callbacks import SaveVecNormalizeCallback, TrialEvalCallback
4949
from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER
50-
from rl_zoo3.utils import ALGOS, get_callback_list, get_latest_run_id, get_wrapper_class, linear_schedule
50+
from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule
5151

5252

5353
class ExperimentManager:
@@ -390,6 +390,10 @@ def _preprocess_hyperparams(
390390
self.frame_stack = hyperparams["frame_stack"]
391391
del hyperparams["frame_stack"]
392392

393+
# import the policy when using a custom policy
394+
if "policy" in hyperparams and "." in hyperparams["policy"]:
395+
hyperparams["policy"] = get_class_by_name(hyperparams["policy"])
396+
393397
# obtain a class object from a wrapper name string in hyperparams
394398
# and delete the entry
395399
env_wrapper = get_wrapper_class(hyperparams)

rl_zoo3/utils.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import glob
33
import importlib
44
import os
5-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
66

77
import gym
88
import stable_baselines3 as sb3 # noqa: F401
@@ -118,6 +118,26 @@ def wrap_env(env: gym.Env) -> gym.Env:
118118
return None
119119

120120

121+
def get_class_by_name(name: str) -> Type:
122+
"""
123+
Imports and returns a class given the name, e.g. passing
124+
'stable_baselines3.common.callbacks.CheckpointCallback' returns the
125+
CheckpointCallback class.
126+
127+
:param name:
128+
:return:
129+
"""
130+
131+
def get_module_name(name: str) -> str:
132+
return ".".join(name.split(".")[:-1])
133+
134+
def get_class_name(name: str) -> str:
135+
return name.split(".")[-1]
136+
137+
module = importlib.import_module(get_module_name(name))
138+
return getattr(module, get_class_name(name))
139+
140+
121141
def get_callback_list(hyperparams: Dict[str, Any]) -> List[BaseCallback]:
122142
"""
123143
Get one or more Callback class specified as a hyper-parameter
@@ -135,12 +155,6 @@ def get_callback_list(hyperparams: Dict[str, Any]) -> List[BaseCallback]:
135155
:return:
136156
"""
137157

138-
def get_module_name(callback_name):
139-
return ".".join(callback_name.split(".")[:-1])
140-
141-
def get_class_name(callback_name):
142-
return callback_name.split(".")[-1]
143-
144158
callbacks = []
145159

146160
if "callback" in hyperparams.keys():
@@ -168,8 +182,8 @@ def get_class_name(callback_name):
168182
kwargs = callback_dict[callback_name]
169183
else:
170184
kwargs = {}
171-
callback_module = importlib.import_module(get_module_name(callback_name))
172-
callback_class = getattr(callback_module, get_class_name(callback_name))
185+
186+
callback_class = get_class_by_name(callback_name)
173187
callbacks.append(callback_class(**kwargs))
174188

175189
return callbacks

rl_zoo3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.6.3
1+
1.7.0a0

tests/test_train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def test_custom_yaml(tmp_path):
121121
"n_steps:50",
122122
"n_epochs:2",
123123
"batch_size:4",
124+
# Test custom policy
125+
"policy:'stable_baselines3.ppo.MlpPolicy'",
124126
]
125127

126128
return_code = subprocess.call(["python", "train.py"] + args)

0 commit comments

Comments
 (0)