Skip to content

Commit f7b8632

Browse files
authored
update bunch of miscellaneous things (#174)
* change RLBase.action_space to Base.OneTo(NUM_ACTIONS) * rename AbstractGridWorldGame to AbstractGridWorld * remove RLBaseEnvModule from benchmark.jl * rename get_tile_map_height to get_height & similarly for width * rename get_tile_pretty_repr & get_tile_map_pretty_repr to get_pretty_tile_map * rename get_sub_tile_map_pretty_repr to get_pretty_sub_tile_map * update printing for Catcher. Other envs broken temporarily * update printing for CollectGems envs * update miscellaneous methods for DoorKey envs * update miscellaneous methods for DynamicObstacles envs * update miscellaneous methods for GoToTarget envs * update miscellaneous methods for GridRooms envs * update miscellaneous methods for Maze envs * update miscellaneous methods for SequentialRooms envs * update miscellaneous methods for SingleRoom envs * update miscellaneous methods for Snake * update miscellaneous methods for Sokoban envs * update miscellaneous methods for Transport envs * update README
1 parent f05079c commit f7b8632

29 files changed

+858
-367
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ A package for creating grid world environments for reinforcement learning in Jul
44

55
This package is inspired by [gym-minigrid](https://github.com/maximecb/gym-minigrid). In order to cite this package, please refer to the file `CITATION.bib`. Starring the repository on GitHub is also appreciated. For benchmarks, refer to `benchmark/benchmarks.md`.
66

7-
**Important note:** This package is undergoing heavy internal redesign. This README reflects the new design. The README for the last released version (`0.4.0`) can be found [here](https://github.com/JuliaReinforcementLearning/GridWorlds.jl/tree/c0e86bb6c33819f0e4a4cefe0284d985d0474ed3).
7+
**Important note:** This package is undergoing heavy internal redesign. This README reflects the new design (some visualizations might be oudated). The README for the last released version (`0.4.0`) can be found [here](https://github.com/JuliaReinforcementLearning/GridWorlds.jl/tree/c0e86bb6c33819f0e4a4cefe0284d985d0474ed3).
88

99
## Table of contents:
1010

benchmark/benchmark.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
function benchmark(Env, num_resets, steps_per_reset)
2323
benchmark = DS.OrderedDict()
2424

25-
env = GW.RLBaseEnvModule.RLBaseEnv(Env())
25+
env = GW.RLBaseEnv(Env())
2626

2727
benchmark[:random_policy] = BT.@benchmark run_random_policy!($(Ref(env))[], $(Ref(num_resets))[], $(Ref(steps_per_reset))[])
2828
benchmark[:reset] = BT.@benchmark RLBase.reset!($(Ref(env))[])

src/abstract_grid_world.jl

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,58 @@
1-
abstract type AbstractGridWorldGame end
1+
abstract type AbstractGridWorld end
22

33
#####
44
##### Game logic methods
55
#####
66

7-
reset!(env::AbstractGridWorldGame) = error("Method not implemented for $(typeof(env))")
8-
act!(env::AbstractGridWorldGame) = error("Method not implemented for $(typeof(env))")
7+
reset!(env::AbstractGridWorld) = error("Method not implemented for $(typeof(env))")
8+
act!(env::AbstractGridWorld) = error("Method not implemented for $(typeof(env))")
99

1010
#####
1111
##### Optional methods for pretty printing, playing, etc...
1212
#####
1313

14-
get_tile_pretty_repr(env::AbstractGridWorldGame, i::Integer, j::Integer) = error("Method not implemented for $(typeof(env))")
15-
get_sub_tile_map_pretty_repr(env::AbstractGridWorldGame, position::CartesianIndex{2}) = error("Method not implemented for $(typeof(env))")
16-
get_action_keys(env::AbstractGridWorldGame) = error("Method not implemented for $(typeof(env))")
17-
get_action_names(env::AbstractGridWorldGame) = error("Method not implemented for $(typeof(env))")
18-
get_tile_map_height(env::AbstractGridWorldGame) = error("Method not implemented for $(typeof(env))")
19-
get_tile_map_width(env::AbstractGridWorldGame) = error("Method not implemented for $(typeof(env))")
14+
get_pretty_tile_map(env::AbstractGridWorld, i::Integer, j::Integer) = error("Method not implemented for $(typeof(env))")
15+
get_pretty_sub_tile_map(env::AbstractGridWorld, position::CartesianIndex{2}) = error("Method not implemented for $(typeof(env))")
16+
get_action_keys(env::AbstractGridWorld) = error("Method not implemented for $(typeof(env))")
17+
get_action_names(env::AbstractGridWorld) = error("Method not implemented for $(typeof(env))")
18+
get_object_names(env::AbstractGridWorld) = error("Method not implemented for $(typeof(env))")
19+
get_height(env::AbstractGridWorld) = error("Method not implemented for $(typeof(env))")
20+
get_width(env::AbstractGridWorld) = error("Method not implemented for $(typeof(env))")
2021

21-
function get_tile_map_pretty_repr(env::AbstractGridWorldGame)
22-
height_tile_map = get_tile_map_height(env)
23-
width_tile_map = get_tile_map_width(env)
22+
function get_pretty_tile_map(env::AbstractGridWorld)
23+
height = get_height(env)
24+
width = get_width(env)
2425

2526
str = ""
2627

27-
for i in 1:height_tile_map
28-
for j in 1:width_tile_map
29-
str = str * get_tile_pretty_repr(env, i, j)
28+
for i in 1:height
29+
for j in 1:width
30+
str = str * get_pretty_tile_map(env, CartesianIndex(i, j))
3031
end
31-
if i < height_tile_map
32+
if i < height
3233
str = str * "\n"
3334
end
3435
end
3536

3637
return str
3738
end
3839

39-
function get_window_size(env::AbstractGridWorldGame)
40-
height = get_tile_map_height(env)
41-
width = get_tile_map_width(env)
40+
function get_window_size(env::AbstractGridWorld)
41+
height = get_height(env)
42+
width = get_width(env)
4243
return (2 * (height ÷ 4) + 1, 2 * (width ÷ 4) + 1)
4344
end
4445

45-
function get_sub_tile_map_pretty_repr(env::AbstractGridWorldGame, window_size)
46-
height_sub_tile_map, width_sub_tile_map = window_size
46+
function get_pretty_sub_tile_map(env::AbstractGridWorld, window_size)
47+
height, width = window_size
4748

4849
str = ""
4950

50-
for i in 1:height_sub_tile_map
51-
for j in 1:width_sub_tile_map
52-
str = str * get_sub_tile_map_pretty_repr(env, window_size, CartesianIndex(i, j))
51+
for i in 1:height
52+
for j in 1:width
53+
str = str * get_pretty_sub_tile_map(env, window_size, CartesianIndex(i, j))
5354
end
54-
if i < height_sub_tile_map
55+
if i < height
5556
str = str * "\n"
5657
end
5758
end

src/envs/catcher.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ const AGENT = 1
1313
const GEM = 2
1414
const NUM_ACTIONS = 3
1515

16-
mutable struct Catcher{R, RNG} <: GW.AbstractGridWorldGame
16+
mutable struct Catcher{R, RNG} <: GW.AbstractGridWorld
1717
tile_map::BitArray{3}
1818
agent_position::CartesianIndex{2}
1919
reward::R
@@ -125,30 +125,35 @@ end
125125
##### miscellaneous
126126
#####
127127

128-
CHARACTERS = ('', '', '')
128+
GW.get_height(env::Catcher) = size(env.tile_map, 2)
129+
GW.get_width(env::Catcher) = size(env.tile_map, 3)
129130

130-
GW.get_tile_map_height(env::Catcher) = size(env.tile_map, 2)
131-
GW.get_tile_map_width(env::Catcher) = size(env.tile_map, 3)
131+
function GW.get_pretty_tile_map(env::Catcher, position::CartesianIndex{2})
132+
characters = ('', '', '')
132133

133-
function GW.get_tile_pretty_repr(env::Catcher, i::Integer, j::Integer)
134-
object = findfirst(@view env.tile_map[:, i, j])
134+
object = findfirst(@view env.tile_map[:, position])
135135
if isnothing(object)
136-
return CHARACTERS[end]
136+
return characters[end]
137137
else
138-
return CHARACTERS[object]
138+
return characters[object]
139139
end
140140
end
141141

142-
GW.get_action_keys(env::Catcher) = ('a', 'd', 's')
142+
GW.get_object_names(env::Catcher) = (:AGENT, :GEM)
143143
GW.get_action_names(env::Catcher) = (:MOVE_LEFT, :MOVE_RIGHT, :NO_MOVE)
144144

145145
function Base.show(io::IO, ::MIME"text/plain", env::Catcher)
146-
str = GW.get_tile_map_pretty_repr(env)
147-
str = str * "\nreward = $(env.reward)\ndone = $(env.done)"
146+
str = "tile_map:\n"
147+
str = str * GW.get_pretty_tile_map(env)
148+
str = str * "\nreward: $(env.reward)\ndone: $(env.done)"
149+
str = str * "\naction_names: $(GW.get_action_names(env))"
150+
str = str * "\nobject_names: $(GW.get_object_names(env))"
148151
print(io, str)
149152
return nothing
150153
end
151154

155+
GW.get_action_keys(env::Catcher) = ('a', 'd', 's')
156+
152157
#####
153158
##### RLBase API
154159
#####
@@ -159,7 +164,7 @@ RLBase.state(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: Catcher}
159164

160165
RLBase.reset!(env::GW.RLBaseEnv{E}) where {E <: Catcher} = GW.reset!(env.env)
161166

162-
RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: Catcher} = 1:NUM_ACTIONS
167+
RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: Catcher} = Base.OneTo(NUM_ACTIONS)
163168
(env::GW.RLBaseEnv{E})(action) where {E <: Catcher} = GW.act!(env.env, action)
164169

165170
RLBase.reward(env::GW.RLBaseEnv{E}) where {E <: Catcher} = env.env.reward

src/envs/collect_gems_directed.jl

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ const WALL = CGUM.WALL
1515
const GEM = CGUM.GEM
1616
const NUM_ACTIONS = 4
1717

18-
mutable struct CollectGemsDirected{R, RNG} <: GW.AbstractGridWorldGame
18+
mutable struct CollectGemsDirected{R, RNG} <: GW.AbstractGridWorld
1919
env::CGUM.CollectGemsUndirected{R, RNG}
2020
agent_direction::Int
2121
end
@@ -78,32 +78,58 @@ end
7878
##### miscellaneous
7979
#####
8080

81-
CHARACTERS = ('', '', '', '', '', '', '', '')
81+
GW.get_height(env::CollectGemsDirected) = GW.get_height(env.env)
82+
GW.get_width(env::CollectGemsDirected) = GW.get_width(env.env)
8283

83-
GW.get_tile_map_height(env::CollectGemsDirected) = size(env.env.tile_map, 2)
84-
GW.get_tile_map_width(env::CollectGemsDirected) = size(env.env.tile_map, 3)
84+
GW.get_action_names(env::CollectGemsDirected) = (:MOVE_FORWARD, :MOVE_BACKWARD, :TURN_LEFT, :TURN_RIGHT)
85+
GW.get_object_names(env::CollectGemsDirected) = GW.get_object_names(env.env)
86+
87+
function GW.get_pretty_tile_map(env::CollectGemsDirected, position::CartesianIndex{2})
88+
characters = ('', '', '', '', '', '', '', '')
8589

86-
function GW.get_tile_pretty_repr(env::CollectGemsDirected, i::Integer, j::Integer)
87-
object = findfirst(@view env.env.tile_map[:, i, j])
90+
object = findfirst(@view env.env.tile_map[:, position])
8891
if isnothing(object)
89-
return CHARACTERS[end]
92+
return characters[end]
9093
elseif object == AGENT
91-
return CHARACTERS[NUM_OBJECTS + 1 + env.agent_direction]
94+
return characters[NUM_OBJECTS + 1 + env.agent_direction]
9295
else
93-
return CHARACTERS[object]
96+
return characters[object]
9497
end
9598
end
9699

97-
GW.get_action_keys(env::CollectGemsDirected) = ('w', 's', 'a', 'd')
98-
GW.get_action_names(env::CollectGemsDirected) = (:MOVE_FORWARD, :MOVE_BACKWARD, :TURN_LEFT, :TURN_RIGHT)
100+
function GW.get_pretty_sub_tile_map(env::CollectGemsDirected, window_size, position::CartesianIndex{2})
101+
tile_map = env.env.tile_map
102+
agent_position = env.env.agent_position
103+
agent_direction = env.agent_direction
104+
105+
characters = ('', '', '', '', '', '', '', '')
106+
107+
sub_tile_map = GW.get_sub_tile_map(tile_map, agent_position, window_size, agent_direction)
108+
109+
object = findfirst(@view sub_tile_map[:, position])
110+
if isnothing(object)
111+
return characters[end]
112+
elseif object == AGENT
113+
return ''
114+
else
115+
return characters[object]
116+
end
117+
end
99118

100119
function Base.show(io::IO, ::MIME"text/plain", env::CollectGemsDirected)
101-
str = GW.get_tile_map_pretty_repr(env)
102-
str = str * "\nreward = $(env.env.reward)\ndone = $(env.env.done)"
120+
str = "tile_map:\n"
121+
str = str * GW.get_pretty_tile_map(env)
122+
str = str * "\nsub_tile_map:\n"
123+
str = str * GW.get_pretty_sub_tile_map(env, GW.get_window_size(env))
124+
str = str * "\nreward: $(env.env.reward)\ndone: $(env.env.done)"
125+
str = str * "\naction_names: $(GW.get_action_names(env))"
126+
str = str * "\nobject_names: $(GW.get_object_names(env))"
103127
print(io, str)
104128
return nothing
105129
end
106130

131+
GW.get_action_keys(env::CollectGemsDirected) = ('w', 's', 'a', 'd')
132+
107133
#####
108134
##### RLBase API
109135
#####
@@ -114,7 +140,7 @@ RLBase.state(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: CollectGe
114140

115141
RLBase.reset!(env::GW.RLBaseEnv{E}) where {E <: CollectGemsDirected} = GW.reset!(env.env)
116142

117-
RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: CollectGemsDirected} = 1:NUM_ACTIONS
143+
RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: CollectGemsDirected} = Base.OneTo(NUM_ACTIONS)
118144
(env::GW.RLBaseEnv{E})(action) where {E <: CollectGemsDirected} = GW.act!(env.env, action)
119145

120146
RLBase.reward(env::GW.RLBaseEnv{E}) where {E <: CollectGemsDirected} = env.env.env.reward

src/envs/collect_gems_multi_agent_undirected.jl

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import ReinforcementLearningBase as RLBase
1010

1111
const NUM_ACTIONS = 4
1212

13-
mutable struct CollectGemsMultiAgentUndirected{R, RNG} <: GW.AbstractGridWorldGame
13+
mutable struct CollectGemsMultiAgentUndirected{R, RNG} <: GW.AbstractGridWorld
1414
tile_map::BitArray{3}
1515
agent_positions::Vector{CartesianIndex{2}}
1616
current_agent::Int
@@ -97,7 +97,7 @@ function GW.reset!(env::CollectGemsMultiAgentUndirected)
9797
end
9898

9999
function GW.act!(env::CollectGemsMultiAgentUndirected, action)
100-
@assert action in 1:NUM_ACTIONS "Invalid action $(action)"
100+
@assert action in Base.OneTo(NUM_ACTIONS) "Invalid action $(action)"
101101

102102
tile_map = env.tile_map
103103
agent_positions = env.agent_positions
@@ -148,12 +148,26 @@ end
148148
##### miscellaneous
149149
#####
150150

151-
GW.get_tile_map_height(env::CollectGemsMultiAgentUndirected) = size(env.tile_map, 2)
152-
GW.get_tile_map_width(env::CollectGemsMultiAgentUndirected) = size(env.tile_map, 3)
151+
GW.get_height(env::CollectGemsMultiAgentUndirected) = size(env.tile_map, 2)
152+
GW.get_width(env::CollectGemsMultiAgentUndirected) = size(env.tile_map, 3)
153153

154-
function GW.get_tile_pretty_repr(env::CollectGemsMultiAgentUndirected, i::Integer, j::Integer)
154+
GW.get_action_names(env::CollectGemsMultiAgentUndirected) = (:MOVE_UP, :MOVE_DOWN, :MOVE_LEFT, :MOVE_RIGHT)
155+
156+
function GW.get_object_names(env::CollectGemsMultiAgentUndirected)
157+
num_agents = length(env.agent_positions)
158+
object_names = Array{Symbol}(undef, num_agents + 2)
159+
for i in 1:num_agents
160+
object_names[i] = Symbol("AGENT", "$(i)")
161+
end
162+
object_names[end - 1] = :WALL
163+
object_names[end] = :GEM
164+
165+
return object_names
166+
end
167+
168+
function GW.get_pretty_tile_map(env::CollectGemsMultiAgentUndirected, position::CartesianIndex{2})
155169
tile_map = env.tile_map
156-
object = findfirst(@view tile_map[:, i, j])
170+
object = findfirst(@view tile_map[:, position])
157171
num_agents = size(tile_map, 1) - 2
158172

159173
if isnothing(object)
@@ -167,16 +181,40 @@ function GW.get_tile_pretty_repr(env::CollectGemsMultiAgentUndirected, i::Intege
167181
end
168182
end
169183

170-
GW.get_action_keys(env::CollectGemsMultiAgentUndirected) = ('w', 's', 'a', 'd')
171-
GW.get_action_names(env::CollectGemsMultiAgentUndirected) = (:MOVE_UP, :MOVE_DOWN, :MOVE_LEFT, :MOVE_RIGHT)
184+
function GW.get_pretty_sub_tile_map(env::CollectGemsMultiAgentUndirected, window_size, position::CartesianIndex{2})
185+
tile_map = env.tile_map
186+
agent_positions = env.agent_positions
187+
agent_position = agent_positions[env.current_agent]
188+
num_agents = length(agent_positions)
189+
190+
sub_tile_map = GW.get_sub_tile_map(tile_map, agent_position, window_size)
191+
192+
object = findfirst(@view sub_tile_map[:, position])
193+
if isnothing(object)
194+
return ""
195+
elseif object in 1 : num_agents
196+
return "$(object)"
197+
elseif object == num_agents + 1
198+
return ""
199+
else
200+
return ""
201+
end
202+
end
172203

173204
function Base.show(io::IO, ::MIME"text/plain", env::CollectGemsMultiAgentUndirected)
174-
str = GW.get_tile_map_pretty_repr(env)
175-
str = str * "\nreward = $(env.reward)\ndone = $(env.done)\ncurrent_agent = $(env.current_agent)"
205+
str = "tile_map:\n"
206+
str = str * GW.get_pretty_tile_map(env)
207+
str = str * "\nsub_tile_map:\n"
208+
str = str * GW.get_pretty_sub_tile_map(env, GW.get_window_size(env))
209+
str = str * "\nreward: $(env.reward)\ndone: $(env.done)\ncurrent_agent = $(env.current_agent)"
210+
str = str * "\naction_names: $(GW.get_action_names(env))"
211+
str = str * "\nobject_names: $(GW.get_object_names(env))"
176212
print(io, str)
177213
return nothing
178214
end
179215

216+
GW.get_action_keys(env::CollectGemsMultiAgentUndirected) = ('w', 's', 'a', 'd')
217+
180218
#####
181219
##### RLBase API
182220
#####
@@ -187,7 +225,7 @@ RLBase.state(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: CollectGe
187225

188226
RLBase.reset!(env::GW.RLBaseEnv{E}) where {E <: CollectGemsMultiAgentUndirected} = GW.reset!(env.env)
189227

190-
RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: CollectGemsMultiAgentUndirected} = 1:NUM_ACTIONS
228+
RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: CollectGemsMultiAgentUndirected} = Base.OneTo(NUM_ACTIONS)
191229
(env::GW.RLBaseEnv{E})(action) where {E <: CollectGemsMultiAgentUndirected} = GW.act!(env.env, action)
192230

193231
RLBase.reward(env::GW.RLBaseEnv{E}) where {E <: CollectGemsMultiAgentUndirected} = env.env.reward

0 commit comments

Comments
 (0)