[Feature Request] Support compiling ReplayBuffer.extend/sample without recompile #2501

@kurtamohler

Description

Motivation

Compiling a function that calls ReplayBuffer.extend and ReplayBuffer.sample back to back, and then calling it repeatedly, causes the function to be recompiled on each call.

import torch
import torchrl

# Log every recompilation that Dynamo performs
torch._logging.set_logs(recompiles=True)

rb = torchrl.data.ReplayBuffer(
    storage=torchrl.data.LazyTensorStorage(1000)
)

@torch.compile
def extend_and_sample(data):
    rb.extend(data)
    return rb.sample(2)

# Each iteration extends the buffer with a batch of a different size,
# so both the input shape and the storage length change between calls.
for idx in range(15):
    print('---------------------')
    print(f'iteration: {idx}')
    print(f'len: {len(rb.storage)}')
    data = torch.randn(idx + 1, 1)
    extend_and_sample(data)

Running the above script gives the following output, showing that each of the first 9 calls triggers a recompilation. After that the cache limit is hit, so later calls are no longer compiled and simply run the eager function (per the PyTorch docs: https://pytorch.org/docs/stable/generated/torch.compile.html).

---------------------
iteration: 0
len: 0
/home/endoplasm/miniconda/envs/torchrl-0/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:733: UserWarning: Graph break due to unsupported builtin None.SemLock.acquire. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
  torch._dynamo.utils.warn_once(msg)
V1021 16:37:15.101000 599972 site-packages/torch/_dynamo/guards.py:2842] [11/1] [__recompiles] Recompiling function _lazy_call_fn in /home/endoplasm/develop/torchrl-0/torchrl/_utils.py:389
V1021 16:37:15.101000 599972 site-packages/torch/_dynamo/guards.py:2842] [11/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:15.101000 599972 site-packages/torch/_dynamo/guards.py:2842] [11/1] [__recompiles]     - 11/0: L['self'].func_name == 'torchrl.data.replay_buffers.storages.TensorStorage.set'
V1021 16:37:15.108000 599972 site-packages/torch/_dynamo/guards.py:2842] [12/1] [__recompiles] Recompiling function torch_dynamo_resume_in__lazy_call_fn_at_394 in /home/endoplasm/develop/torchrl-0/torchrl/_utils.py:394
V1021 16:37:15.108000 599972 site-packages/torch/_dynamo/guards.py:2842] [12/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:15.108000 599972 site-packages/torch/_dynamo/guards.py:2842] [12/1] [__recompiles]     - 12/0: len(L['args']) == 3                                         
---------------------
iteration: 1
len: 1
V1021 16:37:16.300000 599972 site-packages/torch/_dynamo/guards.py:2842] [0/1] [__recompiles] Recompiling function extend_and_sample in /home/endoplasm/tmp/rb_compiled.py:10
V1021 16:37:16.300000 599972 site-packages/torch/_dynamo/guards.py:2842] [0/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.300000 599972 site-packages/torch/_dynamo/guards.py:2842] [0/1] [__recompiles]     - 0/0: tensor 'L['data']' size mismatch at index 0. expected 1, actual 2
V1021 16:37:16.327000 599972 site-packages/torch/_dynamo/guards.py:2842] [1/1] [__recompiles] Recompiling function extend in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/replay_buffers.py:610
V1021 16:37:16.327000 599972 site-packages/torch/_dynamo/guards.py:2842] [1/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.327000 599972 site-packages/torch/_dynamo/guards.py:2842] [1/1] [__recompiles]     - 1/0: tensor 'L['data']' size mismatch at index 0. expected 1, actual 2
V1021 16:37:16.340000 599972 site-packages/torch/_dynamo/guards.py:2842] [13/1] [__recompiles] Recompiling function set in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:686
V1021 16:37:16.340000 599972 site-packages/torch/_dynamo/guards.py:2842] [13/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.340000 599972 site-packages/torch/_dynamo/guards.py:2842] [13/1] [__recompiles]     - 13/0: tensor 'L['data']' size mismatch at index 0. expected 1, actual 2
V1021 16:37:16.357000 599972 site-packages/torch/_dynamo/guards.py:2842] [18/1] [__recompiles] Recompiling function torch_dynamo_resume_in_set_at_713 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:713
V1021 16:37:16.357000 599972 site-packages/torch/_dynamo/guards.py:2842] [18/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.357000 599972 site-packages/torch/_dynamo/guards.py:2842] [18/1] [__recompiles]     - 18/0: tensor 'L['data']' size mismatch at index 0. expected 1, actual 2
V1021 16:37:16.406000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/1] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.406000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.406000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/1] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
V1021 16:37:16.423000 599972 site-packages/torch/_dynamo/guards.py:2842] [31/1] [__recompiles] Recompiling function <lambda> in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:802
V1021 16:37:16.423000 599972 site-packages/torch/_dynamo/guards.py:2842] [31/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.423000 599972 site-packages/torch/_dynamo/guards.py:2842] [31/1] [__recompiles]     - 31/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 3
---------------------
iteration: 2
len: 3
V1021 16:37:16.452000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/2] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.452000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/2] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.452000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/2] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.452000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/2] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 3
len: 6
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 4
len: 10
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 5
len: 15
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/4: L['___stack2'] == 15  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 6
len: 21
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/5: L['___stack2'] == 21  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/4: L['___stack2'] == 15  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 7
len: 28
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/6: L['___stack2'] == 28  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/5: L['___stack2'] == 21  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/4: L['___stack2'] == 15  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 8
len: 36
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/7: L['___stack2'] == 36  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/6: L['___stack2'] == 28  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/5: L['___stack2'] == 21  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/4: L['___stack2'] == 15  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8] torch._dynamo hit config.cache_size_limit (8)
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8]    function: 'torch_dynamo_resume_in__rand_given_ndim_at_152' (/home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152)
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8]    last reason: 28/0: L['___stack2'] == 1                                         
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
---------------------
iteration: 9
len: 45
---------------------
iteration: 10
len: 55
---------------------
iteration: 11
len: 66
---------------------
iteration: 12
len: 78
---------------------
iteration: 13
len: 91
---------------------
iteration: 14
len: 105
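
The cache-limit warning at the end of the log suggests one possible stopgap: raising Dynamo's recompile budget and asking torch.compile for dynamic shapes. The snippet below is only a sketch of that workaround, assuming torch._dynamo.config.cache_size_limit and torch.compile(dynamic=True) behave as documented; it does not remove the guards on the storage length, so some recompilations may remain.

import torch
import torchrl

# Raise the per-frame recompile budget so the function keeps being compiled
# instead of silently falling back to eager after 8 recompiles.
torch._dynamo.config.cache_size_limit = 64

rb = torchrl.data.ReplayBuffer(
    storage=torchrl.data.LazyTensorStorage(1000)
)

# dynamic=True asks Dynamo to treat varying input sizes as symbolic rather
# than specializing on each new shape.
@torch.compile(dynamic=True)
def extend_and_sample(data):
    rb.extend(data)
    return rb.sample(2)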

Solution

Calling a compiled function that runs ReplayBuffer.extend and ReplayBuffer.sample back to back should not trigger a recompilation on every call.

We need to support the base case of torchrl.data.ReplayBuffer(storage=torchrl.data.LazyTensorStorage(1000)), as well as cases where the storage is a LazyMemmapStorage and where the sampler is a SliceSampler.
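
For reference, these are roughly the configurations that should compile without recompiles. The constructor arguments below are illustrative only; in particular, SliceSampler expects trajectory-structured data (e.g. a TensorDict carrying an episode or trajectory key), so a real test would use a different data layout than the plain tensors in the repro above.

import torchrl

# Base case from the repro above
rb_tensor = torchrl.data.ReplayBuffer(
    storage=torchrl.data.LazyTensorStorage(1000),
)

# Memory-mapped storage
rb_memmap = torchrl.data.ReplayBuffer(
    storage=torchrl.data.LazyMemmapStorage(1000),
)

# Slice-based sampling over stored trajectories
rb_slices = torchrl.data.ReplayBuffer(
    storage=torchrl.data.LazyTensorStorage(1000),
    sampler=torchrl.data.SliceSampler(num_slices=2),
)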

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels

enhancement (New feature or request)
