Commit 1f9cfcf

ernestum and araffin authored
Allow python script for hyperparameter configuration (#318)
* Allow loading configuration from python package in addition to yaml files.
* Add an example python configuration file.
* Add a test for training with the example python config file.
* Formatting fixes.
* Flake8 6 doesn't support inline comments
* Allow python scripts
* Remove debug print
* Add python config file usage example to README.md
* Update CHANGELOG.md
* Fix paths in python config example and add variant where we specify the python file instead of package.
* Update README

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
1 parent 168c34e commit 1f9cfcf

9 files changed, +119 -21 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -1,11 +1,13 @@
11
## Release 1.7.0a2 (WIP)
22

33
### Breaking Changes
4+
- `--yaml-file` argument was renamed to `-conf` (`--conf-file`) as python files are now supported too
45

56
### New Features
67
- Specifying custom policies in yaml file is now supported (@Rick-v-E)
78
- Added ``monitor_kwargs`` parameter
89
- Handle the `env_kwargs` of `render:True` under the hood for panda-gym v1 envs in `enjoy` replay to match visualization behavior of other envs
10+
- Added support for python config file
911

1012
### Bug fixes
1113
- Allow `python -m rl_zoo3.cli` to be called directly

Makefile

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
1-
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/
1+
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/ hyperparams/python/*.py
22

33
# Run pytest and coverage report
44
pytest:

README.md

Lines changed: 15 additions & 3 deletions
@@ -61,11 +61,23 @@ python train.py --algo algo_name --env env_id
6161
```
6262
You can use `-P` (`--progress`) option to display a progress bar.
6363

64-
Using a custom yaml file (which contains a `env_id` entry):
64+
Using a custom config file when it is a yaml file (which contains an `env_id` entry):
6565
```
66-
python train.py --algo algo_name --env env_id --yaml-file my_yaml.yml
66+
python train.py --algo algo_name --env env_id --conf-file my_yaml.yml
6767
```
6868

69+
You can also use a python file that contains a dictionary called `hyperparams` with an entry for each `env_id`.
70+
(see `hyperparams/python/ppo_config_example.py` for an example)
71+
```bash
72+
# You can pass a path to a python file
73+
python train.py --algo ppo --env MountainCarContinuous-v0 --conf-file hyperparams/python/ppo_config_example.py
74+
# Or pass a path to a file from a module (for instance my_package.my_file)
75+
python train.py --algo ppo --env MountainCarContinuous-v0 --conf-file hyperparams.python.ppo_config_example
76+
```
77+
The advantage of this approach is that you can specify arbitrary python dictionaries
78+
and ensure that all their dependencies are imported in the config file itself.
79+
80+
6981
For example (with tensorboard support):
7082
```
7183
python train.py --algo ppo --env CartPole-v1 --tensorboard-log /tmp/stable-baselines/
@@ -139,7 +151,7 @@ Remark: plotting with the `--rliable` option is usually slow as confidence inter
139151

140152
## Custom Environment
141153

142-
The easiest way to add support for a custom environment is to edit `rl_zoo3/import_envs.py` and register your environment here. Then, you need to add a section for it in the hyperparameters file (`hyperparams/algo.yml` or a custom yaml file that you can specify using `--yaml-file` argument).
154+
The easiest way to add support for a custom environment is to edit `rl_zoo3/import_envs.py` and register your environment here. Then, you need to add a section for it in the hyperparameters file (`hyperparams/algo.yml` or a custom yaml file that you can specify using `--conf-file` argument).
143155

144156
## Enjoy a Trained Agent
145157

hyperparams/python/ppo_config_example.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
1+
"""This file just serves as an example on how to configure the zoo
2+
using python scripts instead of yaml files."""
3+
import torch
4+
5+
hyperparams = {
6+
"MountainCarContinuous-v0": dict(
7+
env_wrapper=[{"gym.wrappers.TimeLimit": {"max_episode_steps": 100}}],
8+
normalize=True,
9+
n_envs=1,
10+
n_timesteps=20000.0,
11+
policy="MlpPolicy",
12+
batch_size=8,
13+
n_steps=8,
14+
gamma=0.9999,
15+
learning_rate=7.77e-05,
16+
ent_coef=0.00429,
17+
clip_range=0.1,
18+
n_epochs=2,
19+
gae_lambda=0.9,
20+
max_grad_norm=5,
21+
vf_coef=0.19,
22+
use_sde=True,
23+
policy_kwargs=dict(
24+
log_std_init=-3.29,
25+
ortho_init=False,
26+
activation_fn=torch.nn.ReLU,
27+
),
28+
)
29+
}
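
The README hunk above notes that a python config can hold arbitrary python dictionaries. As one illustration of what that allows, a config module could build its entries from shared defaults. This is only a sketch: the `common` name, the `Pendulum-v1` entry, and every value in it are assumptions made up for the example, not settings from the repository.

```python
# Hypothetical config module: shared defaults reused across environments.
# Only the name `hyperparams` is what the loader in this commit looks up;
# `common` and the values below are illustrative assumptions.
common = dict(policy="MlpPolicy", n_envs=1, normalize=True, gamma=0.9999)

hyperparams = {
    # dict(base, **overrides) copies the defaults, then applies the overrides
    "MountainCarContinuous-v0": dict(common, n_timesteps=20000.0, batch_size=8, n_steps=8),
    # Second entry exists only to show an override; the values are not tuned
    "Pendulum-v1": dict(common, n_timesteps=20000.0, gamma=0.9),
}
```

Because each entry is a copy, changing a key for one environment leaves `common` and the other entries untouched.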

rl_zoo3/exp_manager.py

Lines changed: 29 additions & 12 deletions
@@ -1,4 +1,5 @@
11
import argparse
2+
import importlib
23
import os
34
import pickle as pkl
45
import time
@@ -93,7 +94,7 @@ def __init__(
9394
n_eval_envs: int = 1,
9495
no_optim_plots: bool = False,
9596
device: Union[th.device, str] = "auto",
96-
yaml_file: Optional[str] = None,
97+
config: Optional[str] = None,
9798
show_progress: bool = False,
9899
):
99100
super().__init__()
@@ -108,7 +109,7 @@ def __init__(
108109
# Take the root folder
109110
default_path = Path(__file__).parent.parent
110111

111-
self.yaml_file = yaml_file or str(default_path / f"hyperparams/{self.algo}.yml")
112+
self.config = config or str(default_path / f"hyperparams/{self.algo}.yml")
112113
self.env_kwargs = {} if env_kwargs is None else env_kwargs
113114
self.n_timesteps = n_timesteps
114115
self.normalize = False
@@ -281,16 +282,28 @@ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
281282
print(f"Log path: {self.save_path}")
282283

283284
def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
284-
# Load hyperparameters from yaml file
285-
print(f"Loading hyperparameters from: {self.yaml_file}")
286-
with open(self.yaml_file) as f:
287-
hyperparams_dict = yaml.safe_load(f)
288-
if self.env_name.gym_id in list(hyperparams_dict.keys()):
289-
hyperparams = hyperparams_dict[self.env_name.gym_id]
290-
elif self._is_atari:
291-
hyperparams = hyperparams_dict["atari"]
292-
else:
293-
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}")
285+
print(f"Loading hyperparameters from: {self.config}")
286+
287+
if self.config.endswith(".yml") or self.config.endswith(".yaml"):
288+
# Load hyperparameters from yaml file
289+
with open(self.config) as f:
290+
hyperparams_dict = yaml.safe_load(f)
291+
elif self.config.endswith(".py"):
292+
global_variables = {}
293+
# Load hyperparameters from python file
294+
exec(Path(self.config).read_text(), global_variables)
295+
hyperparams_dict = global_variables["hyperparams"]
296+
else:
297+
# Load hyperparameters from python package
298+
hyperparams_dict = importlib.import_module(self.config).hyperparams
299+
# raise ValueError(f"Unsupported config file format: {self.config}")
300+
301+
if self.env_name.gym_id in list(hyperparams_dict.keys()):
302+
hyperparams = hyperparams_dict[self.env_name.gym_id]
303+
elif self._is_atari:
304+
hyperparams = hyperparams_dict["atari"]
305+
else:
306+
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id} in {self.config}")
294307

295308
if self.custom_hyperparams is not None:
296309
# Overwrite hyperparams if needed
@@ -336,6 +349,10 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, An
336349
self.normalize_kwargs = eval(self.normalize)
337350
self.normalize = True
338351

352+
if isinstance(self.normalize, dict):
353+
self.normalize_kwargs = self.normalize
354+
self.normalize = True
355+
339356
# Use the same discount factor as for the algorithm
340357
if "gamma" in hyperparams:
341358
self.normalize_kwargs["gamma"] = hyperparams["gamma"]
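
The new `read_hyperparameters` logic in this file dispatches on the form of the config string: a yaml path, a `.py` path, or anything else, which is treated as a dotted module path. A standalone sketch of that dispatch might look like the following. The function name `load_hyperparams_dict` is hypothetical, and the lazy `yaml` import is an assumption made so that the python branches run without PyYAML installed.

```python
import importlib
from pathlib import Path
from typing import Any, Dict


def load_hyperparams_dict(config: str) -> Dict[str, Any]:
    """Return the top-level mapping (env id -> hyperparameters) for `config`.

    Hypothetical helper mirroring the three branches in the diff above.
    """
    if config.endswith((".yml", ".yaml")):
        # Imported lazily (an assumption of this sketch) so that the
        # python branches below do not require PyYAML.
        import yaml

        with open(config) as f:
            return yaml.safe_load(f)
    if config.endswith(".py"):
        # Execute the file in a fresh namespace and read `hyperparams` from it
        namespace: Dict[str, Any] = {}
        exec(Path(config).read_text(), namespace)
        return namespace["hyperparams"]
    # Anything else is treated as a dotted module path, e.g. my_package.my_file
    return importlib.import_module(config).hyperparams
```

Both python branches run the config as code, so only trusted files should be passed. The module branch also depends on the package being importable from the current `sys.path`.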

rl_zoo3/train.py

Lines changed: 19 additions & 2 deletions
@@ -122,7 +122,19 @@ def train():
122122
help="Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)",
123123
)
124124
parser.add_argument(
125-
"-yaml", "--yaml-file", type=str, default=None, help="Custom yaml file from which the hyperparameters will be loaded"
125+
"-conf",
126+
"--conf-file",
127+
type=str,
128+
default=None,
129+
help="Custom yaml file or python package from which the hyperparameters will be loaded. "
130+
"We expect that python packages contain a dictionary called 'hyperparams' which contains a key for each environment.",
131+
)
132+
parser.add_argument(
133+
"-yaml",
134+
"--yaml-file",
135+
type=str,
136+
default=None,
137+
help="This parameter is deprecated, please use `--conf-file` instead",
126138
)
127139
parser.add_argument("-uuid", "--uuid", action="store_true", default=False, help="Ensure that the run has a unique ID")
128140
parser.add_argument(
@@ -150,6 +162,11 @@ def train():
150162
env_id = args.env
151163
registered_envs = set(gym.envs.registry.env_specs.keys()) # pytype: disable=module-attr
152164

165+
if args.yaml_file is not None:
166+
raise ValueError(
167+
"The `--yaml-file` parameter is deprecated and will be removed in RL Zoo3 v1.8, please use `--conf-file` instead",
168+
)
169+
153170
# If the environment is not found, suggest the closest match
154171
if env_id not in registered_envs:
155172
try:
@@ -234,7 +251,7 @@ def train():
234251
n_eval_envs=args.n_eval_envs,
235252
no_optim_plots=args.no_optim_plots,
236253
device=args.device,
237-
yaml_file=args.yaml_file,
254+
config=args.conf_file,
238255
show_progress=args.progress,
239256
)
240257

rl_zoo3/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
1-
1.7.0a2
1+
1.7.0a3

setup.cfg

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,8 @@ markers =
1616
inputs = .
1717

1818
[flake8]
19-
ignore = W503,W504,E203,E231 # line breaks before and after binary operators
19+
# line breaks before and after binary operators
20+
ignore = W503,W504,E203,E231
2021
# Ignore import not used when aliases are defined
2122
per-file-ignores =
2223
./rl_zoo3/__init__.py:F401

tests/test_train.py

Lines changed: 21 additions & 1 deletion
@@ -116,7 +116,7 @@ def test_custom_yaml(tmp_path):
116116
"CartPole-v1",
117117
"--log-folder",
118118
tmp_path,
119-
"-yaml",
119+
"-conf",
120120
"hyperparams/a2c.yml",
121121
"-params",
122122
"n_envs:2",
@@ -129,3 +129,23 @@ def test_custom_yaml(tmp_path):
129129

130130
return_code = subprocess.call(["python", "train.py"] + args)
131131
_assert_eq(return_code, 0)
132+
133+
134+
@pytest.mark.parametrize("config_file", ["hyperparams.python.ppo_config_example", "hyperparams/python/ppo_config_example.py"])
135+
def test_python_config_file(tmp_path, config_file):
136+
# Use the example python config file for training
137+
args = [
138+
"-n",
139+
str(N_STEPS),
140+
"--algo",
141+
"ppo",
142+
"--env",
143+
"MountainCarContinuous-v0",
144+
"--log-folder",
145+
tmp_path,
146+
"-conf",
147+
config_file,
148+
]
149+
150+
return_code = subprocess.call(["python", "train.py"] + args)
151+
_assert_eq(return_code, 0)
