Skip to content

Commit c70f4f0

Browse files
authored
enable OpenSpiel (#691)
* enable OpenSpiel * passCI
1 parent d2bfd1f commit c70f4f0

File tree

4 files changed

+21
-25
lines changed

4 files changed

+21
-25
lines changed

src/ReinforcementLearningEnvironments/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ julia = "1.3"
2727

2828
[extras]
2929
ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977"
30+
OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
3031
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
3132
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
3233
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3334
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3435
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3536

3637
[targets]
37-
test = ["ArcadeLearningEnvironment", "OrdinaryDiffEq", "PyCall", "StableRNGs", "Statistics", "Test"]
38+
test = ["ArcadeLearningEnvironment", "OpenSpiel", "OrdinaryDiffEq", "PyCall", "StableRNGs", "Statistics", "Test"]

src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ RLBase.current_player(env::OpenSpielEnv) = OpenSpiel.current_player(env.state)
6060
RLBase.chance_player(env::OpenSpielEnv) = convert(Int, OpenSpiel.CHANCE_PLAYER)
6161

6262
function RLBase.players(env::OpenSpielEnv)
63-
p = 0:(num_players(env.game) - 1)
63+
p = 0:(num_players(env.game)-1)
6464
if ChanceStyle(env) === EXPLICIT_STOCHASTIC
6565
(p..., RLBase.chance_player(env))
6666
else
@@ -73,9 +73,9 @@ function RLBase.action_space(env::OpenSpielEnv, player)
7373
# !!! this bug is already fixed in OpenSpiel
7474
# replace it with the following one later
7575
# ZeroTo(max_chance_outcomes(env.game)-1)
76-
ZeroTo(max_chance_outcomes(env.game))
76+
Space(0:max_chance_outcomes(env.game))
7777
else
78-
ZeroTo(num_distinct_actions(env.game) - 1)
78+
Space(0:num_distinct_actions(env.game)-1)
7979
end
8080
end
8181

@@ -91,7 +91,7 @@ function RLBase.prob(env::OpenSpielEnv, player)
9191
# @assert player == chance_player(env)
9292
p = zeros(length(action_space(env)))
9393
for (k, v) in chance_outcomes(env.state)
94-
p[k + 1] = v
94+
p[k+1] = v
9595
end
9696
p
9797
end
@@ -102,7 +102,7 @@ function RLBase.legal_action_space_mask(env::OpenSpielEnv, player)
102102
num_distinct_actions(env.game)
103103
mask = BitArray(undef, n)
104104
for a in legal_actions(env.state, player)
105-
mask[a + 1] = true
105+
mask[a+1] = true
106106
end
107107
mask
108108
end
@@ -126,7 +126,7 @@ function RLBase.state(env::OpenSpielEnv, ss::RLBase.AbstractStateStyle, player)
126126
if player < 0 # TODO: revisit this in OpenSpiel@v0.2
127127
@warn "unexpected player $player, falling back to default state value." maxlog = 1
128128
s = state_space(env)
129-
if s isa WorldSpace
129+
if s === Space(AbstractString)
130130
""
131131
elseif s isa Array{<:Interval}
132132
rand(s)
@@ -149,19 +149,15 @@ RLBase.state_space(
149149
env::OpenSpielEnv,
150150
::Union{InformationSet{String},Observation{String}},
151151
p,
152-
) = WorldSpace{AbstractString}()
152+
) = Space(AbstractString)
153153

154154
RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array},
155155
p,
156-
) = Space(
157-
fill(typemin(Float64)..typemax(Float64), reverse(information_state_tensor_shape(env.game))...),
158-
)
156+
) = Space(Float64, reverse(information_state_tensor_shape(env.game))...)
159157

160158
RLBase.state_space(env::OpenSpielEnv, ::Observation{Array},
161159
p,
162-
) = Space(
163-
fill(typemin(Float64)..typemax(Float64), reverse(observation_tensor_shape(env.game))...),
164-
)
160+
) = Space(Float64, reverse(observation_tensor_shape(env.game))...)
165161

166162
Random.seed!(env::OpenSpielEnv, s) = @warn "seed!(OpenSpielEnv) is not supported currently."
167163

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
@testset "OpenSpielEnv" begin
2-
3-
# for name in [
4-
# "tic_tac_toe",
5-
# "kuhn_poker",
6-
# "goofspiel(imp_info=True,num_cards=4,points_order=descending)",
7-
# ]
8-
# @info "testing OpenSpiel: $name"
9-
# env = OpenSpielEnv(name)
10-
# RLBase.test_runnable!(env)
11-
# end
2+
for name in [
3+
"tic_tac_toe",
4+
"kuhn_poker",
5+
"goofspiel(imp_info=True,num_cards=4,points_order=descending)",
6+
]
7+
@info "testing OpenSpiel: $name"
8+
env = OpenSpielEnv(name)
9+
RLBase.test_runnable!(env)
10+
end
1211
end

src/ReinforcementLearningEnvironments/test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using ReinforcementLearningBase
33
using ReinforcementLearningEnvironments
44
using ArcadeLearningEnvironment
55
using PyCall
6-
# using OpenSpiel
6+
using OpenSpiel
77
# using SnakeGames
88
using Random
99
using StableRNGs

0 commit comments

Comments
 (0)