From 950eca32a5ced17a01ed2377143c21d96cc7584b Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 19 Feb 2025 13:36:24 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- torchrl/data/map/tree.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 7e7567ff974..283bd99bd52 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -1003,6 +1003,33 @@ def _make_node_map(self, source, dest): self.max_size = self.data_map.max_size def extend(self, rollout, *, return_node: bool = False): + """Add a rollout to the forest. + + Nodes are only added to a tree at points where rollouts diverge from + each other and at the endpoints of rollouts. + + If there is no existing tree that matches the first steps of the + rollout, a new tree is added. Only one node is created, for the final + step. + + If there is an existing tree that matches, the rollout is added to that + tree. If the rollout diverges from all other rollouts in the tree at + some step, a new node is created before the step where the rollouts + diverge, and a leaf node is created for the final step of the rollout. + If all of the rollout's steps match with a previously added rollout, + nothing changes. If the rollout matches up to a leaf node of a tree but + continues beyond it, that node is extended to the end of the rollout, + and no new nodes are created. + + Args: + rollout (TensorDict): The rollout to add to the forest. + return_node (bool, optional): If True, the method returns the added + node. Default is ``False``. + + Returns: + Tree: The node that was added to the forest. This is only + returned if ``return_node`` is True. + """ source, dest = ( rollout.exclude("next").copy(), rollout.select("next", *self.action_keys).copy(), From 248c8714536340009a597358caeb7a803f2e3ead Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Mon, 24 Feb 2025 23:27:27 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- torchrl/data/map/tree.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index ccd3fb6c7bc..684c4f9901b 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -1060,8 +1060,8 @@ def extend(self, rollout, *, return_node: bool = False): Args: rollout (TensorDict): The rollout to add to the forest. - return_node (bool, optional): If True, the method returns the added - node. Default is ``False``. + return_node (bool, optional): If ``True``, the method returns the + added node. Default is ``False``. Returns: Tree: The node that was added to the forest. This is only