Commit 6300cc6

Merge pull request #202 from huangshiyu13/main
update
2 parents: 9623d5c + 1f8c3ef

File tree

10 files changed, +45 -69 lines

Gallery.md

Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ Users are also welcome to contribute their own training examples and demos to th…
 | [Chat Bot](https://openrl-docs.readthedocs.io/en/latest/quick_start/train_nlp.html)<br> <img width="300px" height="auto" src="./docs/images/chat.gif"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![NLP](https://img.shields.io/badge/-NLP-green) ![Transformer](https://img.shields.io/badge/-Transformer-blue) | [code](./examples/nlp/) |
 | [Atari Pong](https://gymnasium.farama.org/environments/atari/pong/)<br> <img width="300px" height="auto" src="./docs/images/pong.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/atari/) |
 | [PettingZoo: Tic-Tac-Toe](https://pettingzoo.farama.org/environments/classic/tictactoe/)<br> <img width="300px" height="auto" src="./docs/images/tic-tac-toe.jpeg"> | ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/selfplay/) |
+| [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)<br> <img width="300px" height="auto" src="https://shimmy.farama.org/_images/dm_locomotion.png"> | ![continuous](https://img.shields.io/badge/-continous-green) | [code](./examples/dm_control/) |
 | [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)<br> <img width="300px" height="auto" src="https://user-images.githubusercontent.com/34286328/171454189-6afafbff-bb61-4aac-b518-24646007cb9f.gif"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/isaac/) |
 | [GridWorld](./examples/gridworld/)<br> <img width="300px" height="auto" src="./docs/images/gridworld.jpg"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/gridworld/) |
 | [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)<br> <img width="300px" height="auto" src="https://user-images.githubusercontent.com/2184469/40948820-3d15e5c2-6830-11e8-81d4-ecfaffee0a14.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/super_mario/) |

README.md

Lines changed: 1 addition & 0 deletions

@@ -104,6 +104,7 @@ Environments currently supported by OpenRL (for more details, please refer to [Gallery](Gallery.md)):
 - [Atari](https://gymnasium.farama.org/environments/atari/)
 - [StarCraft II](https://github.com/oxwhirl/smac)
 - [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
+- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
 - [GridWorld](./examples/gridworld/)
 - [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
 - [Gym Retro](https://github.com/openai/retro)

README_zh.md

Lines changed: 1 addition & 0 deletions

@@ -86,6 +86,7 @@ Environments currently supported by OpenRL (for more details, please refer to [Gallery](Gallery.md)):
 - [Atari](https://gymnasium.farama.org/environments/atari/)
 - [StarCraft II](https://github.com/oxwhirl/smac)
 - [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
+- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
 - [GridWorld](./examples/gridworld/)
 - [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
 - [Gym Retro](https://github.com/openai/retro)

examples/behavior_cloning/test_env.py

Lines changed: 2 additions & 3 deletions

@@ -10,9 +10,8 @@ def test_env():
     cfg_parser = create_config_parser()
     cfg = cfg_parser.parse_args()

-    # create environment, set environment parallelism to 9
-    # env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=True)
-    env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=False)
+    # create environment
+    env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=True)

     for ep_index in range(10):
         done = False
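The hunk above stops at the top of the rollout loop. For readers following along, here is a sketch of how such an OfflineEnv smoke test typically continues, assuming OpenRL's example conventions (`env.random_action()` and the four-tuple `env.step` return); the continuation itself is not part of this commit:

```python
import numpy as np

# Hypothetical continuation of test_env()'s loop, for illustration only.
obs, info = env.reset()
for ep_index in range(10):
    done = False
    while not np.any(done):
        # random actions stand in for a learned policy here (assumption)
        obs, r, done, info = env.step(env.random_action())
env.close()
```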

examples/dm_control/README.md

Lines changed: 9 additions & 0 deletions (new file)

@@ -0,0 +1,9 @@
+## Installation
+```bash
+pip install shimmy[dm-control]
+```
+
+## Usage
+```bash
+python train_ppo.py
+```
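For orientation before the renamed training script below: a minimal sketch of the environment that `python train_ppo.py` sets up, using the `openrl.envs.common.make` import path from OpenRL's examples and the env id from the diff below. The snippet is illustrative, not part of this commit:

```python
# Minimal sketch: create the Shimmy-backed dm_control task used by train_ppo.py.
# Requires `pip install shimmy[dm-control]` as described above.
from openrl.envs.common import make

env = make("dm_control/cartpole-balance-v0", env_num=4, asynchronous=True)
obs, info = env.reset()
env.close()
```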
File renamed without changes.

examples/dmc/train_ppo.py renamed to examples/dm_control/train_ppo.py

Lines changed: 12 additions & 33 deletions

@@ -7,25 +7,7 @@
 from openrl.envs.wrappers.extra_wrappers import GIFWrapper
 from openrl.modules.common import PPONet as Net
 from openrl.runners.common import PPOAgent as Agent
-
-
-class FrameSkip(BaseWrapper):
-    def __init__(self, env, num_frames: int = 8):
-        super().__init__(env)
-        self.num_frames = num_frames
-
-    def step(self, action):
-        num_skips = self.num_frames
-        total_reward = 0.0
-
-        for x in range(num_skips):
-            obs, rew, term, trunc, info = super().step(action)
-            total_reward += rew
-            if term or trunc:
-                break
-
-        return obs, total_reward, term, trunc, info
-
+from openrl.envs.wrappers.extra_wrappers import FrameSkip

 env_name = "dm_control/cartpole-balance-v0"
 # env_name = "dm_control/walker-walk-v0"
@@ -36,7 +18,7 @@ def train():
     cfg_parser = create_config_parser()
     cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

-    # create environment, set environment parallelism to 9
+    # create environment, set environment parallelism to 64
     env = make(
         env_name,
         env_num=64,
@@ -50,35 +32,30 @@ def train():
     agent = Agent(
         net,
     )
-    # start training, set total number of training steps to 20000
+    # start training, set total number of training steps to 100000
     agent.train(total_time_steps=100000)
     agent.save("./ppo_agent")
     env.close()
     return agent


-
-
-
 def evaluation():
     cfg_parser = create_config_parser()
     cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
     # begin to test
-    # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human.
-    render_mode = "group_human"
+    # Create an environment for testing and set the number of environments to interact with to 4. Set rendering mode to group_rgb_array.
     render_mode = "group_rgb_array"
     env = make(
         env_name,
         render_mode=render_mode,
         env_num=4,
         asynchronous=True,
-        env_wrappers=[FrameSkip,FlattenObservation],
-        cfg=cfg
+        env_wrappers=[FrameSkip, FlattenObservation],
+        cfg=cfg,
     )
+    # Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
     env = GIFWrapper(env, gif_path="./new.gif", fps=5)

-
-
     net = Net(env, cfg=cfg, device="cuda")
     # initialize the trainer
     agent = Agent(
@@ -103,8 +80,10 @@ def evaluation():
         total_reward += np.mean(r)
         if step % 50 == 0:
             print(f"{step}: reward:{np.mean(r)}")
-    print("total step:", step, total_reward)
+    print("total step:", step, "total reward:", total_reward)
     env.close()

-train()
-evaluation()
+
+if __name__ == "__main__":
+    train()
+    evaluation()
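A side note on the evaluation path above: with `render_mode="group_rgb_array"`, the rendered frames can be captured by `GIFWrapper` exactly as the script does. A standalone sketch of just the recording setup, with parameters copied from the diff (illustrative rather than normative):

```python
# Sketch: record evaluation rollouts to ./new.gif at 5 fps, as evaluation() does.
from openrl.envs.common import make
from openrl.envs.wrappers.extra_wrappers import GIFWrapper

env = make(
    "dm_control/cartpole-balance-v0",
    render_mode="group_rgb_array",
    env_num=4,
    asynchronous=True,
)
env = GIFWrapper(env, gif_path="./new.gif", fps=5)
```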

openrl/envs/dmc/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ def make_dmc_envs(
     render_mode: Optional[Union[str, List[str]]] = None,
     **kwargs,
 ):
-    from openrl.envs.wrappers import (  # AutoReset,; DictWrapper,
+    from openrl.envs.wrappers import (
         RemoveTruncated,
         Single2MultiAgentWrapper,
     )

openrl/envs/dmc/dmc_env.py

Lines changed: 0 additions & 32 deletions

@@ -1,45 +1,13 @@
 from typing import Any, Optional

-import dmc2gym
 import gymnasium as gym
 import numpy as np

-# class DmcEnv:
-#     def __init__(self):
-#         env = dmc2gym.make(
-#             domain_name='walker',
-#             task_name='walk',
-#             seed=42,
-#             visualize_reward=False,
-#             from_pixels='features',
-#             height=224,
-#             width=224,
-#             frame_skip=2
-#         )
-#         # self.observation_space = spaces.Box(
-#         #     low=np.array([0, 0, 0, 0]),
-#         #     high=np.array([self.nrow - 1, self.ncol - 1, self.nrow - 1, self.ncol - 1]),
-#         #     dtype=int,
-#         # )  # current position and target position
-#         # self.action_space = spaces.Discrete(
-#         #     5
-#         # )
-

 def make(
     id: str,
     render_mode: Optional[str] = None,
     **kwargs: Any,
 ):
     env = gym.make(id, render_mode=render_mode)
-    # env = dmc2gym.make(
-    #     domain_name='walker',
-    #     task_name='walk',
-    #     seed=42,
-    #     visualize_reward=False,
-    #     from_pixels='features',
-    #     height=224,
-    #     width=224,
-    #     frame_skip=2
-    # )
     return env
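The slimmed-down `make` above now delegates entirely to Gymnasium, which resolves `dm_control/...` ids through Shimmy's compatibility layer rather than the removed `dmc2gym` path. A hedged sketch of that resolution in isolation:

```python
# Sketch: with shimmy[dm-control] installed, Gymnasium exposes dm_control
# tasks under "dm_control/<domain>-<task>-v0" ids, so a plain gym.make works.
import gymnasium as gym

env = gym.make("dm_control/cartpole-balance-v0", render_mode=None)
obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```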

openrl/envs/wrappers/extra_wrappers.py

Lines changed: 18 additions & 0 deletions

@@ -28,6 +28,24 @@
 from openrl.envs.wrappers.flatten import flatten


+class FrameSkip(BaseWrapper):
+    def __init__(self, env, num_frames: int = 8):
+        super().__init__(env)
+        self.num_frames = num_frames
+
+    def step(self, action):
+        num_skips = self.num_frames
+        total_reward = 0.0
+
+        for x in range(num_skips):
+            obs, rew, term, trunc, info = super().step(action)
+            total_reward += rew
+            if term or trunc:
+                break
+
+        return obs, total_reward, term, trunc, info
+
+
 class RemoveTruncated(StepAPICompatibility, BaseWrapper):
     def __init__(
         self,
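Since `FrameSkip` now lives in the shared wrapper module, it can be combined with other wrappers through `make`'s `env_wrappers` hook, as the renamed `train_ppo.py` does. A minimal sketch; the `gymnasium.wrappers.FlattenObservation` import is an assumption, since the diff does not show where `train_ppo.py` imports it from:

```python
# Sketch: reuse the relocated FrameSkip via OpenRL's env_wrappers hook.
from gymnasium.wrappers import FlattenObservation  # assumed import source

from openrl.envs.common import make
from openrl.envs.wrappers.extra_wrappers import FrameSkip

env = make(
    "dm_control/cartpole-balance-v0",
    env_num=4,
    asynchronous=True,
    env_wrappers=[FrameSkip, FlattenObservation],
)
```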
