 import warnings
-from typing import Any
+from typing import Any, Callable

+from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
 from stable_baselines3.common.utils import constant_fn
 from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn

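As context for the new import, here is a minimal sketch (not part of the patch; the space shape and bounds are illustrative assumptions) of what the two sb3 helpers report for a normalized, channels-last image space, assuming gymnasium, which sb3 2.x uses:

import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first

# a normalized, channels-last RGB space (illustrative shape and bounds)
space = spaces.Box(low=-1.0, high=1.0, shape=(64, 64, 3), dtype=np.float32)
print(is_image_space(space, check_channels=True, normalized_image=True))  # True: 3 channels; dtype/bounds checks relaxed
print(is_image_space_channels_first(space))  # False: smallest dim is last, i.e. (H, W, C)
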
@@ -156,17 +157,8 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, fast_variant: bool = Tr
         self.num_envs = self.unwrapped.num_envs
         self.sim_device = self.unwrapped.device
         self.render_mode = self.unwrapped.render_mode
-
-        # obtain gym spaces
-        # note: stable-baselines3 does not like when we have unbounded action space so
-        # we set it to some high value here. Maybe this is not general but something to think about.
-        observation_space = self.unwrapped.single_observation_space["policy"]
-        action_space = self.unwrapped.single_action_space
-        if isinstance(action_space, gym.spaces.Box) and not action_space.is_bounded("both"):
-            action_space = gym.spaces.Box(low=-100, high=100, shape=action_space.shape)
-
-        # initialize vec-env
-        VecEnv.__init__(self, self.num_envs, observation_space, action_space)
+        self.observation_processors: dict[str, Callable[[torch.Tensor], Any]] = {}
+        self._process_spaces()
         # add buffer for logging episodic information
         self._ep_rew_buf = np.zeros(self.num_envs)
         self._ep_len_buf = np.zeros(self.num_envs)
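The space handling removed here is not deleted but moved into the new `_process_spaces()` helper shown in the next hunk. For reference, a self-contained sketch of the unbounded-action clamping it preserves (the 6-dim shape is an illustrative assumption):

import gymnasium as gym
import numpy as np

# an unbounded Box action space, as many environments define it (illustrative 6-dim shape)
action_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,))
if isinstance(action_space, gym.spaces.Box) and not action_space.is_bounded("both"):
    # clamp to a large finite range so stable-baselines3 can scale and sample actions
    action_space = gym.spaces.Box(low=-100, high=100, shape=action_space.shape)
assert action_space.is_bounded("both")
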
|
@@ -303,14 +295,69 @@ def get_images(self):  # noqa: D102
     Helper functions.
     """

+    def _process_spaces(self):
+        # process observation space
+        observation_space = self.unwrapped.single_observation_space["policy"]
+        if isinstance(observation_space, gym.spaces.Dict):
+            for obs_key, obs_space in observation_space.spaces.items():
+                processors: list[Callable[[torch.Tensor], Any]] = []
+                # pass normalized_image=True, otherwise is_image_space would reject anything that is not
+                # uint8 in [0, 255]; image-like spaces with the right shape but unscaled values are handled below.
+                if is_image_space(obs_space, check_channels=True, normalized_image=True):
+                    actually_normalized = np.all(obs_space.low == -1.0) and np.all(obs_space.high == 1.0)
+                    if not actually_normalized:
+                        if np.any(obs_space.low != 0) or np.any(obs_space.high != 255):
+                            raise ValueError(
+                                "Your image observation is not normalized in the environment, and sb3 will not "
+                                "normalize it unless its minimum is 0 and its maximum is 255."
+                            )
+                        # sb3 will handle the normalization and transpose, but it expects uint8 images
+                        if obs_space.dtype != np.uint8:
+                            processors.append(lambda obs: obs.to(torch.uint8))
+                        observation_space.spaces[obs_key] = gym.spaces.Box(0, 255, obs_space.shape, np.uint8)
+                    else:
+                        # sb3 will NOT handle the normalization here. It would handle the transpose, but its
+                        # transpose applies to every image term, which may be non-ideal; doing it ourselves in
+                        # torch on the GPU is also faster than sb3 transposing in numpy on the CPU.
+                        if not is_image_space_channels_first(obs_space):
+
+                            def transpose(img: torch.Tensor) -> torch.Tensor:
+                                return img.permute(2, 0, 1) if len(img.shape) == 3 else img.permute(0, 3, 1, 2)
+
+                            processors.append(transpose)
+                            h, w, c = obs_space.shape
+                            observation_space.spaces[obs_key] = gym.spaces.Box(-1.0, 1.0, (c, h, w), obs_space.dtype)
+
+                def chained_processor(obs: torch.Tensor, procs=processors) -> Any:
+                    for proc in procs:
+                        obs = proc(obs)
+                    return obs

+                # add the processor chain to the dictionary
+                if len(processors) > 0:
+                    self.observation_processors[obs_key] = chained_processor
+
+        # obtain gym spaces
+        # note: stable-baselines3 does not like unbounded action spaces, so we clamp the bounds
+        # to some high value here. Maybe this is not general, but it is something to think about.
+        action_space = self.unwrapped.single_action_space
+        if isinstance(action_space, gym.spaces.Box) and not action_space.is_bounded("both"):
+            action_space = gym.spaces.Box(low=-100, high=100, shape=action_space.shape)
+
+        # initialize vec-env
+        VecEnv.__init__(self, self.num_envs, observation_space, action_space)
+
     def _process_obs(self, obs_dict: torch.Tensor | dict[str, torch.Tensor]) -> np.ndarray | dict[str, np.ndarray]:
         """Convert observations into NumPy data type."""
         # Sb3 doesn't support asymmetric observation spaces, so we only use "policy"
         obs = obs_dict["policy"]
         # note: ManagerBasedRLEnv uses torch backend (by default).
         if isinstance(obs, dict):
             for key, value in obs.items():
-                obs[key] = value.detach().cpu().numpy()
+                # apply any registered processor (e.g., uint8 cast, channels-first transpose) first
+                if key in self.observation_processors:
+                    value = self.observation_processors[key](value)
+                obs[key] = value.detach().cpu().numpy()
         elif isinstance(obs, torch.Tensor):
             obs = obs.detach().cpu().numpy()
         else:
|
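Taken together, `_process_spaces` registers a per-key processor chain at construction time and `_process_obs` applies it before the NumPy conversion. A self-contained sketch of that runtime path for a float, channels-last image term (shapes are illustrative assumptions, not part of the patch):

import torch

def to_channels_first(img: torch.Tensor) -> torch.Tensor:
    # (H, W, C) -> (C, H, W) for a single image, (N, H, W, C) -> (N, C, H, W) for a batch
    return img.permute(2, 0, 1) if len(img.shape) == 3 else img.permute(0, 3, 1, 2)

# the chain _process_spaces would register for this term
processors = [to_channels_first]

def chained_processor(obs: torch.Tensor, procs=processors) -> torch.Tensor:
    for proc in procs:
        obs = proc(obs)
    return obs

obs = torch.zeros(8, 64, 64, 3)  # assumed batch of normalized (N, H, W, C) images
print(chained_processor(obs).shape)  # torch.Size([8, 3, 64, 64])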