Skip to content

Commit f05079c

Browse files
authored
rearrange codebase (#173)
* remove objects.jl
* move directions.jl into actions.jl
* rename actions.jl to navigation.jl
* move sampling methods from envs.jl to abstract_grid_world.jl
* better organize SingleRoom envs
* reorganize GridRoom envs
* reorganize SequentialRoom envs
* reorganize Maze envs
* reorganize GoToTarget envs
* reorganize DoorKey envs
* reorganize CollectGems envs
* reorganize DynamicObstacles envs
* reorganize Sokoban envs
* reorganize Catcher
* reorganize Snake
* reorganize Transport envs
* move RLBase API into respective env modules
* minor cleanup
1 parent a2f472b commit f05079c

31 files changed

+1283
-1125
lines changed

src/GridWorlds.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@ import Random
55
import REPL
66
import ReinforcementLearningBase as RLBase
77

8-
include("directions.jl")
9-
include("actions.jl")
10-
include("objects.jl")
8+
include("navigation.jl")
119
include("abstract_grid_world.jl")
1210
include("play.jl")
13-
include("envs/envs.jl")
1411
include("rlbase.jl")
12+
include("envs/envs.jl")
1513

1614
end

src/abstract_grid_world.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ function get_sub_tile_map_pretty_repr(env::AbstractGridWorldGame, window_size)
5959
return str
6060
end
6161

62-
6362
#####
6463
##### Sub tile map
6564
#####
@@ -155,3 +154,35 @@ function get_sub_tile_map!(sub_tile_map, tile_map, position, window_size, direct
155154

156155
return nothing
157156
end
157+
158+
#####
159+
##### sampling tile map positions
160+
#####
161+
162+
"""
    sample_empty_position(rng, tile_map, max_tries = 1024)

Sample a position `CartesianIndex(i, j)` of `tile_map` at which no entry along the
first dimension is set.

`tile_map` is indexed as `tile_map[:, i, j]`; a position counts as empty when
`any(tile_map[:, i, j])` is `false`. A candidate is drawn uniformly with `rng` and
re-drawn up to `max_tries` times while it is occupied.

If the retry budget is exhausted, a warning is logged and the most recently drawn
candidate is returned anyway, so callers always receive a position.
"""
function sample_empty_position(rng, tile_map, max_tries = 1024)
    _, height, width = size(tile_map)
    position = CartesianIndex(rand(rng, 1:height), rand(rng, 1:width))

    # Honour the caller-supplied retry budget (previously hard-coded to 1000,
    # which silently ignored `max_tries`).
    for _ in 1:max_tries
        if any(@view tile_map[:, position])
            position = CartesianIndex(rand(rng, 1:height), rand(rng, 1:width))
        else
            return position
        end
    end

    @warn "Returning non-empty position: $(position)"

    return position
end
178+
179+
"""
    sample_two_positions_without_replacement(rng, region)

Draw two distinct elements from `region` using `rng` and return them as a tuple.

The first element is drawn once; the second is re-drawn until it differs from the
first.
"""
function sample_two_positions_without_replacement(rng, region)
    first_position = rand(rng, region)

    # Keep drawing candidates until one differs from the first pick.
    while true
        candidate = rand(rng, region)
        if candidate != first_position
            return (first_position, candidate)
        end
    end
end

src/directions.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/envs/catcher.jl

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@ module CatcherModule
22

33
import ..GridWorlds as GW
44
import Random
5+
import ReinforcementLearningBase as RLBase
6+
7+
#####
8+
##### game logic
9+
#####
10+
11+
const NUM_OBJECTS = 2
12+
const AGENT = 1
13+
const GEM = 2
14+
const NUM_ACTIONS = 3
515

616
mutable struct Catcher{R, RNG} <: GW.AbstractGridWorldGame
717
tile_map::BitArray{3}
@@ -14,28 +24,6 @@ mutable struct Catcher{R, RNG} <: GW.AbstractGridWorldGame
1424
terminal_penalty::R
1525
end
1626

17-
const NUM_OBJECTS = 2
18-
const AGENT = 1
19-
const GEM = 2
20-
21-
CHARACTERS = ('', '', '')
22-
23-
GW.get_tile_map_height(env::Catcher) = size(env.tile_map, 2)
24-
GW.get_tile_map_width(env::Catcher) = size(env.tile_map, 3)
25-
26-
function GW.get_tile_pretty_repr(env::Catcher, i::Integer, j::Integer)
27-
object = findfirst(@view env.tile_map[:, i, j])
28-
if isnothing(object)
29-
return CHARACTERS[end]
30-
else
31-
return CHARACTERS[object]
32-
end
33-
end
34-
35-
const NUM_ACTIONS = 3
36-
GW.get_action_keys(env::Catcher) = ('a', 'd', 's')
37-
GW.get_action_names(env::Catcher) = (:MOVE_LEFT, :MOVE_RIGHT, :NO_MOVE)
38-
3927
function Catcher(; R = Float32, height = 8, width = 8, rng = Random.GLOBAL_RNG)
4028
tile_map = falses(NUM_OBJECTS, height, width)
4129

@@ -133,11 +121,48 @@ function GW.act!(env::Catcher, action)
133121
return nothing
134122
end
135123

124+
#####
125+
##### miscellaneous
126+
#####
127+
128+
CHARACTERS = ('', '', '')
129+
130+
GW.get_tile_map_height(env::Catcher) = size(env.tile_map, 2)
131+
GW.get_tile_map_width(env::Catcher) = size(env.tile_map, 3)
132+
133+
# Return the display character for tile (i, j): the entry of `CHARACTERS` for the
# first object layer set at that tile, or the last entry of `CHARACTERS` when no
# layer is set there.
function GW.get_tile_pretty_repr(env::Catcher, i::Integer, j::Integer)
    object = findfirst(@view env.tile_map[:, i, j])
    index = isnothing(object) ? lastindex(CHARACTERS) : object
    return CHARACTERS[index]
end
141+
142+
GW.get_action_keys(env::Catcher) = ('a', 'd', 's')
143+
GW.get_action_names(env::Catcher) = (:MOVE_LEFT, :MOVE_RIGHT, :NO_MOVE)
144+
136145
# Multi-line display: the pretty-printed tile map followed by the current
# `reward` and `done` fields.
function Base.show(io::IO, ::MIME"text/plain", env::Catcher)
    board = GW.get_tile_map_pretty_repr(env)
    status = "\nreward = $(env.reward)\ndone = $(env.done)"
    print(io, board * status)
    return nothing
end
142151

152+
#####
153+
##### RLBase API
154+
#####
155+
156+
RLBase.StateStyle(env::GW.RLBaseEnv{E}) where {E <: Catcher} = RLBase.InternalState{Any}()
157+
RLBase.state_space(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: Catcher} = nothing
158+
RLBase.state(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: Catcher} = env.env.tile_map
159+
160+
RLBase.reset!(env::GW.RLBaseEnv{E}) where {E <: Catcher} = GW.reset!(env.env)
161+
162+
RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: Catcher} = 1:NUM_ACTIONS
163+
(env::GW.RLBaseEnv{E})(action) where {E <: Catcher} = GW.act!(env.env, action)
164+
165+
RLBase.reward(env::GW.RLBaseEnv{E}) where {E <: Catcher} = env.env.reward
166+
RLBase.is_terminated(env::GW.RLBaseEnv{E}) where {E <: Catcher} = env.env.done
167+
143168
end # module

src/envs/collect_gems_directed.jl

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,23 @@ module CollectGemsDirectedModule
33
import ..CollectGemsUndirectedModule as CGUM
44
import ..GridWorlds as GW
55
import Random
6+
import ReinforcementLearningBase as RLBase
67

7-
mutable struct CollectGemsDirected{R, RNG} <: GW.AbstractGridWorldGame
8-
env::CGUM.CollectGemsUndirected{R, RNG}
9-
agent_direction::Int
10-
end
8+
#####
9+
##### game logic
10+
#####
1111

1212
const NUM_OBJECTS = CGUM.NUM_OBJECTS
1313
const AGENT = CGUM.AGENT
1414
const WALL = CGUM.WALL
1515
const GEM = CGUM.GEM
16+
const NUM_ACTIONS = 4
1617

17-
CHARACTERS = ('', '', '', '', '', '', '', '')
18-
19-
GW.get_tile_map_height(env::CollectGemsDirected) = size(env.env.tile_map, 2)
20-
GW.get_tile_map_width(env::CollectGemsDirected) = size(env.env.tile_map, 3)
21-
22-
function GW.get_tile_pretty_repr(env::CollectGemsDirected, i::Integer, j::Integer)
23-
object = findfirst(@view env.env.tile_map[:, i, j])
24-
if isnothing(object)
25-
return CHARACTERS[end]
26-
elseif object == AGENT
27-
return CHARACTERS[NUM_OBJECTS + 1 + env.agent_direction]
28-
else
29-
return CHARACTERS[object]
30-
end
18+
mutable struct CollectGemsDirected{R, RNG} <: GW.AbstractGridWorldGame
19+
env::CGUM.CollectGemsUndirected{R, RNG}
20+
agent_direction::Int
3121
end
3222

33-
const NUM_ACTIONS = 4
34-
GW.get_action_keys(env::CollectGemsDirected) = ('w', 's', 'a', 'd')
35-
GW.get_action_names(env::CollectGemsDirected) = (:MOVE_FORWARD, :MOVE_BACKWARD, :TURN_LEFT, :TURN_RIGHT)
36-
3723
function CollectGemsDirected(; R = Float32, height = 8, width = 8, num_gem_init = floor(Int, sqrt(height * width)), rng = Random.GLOBAL_RNG)
3824
env = CGUM.CollectGemsUndirected(R = R, height = height, width = width, num_gem_init = num_gem_init, rng = rng)
3925
agent_direction = rand(rng, 0:GW.NUM_DIRECTIONS-1)
@@ -88,11 +74,50 @@ function GW.act!(env::CollectGemsDirected, action)
8874
return nothing
8975
end
9076

77+
#####
78+
##### miscellaneous
79+
#####
80+
81+
CHARACTERS = ('', '', '', '', '', '', '', '')
82+
83+
GW.get_tile_map_height(env::CollectGemsDirected) = size(env.env.tile_map, 2)
84+
GW.get_tile_map_width(env::CollectGemsDirected) = size(env.env.tile_map, 3)
85+
86+
# Return the display character for tile (i, j) of the wrapped environment's tile map.
# - no object layer set: the last entry of `CHARACTERS`
# - agent: the entry selected by `env.agent_direction`, offset past the object entries
# - any other object: the entry at that object's index
function GW.get_tile_pretty_repr(env::CollectGemsDirected, i::Integer, j::Integer)
    object = findfirst(@view env.env.tile_map[:, i, j])

    isnothing(object) && return CHARACTERS[end]
    object == AGENT && return CHARACTERS[NUM_OBJECTS + 1 + env.agent_direction]
    return CHARACTERS[object]
end
96+
97+
GW.get_action_keys(env::CollectGemsDirected) = ('w', 's', 'a', 'd')
98+
GW.get_action_names(env::CollectGemsDirected) = (:MOVE_FORWARD, :MOVE_BACKWARD, :TURN_LEFT, :TURN_RIGHT)
99+
91100
# Multi-line display: the pretty-printed tile map followed by the wrapped
# environment's `reward` and `done` fields.
function Base.show(io::IO, ::MIME"text/plain", env::CollectGemsDirected)
    board = GW.get_tile_map_pretty_repr(env)
    status = "\nreward = $(env.env.reward)\ndone = $(env.env.done)"
    print(io, board * status)
    return nothing
end
97106

107+
#####
108+
##### RLBase API
109+
#####
110+
111+
RLBase.StateStyle(env::GW.RLBaseEnv{E}) where {E <: CollectGemsDirected} = RLBase.InternalState{Any}()
112+
RLBase.state_space(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: CollectGemsDirected} = nothing
113+
RLBase.state(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: CollectGemsDirected} = (env.env.env.tile_map, env.env.agent_direction)
114+
115+
RLBase.reset!(env::GW.RLBaseEnv{E}) where {E <: CollectGemsDirected} = GW.reset!(env.env)
116+
117+
RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: CollectGemsDirected} = 1:NUM_ACTIONS
118+
(env::GW.RLBaseEnv{E})(action) where {E <: CollectGemsDirected} = GW.act!(env.env, action)
119+
120+
RLBase.reward(env::GW.RLBaseEnv{E}) where {E <: CollectGemsDirected} = env.env.env.reward
121+
RLBase.is_terminated(env::GW.RLBaseEnv{E}) where {E <: CollectGemsDirected} = env.env.env.done
122+
98123
end # module

