# ---
# title: JuliaRL_IDQN_TicTacToe
# cover:
# description: IDQN applied to the competitive TicTacToe environment
# date: 2023-07-03
# author: "[Panajiotis Keßler](mailto:panajiotis.kessler@gmail.com)"
# ---

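# In this experiment a DQN-based agent is trained on the multi-agent TicTacToe
# environment: one player learns with a Q-based (DQN) policy while its opponent
# follows a random policy. Each player keeps its own policy and trajectory,
# following the independent-learning (IDQN) setup.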
using StableRNGs
using ReinforcementLearning
using ReinforcementLearningBase
using ReinforcementLearningZoo
using ReinforcementLearningCore
using Plots
using Flux
using Flux.Losses: huber_loss
using Flux: glorot_uniform

using ProgressMeter

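# A fixed seed keeps the run reproducible; `cap` is the replay buffer capacity.
# A small buffer is sufficient here because TicTacToe episodes are very short.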
rng = StableRNG(1234)

cap = 100

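# The learner's network expects an array-shaped state, but the TicTacToe
# environment used here exposes its state as a single integer, so scalar states
# are wrapped in a one-element vector before being passed to the learner.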
RLCore.forward(L::DQNLearner, state::A) where {A <: Real} = RLCore.forward(L, [state])

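# Factory for a fresh Q-based policy: a double DQN whose twin (online/target)
# network is synchronised every 100 updates, trained with the Huber loss. The
# network maps the scalar board state to 9 action values, one per cell, and an
# ϵ-greedy explorer decays ϵ exponentially towards 0.01 over roughly 500 steps.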
create_policy() = QBasedPolicy(
    learner=DQNLearner(
        approximator=Approximator(
            model=TwinNetwork(
                Chain(
                    Dense(1, 512, relu; init=glorot_uniform(rng)),
                    Dense(512, 256, relu; init=glorot_uniform(rng)),
                    Dense(256, 9; init=glorot_uniform(rng)),
                );
                sync_freq=100
            ),
            optimiser=Adam(),
        ),
        n=32,
        γ=0.99f0,
        is_enable_double_DQN=true,
        loss_func=huber_loss,
        rng=rng,
    ),
    explorer=EpsilonGreedyExplorer(
        kind=:exp,
        ϵ_stable=0.01,
        decay_steps=500,
        rng=rng,
    ),
)

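# Set up the environment and one `Agent` per player. Here `:Cross` follows a
# `RandomPolicy` while `:Nought` learns with the DQN policy defined above. Each
# agent stores its transitions in a small circular SART buffer and, for every
# inserted transition, samples a 1-step batch of size 1 for learning.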
e = TicTacToeEnv();
m = MultiAgentPolicy(NamedTuple((player =>
        Agent(player != :Cross ? create_policy() : RandomPolicy(;rng=rng),
              Trajectory(
                  container=CircularArraySARTTraces(
                      capacity=cap,
                      state=Integer => (1,),
                  ),
                  sampler=NStepBatchSampler{SS′ART}(
                      n=1,
                      γ=0.99f0,
                      batch_size=1,
                      rng=rng
                  ),
                  controller=InsertSampleRatioController(
                      threshold=1,
                      n_inserted=0
                  )
              )
        )
        for player in players(e)))
);
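
# One `TotalRewardPerEpisode` hook per player records the reward of every
# finished episode (+1 for a win, -1 for a loss, 0 for a draw).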
hooks = MultiAgentHook(NamedTuple((p => TotalRewardPerEpisode() for p ∈ players(e))))

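# Train in chunks of `episodes_per_step` episodes. After each chunk the cumulative
# mean reward per episode is stored for both players; with the ±1 win/loss rewards
# this acts as a rough "win rate" (wins minus losses per episode). Only 2 outer
# iterations are run here to keep the example fast; increase this for a meaningful
# learning curve.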
episodes_per_step = 25
win_rates = (Cross=Float64[], Nought=Float64[])
@showprogress for i ∈ 1:2
    run(m, e, StopAfterEpisode(episodes_per_step; is_show_progress=false), hooks)
    wr_cross = sum(hooks[:Cross].rewards)/(i*episodes_per_step)
    wr_nought = sum(hooks[:Nought].rewards)/(i*episodes_per_step)
    push!(win_rates[:Cross], wr_cross)
    push!(win_rates[:Nought], wr_nought)
end
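
# Plot the per-chunk "win rates" and the raw per-episode rewards of both players,
# then save the combined figure.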
p1 = plot([win_rates[:Cross] win_rates[:Nought]], labels=["Cross" "Nought"])
xlabel!("Iteration steps of $episodes_per_step episodes")
ylabel!("Win rate of the player")

p2 = plot([hooks[:Cross].rewards hooks[:Nought].rewards], labels=["Cross" "Nought"])
xlabel!("Overall episodes")
ylabel!("Rewards of the players")

p = plot(p1, p2, layout=(2, 1), size=(1000, 1000))
savefig("TTT_CROSS_RANDOM_NOUGHT_DQN.png")