Commit 77b3376

Merge pull request #695 from JuliaReinforcementLearning/HenriDeh-patch-2
Update the "how to implement a new algorithm"
2 parents 9bfe4cf + 84941ee

# How to implement a new algorithm

All algorithms in ReinforcementLearning.jl are based on a common `run` function defined in [run.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/master/src/ReinforcementLearningCore/src/core/run.jl) that is dispatched based on the types of its arguments. As you can see, `run` first performs a check and then calls a "private" `_run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook::AbstractHook)`; this is the main function we are interested in. It consists of an outer and an inner loop that repeatedly call `policy(stage, env)`.

Let's take a closer look at it in this simplified version (hooks are discussed [here](./How_to_use_hooks.md)):

```julia
function _run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook::AbstractHook)

    policy(PreExperimentStage(), env)
    is_stop = false
    while !is_stop
        reset!(env)
        policy(PreEpisodeStage(), env)

        while !is_terminated(env) # one episode
            policy(PreActStage(), env)
            env |> policy |> env # get the action from the policy and apply it to the environment
            optimise!(policy)
            policy(PostActStage(), env)
            if stop_condition(policy, env)
                policy(PreActStage(), env)
                policy(env)
                is_stop = true
                break
            end
        end # end of an episode

        if is_terminated(env)
            policy(PostEpisodeStage(), env)
        end
    end
end
```

Implementing a new algorithm mainly consists of creating your own `AbstractPolicy` subtype, its action sampling function `(policy)(env)`, and its behavior at each stage. However, ReinforcementLearning.jl provides plenty of pre-implemented utilities that you should use to 1) write less code, 2) lower the chances of bugs, and 3) make your code more understandable and maintainable (if you intend to contribute your algorithm).
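
To make this concrete, here is a minimal, hedged sketch of the bare-bones route (no `Agent`, no learning). `MyRandomPolicy` is a made-up name and the built-in `RandomPolicy` already provides this behavior; the sketch only illustrates the three pieces mentioned above.

```julia
using ReinforcementLearning # assumed set-up; brings AbstractPolicy, AbstractEnv, the stages, etc. into scope
using Random

struct MyRandomPolicy <: AbstractPolicy
    rng::Random.AbstractRNG
end

# Action sampling: called as `policy(env)` by the run loop above.
(p::MyRandomPolicy)(env::AbstractEnv) = rand(p.rng, action_space(env))

# Stage behavior: a no-op here; a learning algorithm would react to the stages instead.
(p::MyRandomPolicy)(::AbstractStage, ::AbstractEnv) = nothing
```

A real algorithm additionally needs somewhere to store transitions and a learning step, which is exactly what the `Agent` wrapper described next provides.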

## Using Agents
A better way is to use the policy wrapper `Agent`. An `Agent` is an `AbstractPolicy` that wraps a policy and a trajectory (also called an experience replay buffer in the RL literature). `Agent` comes with default implementations of `agent(stage, env)` that will probably fit what you need at most stages, so that you don't have to write them again. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/master/src/ReinforcementLearningCore/src/policies/agent.jl/), we can see that the default `Agent` calls are

```julia
function (agent::Agent)(env::AbstractEnv)
    action = agent.policy(env)
    push!(agent.trajectory, (agent.cache..., action = action))
    agent.cache = (;)
    action
end

(agent::Agent)(::PreActStage, env::AbstractEnv) =
    agent.cache = (agent.cache..., state = state(env))

(agent::Agent)(::PostActStage, env::AbstractEnv) =
    agent.cache = (agent.cache..., reward = reward(env), terminal = is_terminated(env))
```

The default behavior at other stages is a no-op. The first function, `(agent::Agent)(env::AbstractEnv)`, is called at the `env |> policy |> env` line: it gets an action from the policy (since you implemented the `your_new_policy(env)` sampling function), pushes its `cache` together with that action to the agent's trajectory, empties the cache, and returns the action (which is then immediately applied to the environment). At `PreActStage()` the agent records the current state of the environment in its cache, and at `PostActStage()` it records the returned reward and the termination signal (both to be pushed to the trajectory by the first function).

If you need a different behavior at some stages, you can overload `(Agent{<:YourPolicyType})([stage,] env)` or `(Agent{<:Any, <:YourTrajectoryType})([stage,] env)`, depending on whether you have a custom policy or just a custom trajectory. For example, many algorithms (such as PPO) need to store an additional trace of the logpdf of the sampled actions and thus overload these functions; a sketch of one such overload follows.
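
For illustration, here is a hedged sketch of dispatching on the policy type to store one extra entry in the cache at `PreActStage()`. `MyMaskedPolicy` and the `legal_actions_mask` entry are made up for the example, and the agent's trajectory must have a matching trace for the extra entry.

```julia
# A hypothetical policy that wants the legal action mask stored in the trajectory.
struct MyMaskedPolicy <: AbstractPolicy end # details omitted

# More specific than the default `(agent::Agent)(::PreActStage, env)`, so it takes precedence.
function (agent::Agent{<:MyMaskedPolicy})(::PreActStage, env::AbstractEnv)
    agent.cache = (agent.cache...,
        state = state(env),
        legal_actions_mask = legal_action_space_mask(env))
end
```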

## Updating the policy

Finally, you need to implement the learning function by defining `RLBase.optimise!(p::YourPolicyType, batch::NamedTuple)` (note that it is called from `optimise!(agent)` via `RLBase.optimise!(p::YourPolicyType, b::Trajectory)`).
In principle you can perform the update at other stages by overloading `(agent::Agent)([stage,] env)`, but this is not recommended because the trajectory may not be consistent and samples from it could be incorrect. If you choose to do it anyway, make sure you know what you are doing.
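
As a rough sketch of the shape of that method (`YourPolicyType` is a placeholder, and the keys available in `batch` depend on the traces your trajectory stores):

```julia
struct YourPolicyType <: AbstractPolicy
    # approximator, hyper-parameters, rng, ...
end

# Receives one batch sampled from the agent's trajectory.
function RLBase.optimise!(p::YourPolicyType, batch::NamedTuple)
    # For SART-style traces the batch would typically contain these keys (assumption):
    # s, a, r, t = batch.state, batch.action, batch.reward, batch.terminal
    # ... compute the loss and update the approximator's parameters here ...
    nothing
end
```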

## ReinforcementLearningTrajectories

Trajectories are handled in a stand-alone package called [ReinforcementLearningTrajectories](https://github.com/JuliaReinforcementLearning/ReinforcementLearningTrajectories.jl). Refer to its documentation (in progress) to learn how to use it.

## Using resources from RLCore

RL algorithms typically differ only partially but broadly use the same mechanisms. The subpackage RLCore contains a lot of utilities that you can reuse to implement your algorithm.

The utils folder contains utilities and extensions to external packages to fit needs that are specific to RL.jl. We will not list them all here, but it is a good idea to skim over the files to see what they contain. The policies folder notably contains several explorer implementations. Here are a few interesting examples:

- `QBasedPolicy` wraps a policy that relies on a Q-value _learner_ (tabular or approximated) and an _explorer_. RLCore provides several pre-implemented learners and the most common explorers (such as epsilon-greedy, UCB, etc.).

- If your algorithm uses tabular learners, check out the tabular_learner.jl and tabular_approximator.jl source files. If your algorithm uses deep neural nets, use the `NeuralNetworkApproximator` to wrap a neural network and an optimizer. Common policy architectures are also provided, such as the `GaussianNetwork`.

- Equivalently, the `VBasedPolicy` learner is provided for algorithms that use a state-value function. Though they are not bundled in the same folder, most approximators can be used with a `VBasedPolicy` too.

<!--- ### Batch samplers
Since this is going to be outdated soon, I'll write this part later on when Trajectories.jl will be done -->

- In utils/distributions.jl you will find implementations of Gaussian log-probability functions that are both GPU compatible and differentiable and that do not require the overhead of using Distributions.jl structs.

## Conventions
Finally, there are a few "conventions" and good practices that you should follow, especially if you intend to contribute to this package (don't worry, we'll be happy to help if needed).

ReinforcementLearning.jl aims to provide a framework for reproducible experiments. To do so, make sure that your policy type has a `rng` field and that all random operations (e.g. action sampling or trajectory sampling) use `rand(your_policy.rng, args...)`.
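
Concretely, the convention is the difference between these two calls (`policy` here is any policy with an `rng` field):

```julia
# Not reproducible: draws from the global rng.
explore = rand() < 0.1

# Reproducible: draws from the policy's own rng, which can be seeded.
explore = rand(policy.rng) < 0.1
```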

### GPU friendliness
Deep RL algorithms are often much faster when the neural nets are updated on a GPU. For now, we only support CUDA.jl as a backend. This means that you will have to think about the transfer of data between the CPU (where the trajectory is) and the GPU memory (where the neural nets are). To help with this, you will find in utils/device.jl some functions that do most of the work for you. The ones you need to know are `send_to_device(device, data)`, which sends data to the specified device, `send_to_host(data)`, which sends data to CPU memory (it falls back to `send_to_device(Val{:cpu}, data)`), and `device(x)`, which returns the device on which `x` is.
Normally, you should be able to write a single implementation of your algorithm that works on CPU and GPUs thanks to the multiple dispatch offered by Julia.
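
For instance, a hedged sketch of moving a batch sampled from the CPU-side trajectory to wherever the network lives before computing a loss (`my_policy.approximator` is a made-up field; whether `device` accepts it directly depends on your approximator type):

```julia
# Find out where the neural network currently lives (CPU or GPU) ...
target = device(my_policy.approximator)
# ... and move the sampled batch there before computing the loss.
batch_on_device = send_to_device(target, batch)
# Anything that must go back into the CPU-side trajectory is sent back with send_to_host.
# result_on_cpu = send_to_host(result)
```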

GPU friendliness also requires that your code does not use _scalar indexing_ (see the CUDA.jl documentation for more information); make sure to test your algorithm on the GPU after disallowing scalar indexing with `CUDA.allowscalar(false)`.

Finally, it is a good idea to implement the `Flux.gpu(yourpolicy)` and `cpu(yourpolicy)` functions for user convenience. Be careful: sampling on the GPU requires a specific type of rng; you can generate one with `CUDA.default_rng()`.
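
A hedged sketch of what these convenience methods might look like for a hypothetical `MyPolicy` holding a Flux model and an rng (not the package's actual implementation):

```julia
using Flux, CUDA, Random

struct MyPolicy{A,R} # hypothetical: a Flux model plus an rng
    approximator::A
    rng::R
end

# Move the model to the GPU and switch to a CUDA-compatible rng.
Flux.gpu(p::MyPolicy) = MyPolicy(Flux.gpu(p.approximator), CUDA.default_rng())

# Move everything back to the CPU with a standard rng.
Flux.cpu(p::MyPolicy) = MyPolicy(Flux.cpu(p.approximator), Random.default_rng())
```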