src/envs/collect_gems_multi_agent_undirected.jl

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ module CollectGemsMultiAgentUndirectedModule
22

33
import ..GridWorlds as GW
44
import Random
5+
import ReinforcementLearningBase as RLBase
6+
7+
#####
8+
##### game logic
9+
#####
10+
11+
const NUM_ACTIONS = 4
512

613
mutable struct CollectGemsMultiAgentUndirected{R, RNG} <: GW.AbstractGridWorldGame
714
tile_map::BitArray{3}
@@ -16,29 +23,6 @@ mutable struct CollectGemsMultiAgentUndirected{R, RNG} <: GW.AbstractGridWorldGa
1623
gem_positions::Vector{CartesianIndex{2}}
1724
end
1825

19-
GW.get_tile_map_height(env::CollectGemsMultiAgentUndirected) = size(env.tile_map, 2)
20-
GW.get_tile_map_width(env::CollectGemsMultiAgentUndirected) = size(env.tile_map, 3)
21-
22-
function GW.get_tile_pretty_repr(env::CollectGemsMultiAgentUndirected, i::Integer, j::Integer)
23-
tile_map = env.tile_map
24-
object = findfirst(@view tile_map[:, i, j])
25-
num_agents = size(tile_map, 1) - 2
26-
27-
if isnothing(object)
28-
return ""
29-
elseif object in 1 : num_agents
30-
return "$(object)"
31-
elseif object == num_agents + 1
32-
return ""
33-
else
34-
return ""
35-
end
36-
end
37-
38-
const NUM_ACTIONS = 4
39-
GW.get_action_keys(env::CollectGemsMultiAgentUndirected) = ('w', 's', 'a', 'd')
40-
GW.get_action_names(env::CollectGemsMultiAgentUndirected) = (:MOVE_UP, :MOVE_DOWN, :MOVE_LEFT, :MOVE_RIGHT)
41-
4226
function CollectGemsMultiAgentUndirected(; R = Float32, height = 8, width = 8, num_gem_init = floor(Int, sqrt(height * width)), num_agents = 4, rng = Random.GLOBAL_RNG)
4327
tile_map = falses(num_agents + 2, height, width)
4428
WALL = num_agents + 1
@@ -160,11 +144,53 @@ function GW.act!(env::CollectGemsMultiAgentUndirected, action)
160144
return nothing
161145
end
162146

