Commit 82f77f4

added experiments to test
1 parent efc7f7b commit 82f77f4

File tree

6 files changed: +170 -7 lines changed


src/ReinforcementLearningExperiments/Project.toml

Lines changed: 5 additions & 1 deletion
@@ -6,12 +6,14 @@ version = "0.3.1"
 [deps]
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
 ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 ReinforcementLearningZoo = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
+Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Weave = "44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9"

@@ -29,7 +31,9 @@ julia = "1.9"

 [extras]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
+Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["CUDA", "Test"]
+test = ["CUDA", "PyCall", "Test"]
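PyCall and Requires enter only through [extras] and the test target, so the PettingZoo-dependent experiment is exercised only when PyCall is actually present. A minimal sketch of the Requires.jl pattern this enables (the include path and the __init__ placement here are illustrative, not part of this diff):

using Requires

function __init__()
    # The block runs only once PyCall has been loaded in the session.
    @require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" begin
        include(joinpath(@__DIR__, "experiments", "DQN_mpe_simple.jl"))
    end
end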

src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl renamed to src/ReinforcementLearningExperiments/deps/experiments/experiments/MARL/DQN_mpe_simple.jl

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ using Flux.Losses: huber_loss
 function RLCore.Experiment(
     ::Val{:JuliaRL},
     ::Val{:DQN},
-    ::Val{:MPESimple};
+    ::Val{:MPESimple},
     seed=123,
     n=1,
     γ=0.99f0,
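Dropping the semicolon after ::Val{:MPESimple} turns seed, n, γ, and is_enable_double_DQN from keyword arguments into optional positional arguments. A hedged sketch of the resulting call; direct construction is shown only for illustration, the E string macro used elsewhere resolves to the same method:

# before: RLCore.Experiment(Val(:JuliaRL), Val(:DQN), Val(:MPESimple); seed=123, γ=0.99f0)
# after: optional positional arguments with defaults (seed, n, γ)
ex = RLCore.Experiment(Val(:JuliaRL), Val(:DQN), Val(:MPESimple), 123, 1, 0.99f0)
run(ex)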

src/ReinforcementLearningExperiments/src/experiments/MARL/IDQN_TicTacToe.jl renamed to src/ReinforcementLearningExperiments/deps/experiments/experiments/MARL/IDQN_TicTacToe.jl

Lines changed: 67 additions & 5 deletions
@@ -7,7 +7,6 @@
 # ---
 
 using StableRNGs
-using ReinforcementLearning
 using ReinforcementLearningBase
 using ReinforcementLearningZoo
 using ReinforcementLearningCore
@@ -16,7 +15,6 @@ using Flux
 using Flux.Losses: huber_loss
 using Flux: glorot_uniform
 
-using ProgressMeter
 
 
 rng = StableRNG(1234)
@@ -25,6 +23,71 @@ cap = 100
 
 RLCore.forward(L::DQNLearner, state::A) where {A <: Real} = RLCore.forward(L, [state])
 
+
+episodes_per_step = 25
+
+function RLCore.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:IDQN},
+    ::Val{:TicTacToe},
+    seed=123,
+    n=1,
+    γ=0.99f0,
+    is_enable_double_DQN=true
+)
+    rng = StableRNG(seed)
+    create_policy() = QBasedPolicy(
+        learner=DQNLearner(
+            approximator=Approximator(
+                model=TwinNetwork(
+                    Chain(
+                        Dense(1, 512, relu; init=glorot_uniform(rng)),
+                        Dense(512, 256, relu; init=glorot_uniform(rng)),
+                        Dense(256, 9; init=glorot_uniform(rng)),
+                    );
+                    sync_freq=100
+                ),
+                optimiser=Adam(),
+            ),
+            n=n,
+            γ=γ,
+            is_enable_double_DQN=is_enable_double_DQN,
+            loss_func=huber_loss,
+            rng=rng,
+        ),
+        explorer=EpsilonGreedyExplorer(
+            kind=:exp,
+            ϵ_stable=0.01,
+            decay_steps=500,
+            rng=rng,
+        ),
+    )
+
+    e = TicTacToeEnv();
+    m = MultiAgentPolicy(NamedTuple((player =>
+        Agent(player != :Cross ? create_policy() : RandomPolicy(;rng=rng),
+            Trajectory(
+                container=CircularArraySARTTraces(
+                    capacity=cap,
+                    state=Integer => (1,),
+                ),
+                sampler=NStepBatchSampler{SS′ART}(
+                    n=n,
+                    γ=γ,
+                    batch_size=1,
+                    rng=rng
+                ),
+                controller=InsertSampleRatioController(
+                    threshold=1,
+                    n_inserted=0
+                ))
+        )
+        for player in players(e)))
+    );
+    hooks = MultiAgentHook(NamedTuple((p => TotalRewardPerEpisode() for p in players(e))))
+    Experiment(m, e, StopAfterEpisode(episodes_per_step), hooks)
+end
+
 create_policy() = QBasedPolicy(
     learner=DQNLearner(
         approximator=Approximator(
@@ -36,7 +99,7 @@ create_policy() = QBasedPolicy(
             );
             sync_freq=100
         ),
-        optimiser=ADAM(),
+        optimiser=Adam(),
     ),
     n=32,
     γ=0.99f0,
@@ -75,9 +138,8 @@ m = MultiAgentPolicy(NamedTuple((player =>
 );
 hooks = MultiAgentHook(NamedTuple((p => TotalRewardPerEpisode() for p in players(e))))
 
-episodes_per_step = 25
 win_rates = (Cross=Float64[], Nought=Float64[])
-@showprogress for i in 1:2
+for i in 1:2
     run(m, e, StopAfterEpisode(episodes_per_step; is_show_progress=false), hooks)
     wr_cross = sum(hooks[:Cross].rewards)/(i*episodes_per_step)
     wr_nought = sum(hooks[:Nought].rewards)/(i*episodes_per_step)
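Once this file is included by ReinforcementLearningExperiments (see the include added below), the new experiment is reachable through the usual E string macro. A minimal usage sketch, assuming the MultiAgentHook layout defined above:

using ReinforcementLearningExperiments

# Build and run the IDQN Tic-Tac-Toe experiment (stops after 25 episodes per the definition above).
ex = E`JuliaRL_IDQN_TicTacToe`
run(ex)

# Each player's TotalRewardPerEpisode hook lives inside the MultiAgentHook.
sum(ex.hook[:Cross].rewards)   # total reward collected by the Cross player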

src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ include(joinpath(EXPERIMENTS_DIR, "JuliaRL_Rainbow_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_VPG_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_TRPO_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_MPO_CartPole.jl"))
+include(joinpath(EXPERIMENTS_DIR, "IDQN_TicTacToe.jl"))
 
 # dynamic loading environments
 function __init__() end
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+
+# ---
+# title: JuliaRL\_DQN\_MPESimple
+# cover:
+# description: DQN applied to MPE simple
+# date: 2023-02-01
+# author: "[Panajiotis Keßler](mailto:panajiotis@christoforidis.net)"
+# ---
+
+using PyCall
+using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo
+using Flux
+using Flux: glorot_uniform
+
+using StableRNGs: StableRNG
+using Flux.Losses: huber_loss
+
+function RLCore.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:DQN},
+    ::Val{:MPESimple},
+    seed=123,
+    n=1,
+    γ=0.99f0,
+    is_enable_double_DQN=true
+)
+    rng = StableRNG(seed)
+    env = discrete2standard_discrete(PettingZooEnv("mpe.simple_v2"; seed=seed))
+    ns, na = length(state(env)), length(action_space(env))
+
+    agent = Agent(
+        policy=QBasedPolicy(
+            learner=DQNLearner(
+                approximator=Approximator(
+                    model=TwinNetwork(
+                        Chain(
+                            Dense(ns, 128, relu; init=glorot_uniform(rng)),
+                            Dense(128, 128, relu; init=glorot_uniform(rng)),
+                            Dense(128, na; init=glorot_uniform(rng)),
+                        );
+                        sync_freq=100
+                    ),
+                    optimiser=Adam(),
+                ),
+                n=n,
+                γ=γ,
+                is_enable_double_DQN=is_enable_double_DQN,
+                loss_func=huber_loss,
+                rng=rng,
+            ),
+            explorer=EpsilonGreedyExplorer(
+                kind=:exp,
+                ϵ_stable=0.01,
+                decay_steps=500,
+                rng=rng,
+            ),
+        ),
+        trajectory=Trajectory(
+            container=CircularArraySARTTraces(
+                capacity=1000,
+                state=Float32 => (ns,),
+            ),
+            sampler=NStepBatchSampler{SS′ART}(
+                n=n,
+                γ=γ,
+                batch_size=32,
+                rng=rng
+            ),
+            controller=InsertSampleRatioController(
+                threshold=100,
+                n_inserted=-1
+            )
+        )
+    )
+
+    stop_condition = StopAfterEpisode(150, is_show_progress=!haskey(ENV, "CI"))
+    hook = TotalRewardPerEpisode()
+    Experiment(agent, env, stop_condition, hook)
+end
+
+using Plots
+ex = E`JuliaRL_DQN_MPESimple`
+run(ex)
+plot(ex.hook.rewards)
+savefig("JuliaRL_DQN_MPESimple.png")
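The experiment above goes through PettingZooEnv, which wraps the Python pettingzoo package via PyCall. A hedged setup sketch for a PyCall-managed Conda environment; the pettingzoo[mpe] package name is an assumption about the Python side, not something this diff installs:

using Conda

# Assumption: PyCall uses its own Conda root; enable pip interop and pull in PettingZoo's MPE environments.
Conda.pip_interop(true)
Conda.pip("install", "pettingzoo[mpe]")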

src/ReinforcementLearningExperiments/test/runtests.jl

Lines changed: 10 additions & 0 deletions
@@ -1,6 +1,14 @@
 using ReinforcementLearningExperiments
 using CUDA
 
+using Requires
+
+const EXPERIMENTS_DIR = joinpath(@__DIR__, "experiments")
+@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" include(
+    joinpath(EXPERIMENTS_DIR, "DQN_mpe_simple.jl")
+)
+
+
 CUDA.allowscalar(false)
 
 run(E`JuliaRL_NFQ_CartPole`)
@@ -15,6 +23,8 @@ run(E`JuliaRL_VPG_CartPole`)
 run(E`JuliaRL_MPODiscrete_CartPole`)
 run(E`JuliaRL_MPOContinuous_CartPole`)
 run(E`JuliaRL_MPOCovariance_CartPole`)
+@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" run(E`JuliaRL_DQN_MPESimple`)
+run(E`JuliaRL_IDQN_TicTacToe`)
 # run(E`JuliaRL_BC_CartPole`)
 # run(E`JuliaRL_VMPO_CartPole`)
 # run(E`JuliaRL_BasicDQN_MountainCar`)
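With PyCall and Requires on the test target, the updated suite can be exercised locally in the usual way; a minimal sketch, assuming a checkout of the monorepo as the working directory:

using Pkg

# Activate and test the subpackage; CUDA and PyCall come from the [targets] test list in Project.toml.
Pkg.activate("src/ReinforcementLearningExperiments")
Pkg.test()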
