Skip to content

Commit e864a08

Browse files
authored
Merge pull request #278 from kingjuno/Issue-#216
Issue #216: Add envpool to openrl
2 parents 8185373 + 693d2e1 commit e864a08

File tree

6 files changed

+424
-2
lines changed

6 files changed

+424
-2
lines changed

examples/envpool/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
## Installation
2+
3+
4+
Install envpool with:
5+
6+
``` shell
7+
pip install envpool
8+
```
9+
10+
Note: envpool only supports the Linux operating system.
11+
12+
## Usage
13+
14+
You can use `OpenRL` to train Cartpole (envpool) via:
15+
16+
``` shell
17+
python train_ppo.py
18+
```
19+
20+
You can also add custom wrappers in `envpool_wrappers.py`. Currently we have `VecAdapter` and `VecMonitor` wrappers.

examples/envpool/envpool_wrappers.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import time
2+
import warnings
3+
from typing import Optional
4+
5+
import gym
6+
import gymnasium
7+
import numpy as np
8+
from envpool.python.protocol import EnvPool
9+
from packaging import version
10+
from stable_baselines3.common.vec_env import VecEnvWrapper as BaseWrapper
11+
from stable_baselines3.common.vec_env import VecMonitor
12+
from stable_baselines3.common.vec_env.base_vec_env import (VecEnvObs,
13+
VecEnvStepReturn)
14+
15+
is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")
16+
17+
18+
class VecEnvWrapper(BaseWrapper):
    """Multi-agent-aware base wrapper on top of SB3's ``VecEnvWrapper``.

    A plain envpool environment exposes no ``agent_num`` attribute and is
    treated as single-agent; otherwise the wrapped vec-env's count is used.
    """

    @property
    def agent_num(self):
        """Number of agents in the wrapped env (1 for a plain envpool env)."""
        if self.is_original_envpool_env():
            return 1
        else:
            # SB3's VecEnvWrapper stores the wrapped env as `venv`, not `env`;
            # the original `self.env.agent_num` read a non-existent attribute.
            return self.venv.agent_num

    def is_original_envpool_env(self):
        """Return True when the wrapped env does not declare ``agent_num``.

        The original attribute name contained a stray backtick
        (``"agent_num`"``), so this check always returned True and the
        multi-agent branch of ``agent_num`` was unreachable.
        """
        return not hasattr(self.venv, "agent_num")
29+
30+
class VecAdapter(VecEnvWrapper):
    """
    Convert an EnvPool object to a Stable-Baselines3 (SB3) VecEnv.

    :param venv: The envpool object.
    """

    def __init__(self, venv: EnvPool):
        # EnvPool keeps the env count in its spec config; SB3 expects num_envs.
        venv.num_envs = venv.spec.config.num_envs
        observation_space = venv.observation_space
        # Re-wrap the legacy gym spaces that envpool returns as gymnasium spaces.
        new_observation_space = gymnasium.spaces.Box(
            low=observation_space.low,
            high=observation_space.high,
            dtype=observation_space.dtype,
        )
        action_space = venv.action_space
        if isinstance(action_space, gym.spaces.Discrete):
            new_action_space = gymnasium.spaces.Discrete(action_space.n)
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            new_action_space = gymnasium.spaces.MultiDiscrete(action_space.nvec)
        elif isinstance(action_space, gym.spaces.MultiBinary):
            new_action_space = gymnasium.spaces.MultiBinary(action_space.n)
        elif isinstance(action_space, gym.spaces.Box):
            new_action_space = gymnasium.spaces.Box(
                low=action_space.low,
                high=action_space.high,
                dtype=action_space.dtype,
            )
        else:
            raise NotImplementedError(f"Action space {action_space} is not supported")
        super().__init__(
            venv=venv,
            observation_space=new_observation_space,
            action_space=new_action_space,
        )

    def step_async(self, actions: np.ndarray) -> None:
        # Stash the actions; the actual step happens in step_wait().
        self.actions = actions

    def reset(self) -> VecEnvObs:
        # Legacy gym (<0.26) returns only obs; normalize to the (obs, info) API.
        if is_legacy_gym:
            return self.venv.reset(), {}
        else:
            return self.venv.reset()

    def step_wait(self) -> VecEnvStepReturn:
        if is_legacy_gym:
            obs, rewards, dones, info_dict = self.venv.step(self.actions)
        else:
            obs, rewards, terms, truncs, info_dict = self.venv.step(self.actions)
            # An episode is done when it terminated or was truncated.
            dones = terms + truncs
        # (The original had a dead `rewards = rewards` no-op here; removed.)
        infos = []
        # Convert envpool's dict-of-arrays info into SB3's list-of-dicts form.
        for i in range(self.num_envs):
            infos.append(
                {
                    key: info_dict[key][i]
                    for key in info_dict.keys()
                    if isinstance(info_dict[key], np.ndarray)
                }
            )
            if dones[i]:
                # SB3 convention: expose the final observation, then reset
                # only the finished sub-environment.
                infos[i]["terminal_observation"] = obs[i]
                if is_legacy_gym:
                    obs[i] = self.venv.reset(np.array([i]))
                else:
                    obs[i] = self.venv.reset(np.array([i]))[0]
        return obs, rewards, dones, infos