147+
#####
148+
##### miscellaneous
149+
#####
150+
151+
GW.get_tile_map_height(env::CollectGemsMultiAgentUndirected) = size(env.tile_map, 2)
152+
GW.get_tile_map_width(env::CollectGemsMultiAgentUndirected) = size(env.tile_map, 3)
153+
154+
# Return the display string for tile (i, j).
# The first `num_agents` layers of the tile map are agents and are rendered as their
# layer index; the two remaining layers and the empty tile each get a fixed string.
# NOTE(review): the fixed string literals below appear empty in this view — confirm
# the intended glyphs against the repository before relying on them.
function GW.get_tile_pretty_repr(env::CollectGemsMultiAgentUndirected, i::Integer, j::Integer)
    layers = env.tile_map
    object = findfirst(@view layers[:, i, j])
    num_agents = size(layers, 1) - 2

    isnothing(object) && return ""
    (object in 1 : num_agents) && return "$(object)"
    object == num_agents + 1 && return ""
    return ""
end
169+
170+
GW.get_action_keys(env::CollectGemsMultiAgentUndirected) = ('w', 's', 'a', 'd')
171+
GW.get_action_names(env::CollectGemsMultiAgentUndirected) = (:MOVE_UP, :MOVE_DOWN, :MOVE_LEFT, :MOVE_RIGHT)
172+
163173
# Multi-line display: the pretty-printed tile map followed by the current
# `reward`, `done` and `current_agent` fields.
function Base.show(io::IO, ::MIME"text/plain", env::CollectGemsMultiAgentUndirected)
    board = GW.get_tile_map_pretty_repr(env)
    status = "\nreward = $(env.reward)\ndone = $(env.done)\ncurrent_agent = $(env.current_agent)"
    print(io, board * status)
    return nothing
