[Doc] Add docstring for MCTSForest.extend #2795

Merged
merged 3 commits on Feb 25, 2025
127 changes: 127 additions & 0 deletions torchrl/data/map/tree.py
@@ -1040,6 +1040,133 @@ def _make_node_map(self, source, dest):
self.max_size = self.data_map.max_size

def extend(self, rollout, *, return_node: bool = False):
"""Add a rollout to the forest.

Nodes are only added to a tree at points where rollouts diverge from
each other and at the endpoints of rollouts.

If there is no existing tree that matches the first steps of the
rollout, a new tree is added, and a single node is created for the
final step.

If there is an existing tree that matches, the rollout is added to that
tree. If the rollout diverges from all other rollouts in the tree at
some step, a new node is created before the step where the rollouts
diverge, and a leaf node is created for the final step of the rollout.
If all of the rollout's steps match a previously added rollout,
nothing changes. If the rollout matches up to a leaf node of a tree but
continues beyond it, that node is extended to the end of the rollout,
and no new nodes are created.

Args:
rollout (TensorDict): The rollout to add to the forest.
return_node (bool, optional): If ``True``, the method returns the
added node. Default is ``False``.

Returns:
Tree: The node that was added to the forest. This is only
returned if ``return_node`` is ``True``.

Examples:
>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> import torch
>>> forest = MCTSForest()
>>> r0 = TensorDict({
... 'action': torch.tensor([1, 2, 3, 4, 5]),
... 'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
... 'observation': torch.tensor([ 0, 123, 392, 989, 809])
... }, [5])
>>> r1 = TensorDict({
... 'action': torch.tensor([1, 2, 6, 7]),
... 'next': {'observation': torch.tensor([123, 392, 235, 38])},
... 'observation': torch.tensor([ 0, 123, 392, 235])
... }, [4])
>>> td_root = r0[0].exclude("next")
>>> forest.extend(r0)
>>> forest.extend(r1)
>>> tree = forest.get_tree(td_root)
>>> print(tree)
Tree(
count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
node_data=TensorDict(
fields={
observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([]),
device=cpu,
is_shared=False),
node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
rollout=TensorDict(
fields={
action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
next: TensorDict(
fields={
observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([2]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([2]),
device=cpu,
is_shared=False),
subtree=Tree(
_parent=NonTensorStack(
[<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
batch_size=torch.Size([2]),
device=None),
count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
hash=NonTensorStack(
[4341220243998689835, 6745467818783115365],
batch_size=torch.Size([2]),
device=None),
node_data=LazyStackedTensorDict(
fields={
observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2]),
device=cpu,
is_shared=False,
stack_dim=0),
node_id=NonTensorStack(
[1, 2],
batch_size=torch.Size([2]),
device=None),
rollout=LazyStackedTensorDict(
fields={
action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
next: LazyStackedTensorDict(
fields={
observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, -1]),
device=cpu,
is_shared=False,
stack_dim=0),
observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, -1]),
device=cpu,
is_shared=False,
stack_dim=0),
wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
index=None,
subtree=None,
specs=None,
batch_size=torch.Size([2]),
device=None,
is_shared=False),
wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
hash=None,
_parent=None,
specs=None,
batch_size=torch.Size([]),
device=None,
is_shared=False)
"""
source, dest = (
rollout.exclude("next").copy(),
rollout.select("next", *self.action_keys).copy(),
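
A minimal sketch of the leaf-extension case the docstring describes but the printed example above does not exercise. The rollout r2 below is hypothetical: it matches r1 step for step and then continues one step past r1's leaf, so per the docstring the existing leaf is extended in place and no new node is created. Treating the value returned with return_node=True in this case as the extended node is an assumption drawn from the docstring's wording, not verified behavior.

import torch
from tensordict import TensorDict
from torchrl.data import MCTSForest

forest = MCTSForest()

# Same r0/r1 as in the docstring example: they share their first two
# steps and then diverge, producing a two-step root branch and two leaves.
r0 = TensorDict({
    'action': torch.tensor([1, 2, 3, 4, 5]),
    'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
    'observation': torch.tensor([0, 123, 392, 989, 809]),
}, [5])
r1 = TensorDict({
    'action': torch.tensor([1, 2, 6, 7]),
    'next': {'observation': torch.tensor([123, 392, 235, 38])},
    'observation': torch.tensor([0, 123, 392, 235]),
}, [4])
forest.extend(r0)
forest.extend(r1)

# Hypothetical continuation: r2 repeats all of r1 and adds one more step
# (action 8 taking observation 38 to observation 40), so it matches up to
# r1's leaf and continues beyond it.
r2 = TensorDict({
    'action': torch.tensor([1, 2, 6, 7, 8]),
    'next': {'observation': torch.tensor([123, 392, 235, 38, 40])},
    'observation': torch.tensor([0, 123, 392, 235, 38]),
}, [5])

# Assumption: in the leaf-extension case, return_node=True hands back the
# node whose rollout now covers r2's extra step.
node = forest.extend(r2, return_node=True)
print(node)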