Skip to content

Commit 65f50d9

Browse files
committed
add snake
1 parent 1f8c3ef commit 65f50d9

File tree

13 files changed

+2134
-1
lines changed

13 files changed

+2134
-1
lines changed

examples/smac/README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,7 @@ Installation guide for Linux:
1111

1212
Train SMAC with [MAPPO](https://arxiv.org/abs/2103.01955) algorithm:
1313

14-
`python train_ppo.py --config smac_ppo.yaml`
14+
`python train_ppo.py --config smac_ppo.yaml`
15+
16+
## Render replay on Mac
17+

examples/snake/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
This is the example for the snake game.
3+
4+
5+
## Submit to JiDi
6+
7+
Submission site: http://www.jidiai.cn/env_detail?envid=1.
8+
9+
Snake scenarios: [here](https://github.com/jidiai/ai_lib/blob/7a6986f0cb543994277103dbf605e9575d59edd6/env/config.json#L94)
10+
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# -*- coding:utf-8 -*-
2+
def sample_single_dim(action_space_list_each, is_act_continuous):
    """Draw one random action from a single action space.

    Args:
        action_space_list_each: The action space to sample from. It must
            provide ``sample()``. For discrete sampling its class must be named
            ``Discrete`` (with attribute ``n``) or ``MultiDiscreteParticle``
            (with attributes ``high`` and ``low``).
        is_act_continuous: When true, the raw result of ``sample()`` is
            returned unchanged.

    Returns:
        For continuous spaces, whatever ``sample()`` returns. For ``Discrete``,
        a one-hot list of length ``n``. For ``MultiDiscreteParticle``, the
        concatenation of one one-hot list per dimension.

    Raises:
        TypeError: If ``is_act_continuous`` is false and the space class is not
            one of the supported discrete types (previously this surfaced as
            an ``UnboundLocalError`` on the return statement).
    """
    if is_act_continuous:
        return action_space_list_each.sample()

    # The space type is matched by class name, so look-alike classes from
    # different packages are accepted as long as the name matches.
    space_type = action_space_list_each.__class__.__name__

    if space_type == "Discrete":
        each = [0] * action_space_list_each.n
        idx = action_space_list_each.sample()
        each[idx] = 1
        return each

    if space_type == "MultiDiscreteParticle":
        each = []
        # Number of choices per dimension.
        nvec = action_space_list_each.high - action_space_list_each.low + 1
        sample_indexes = action_space_list_each.sample()
        # NOTE(review): the sampled value is used directly as the one-hot
        # index, which presumably assumes ``low`` is all zeros -- confirm.
        for i, dim in enumerate(nvec):
            new_action = [0] * dim
            new_action[sample_indexes[i]] = 1
            each.extend(new_action)
        return each

    raise TypeError(
        "unsupported discrete action space type: {}".format(space_type)
    )
22+
23+
24+
def my_controller(observation, action_space, is_act_continuous):
    """Return one randomly sampled action per entry of ``action_space``.

    Args:
        observation: Unused; accepted so the signature matches the controller
            interface.
        action_space: Sequence of action spaces, one per controlled entry.
        is_act_continuous: Forwarded to ``sample_single_dim`` for every entry.

    Returns:
        A list with one sampled action per element of ``action_space``, in the
        same order.
    """
    return [
        sample_single_dim(single_space, is_act_continuous)
        for single_space in action_space
    ]
30+

examples/snake/test_env.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
import numpy as np
19+
from openrl.envs.snake.snake import SnakeEatBeans
20+
21+
# Smoke test: run the SnakeEatBeans environment with random actions until any
# entry of ``done`` is truthy, printing every transition.
env = SnakeEatBeans()

obs, info = env.reset()

done = False
while not np.any(done):
    # Build two one-hot action vectors of length 4; the hot index is drawn
    # from the environment's action space.
    a1 = np.zeros(4)
    a1[env.action_space.sample()] = 1
    a2 = np.zeros(4)
    a2[env.action_space.sample()] = 1
    # Both actions are submitted together as one joint action.
    obs, reward, done, info = env.step([a1, a2])
    print("obs:", obs, reward, "\ndone:", done, info)

openrl/envs/snake/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""

openrl/envs/snake/common.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import numpy as np
2+
import sys
3+
import os
4+
5+
class HiddenPrints:
    """Context manager that discards everything written to ``sys.stdout``.

    On entry ``sys.stdout`` is redirected to ``os.devnull``; on exit the
    stream that was active at entry is restored. Exceptions raised inside the
    ``with`` body are not suppressed.
    """

    def __enter__(self):
        """Redirect ``sys.stdout`` to the null device and return ``self``."""
        self._original_stdout = sys.stdout
        # Keep our own handle so __exit__ closes exactly the file opened here,
        # even if the body rebinds sys.stdout to something else.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the original ``sys.stdout`` and close the null device."""
        # Restore first so stdout is usable again even if close() fails.
        sys.stdout = self._original_stdout
        self._devnull.close()
13+
14+
class Board:
    """Grid on which every snake simultaneously expands a claimed territory.

    Cell encoding in ``self.board`` (``n`` is ``snakes_count``):
      * ``-n`` (``blank_sign``): unclaimed cell.
      * ``key`` (``0 .. n-1``): cell occupied by the body of snake ``key``.
      * ``state * n + key``: cell claimed by snake ``key`` at expansion step
        ``state``; decoded with ``sign % n`` and ``sign // n``.
      * ``-state * n + k``: cell contested at step ``state`` by snakes of equal
        length; ``sign // n`` then evaluates to ``-state``.

    Each successful claim updates the claiming snake's ``claimed_count``.
    """

    def __init__(self, board_height, board_width, snakes, beans_positions, teams):
        """Initialise the grid and the per-snake frontier.

        Args:
            board_height: Number of rows.
            board_width: Number of columns.
            snakes: Mapping from snake key to a snake object exposing ``head``,
                ``pos``, ``len`` and ``claimed_count``.
            beans_positions: Stored as-is; not written onto the board here.
            teams: Stored as-is; not otherwise used in this class.
        """
        self.height = board_height
        self.width = board_width
        self.snakes = snakes
        self.snakes_count = len(snakes)
        self.beans_positions = beans_positions
        self.blank_sign = -self.snakes_count
        self.bean_sign = -self.snakes_count + 1
        self.board = np.zeros((board_height, board_width), dtype=int) + self.blank_sign
        self.open = dict()
        for key, snake in self.snakes.items():
            self.open[key] = [snake.head]  # state 0 open list, heads, ready to spread
            # see [A* Pathfinding (E01: algorithm explanation)](https://www.youtube.com/watch?v=-L-WgKMFuhE)
            for x, y in snake.pos:
                self.board[x][y] = key  # obstacles, e.g. 0, 1, 2, 3, 4, 5
        # Beans are intentionally not marked on the board (the code that wrote
        # ``bean_sign`` into bean cells is disabled in the original).

        self.state = 0
        self.controversy = dict()
        self.teams = teams

    def step(self):  # delay: prevent rear-end collision
        """Advance the expansion by one step.

        Frees the cell each snake's tail vacates, then lets every snake spread
        from its current frontier (``self.open``) to the four wrap-around
        neighbours, resolving simultaneous claims by snake length: the shorter
        snake wins, equal lengths share the cell as a "controversy".
        """
        new_open = {key: [] for key in self.snakes.keys()}
        self.state += 1  # update state
        for key, snake in self.snakes.items():
            if snake.len >= self.state:
                self.board[snake.pos[-self.state][0]][snake.pos[-self.state][1]] = self.blank_sign  # drop tail
        for key, value in self.open.items():  # value: e.g. [[8, 3], [6, 3], [7, 4]]
            # Cells just vacated by the other snakes' tails this step.
            # NOTE(review): assumes snake keys are exactly 0 .. snakes_count-1.
            others_tail_pos = [self.snakes[_].pos[-self.state]
                               if self.snakes[_].len >= self.state else []
                               for _ in set(range(self.snakes_count)) - {key}]
            for x, y in value:
                for x_, y_ in [((x + 1) % self.height, y),  # down
                               ((x - 1) % self.height, y),  # up
                               (x, (y + 1) % self.width),  # right
                               (x, (y - 1) % self.width)]:  # left
                    sign = self.board[x_][y_]
                    idx = sign % self.snakes_count  # which snake, e.g. 0, 1, 2, 3, 4, 5 / number of claims
                    state = sign // self.snakes_count  # manhattan distance to snake who claim the point or its negative
                    if sign == self.blank_sign:  # grid in initial state
                        if [x_, y_] in others_tail_pos:
                            continue  # do not spread other snakes' tail, in case of rear-end collision
                        self.board[x_][y_] = self.state * self.snakes_count + key
                        self.snakes[key].claimed_count += 1
                        new_open[key].append([x_, y_])

                    elif key != idx and self.state == state:
                        # second claim, init controversy, change grid value from + to -
                        if self.snakes[idx].len > self.snakes[key].len:  # shorter snake claim the controversial grid
                            self.snakes[idx].claimed_count -= 1
                            new_open[idx].remove([x_, y_])
                            self.board[x_][y_] = self.state * self.snakes_count + key
                            self.snakes[key].claimed_count += 1
                            new_open[key].append([x_, y_])
                        elif self.snakes[idx].len == self.snakes[key].len:  # controversial claim
                            self.controversy[(x_, y_)] = {'state': self.state,
                                                          'length': self.snakes[idx].len,
                                                          'indexes': [idx, key]}
                            # first claim by snake idx, then claim by snake key
                            self.board[x_][y_] = -self.state * self.snakes_count + 1
                            # if + 2, not enough for all snakes claim one grid!!
                            self.snakes[idx].claimed_count -= 1  # controversy, no snake claim this grid!!
                            new_open[key].append([x_, y_])
                        else:  # (self.snakes[idx].len < self.snakes[key].len)
                            pass  # longer snake do not claim the controversial grid

                    elif (x_, y_) in self.controversy \
                            and key not in self.controversy[(x_, y_)]['indexes'] \
                            and self.state + state == 0:  # third claim or more
                        controversy = self.controversy[(x_, y_)]
                        if controversy['length'] > self.snakes[key].len:  # shortest snake claim grid, do 4 things
                            indexes_count = len(controversy['indexes'])
                            for i in controversy['indexes']:
                                self.snakes[i].claimed_count -= 1 / indexes_count  # update claimed_count !
                                new_open[i].remove([x_, y_])
                            del self.controversy[(x_, y_)]
                            self.board[x_][y_] = self.state * self.snakes_count + key
                            self.snakes[key].claimed_count += 1
                            new_open[key].append([x_, y_])
                        elif controversy['length'] == self.snakes[key].len:  # controversial claim
                            self.controversy[(x_, y_)]['indexes'].append(key)
                            self.board[x_][y_] += 1
                            new_open[key].append([x_, y_])
                        else:  # (controversy['length'] < self.snakes[key].len)
                            pass  # longer snake do not claim the controversial grid
                    else:
                        pass  # do nothing with lower state grids

        self.open = new_open  # update open
        # update controversial snakes' claimed_count (in fraction) in the end
        # NOTE(review): this iterates over every recorded controversy on every
        # call, not only those created in the current step -- confirm intended.
        for _, d in self.controversy.items():
            controversial_snake_count = len(d['indexes'])  # number of controversial snakes
            for idx in d['indexes']:
                self.snakes[idx].claimed_count += 1 / controversial_snake_count
132+
133+
134+
class SnakePos:
    """Position, heading and legal moves of one snake on a wrap-around board.

    Direction / action codes: 0 = up, 1 = down, 2 = left, 3 = right.

    Attributes set by ``__init__``:
        pos: The segment list passed in (head first); mutated by ``step``.
        len: Number of segments.
        head: Current head cell ``[row, col]``.
        dir: Current heading, derived from the first two segments.
        legal_action: The three actions that do not reverse the heading,
            ordered left-turn, straight, right-turn relative to the body.
        legal_position: The cell reached by each entry of ``legal_action``.
        claimed_count: Counter initialised to 0 (updated by external code).
    """

    def __init__(self, snake_positions, board_height, board_width, beans_positions):
        """Derive heading and legal moves from the snake's segment list.

        Args:
            snake_positions: Segments as ``[row, col]`` lists, head first,
                e.g. ``[[2, 9], [2, 8], [2, 7]]``. At least two segments are
                required to determine the heading.
            board_height: Number of rows (rows wrap around).
            board_width: Number of columns (columns wrap around).
            beans_positions: Cells containing beans; consulted by ``step``.

        Raises:
            AssertionError: If the first two segments are not adjacent.
        """
        self.pos = snake_positions  # [[2, 9], [2, 8], [2, 7]]
        self.len = len(snake_positions)  # >= 3
        self.head = snake_positions[0]
        self.beans_positions = beans_positions
        self.claimed_count = 0

        # Offset from the second segment to the head, modulo the board size so
        # a move across an edge still reads as a unit step.
        displace = [(self.head[0] - snake_positions[1][0]) % board_height,
                    (self.head[1] - snake_positions[1][1]) % board_width]
        if displace == [board_height - 1, 0]:  # all action are ordered by left, up, right, relative to the body
            self.dir = 0  # up
            self.legal_action = [2, 0, 3]
        elif displace == [1, 0]:
            self.dir = 1  # down
            self.legal_action = [3, 1, 2]
        elif displace == [0, board_width - 1]:
            self.dir = 2  # left
            self.legal_action = [1, 2, 0]
        elif displace == [0, 1]:
            self.dir = 3  # right
            self.legal_action = [0, 3, 1]
        else:
            # Raised explicitly (not via ``assert``) so the check still runs
            # when Python is started with -O.
            raise AssertionError('snake positions error')
        # Neighbour cells indexed by action code: up, down, left, right.
        positions = [[(self.head[0] - 1) % board_height, self.head[1]],
                     [(self.head[0] + 1) % board_height, self.head[1]],
                     [self.head[0], (self.head[1] - 1) % board_width],
                     [self.head[0], (self.head[1] + 1) % board_width]]
        self.legal_position = [positions[_] for _ in self.legal_action]

    def get_action(self, position):
        """Return the action code that moves the head to ``position``.

        Raises:
            AssertionError: If ``position`` is not one of ``legal_position``.
        """
        if position not in self.legal_position:
            raise AssertionError('the start and end points do not match')
        idx = self.legal_position.index(position)
        return self.legal_action[idx]  # 0, 1, 2, 3: up, down, left, right

    def step(self, legal_input):
        """Move the snake one cell.

        Args:
            legal_input: Either a target cell from ``legal_position`` or an
                action code from ``legal_action``.

        The head is prepended to ``pos``. If the new head is on a bean the
        snake grows by one; otherwise the last segment is removed. ``dir``,
        ``legal_action`` and ``legal_position`` are not recomputed here.

        Raises:
            AssertionError: If ``legal_input`` is neither a legal position nor
                a legal action.
        """
        if legal_input in self.legal_position:
            position = legal_input
        elif legal_input in self.legal_action:
            idx = self.legal_action.index(legal_input)
            position = self.legal_position[idx]
        else:
            raise AssertionError('illegal snake move')
        self.head = position
        self.pos.insert(0, position)
        if position in self.beans_positions:  # eat a bean
            self.len += 1
        else:  # do not eat a bean
            self.pos.pop()

openrl/envs/snake/discrete.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import numpy as np
2+
from .space import Space
3+
4+
5+
class Discrete(Space):
    r"""A discrete space holding the integers :math:`\{ 0, 1, \dots, n-1 \}`.

    Example::

        >>> Discrete(2)
    """

    def __init__(self, n):
        """Create a space with ``n`` possible values (``n`` must be >= 0)."""
        assert n >= 0
        self.n = n
        super().__init__((), np.int64)

    def sample(self):
        """Draw a random integer below ``n`` from the space's ``np_random``."""
        return self.np_random.randint(self.n)

    def contains(self, x):
        """Return whether ``x`` is an integer (Python or 0-d NumPy) in range."""
        if isinstance(x, int):
            value = x
        else:
            is_numpy_value = isinstance(x, (np.generic, np.ndarray))
            # Only scalar-shaped NumPy values of an integer dtype qualify.
            if not (is_numpy_value
                    and (x.dtype.char in np.typecodes['AllInteger'] and x.shape == ())):
                return False
            value = int(x)
        return 0 <= value < self.n

    def __repr__(self):
        return "Discrete(%d)" % self.n

    def __eq__(self, other):
        """Two spaces are equal when both are ``Discrete`` with the same ``n``."""
        if not isinstance(other, Discrete):
            return False
        return self.n == other.n

openrl/envs/snake/game.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# -*- coding:utf-8 -*-
2+
# 作者:zruizhi
3+
# 创建时间: 2020/7/10 10:24 上午
4+
# 描述:
5+
from abc import ABC, abstractmethod
6+
7+
8+
class Game(ABC):
    """Abstract base class for game environments.

    Subclasses must implement ``is_terminal``, ``step`` and ``reset``. The
    remaining hooks raise ``NotImplementedError`` unless overridden, except
    ``get_render_data``, which returns its argument unchanged.
    """

    def __init__(self, n_player, is_obs_continuous, is_act_continuous, game_name, agent_nums, obs_type):
        """Store the game configuration; state and observations start as None."""
        # Static configuration supplied by the caller.
        self.n_player = n_player
        self.game_name = game_name
        self.agent_nums = agent_nums
        self.obs_type = obs_type
        self.is_obs_continuous = is_obs_continuous
        self.is_act_continuous = is_act_continuous

        # Runtime state, to be filled in by subclasses.
        self.current_state = None
        self.all_observes = None

    def get_config(self, player_id):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError

    def get_render_data(self, current_state):
        """Return the data to render; by default the state itself."""
        return current_state

    def set_current_state(self, current_state):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError

    @abstractmethod
    def is_terminal(self):
        """Abstract: must be overridden by concrete games."""
        raise NotImplementedError

    def get_next_state(self, all_action):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError

    def get_reward(self, all_action):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError

    @abstractmethod
    def step(self, all_action):
        """Abstract: must be overridden by concrete games."""
        raise NotImplementedError

    @abstractmethod
    def reset(self):
        """Abstract: must be overridden by concrete games."""
        raise NotImplementedError

    def set_action_space(self):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError

0 commit comments

Comments
 (0)