99+
100+
class VecMonitor(VecEnvWrapper):
    """Vectorized monitor: tracks per-env episode returns and lengths and,
    when ``filename`` is given, appends finished-episode records to disk.

    :param venv: the vectorized environment to monitor.
    :param filename: optional path for an SB3 ``ResultsWriter`` log file.
    :param info_keywords: extra info keys copied into each episode record.
    """

    def __init__(
        self,
        venv,
        filename: Optional[str] = None,
        info_keywords=(),
    ):
        # Avoid circular import
        from stable_baselines3.common.monitor import Monitor, ResultsWriter

        try:
            already_monitored = venv.env_is_wrapped(Monitor)[0]
        except AttributeError:
            already_monitored = False

        if already_monitored:
            warnings.warn(
                "The environment is already wrapped with a `Monitor` wrapper"
                "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be"
                "overwritten by the `VecMonitor` ones.",
                UserWarning,
            )

        VecEnvWrapper.__init__(self, venv)
        self.episode_count = 0
        self.t_start = time.time()

        # Record the env id in the log header when the spec exposes one.
        env_id = None
        spec = getattr(venv, "spec", None)
        if spec is not None:
            env_id = spec.id

        self.results_writer: Optional[ResultsWriter] = None
        if filename:
            self.results_writer = ResultsWriter(
                filename,
                header={"t_start": self.t_start, "env_id": str(env_id)},
                extra_keys=info_keywords,
            )

        self.info_keywords = info_keywords
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)

    def reset(self, **kwargs) -> VecEnvObs:
        """Reset all sub-envs and zero the episode statistics."""
        obs, info = self.venv.reset()
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return obs, info

    def step_wait(self) -> VecEnvStepReturn:
        """Step the wrapped env and accumulate per-episode stats; when an
        episode ends, attach an ``"episode"`` record to that env's info."""
        obs, rewards, dones, infos = self.venv.step_wait()
        self.episode_returns += rewards
        self.episode_lengths += 1
        new_infos = list(infos[:])
        for idx, done in enumerate(dones):
            if not done:
                continue
            info = infos[idx].copy()
            # Build the record before zeroing the accumulators.
            record = {
                "r": self.episode_returns[idx],
                "l": self.episode_lengths[idx],
                "t": round(time.time() - self.t_start, 6),
            }
            for key in self.info_keywords:
                record[key] = info[key]
            info["episode"] = record
            self.episode_count += 1
            self.episode_returns[idx] = 0
            self.episode_lengths[idx] = 0
            if self.results_writer:
                self.results_writer.write_row(record)
            new_infos[idx] = info
        # OpenRL expects rewards with an explicit trailing (agent) axis.
        rewards = np.expand_dims(rewards, 1)
        return obs, rewards, dones, new_infos

    def close(self) -> None:
        """Flush the results writer (if any) and close the wrapped env."""
        if self.results_writer:
            self.results_writer.close()
        return self.venv.close()
181+
182+
# Public API of this module: only the adapter and monitor wrappers.
__all__ = ["VecAdapter", "VecMonitor"]