end
169179

180+
#####
181+
##### RLBase API
182+
#####
183+
184+
RLBase.StateStyle(env::GW.RLBaseEnv{E}) where {E <: CollectGemsMultiAgentUndirected} = RLBase.InternalState{Any}()
185+
RLBase.state_space(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: CollectGemsMultiAgentUndirected} = nothing
186+
RLBase.state(env::GW.RLBaseEnv{E}, ::RLBase.InternalState) where {E <: CollectGemsMultiAgentUndirected} = env.env.tile_map
187+
188+
RLBase.reset!(env::GW.RLBaseEnv{E}) where {E <: CollectGemsMultiAgentUndirected} = GW.reset!(env.env)
189+
190+
RLBase.action_space(env::GW.RLBaseEnv{E}) where {E <: CollectGemsMultiAgentUndirected} = 1:NUM_ACTIONS
191+
(env::GW.RLBaseEnv{E})(action) where {E <: CollectGemsMultiAgentUndirected} = GW.act!(env.env, action)
192+
193+
RLBase.reward(env::GW.RLBaseEnv{E}) where {E <: CollectGemsMultiAgentUndirected} = env.env.reward
194+
RLBase.is_terminated(env::GW.RLBaseEnv{E}) where {E <: CollectGemsMultiAgentUndirected} = env.env.done
195+
170196
end # module

0 commit comments

Comments
 (0)