
Commit 401569b: add snake environment

Parent: 65f50d9

35 files changed: 1040 additions and 324 deletions.

Gallery.md

Lines changed: 15 additions & 14 deletions

The environment table (in the section inviting users to contribute their own training examples and demos) is re-aligned and gains a Snake row; every other row carries over unchanged. The resulting table:

<div align="center">

| Environment/Demo | Tags | Refs |
|:----------------:|:----:|:----:|
| [MuJoCo](https://github.com/deepmind/mujoco)<br> <img width="300px" height="auto" src="./docs/images/mujoco.png"> | ![continuous](https://img.shields.io/badge/-continuous-green) | [code](./examples/mujoco/) |
| [CartPole](https://gymnasium.farama.org/environments/classic_control/cart_pole/)<br> <img width="300px" height="auto" src="./docs/images/cartpole.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
| [MPE: Simple Spread](https://pettingzoo.farama.org/environments/mpe/simple_spread/)<br> <img width="300px" height="auto" src="./docs/images/simple_spread_trained.gif"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
| [StarCraft II](https://github.com/oxwhirl/smac)<br> <img width="300px" height="auto" src="./docs/images/smac.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/smac/) |
| [Chat Bot](https://openrl-docs.readthedocs.io/en/latest/quick_start/train_nlp.html)<br> <img width="300px" height="auto" src="./docs/images/chat.gif"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![NLP](https://img.shields.io/badge/-NLP-green) ![Transformer](https://img.shields.io/badge/-Transformer-blue) | [code](./examples/nlp/) |
| [Atari Pong](https://gymnasium.farama.org/environments/atari/pong/)<br> <img width="300px" height="auto" src="./docs/images/pong.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/atari/) |
| [PettingZoo: Tic-Tac-Toe](https://pettingzoo.farama.org/environments/classic/tictactoe/)<br> <img width="300px" height="auto" src="./docs/images/tic-tac-toe.jpeg"> | ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/selfplay/) |
| [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)<br> <img width="300px" height="auto" src="https://shimmy.farama.org/_images/dm_locomotion.png"> | ![continuous](https://img.shields.io/badge/-continuous-green) | [code](./examples/dm_control/) |
| [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)<br> <img width="300px" height="auto" src="https://user-images.githubusercontent.com/34286328/171454189-6afafbff-bb61-4aac-b518-24646007cb9f.gif"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/isaac/) |
| [Snake](http://www.jidiai.cn/env_detail?envid=1)<br> <img width="300px" height="auto" src="./docs/images/snakes_1v1.gif"> | ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/snake/) |
| [GridWorld](./examples/gridworld/)<br> <img width="300px" height="auto" src="./docs/images/gridworld.jpg"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/gridworld/) |
| [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)<br> <img width="300px" height="auto" src="https://user-images.githubusercontent.com/2184469/40948820-3d15e5c2-6830-11e8-81d4-ecfaffee0a14.png"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/super_mario/) |
| [Gym Retro](https://github.com/openai/retro)<br> <img width="300px" height="auto" src="./docs/images/gym-retro.jpg"> | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/retro/) |

</div>

README.md

Lines changed: 2 additions & 1 deletion

Snake joins the "Environments currently supported by OpenRL" list:

```diff
@@ -104,7 +104,8 @@
 - [Atari](https://gymnasium.farama.org/environments/atari/)
 - [StarCraft II](https://github.com/oxwhirl/smac)
 - [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
-- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
+- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
+- [Snake](http://www.jidiai.cn/env_detail?envid=1)
 - [GridWorld](./examples/gridworld/)
 - [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
 - [Gym Retro](https://github.com/openai/retro)
```
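For reference, a minimal sketch of instantiating the newly registered environment, mirroring the pattern in `examples/snake/test_env.py` from this same commit (`ConvertObs` is a local helper shipped in `examples/snake/wrappers.py`, and `env_num=2` is an arbitrary choice):

```python
import numpy as np

from openrl.envs.common import make
from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper
from wrappers import ConvertObs  # local helper in examples/snake/

# "snakes_1v1" is registered by this commit; the opponent snake is driven
# by RandomOpponentWrapper, so the env can be stepped like a single-agent env.
env = make(
    "snakes_1v1",
    opponent_wrappers=[RandomOpponentWrapper],
    env_wrappers=[ConvertObs],
    env_num=2,
)

obs, info = env.reset()
done = False
while not np.any(done):
    obs, reward, done, info = env.step(env.random_action())
```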

README_zh.md

Lines changed: 2 additions & 1 deletion

The same Snake entry is added to the Chinese README's list ("Environments currently supported by OpenRL; for more details, please refer to [Gallery](Gallery.md)"):

```diff
@@ -86,7 +86,8 @@
 - [Atari](https://gymnasium.farama.org/environments/atari/)
 - [StarCraft II](https://github.com/oxwhirl/smac)
 - [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
-- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
+- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
+- [Snake](http://www.jidiai.cn/env_detail?envid=1)
 - [GridWorld](./examples/gridworld/)
 - [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
 - [Gym Retro](https://github.com/openai/retro)
```

docs/images/snakes_1v1.gif

New binary file, 108 KB: the 1v1 snake animation referenced by the new Gallery row.

examples/dm_control/train_ppo.py

Lines changed: 1 addition & 2 deletions

The two separate imports from `extra_wrappers` are consolidated into one line:

```diff
@@ -4,10 +4,9 @@
 from openrl.configs.config import create_config_parser
 from openrl.envs.common import make
 from openrl.envs.wrappers.base_wrapper import BaseWrapper
-from openrl.envs.wrappers.extra_wrappers import GIFWrapper
+from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper
 from openrl.modules.common import PPONet as Net
 from openrl.runners.common import PPOAgent as Agent
-from openrl.envs.wrappers.extra_wrappers import FrameSkip

 env_name = "dm_control/cartpole-balance-v0"
 # env_name = "dm_control/walker-walk-v0"
```
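Both wrapper classes are presumably handed to `make` through `env_wrappers`, the same mechanism `examples/snake/test_env.py` uses for `ConvertObs`. A hypothetical sketch; the environment name comes from this file, while the wrapper arguments and `env_num` are assumptions about defaults:

```python
from openrl.envs.common import make
from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper

# Hypothetical: wrapper classes passed via env_wrappers, mirroring the
# env_wrappers=[ConvertObs] pattern elsewhere in this commit; whether
# FrameSkip or GIFWrapper need extra constructor arguments is not shown here.
env = make(
    "dm_control/cartpole-balance-v0",
    env_num=4,
    env_wrappers=[FrameSkip, GIFWrapper],
)
```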

examples/snake/README.md

Lines changed: 7 additions & 0 deletions

A Usage section and a link to the original environment source are added:

````diff
@@ -1,10 +1,17 @@
 This is the example for the snake game.
 
+## Usage
+
+```bash
+python train_selfplay.py
+```
+
 ## Submit to JiDi
 
 Submission site: http://www.jidiai.cn/env_detail?envid=1.
 
 Snake scenarios: [here](https://github.com/jidiai/ai_lib/blob/7a6986f0cb543994277103dbf605e9575d59edd6/env/config.json#L94)
+Original Snake environment: [here](https://github.com/jidiai/ai_lib/blob/master/env/snakes.py)
````
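For context on what a JiDi submission looks like, this commit also bundles a random baseline under `examples/snake/submissions/random_agent/`. Its entry point has the shape below; this is a sketch reconstructed from the visible diff context, the loop structure is inferred, and `sample_single_dim` is a helper defined in that file:

```python
def my_controller(observation, action_space, is_act_continuous):
    # JiDi calls this once per turn; return one sub-action per
    # action-space dimension.
    joint_action = []
    for i in range(len(action_space)):
        # The random baseline samples uniformly; a trained submission
        # would query its policy here instead.
        player = sample_single_dim(action_space[i], is_act_continuous)
        joint_action.append(player)
    return joint_action
```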

examples/snake/selfplay.yaml

Lines changed: 3 additions & 0 deletions (new file)

```yaml
seed: 0
callbacks:
  - id: "ProgressBarCallback"
```
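The training script consumes this file through OpenRL's config parser, exactly as `train_selfplay.py` below does:

```python
from openrl.configs.config import create_config_parser

# Same call train_selfplay.py makes: the seed and callback list from
# selfplay.yaml are parsed into cfg, which is later passed to make() and Net().
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "selfplay.yaml"])
```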

examples/snake/submissions/random_agent/submission.py

Lines changed: 0 additions & 1 deletion

Only a trailing blank line after the final `return` of `my_controller` is removed:

```diff
@@ -27,4 +27,3 @@ def my_controller(observation, action_space, is_act_continuous):
         player = sample_single_dim(action_space[i], is_act_continuous)
         joint_action.append(player)
     return joint_action
-
```

examples/snake/test_env.py

Lines changed: 85 additions & 10 deletions

The inline smoke test is refactored into three functions covering the raw environment, the PettingZoo AEC wrapper, and the vectorized self-play environment; `__main__` runs the last of these. The file after the change:

```python
# limitations under the License.

""""""
import time

import numpy as np
from wrappers import ConvertObs

from openrl.envs.snake.snake import SnakeEatBeans
from openrl.envs.snake.snake_pettingzoo import SnakeEatBeansAECEnv
from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper


def test_raw_env():
    # Step the raw two-snake environment with random one-hot actions
    # until either snake is done.
    env = SnakeEatBeans()

    obs, info = env.reset()

    done = False
    while not np.any(done):
        a1 = np.zeros(4)
        a1[env.action_space.sample()] = 1
        a2 = np.zeros(4)
        a2[env.action_space.sample()] = 1
        obs, reward, done, info = env.step([a1, a2])
        print("obs:", obs)
        print("reward:", reward)
        print("done:", done)
        print("info:", info)


def test_aec_env():
    # Drive the PettingZoo AEC wrapper with random actions for up to
    # 20 steps, collecting one frame per round into a GIF.
    from PIL import Image

    img_list = []
    env = SnakeEatBeansAECEnv(render_mode="rgb_array")
    env.reset(seed=0)
    img = env.render()
    img_list.append(img)
    step = 0
    for player_name in env.agent_iter():
        if step > 20:
            break
        observation, reward, termination, truncation, info = env.last()
        if termination or truncation:
            break
        action = env.action_space(player_name).sample()
        env.step(action)
        img = env.render()
        if player_name == "player_0":
            img_list.append(img)
        step += 1
    print("Total steps: {}".format(step))

    save_path = "test.gif"
    img_list = [Image.fromarray(img) for img in img_list]
    img_list[0].save(save_path, save_all=True, append_images=img_list[1:], duration=500)


def test_vec_env():
    # Run the registered vectorized env against a random opponent.
    from openrl.envs.common import make

    env = make(
        "snakes_1v1",
        opponent_wrappers=[
            RandomOpponentWrapper,
        ],
        env_wrappers=[ConvertObs],
        render_mode="group_human",
        env_num=2,
    )
    obs, info = env.reset()
    step = 0
    done = False
    while not np.any(done):
        action = env.random_action()
        obs, reward, done, info = env.step(action)
        time.sleep(0.3)
        step += 1
    print("Total steps: {}".format(step))


if __name__ == "__main__":
    test_vec_env()
```
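Since the three functions follow pytest naming conventions, `pytest -s test_env.py` should run all of them (assuming pytest is installed), while `python test_env.py` executes only `test_vec_env`.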

examples/snake/train_selfplay.py

Lines changed: 87 additions & 0 deletions (new file)

The new training script fits a PPO agent against a random opponent, saves it, then reloads it and reports a win rate over five deterministic evaluation episodes:

```python
import numpy as np
import torch
from wrappers import ConvertObs

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper


def train():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "selfplay.yaml"])

    # Create the environment
    env_num = 10
    render_mode = None
    env = make(
        "snakes_1v1",
        render_mode=render_mode,
        env_num=env_num,
        asynchronous=True,
        opponent_wrappers=[RandomOpponentWrapper],
        env_wrappers=[ConvertObs],
        cfg=cfg,
    )
    # Create the neural network
    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # Create the agent
    agent = Agent(net)
    # Begin training
    agent.train(total_time_steps=100000)
    env.close()
    agent.save("./selfplay_agent/")
    return agent


def evaluation():
    print("Evaluation...")
    env_num = 1
    env = make(
        "snakes_1v1",
        env_num=env_num,
        asynchronous=True,
        opponent_wrappers=[RandomOpponentWrapper],
        env_wrappers=[ConvertObs],
        auto_reset=False,
    )

    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args()
    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")

    agent = Agent(net)

    agent.load("./selfplay_agent/")
    agent.set_env(env)
    env.reset(seed=0)

    total_reward = 0.0
    ep_num = 5
    for ep_now in range(ep_num):
        obs, info = env.reset()
        done = False
        step = 0

        while not np.any(done):
            # Predict the next action from the current observation
            action, _ = agent.act(obs, info, deterministic=True)
            obs, r, done, info = env.step(action)
            step += 1

            if np.any(done):
                # Count the episode as a win when the mean final reward is positive
                total_reward += np.mean(r) > 0
                print(f"{ep_now}/{ep_num}: reward: {np.mean(r)}")
    print(f"win rate: {total_reward/ep_num}")
    env.close()
    print("Evaluation finished.")


if __name__ == "__main__":
    train()
    evaluation()
```
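As the example README notes, `python train_selfplay.py` runs training and evaluation back to back. The win-rate arithmetic is straightforward: with `ep_num = 5`, winning three of the five evaluation episodes leaves `total_reward` at 3.0, so the script prints a win rate of 0.6.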