examples/envpool/make_env.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import copy
2+
import inspect
3+
from typing import Callable, Iterable, List, Optional, Union
4+
5+
import envpool
6+
from gymnasium import Env
7+
8+
9+
from openrl.envs.vec_env import (AsyncVectorEnv, RewardWrapper,
10+
SyncVectorEnv, VecMonitorWrapper)
11+
from openrl.envs.vec_env.vec_info import VecInfoFactory
12+
from openrl.envs.wrappers.base_wrapper import BaseWrapper
13+
from openrl.rewards import RewardFactory
14+
15+
16+
def build_envs(
    make,
    id: str,
    env_num: int = 1,
    wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
    need_env_id: bool = False,
    **kwargs,
) -> List[Callable[[], Env]]:
    """Build a list of thunks, each of which creates one wrapped environment.

    :param make: factory callable (e.g. ``envpool.make``).
    :param id: environment id forwarded to ``make``.
    :param env_num: number of environment thunks to create.
    :param wrappers: a single wrapper callable or a list of them.
    :param need_env_id: when True, forward ``env_id``/``env_num`` to ``make``.
    :return: list of zero-argument callables, each returning a ready env.
    """
    # NOTE(review): `cfg` stays in kwargs and is forwarded to `make` via
    # **maker_kwargs — presumably the factory tolerates it; verify.
    cfg = kwargs.get("cfg", None)

    def create_env(env_id: int, env_num: int, need_env_id: bool) -> Callable[[], Env]:
        def _make_env() -> Env:
            maker_kwargs = copy.deepcopy(kwargs)
            if need_env_id:
                maker_kwargs["env_id"] = env_id
                maker_kwargs["env_num"] = env_num
            # envpool supports no render mode and does not store the id, so
            # the internal "envpool" marker flag must not reach the factory.
            maker_kwargs.pop("envpool", None)
            env = make(id, **maker_kwargs)
            env.unwrapped.spec.id = id

            if wrappers is None:
                return env
            if callable(wrappers):
                # A single wrapper; OpenRL wrappers take the config object.
                if issubclass(wrappers, BaseWrapper):
                    return wrappers(env, cfg=cfg)
                return wrappers(env)
            if isinstance(wrappers, Iterable) and all(callable(w) for w in wrappers):
                for wrapper in wrappers:
                    takes_cfg = (
                        issubclass(wrapper, BaseWrapper)
                        and "cfg" in inspect.signature(wrapper.__init__).parameters
                    )
                    env = wrapper(env, cfg=cfg) if takes_cfg else wrapper(env)
                return env
            raise NotImplementedError

        return _make_env

    return [create_env(env_id, env_num, need_env_id) for env_id in range(env_num)]
69+
70+
def make_envpool_envs(
    id: str,
    env_num: int = 1,
    **kwargs,
):
    """Create environment thunks backed by envpool.

    Requires ``env_type`` in kwargs, one of ``"gym"``, ``"dm"``,
    ``"gymnasium"``.

    :param id: envpool environment id.
    :param env_num: number of environments.
    :return: list of zero-argument env factories (see ``build_envs``).
    """
    assert "env_type" in kwargs
    assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"]
    # Marker consumed (and stripped) by build_envs before calling the factory.
    kwargs["envpool"] = True

    env_wrappers = kwargs.pop("env_wrappers", [])
    return build_envs(
        make=envpool.make,
        id=id,
        env_num=env_num,
        wrappers=env_wrappers,
        **kwargs,
    )
91+
92+
93+
def make(
    id: str,
    env_num: int = 1,
    asynchronous: bool = False,
    add_monitor: bool = True,
    render_mode: Optional[str] = None,
    auto_reset: bool = True,
    **kwargs,
):
    """Create an OpenRL vectorized environment backed by envpool.

    :param id: envpool env id; a ``"prefix:Name"`` form is split on ``":"``.
    :param env_num: number of parallel environments.
    :param asynchronous: use ``AsyncVectorEnv`` instead of ``SyncVectorEnv``.
    :param add_monitor: additionally wrap with ``VecMonitorWrapper``.
    :param render_mode: forwarded to the vector env (envpool cannot render).
    :param auto_reset: forwarded to the vector env.
    :raises NotImplementedError: if ``id`` is not registered with envpool.
    """
    cfg = kwargs.get("cfg", None)

    if id not in envpool.registration.list_all_envs():
        raise NotImplementedError(f"env {id} is not supported")

    env_fns = make_envpool_envs(
        id=id.split(":")[-1],
        env_num=env_num,
        **kwargs,
    )
    vector_cls = AsyncVectorEnv if asynchronous else SyncVectorEnv
    env = vector_cls(env_fns, render_mode=render_mode, auto_reset=auto_reset)

    # Attach the reward pipeline configured by cfg (default when cfg is None).
    reward_class = RewardFactory.get_reward_class(
        cfg.reward_class if cfg else None, env
    )
    env = RewardWrapper(env, reward_class)

    if add_monitor:
        vec_info = VecInfoFactory.get_vec_info_class(
            cfg.vec_info_class if cfg else None, env
        )
        env = VecMonitorWrapper(vec_info, env)

    return env

0 commit comments

Comments
 (0)