diff --git a/.gitignore b/.gitignore index 622cef2..24cb5fb 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,7 @@ Manifest.toml # vim temporary files *~ *.swp + +/src/scratchpad.jl + +/benchmark/20* diff --git a/Project.toml b/Project.toml index ea30e6b..5e5a68c 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] Crayons = "4.0" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 2f876ff..dd589b6 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -1,5 +1,6 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" GridWorlds = "e15a9946-cd7f-4d03-83e2-6c30bacb0043" Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" diff --git a/benchmark/benchmark_utils.jl b/benchmark/benchmark_utils.jl new file mode 100644 index 0000000..8e21b2d --- /dev/null +++ b/benchmark/benchmark_utils.jl @@ -0,0 +1,219 @@ +import BenchmarkTools as BT +import DataStructures as DS +import Dates +import GridWorlds as GW +import ReinforcementLearningBase as RLBase +import Statistics + +ENVS = [GW.ModuleSingleRoomUndirected.SingleRoomUndirected] +BATCH_ENVS = [GW.ModuleSingleRoomUndirectedBatch.SingleRoomUndirectedBatch] + +function run_random_policy_env!(env, num_resets, steps_per_reset) + for _ in 1:num_resets + RLBase.reset!(env) + for _ in 1:steps_per_reset + state = RLBase.state(env) + action = rand(RLBase.action_space(env)) + env(action) + is_terminated = RLBase.is_terminated(env) + reward = RLBase.reward(env) + end + end + + return nothing +end + +function run_random_policy_batch_env!(env, num_resets, steps_per_reset) + num_envs = size(env.tile_map, 4) + action = Array{eltype(RLBase.action_space(env))}(undef, num_envs) + for _ in 1:num_resets + RLBase.reset!(env, force = true) + for _ in 1:steps_per_reset + state = RLBase.state(env) + for i in 1:num_envs + action[i] = rand(RLBase.action_space(env)) + end + env(action) + is_terminated = RLBase.is_terminated(env) + reward = RLBase.reward(env) + end + end + + return nothing +end + +# function compile_envs(Envs, num_resets, steps_per_reset) + # for Env in Envs + # env = Env() + # run_random_policy!(env, num_resets, steps_per_reset) + # end + + # @info "Compiled and ran all environments" + + # return nothing +# end + +function benchmark_env(Env, num_resets, steps_per_reset) + benchmark = DS.OrderedDict() + + parent_module = parentmodule(Env) + + env = Env() + + benchmark[:random_policy] = BT.@benchmark run_random_policy_env!($(Ref(env))[], $(Ref(num_resets))[], $(Ref(steps_per_reset))[]) + benchmark[:reset] = BT.@benchmark RLBase.reset!($(Ref(env))[]) + benchmark[:state] = BT.@benchmark RLBase.state($(Ref(env))[]) + + for action in RLBase.action_space(env) + action_name = parent_module.ACTION_NAMES[action] + benchmark[action_name] = BT.@benchmark $(Ref(env))[]($(Ref(action))[]) + end + + benchmark[:action_space] = BT.@benchmark RLBase.action_space($(Ref(env))[]) + benchmark[:is_terminated] = BT.@benchmark RLBase.is_terminated($(Ref(env))[]) + benchmark[:reward] = BT.@benchmark RLBase.reward($(Ref(env))[]) + + @info "$(nameof(Env)) benchmarked" + + return benchmark +end + +function benchmark_batch_env(Env, num_resets, steps_per_reset, num_envs) + benchmark = DS.OrderedDict() + + parent_module = parentmodule(Env) + + env = Env(num_envs = num_envs) + + benchmark[:random_policy] = BT.@benchmark run_random_policy_batch_env!($(Ref(env))[], $(Ref(num_resets))[], $(Ref(steps_per_reset))[]) + benchmark[:reset] = BT.@benchmark RLBase.reset!($(Ref(env))[], force = true) + benchmark[:state] = BT.@benchmark RLBase.state($(Ref(env))[]) + + for action in RLBase.action_space(env) + action_name = parent_module.ACTION_NAMES[action] + batch_action = fill(action, num_envs) + benchmark[action_name] = BT.@benchmark $(Ref(env))[]($(Ref(batch_action))[]) + end + + benchmark[:action_space] = BT.@benchmark RLBase.action_space($(Ref(env))[]) + benchmark[:is_terminated] = BT.@benchmark RLBase.is_terminated($(Ref(env))[]) + benchmark[:reward] = BT.@benchmark RLBase.reward($(Ref(env))[]) + + @info "$(nameof(Env)) benchmarked" + + return benchmark +end + +function benchmark_envs(Envs, num_resets, steps_per_reset) + benchmarks = DS.OrderedDict() + + for Env in Envs + benchmarks[nameof(Env)] = benchmark_env(Env, num_resets, steps_per_reset) + end + + @info "benchmark_envs complete" + + return benchmarks +end + +function benchmark_batch_envs(Envs, num_resets, steps_per_reset, num_envs) + benchmarks = DS.OrderedDict() + + for Env in Envs + benchmarks[nameof(Env)] = benchmark_batch_env(Env, num_resets, steps_per_reset, num_envs) + end + + @info "benchmark_batch_envs complete" + + return benchmarks +end + +function get_summary(trial::BT.Trial) + median_trial = BT.median(trial) + memory = BT.prettymemory(median_trial.memory) + median_time = BT.prettytime(median_trial.time) + return memory, median_time +end + +function get_table(benchmark) + title = "|" + separator = "|" + data = "|" + + for key in keys(benchmark) + title = title * String(key) * "|" + separator = separator * ":---:|" + memory, median_time = get_summary(benchmark[key]) + data = data * "$(memory)
$(median_time)|" + end + + return title, separator, data +end + +function generate_benchmark_file(benchmarks; file_name = nothing) + date = Dates.format(Dates.now(), "yyyy_mm_dd_HH_MM_SS") + + if isnothing(file_name) + file_name = date * ".md" + end + + io = open(file_name, "w") + + println(io, "Date: $(date)") + println(io, "## List of Environments") + + for key in keys(benchmarks) + name_string = String(key) + println(io, " 1. [$(name_string)](#$(lowercase(name_string)))") + end + + println(io) + + for key in keys(benchmarks) + println(io, "### " * String(key)) + title, separator, data = get_table(benchmarks[key]) + println(io, title) + println(io, separator) + println(io, data) + println(io) + end + + close(io) + + return nothing +end + +# function generate_benchmark_file_batch_envs(Envs, num_resets, steps_per_reset, num_envs; file_name = nothing) + # date = Dates.format(Dates.now(), "yyyy_mm_dd_HH_MM_SS") + + # if isnothing(file_name) + # file_name = date * ".md" + # end + + # io = open(file_name, "w") + + # benchmarks = benchmark_batch_envs(Envs, num_resets, steps_per_reset, num_envs) + + # println(io, "Date: $(date)") + # println(io, "## List of Environments") + + # for Env in Envs + # name_string = String(nameof(Env)) + # println(io, " 1. [$(name_string)](#$(lowercase(name_string)))") + # end + + # println(io) + + # for key in keys(benchmarks) + # println(io, "### " * String(key)) + # title, separator, data = get_table(benchmarks[key]) + # println(io, title) + # println(io, separator) + # println(io, data) + # println(io) + # end + + # close(io) + + # return nothing +# end diff --git a/src/GridWorlds.jl b/src/GridWorlds.jl index d77d239..7cb077b 100644 --- a/src/GridWorlds.jl +++ b/src/GridWorlds.jl @@ -19,6 +19,7 @@ include("actions.jl") include("objects.jl") include("grid_world_base.jl") include("abstract_grid_world.jl") +include("play.jl") include("envs/envs.jl") include("textual_rendering.jl") diff --git a/src/envs/envs.jl b/src/envs/envs.jl index 730d069..c9efa9e 100644 --- a/src/envs/envs.jl +++ b/src/envs/envs.jl @@ -42,3 +42,5 @@ include("snake.jl") include("catcher.jl") include("transport.jl") include("collect_gems_undirected_multi_agent.jl") +include("single_room_undirected_batch.jl") +include("single_room_undirected.jl") diff --git a/src/envs/single_room_undirected.jl b/src/envs/single_room_undirected.jl new file mode 100644 index 0000000..7be382f --- /dev/null +++ b/src/envs/single_room_undirected.jl @@ -0,0 +1,242 @@ +module ModuleSingleRoomUndirected + +import Crayons +import ..GridWorlds as GW +import ..Play +import Random +import REPL +import ReinforcementLearningBase as RLBase +import StaticArrays as SA +import StatsBase as SB + +const MOVE_UP = 1 +const MOVE_DOWN = 2 +const MOVE_LEFT = 3 +const MOVE_RIGHT = 4 +const ACTION_NAMES = (:MOVE_UP, :MOVE_DOWN, :MOVE_LEFT, :MOVE_RIGHT) + +const AGENT = 1 +const WALL = 2 +const GOAL = 3 + +const DUMMY_CHARACTER = '⋅' +const CHARACTERS = ('☻', '█', '♥') +const FOREGROUND_COLORS = (:light_red, :white, :light_red) + +function sample_two_positions_without_replacement(rng, region) + position1 = rand(rng, region) + position2 = rand(rng, region) + + while position1 == position2 + position2 = rand(rng, region) + end + + return position1, position2 +end + +function move(action::Integer, i, j) + if action == MOVE_UP + return i - 1, j + elseif action == MOVE_DOWN + return i + 1, j + elseif action == MOVE_LEFT + return i, j - 1 + elseif action == MOVE_RIGHT + return i, j + 1 + else + return i, j + end +end + +mutable struct SingleRoomUndirected{R, RNG} <: GW.AbstractGridWorld + tile_map::BitArray{3} + agent_position::CartesianIndex{2} + reward::R + rng::RNG + done::Bool + terminal_reward::R + goal_position::CartesianIndex{2} +end + +function SingleRoomUndirected(; R = Float32, height = 8, width = 8, rng = Random.MersenneTwister()) + tile_map = BitArray(undef, 3, height, width) + + inner_area = CartesianIndices((2 : height - 1, 2 : width - 1)) + + tile_map[:, :, :] .= false + tile_map[WALL, 1, :] .= true + tile_map[WALL, height, :] .= true + tile_map[WALL, :, 1] .= true + tile_map[WALL, :, width] .= true + + agent_position, goal_position = sample_two_positions_without_replacement(rng, inner_area) + + tile_map[AGENT, agent_position] = true + tile_map[GOAL, goal_position] = true + + reward = zero(R) + done = false + terminal_reward = one(R) + + env = SingleRoomUndirected(tile_map, agent_position, reward, rng, done, terminal_reward, goal_position) + + RLBase.reset!(env) + + return env +end + +RLBase.StateStyle(env::SingleRoomUndirected) = RLBase.InternalState{Any}() +RLBase.state_space(env::SingleRoomUndirected, ::RLBase.InternalState) = nothing +RLBase.state(env::SingleRoomUndirected, ::RLBase.InternalState) = env.tile_map + +RLBase.action_space(env::SingleRoomUndirected) = (MOVE_UP, MOVE_DOWN, MOVE_LEFT, MOVE_RIGHT) +RLBase.reward(env::SingleRoomUndirected) = env.reward[] +RLBase.is_terminated(env::SingleRoomUndirected) = env.done[] + +function RLBase.reset!(env::SingleRoomUndirected{R}) where {R} + tile_map = env.tile_map + rng = env.rng + + num_objects, height, width = size(tile_map) + inner_area = CartesianIndices((2 : height - 1, 2 : width - 1)) + + tile_map[AGENT, env.agent_position] = false + tile_map[GOAL, env.goal_position] = false + + new_agent_position, new_goal_position = sample_two_positions_without_replacement(rng, inner_area) + + env.agent_position = new_agent_position + tile_map[AGENT, new_agent_position] = true + + env.goal_position = new_goal_position + tile_map[GOAL, new_goal_position] = true + + env.reward = zero(R) + env.done = false + + return nothing +end + +function (env::SingleRoomUndirected{R})(action) where {R} + tile_map = env.tile_map + agent_position = env.agent_position + + new_agent_position = CartesianIndex(move(action, agent_position.I...)) + + if !tile_map[WALL, new_agent_position] + tile_map[AGENT, agent_position] = false + env.agent_position = new_agent_position + tile_map[AGENT, new_agent_position] = true + end + + if tile_map[GOAL, env.agent_position] + env.reward = env.terminal_reward + done = true + else + env.reward = zero(R) + done = false + end + + return nothing +end + +function Base.show(io::IO, ::MIME"text/plain", env::SingleRoomUndirected) + tile_map = env.tile_map + + num_objects, height, width = size(tile_map) + + print(io, "objects = ") + for i in 1 : length(CHARACTERS) + print(io, Crayons.Crayon(foreground = FOREGROUND_COLORS[i]), CHARACTERS[i], Crayons.Crayon(reset = true)) + if i < length(CHARACTERS) + print(io, ", ") + else + print(io, "\n") + end + end + println(io, "dummy character = ", DUMMY_CHARACTER) + + println(io) + for i in 1:height + for j in 1:width + idx = findfirst(@view tile_map[:, i, j]) + if isnothing(idx) + print(io, DUMMY_CHARACTER) + else + print(io, Crayons.Crayon(foreground = FOREGROUND_COLORS[idx]), CHARACTERS[idx], Crayons.Crayon(reset = true)) + end + end + + println(io) + end + + println(io, "reward = ", env.reward) + println(io, "done = ", env.done) + + return nothing +end + +get_string_key_bindings(env::SingleRoomUndirected) = """Key bindings: + 'q': quit + 'r': RLBase.reset!(env) + 'w': MOVE_UP + 's': MOVE_DOWN + 'a': MOVE_LEFT + 'd': MOVE_RIGHT + """ + +function play!(terminal::REPL.Terminals.UnixTerminal, env::SingleRoomUndirected; file_name::Union{Nothing, AbstractString} = nothing) + REPL.Terminals.raw!(terminal, true) + + terminal_out = terminal.out_stream + terminal_in = terminal.in_stream + file = Play.open_maybe(file_name) + + Play.write_io1_maybe_io2(terminal_out, file, Play.CLEAR_SCREEN) + Play.write_io1_maybe_io2(terminal_out, file, Play.MOVE_CURSOR_TO_ORIGIN) + Play.write_io1_maybe_io2(terminal_out, file, Play.HIDE_CURSOR) + + action_chars = ('w', 's', 'a', 'd') + + char_to_action = Dict('w' => MOVE_UP, + 's' => MOVE_DOWN, + 'a' => MOVE_LEFT, + 'd' => MOVE_RIGHT, + ) + + try + while true + Play.write_io1_maybe_io2(terminal_out, file, get_string_key_bindings(env)) + Play.show_io1_maybe_io2(terminal_out, file, MIME("text/plain"), env) + + char = read(terminal_in, Char) + + Play.write_io1_maybe_io2(terminal_out, file, Play.EMPTY_SCREEN) + + if char == 'q' + Play.write_io1_maybe_io2(terminal_out, file, Play.SHOW_CURSOR) + Play.close_maybe(file) + REPL.Terminals.raw!(terminal, false) + return nothing + elseif char == 'r' + RLBase.reset!(env) + elseif char in action_chars + env(char_to_action[char]) + else + @warn "No procedure exists for this character: $char" + end + + Play.write_io1_maybe_io2(terminal_out, file, "Last character = $(char)\n") + end + finally + Play.write_io1_maybe_io2(terminal_out, file, Play.SHOW_CURSOR) + Play.close_maybe(file) + REPL.Terminals.raw!(terminal, false) + end + + return nothing +end + +play!(env::SingleRoomUndirected; file_name = nothing) = play!(REPL.TerminalMenus.terminal, env, file_name = file_name) + +end # module diff --git a/src/envs/single_room_undirected_batch.jl b/src/envs/single_room_undirected_batch.jl new file mode 100644 index 0000000..c7ed553 --- /dev/null +++ b/src/envs/single_room_undirected_batch.jl @@ -0,0 +1,295 @@ +module ModuleSingleRoomUndirectedBatch + +import Crayons +import ..GridWorlds as GW +import ..Play +import Random +import REPL +import ReinforcementLearningBase as RLBase +import StaticArrays as SA +import StatsBase as SB + +const MOVE_UP = 1 +const MOVE_DOWN = 2 +const MOVE_LEFT = 3 +const MOVE_RIGHT = 4 +const ACTION_NAMES = (:MOVE_UP, :MOVE_DOWN, :MOVE_LEFT, :MOVE_RIGHT) + +const AGENT = 1 +const WALL = 2 +const GOAL = 3 + +const DUMMY_CHARACTER = '⋅' +const CHARACTERS = ('☻', '█', '♥') +const FOREGROUND_COLORS = (:light_red, :white, :light_red) + +function sample_two_positions_without_replacement(rng, region) + position1 = rand(rng, region) + position2 = rand(rng, region) + + while position1 == position2 + position2 = rand(rng, region) + end + + return position1, position2 +end + +function move(action::Integer, i, j) + if action == MOVE_UP + return i - 1, j + elseif action == MOVE_DOWN + return i + 1, j + elseif action == MOVE_LEFT + return i, j - 1 + elseif action == MOVE_RIGHT + return i, j + 1 + else + return i, j + end +end + +struct SingleRoomUndirectedBatch{I, R, RNG} <: GW.AbstractGridWorld + tile_map::BitArray{4} + agent_position::Array{I, 2} + reward::Array{R, 1} + rng::Array{RNG, 1} + done::BitArray{1} + terminal_reward::R + goal_position::Array{I, 2} +end + +function SingleRoomUndirectedBatch(; I = Int32, R = Float32, num_envs = 2, height = 8, width = 8, rng = [Random.MersenneTwister() for i in 1:num_envs]) + tile_map = BitArray(undef, 3, height, width, num_envs) + agent_position = Array{I}(undef, 2, num_envs) + reward = Array{R}(undef, num_envs) + done = BitArray(undef, num_envs) + goal_position = Array{I}(undef, 2, num_envs) + terminal_reward = one(R) + + inner_area = CartesianIndices((2 : height - 1, 2 : width - 1)) + + for env_id in 1:num_envs + tile_map[:, :, :, env_id] .= false + tile_map[WALL, 1, :, env_id] .= true + tile_map[WALL, height, :, env_id] .= true + tile_map[WALL, :, 1, env_id] .= true + tile_map[WALL, :, width, env_id] .= true + + random_positions = sample_two_positions_without_replacement(rng[env_id], inner_area) + + agent_position[1, env_id] = random_positions[1][1] + agent_position[2, env_id] = random_positions[1][2] + tile_map[AGENT, random_positions[1], env_id] = true + + goal_position[1, env_id] = random_positions[2][1] + goal_position[2, env_id] = random_positions[2][2] + tile_map[GOAL, random_positions[2], env_id] = true + + reward[env_id] = zero(R) + done[env_id] = false + end + + env = SingleRoomUndirectedBatch(tile_map, agent_position, reward, rng, done, terminal_reward, goal_position) + + RLBase.reset!(env, force = true) + + return env +end + +RLBase.StateStyle(env::SingleRoomUndirectedBatch) = RLBase.InternalState{Any}() +RLBase.state_space(env::SingleRoomUndirectedBatch, ::RLBase.InternalState) = nothing +RLBase.state(env::SingleRoomUndirectedBatch, ::RLBase.InternalState) = env.tile_map + +RLBase.action_space(env::SingleRoomUndirectedBatch) = (MOVE_UP, MOVE_DOWN, MOVE_LEFT, MOVE_RIGHT) +RLBase.reward(env::SingleRoomUndirectedBatch) = env.reward +RLBase.is_terminated(env::SingleRoomUndirectedBatch) = env.done + +function RLBase.reset!(env::SingleRoomUndirectedBatch{I, R}; force = false) where {I, R} + tile_map = env.tile_map + agent_position = env.agent_position + goal_position = env.goal_position + reward = env.reward + done = env.done + rng = env.rng + + num_objects, height, width, num_envs = size(tile_map) + inner_area = CartesianIndices((2 : height - 1, 2 : width - 1)) + + for env_id in 1:num_envs + if force || done[env_id] + tile_map[AGENT, agent_position[1, env_id], agent_position[2, env_id], env_id] = false + tile_map[GOAL, goal_position[1, env_id], goal_position[2, env_id], env_id] = false + + random_positions = sample_two_positions_without_replacement(rng[env_id], inner_area) + + agent_position[1, env_id] = random_positions[1][1] + agent_position[2, env_id] = random_positions[1][2] + tile_map[AGENT, random_positions[1], env_id] = true + + goal_position[1, env_id] = random_positions[2][1] + goal_position[2, env_id] = random_positions[2][2] + tile_map[GOAL, random_positions[2], env_id] = true + + reward[env_id] = zero(R) + done[env_id] = false + end + end + + return nothing +end + +function (env::SingleRoomUndirectedBatch{I, R})(action::Vector) where {I, R} + tile_map = env.tile_map + agent_position = env.agent_position + goal_position = env.goal_position + reward = env.reward + done = env.done + rng = env.rng + terminal_reward = env.terminal_reward + + num_envs = size(tile_map, 4) + + for env_id in 1:num_envs + current_position_i = agent_position[1, env_id] + current_position_j = agent_position[2, env_id] + next_position_i, next_position_j = move(action[env_id], current_position_i, current_position_j) + + if !tile_map[WALL, next_position_i, next_position_j, env_id] + tile_map[AGENT, current_position_i, current_position_j, env_id] = false + agent_position[1, env_id] = next_position_i + agent_position[2, env_id] = next_position_j + tile_map[AGENT, next_position_i, next_position_j, env_id] = true + end + + new_current_position_i = agent_position[1, env_id] + new_current_position_j = agent_position[2, env_id] + + if tile_map[GOAL, new_current_position_i, new_current_position_j, env_id] + done[env_id] = true + reward[env_id] = terminal_reward + else + done[env_id] = false + reward[env_id] = zero(R) + end + end + + return nothing +end + +function Base.show(io::IO, ::MIME"text/plain", env::SingleRoomUndirectedBatch) + tile_map = env.tile_map + reward = env.reward + done = env.done + + num_objects, height, width, num_envs = size(tile_map) + + print(io, "objects = ") + for i in 1 : length(CHARACTERS) + print(io, Crayons.Crayon(foreground = FOREGROUND_COLORS[i]), CHARACTERS[i], Crayons.Crayon(reset = true)) + if i < length(CHARACTERS) + print(io, ", ") + else + print(io, "\n") + end + end + println(io, "dummy character = ", DUMMY_CHARACTER) + + for env_id in 1:num_envs + println(io) + println(io, "env_id = ", env_id) + for i in 1:height + for j in 1:width + idx = findfirst(@view tile_map[:, i, j, env_id]) + if isnothing(idx) + print(io, DUMMY_CHARACTER) + else + print(io, Crayons.Crayon(foreground = FOREGROUND_COLORS[idx]), CHARACTERS[idx], Crayons.Crayon(reset = true)) + end + end + + println(io) + end + + println(io, "reward = ", reward[env_id]) + println(io, "done = ", done[env_id]) + end + + return nothing +end + +get_string_key_bindings(env::GW.AbstractGridWorld) = """Key bindings: + 'q': quit + 'r': RLBase.reset!(env) + 'w': MOVE_UP + 's': MOVE_DOWN + 'a': MOVE_LEFT + 'd': MOVE_RIGHT + """ + +function play!(terminal::REPL.Terminals.UnixTerminal, env::SingleRoomUndirectedBatch; file_name::Union{Nothing, AbstractString} = nothing) + REPL.Terminals.raw!(terminal, true) + + terminal_out = terminal.out_stream + terminal_in = terminal.in_stream + file = Play.open_maybe(file_name) + + Play.write_io1_maybe_io2(terminal_out, file, Play.CLEAR_SCREEN) + Play.write_io1_maybe_io2(terminal_out, file, Play.MOVE_CURSOR_TO_ORIGIN) + Play.write_io1_maybe_io2(terminal_out, file, Play.HIDE_CURSOR) + + num_envs = size(env.tile_map, 4) + chars = Array{Char}(undef, num_envs) + + action_chars = ('w', 's', 'a', 'd') + + char_to_action = Dict('w' => MOVE_UP, + 's' => MOVE_DOWN, + 'a' => MOVE_LEFT, + 'd' => MOVE_RIGHT, + ) + + action = Array{Int}(undef, num_envs) + + try + while true + Play.write_io1_maybe_io2(terminal_out, file, get_string_key_bindings(env)) + Play.show_io1_maybe_io2(terminal_out, file, MIME("text/plain"), env) + + for i in 1:num_envs + c = read(terminal_in, Char) + chars[i] = c + write(terminal_out, c) + end + + Play.write_io1_maybe_io2(terminal_out, file, Play.EMPTY_SCREEN) + + if 'q' in chars + Play.write_io1_maybe_io2(terminal_out, file, Play.SHOW_CURSOR) + Play.close_maybe(file) + REPL.Terminals.raw!(terminal, false) + return nothing + elseif 'r' in chars + RLBase.reset!(env) + elseif all(char -> char in action_chars, chars) + for i in 1:num_envs + action[i] = char_to_action[chars[i]] + end + env(action) + else + @warn "No procedure exists for this character sequence: $chars" + end + + Play.write_io1_maybe_io2(terminal_out, file, "Last character sequence = $(chars)\n") + end + finally + Play.write_io1_maybe_io2(terminal_out, file, Play.SHOW_CURSOR) + Play.close_maybe(file) + REPL.Terminals.raw!(terminal, false) + end + + return nothing +end + +play!(env::SingleRoomUndirectedBatch; file_name = nothing) = play!(REPL.TerminalMenus.terminal, env, file_name = file_name) + +end # module diff --git a/src/play.jl b/src/play.jl new file mode 100644 index 0000000..5096f3f --- /dev/null +++ b/src/play.jl @@ -0,0 +1,45 @@ +module Play + +import REPL + +const ESC = Char(0x1B) +const HIDE_CURSOR = ESC * "[?25l" +const SHOW_CURSOR = ESC * "[?25h" +const CLEAR_SCREEN = ESC * "[2J" +const MOVE_CURSOR_TO_ORIGIN = ESC * "[H" +const CLEAR_SCREEN_BEFORE_CURSOR = ESC * "[1J" +const EMPTY_SCREEN = CLEAR_SCREEN_BEFORE_CURSOR * MOVE_CURSOR_TO_ORIGIN + +open_maybe(file_name::AbstractString) = open(file_name, "w") +open_maybe(::Nothing) = nothing + +close_maybe(io::IO) = close(io) +close_maybe(io::Nothing) = nothing + +write_maybe(io::IO, content) = write(io, content) +write_maybe(io::Nothing, content) = 0 +write_io1_maybe_io2(io1::IO, io2::Union{Nothing, IO}, content) = write(io1, content) + write_maybe(io2, content) + +show_maybe(io::IO, mime::MIME, content) = show(io, mime, content) +show_maybe(io::Nothing, mime::MIME, content) = nothing +function show_io1_maybe_io2(io1::IO, io2::Union{Nothing, IO}, mime::MIME, content) + show(io1, mime, content) + show_maybe(io2, mime, content) +end + +function replay(terminal::REPL.Terminals.UnixTerminal, file_name::AbstractString, frame_rate) + terminal_out = terminal.out_stream + delimiter = EMPTY_SCREEN + frames = split(read(file_name, String), delimiter) + for frame in frames + write(terminal_out, frame) + sleep(1 / frame_rate) + write(terminal_out, delimiter) + end + + return nothing +end + +replay(file_name; frame_rate = 2) = replay(REPL.TerminalMenus.terminal, file_name, frame_rate) + +end # module diff --git a/test/runtests.jl b/test/runtests.jl index b6d8879..6419b15 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,6 +30,8 @@ ENVS = [GW.EmptyRoomDirected, GW.CollectGemsUndirectedMultiAgent, ] +BATCH_ENVS = [GW.ModuleSingleRoomUndirectedBatch.SingleRoomUndirectedBatch] + const MAX_STEPS = 3000 const NUM_RESETS = 3 @@ -60,45 +62,90 @@ get_terminal_returns(env::GW.Catcher) = env.terminal_reward:env.ball_reward:MAX_ get_terminal_returns(env::GW.TransportDirected) = (GW.get_terminal_reward(env),) get_terminal_returns(env::GW.TransportUndirected) = (GW.get_terminal_reward(env),) +get_terminal_returns(env::GW.ModuleSingleRoomUndirectedBatch.SingleRoomUndirectedBatch) = (env.terminal_reward,) + Test.@testset "GridWorlds.jl" begin - for Env in ENVS - Test.@testset "$(Env)" begin - T = Float32 - env = Env(T = T) - for _ in 1:NUM_RESETS - RLBase.reset!(env) - Test.@test RLBase.reward(env) == zero(T) - Test.@test RLBase.is_terminated(env) == false - - total_reward = zero(T) - for i in 1:MAX_STEPS - action = rand(RLBase.action_space(env)) - env(action) - total_reward += RLBase.reward(env) - - - if Env == GW.CollectGemsUndirectedMultiAgent - for i in 1:GW.get_num_agents(env) - agent_pos = env.agent_pos[i] - Test.@test 1 ≤ agent_pos[1] ≤ GW.get_height(env) - Test.@test 1 ≤ agent_pos[2] ≤ GW.get_width(env) - end - else - Test.@test 1 ≤ GW.get_agent_pos(env)[1] ≤ GW.get_height(env) - Test.@test 1 ≤ GW.get_agent_pos(env)[2] ≤ GW.get_width(env) - end + Test.@testset "Single Environments" begin + for Env in ENVS + Test.@testset "$(Env)" begin + T = Float32 + env = Env(T = T) + for _ in 1:NUM_RESETS + RLBase.reset!(env) + Test.@test RLBase.reward(env) == zero(T) + Test.@test RLBase.is_terminated(env) == false + + total_reward = zero(T) + for i in 1:MAX_STEPS + action = rand(RLBase.action_space(env)) + env(action) + total_reward += RLBase.reward(env) + - if RLBase.is_terminated(env) - if Env == GW.Snake - Test.@test (total_reward in get_terminal_returns_win(env) || total_reward in get_terminal_returns_lose(env)) + if Env == GW.CollectGemsUndirectedMultiAgent + for i in 1:GW.get_num_agents(env) + agent_pos = env.agent_pos[i] + Test.@test 1 ≤ agent_pos[1] ≤ GW.get_height(env) + Test.@test 1 ≤ agent_pos[2] ≤ GW.get_width(env) + end else - Test.@test total_reward in get_terminal_returns(env) + Test.@test 1 ≤ GW.get_agent_pos(env)[1] ≤ GW.get_height(env) + Test.@test 1 ≤ GW.get_agent_pos(env)[2] ≤ GW.get_width(env) + end + + if RLBase.is_terminated(env) + if Env == GW.Snake + Test.@test (total_reward in get_terminal_returns_win(env) || total_reward in get_terminal_returns_lose(env)) + else + Test.@test total_reward in get_terminal_returns(env) + end + break + end + + if i == MAX_STEPS + @info "$Env not terminated after MAX_STEPS = $MAX_STEPS" end - break end + end + end + end + end + + Test.@testset "Batch Environments" begin + for Env in BATCH_ENVS + Test.@testset "$(Env)" begin + num_envs = 2 + R = Float32 + I = Int32 + env = Env(I = I, R = R, num_envs = num_envs) + height = size(env.tile_map, 2) + width = size(env.tile_map, 3) + for _ in 1:NUM_RESETS + RLBase.reset!(env) + Test.@test RLBase.reward(env) == zeros(R, num_envs) + Test.@test RLBase.is_terminated(env) == falses(num_envs) + + total_reward = zeros(R, num_envs) + for i in 1:MAX_STEPS + action = [rand(RLBase.action_space(env)) for _ in 1:num_envs] + env(action) + total_reward .+= RLBase.reward(env) - if i == MAX_STEPS - @info "$Env not terminated after MAX_STEPS = $MAX_STEPS" + for env_id in 1:num_envs + Test.@test 1 ≤ env.agent_position[1, env_id] ≤ height + Test.@test 1 ≤ env.agent_position[2, env_id] ≤ width + end + + for env_id in 1:num_envs + if RLBase.is_terminated(env)[env_id] + Test.@test total_reward[env_id] in get_terminal_returns(env) + total_reward[env_id] = zero(total_reward[env_id]) + end + end + + if i == MAX_STEPS && !any(RLBase.is_terminated(env)) + @info "$Env not terminated after MAX_STEPS = $MAX_STEPS" + end end end end