diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 2e42cd7864d..684c4f9901b 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -1040,6 +1040,133 @@ def _make_node_map(self, source, dest): self.max_size = self.data_map.max_size def extend(self, rollout, *, return_node: bool = False): + """Add a rollout to the forest. + + Nodes are only added to a tree at points where rollouts diverge from + each other and at the endpoints of rollouts. + + If there is no existing tree that matches the first steps of the + rollout, a new tree is added. Only one node is created, for the final + step. + + If there is an existing tree that matches, the rollout is added to that + tree. If the rollout diverges from all other rollouts in the tree at + some step, a new node is created before the step where the rollouts + diverge, and a leaf node is created for the final step of the rollout. + If all of the rollout's steps match with a previously added rollout, + nothing changes. If the rollout matches up to a leaf node of a tree but + continues beyond it, that node is extended to the end of the rollout, + and no new nodes are created. + + Args: + rollout (TensorDict): The rollout to add to the forest. + return_node (bool, optional): If ``True``, the method returns the + added node. Default is ``False``. + + Returns: + Tree: The node that was added to the forest. This is only + returned if ``return_node`` is True. + + Examples: + >>> from torchrl.data import MCTSForest + >>> from tensordict import TensorDict + >>> import torch + >>> forest = MCTSForest() + >>> r0 = TensorDict({ + ... 'action': torch.tensor([1, 2, 3, 4, 5]), + ... 'next': {'observation': torch.tensor([123, 392, 989, 809, 847])}, + ... 'observation': torch.tensor([ 0, 123, 392, 989, 809]) + ... }, [5]) + >>> r1 = TensorDict({ + ... 'action': torch.tensor([1, 2, 6, 7]), + ... 'next': {'observation': torch.tensor([123, 392, 235, 38])}, + ... 'observation': torch.tensor([ 0, 123, 392, 235]) + ... }, [4]) + >>> td_root = r0[0].exclude("next") + >>> forest.extend(r0) + >>> forest.extend(r1) + >>> tree = forest.get_tree(td_root) + >>> print(tree) + Tree( + count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False), + index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False), + node_data=TensorDict( + fields={ + observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None), + rollout=TensorDict( + fields={ + action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False), + next: TensorDict( + fields={ + observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False), + subtree=Tree( + _parent=NonTensorStack( + [