|
1 | 1 | # How to implement a new algorithm
|
2 | 2 |
|
3 |
| -All algorithms in ReinforcementLearning.jl are based on a common `run` function defined in [run.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/master/src/ReinforcementLearningCore/src/core/run.jl) that will be dispatched based on the type of its arguments. As you can see, the run function first performs a check and then calls a "private" `_run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook::AbstractHook)`, this is the main function we are interested in. It consists of an outer and an inner loop that will repeateadly call `policy(stage, env [,action])`. |
| 3 | +All algorithms in ReinforcementLearning.jl are based on a common `run` function defined in [run.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/master/src/ReinforcementLearningCore/src/core/run.jl) that will be dispatched based on the type of its arguments. As you can see, the run function first performs a check and then calls a "private" `_run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook::AbstractHook)`, this is the main function we are interested in. It consists of an outer and an inner loop that will repeateadly call `policy(stage, env)`. |
4 | 4 |
|
5 | 5 | Let's look at it closer in this simplified version (hooks are discussed [here](./How_to_use_hooks.md)):
|
6 | 6 |
|
7 | 7 | ```julia
|
8 | 8 | function _run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook::AbstractHook)
|
9 | 9 |
|
10 |
| - policy(PRE_EXPERIMENT_STAGE, env) |
| 10 | + policy(PreExperimentStage(), env) |
11 | 11 | is_stop = false
|
12 | 12 | while !is_stop
|
13 | 13 | reset!(env)
|
14 |
| - policy(PRE_EPISODE_STAGE, env) |
| 14 | + policy(PreEpisodeStage(), env) |
15 | 15 |
|
16 | 16 | while !is_terminated(env) # one episode
|
17 |
| - action = policy(env) |
18 |
| - |
19 |
| - policy(PRE_ACT_STAGE, env, action) |
20 |
| - |
21 |
| - env(action) |
22 |
| - |
23 |
| - policy(POST_ACT_STAGE, env) |
24 |
| - |
| 17 | + policy(PreActStage(), env) |
| 18 | + env |> policy |> env |
| 19 | + optimise!(policy) |
| 20 | + policy(PostActStage(), env) |
25 | 21 | if stop_condition(policy, env)
|
| 22 | + policy(PreActStage(), env) |
| 23 | + policy(env) |
26 | 24 | is_stop = true
|
27 | 25 | break
|
28 | 26 | end
|
29 | 27 | end # end of an episode
|
30 | 28 |
|
31 | 29 | if is_terminated(env)
|
32 |
| - policy(POST_EPISODE_STAGE, env) |
| 30 | + policy(PostEpisodeStage(), env) |
33 | 31 | end
|
34 | 32 | end
|
35 | 33 | end
|
|
38 | 36 | Implementing a new algorithm mainly consists of creating your own `AbstractPolicy` subtype, its action sampling function `(policy)(env)` and implementing its behavior at each stage. However, ReinforcemementLearning.jl provides plenty pre-implemented utilities that you should use to 1) have less code to write 2) lower the chances of bugs and 3) make your code more understandable and maintainable (if you intend to contribute your algorithm).
|
39 | 37 |
|
40 | 38 | ## Using Agents
|
41 |
| -A better way is to use the policy wrapper `Agent`. An agent is an AbstractPolicy that wraps a policy and a trajectory (also called Experience Replay Buffer in RL literature). Agent comes with default implementations of `Agent(stage, agent, env)` that will probably fit what you need at most stages so that you don't have to write them again. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/master/src/ReinforcementLearningCore/src/policies/agents/agent.jl/), we can see that the default Agent calls are |
| 39 | +A better way is to use the policy wrapper `Agent`. An agent is an AbstractPolicy that wraps a policy and a trajectory (also called Experience Replay Buffer in RL literature). Agent comes with default implementations of `agent(stage, env)` that will probably fit what you need at most stages so that you don't have to write them again. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/master/src/ReinforcementLearningCore/src/policies/agent.jl/), we can see that the default Agent calls are |
42 | 40 |
|
43 | 41 | ```julia
|
44 |
| -function (agent::Agent)(stage::AbstractStage, env::AbstractEnv [, action]) |
45 |
| - update!(agent.trajectory, agent.policy, env, stage [,action]) |
46 |
| - update!(agent.policy, agent.trajectory, env, stage) |
| 42 | +function (agent::Agent)(env::AbstractEnv) |
| 43 | + action = agent.policy(env) |
| 44 | + push!(agent.trajectory, (agent.cache..., action = action)) |
| 45 | + agent.cache = (;) |
| 46 | + action |
47 | 47 | end
|
48 |
| -``` |
49 | 48 |
|
50 |
| -Which consists of updating the trajectory then the policy. `update!(agent.policy, agent.trajectory, env, stage)` is a no-op by default at every stage but `update!(agent.trajectory, agent.policy, env, stage [,action])` comes with predefined updates that can be summarized as follows (keep in mind that trajectories are undergoing major changes and this will soon be outdated.). |
| 49 | +(agent::Agent)(::PreActStage, env::AbstractEnv) = |
| 50 | + agent.cache = (agent.cache..., state = state(env)) |
| 51 | + |
| 52 | +(agent::Agent)(::PostActStage, env::AbstractEnv) = |
| 53 | + agent.cache = (agent.cache..., reward = reward(env), terminal = is_terminated(env)) |
| 54 | +``` |
51 | 55 |
|
52 |
| -1. At the PRE\_ACT\_STAGE (after having sampled the action), add the current state and the sampled action to the trajectory. If your policy uses an action mask, it will also save it to a respective trace. |
53 |
| -2. At the POST\_ACT\_STAGE (after having exerted the action to the environment), add the returned reward and save whether the new state is terminal. |
54 |
| -3. At the POST\_EPISODE\_STAGE (before reseting the environment), save the last state to the trajectory. |
55 |
| -4. At the PRE\_EPISODE\_STAGE (after reseting the environment), remove the last state from the trajectory. |
| 56 | +The default behavior at other stages is a no-op. The first function, `(agent::Agent)(env::AbstractEnv)`, is called at the `env |> policy |> env` line. It gets an action from the policy (since you implemented the `your_new_policy(env)` function), then it pushes its `cache` and the action to the trajectory of the agent. Finally, it empties the cache and returns the action (which is immediately applied to env after). At the `PreActStage()` and the `PostActStage`, the agent simply records the current state of the environment, the returned reward and the terminal state signal to its cache (to be pushed to the trajectory by the first function). |
56 | 57 |
|
57 |
| -If you need a different behavior for trajectories, then you may overload the `update!` function with your policy type or a custom trajectory type. For example, many algorithms (such as PPO) need to store an additional trace of the logpdf of the sampled actions and thus overload the function at the PRE\_ACT\_STAGE. |
| 58 | +If you need a different behavior at some stages, then you can overload the `(Agent{<:YourPolicyType})([stage,] env)` or (Agent{<:Any, <: YourTrajectoryType})([stage,] env), depending on whether you have a custom policy or just a custom trajectory. For example, many algorithms (such as PPO) need to store an additional trace of the logpdf of the sampled actions and thus overload the function at the `PreActStage()`. |
58 | 59 |
|
59 | 60 | ## Updating the policy
|
60 | 61 |
|
61 |
| -Finally, you need to implement the learning function by implementing `(your_policy)( env, stage)` or `update!(your_policy, trajectory, env, stage)`. This is usually done at the PRE\_ACT\_STAGE or the POST\_EPISODE\_STAGE, depending on the algorithm. It is not recommended to do it at other stages because the trajectory will not be consistent and samples from it will be be incorrect. |
| 62 | +Finally, you need to implement the learning function by implementing `RLBase.optimise!(p::YourPolicyType, batch::NamedTuple)` (see that it is called by `optimise!(agent)` then `RLBase.optimise!(p::YourPolicyType, b::Trajectory)`). |
| 63 | +In principle you can do the update at other stages by overload the `(agent::Agent)` but this is not recommended because the trajectory may not be consistent and samples from could be incorrect. Be sure to know what you are doing. |
62 | 64 |
|
63 | 65 | ## Using resources from RLCore
|
64 | 66 |
|
|
0 commit comments