Commit 5df8e66 (1 parent: 1f7f347)

Panajiotis Keßler (Mytolo) authored and committed:
added first Independent Q-Learning experiment

1 file changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# ---
# title: JuliaRL\_IDQN\_TicTacToe
# cover:
# description: IDQN applied to competitive TicTacToe
# date: 2023-07-03
# author: "[Panajiotis Keßler](mailto:panajiotis.kessler@gmail.com)"
# ---

using StableRNGs
using ReinforcementLearning
using ReinforcementLearningBase
using ReinforcementLearningZoo
using ReinforcementLearningCore
using Plots
using Flux
using Flux.Losses: huber_loss
using Flux: glorot_uniform

using ProgressMeter


# Reproducible RNG and replay-buffer capacity.
rng = StableRNG(1234)

cap = 100

# The state is consumed here as a single integer; wrap it in a vector so it
# can be fed to the Dense(1, ...) input layer of the Q-network.
RLCore.forward(L::DQNLearner, state::A) where {A <: Real} = RLCore.forward(L, [state])

create_policy() = QBasedPolicy(
    learner=DQNLearner(
        approximator=Approximator(
            model=TwinNetwork(
                Chain(
                    Dense(1, 512, relu; init=glorot_uniform(rng)),
                    Dense(512, 256, relu; init=glorot_uniform(rng)),
                    Dense(256, 9; init=glorot_uniform(rng)),
                );
                sync_freq=100
            ),
            optimiser=ADAM(),
        ),
        n=32,
        γ=0.99f0,
        is_enable_double_DQN=true,
        loss_func=huber_loss,
        rng=rng,
    ),
    explorer=EpsilonGreedyExplorer(
        kind=:exp,
        ϵ_stable=0.01,
        decay_steps=500,
        rng=rng,
    ),
)
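
# Illustrative sketch (kept as comments so it does not affect the experiment):
# with the scalar-state `forward` overload defined above, the Q-values of a
# single integer-encoded board state could be inspected roughly like
#     p = create_policy()
#     RLCore.forward(p.learner, 1)   # one value per board cell (9 outputs)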

e = TicTacToeEnv();

# One independent agent per player, each with its own trajectory/replay buffer:
# Nought uses the DQN policy defined above, Cross plays uniformly at random.
m = MultiAgentPolicy(NamedTuple((player =>
        Agent(player != :Cross ? create_policy() : RandomPolicy(;rng=rng),
            Trajectory(
                container=CircularArraySARTTraces(
                    capacity=cap,
                    state=Integer => (1,),
                ),
                sampler=NStepBatchSampler{SS′ART}(
                    n=1,
                    γ=0.99f0,
                    batch_size=1,
                    rng=rng
                ),
                controller=InsertSampleRatioController(
                    threshold=1,
                    n_inserted=0
                ))
        )
    for player in players(e)))
);
hooks = MultiAgentHook(NamedTuple((p => TotalRewardPerEpisode() for p in players(e))))

episodes_per_step = 25
win_rates = (Cross=Float64[], Nought=Float64[])
@showprogress for i in 1:2
    run(m, e, StopAfterEpisode(episodes_per_step; is_show_progress=false), hooks)
    # Cumulative reward divided by total episodes so far, used as a win-rate proxy.
    wr_cross = sum(hooks[:Cross].rewards) / (i * episodes_per_step)
    wr_nought = sum(hooks[:Nought].rewards) / (i * episodes_per_step)
    push!(win_rates[:Cross], wr_cross)
    push!(win_rates[:Nought], wr_nought)
end
p1 = plot([win_rates[:Cross] win_rates[:Nought]], labels=["Cross" "Nought"])
xlabel!("Iteration steps of $episodes_per_step episodes")
ylabel!("Win rate of the player")

p2 = plot([hooks[:Cross].rewards hooks[:Nought].rewards], labels=["Cross" "Nought"])
xlabel!("Overall episodes")
ylabel!("Rewards of the players")

p = plot(p1, p2, layout=(2,1), size=[1000,1000])
savefig("TTT_CROSS_DQN_NOUGHT_RANDOM.png")
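
A minimal sketch of how this experiment might be run from the Julia REPL, assuming the packages imported above are installed in the active environment and the script is saved as JuliaRL_IDQN_TicTacToe.jl (a filename inferred from the title, not stated in the commit):

    include("JuliaRL_IDQN_TicTacToe.jl")

The two obvious knobs for a longer run are the loop count in `@showprogress for i in 1:2` and the replay capacity `cap`; the resulting figure is written to the working directory as TTT_CROSS_DQN_NOUGHT_RANDOM.png.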
