Skip to content

Commit ba8be9c

Browse files
authored
[BugFix] Tree make node fix (#2839)
1 parent 4c55b65 commit ba8be9c

File tree

4 files changed: 27 additions and 2 deletions.

test/test_storage_map.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,17 @@ def test_edges(self):
350350
edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
351351
assert edges == edges_check
352352

353+
def test_make_node(self):
354+
td = TensorDict({"obs": torch.tensor([0])})
355+
tree = Tree(node_data=td)
356+
assert tree.node_data is not None
357+
358+
tree = Tree.make_node(data=td)
359+
assert tree.node_data is not None
360+
361+
tree = Tree.make_node(td)
362+
assert tree.node_data is not None
363+
353364

354365
class TestMCTSForest:
355366
def dummy_rollouts(self) -> Tuple[TensorDict, ...]:

torchrl/data/llm/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
)
1212
from .prompt import PromptData, PromptTensorDictTokenizer
1313
from .reward import PairwiseDataset, RewardData
14-
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel, LLMData, LLMOutput, LLMInput
14+
from .utils import (
15+
AdaptiveKLController,
16+
ConstantKLController,
17+
LLMData,
18+
LLMInput,
19+
LLMOutput,
20+
RolloutFromModel,
21+
)
1522

1623
__all__ = [
1724
"AdaptiveKLController",

torchrl/data/llm/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,8 +543,10 @@ def step_scheduler(self):
543543
while len(self._kl_queue):
544544
self._kl_queue.remove(self._kl_queue[0])
545545

546+
546547
LLMInpOut = TypeVar("LLMInpOut")
547548

549+
548550
class LLMInput(TensorClass["nocast"]):
549551
"""Represents the input to a Large Language Model (LLM).
550552
@@ -557,11 +559,13 @@ class LLMInput(TensorClass["nocast"]):
557559
.. seealso:: :class:`~torchrl.data.LLMOutput` and :class:`~torchrl.data.LLMData`.
558560
559561
"""
562+
560563
tokens: torch.Tensor
561564
attention_mask: torch.Tensor | None = None
562565
token_list: list[int] | list[list[int]] | None = None
563566
text: str | list[str] | None = None
564567

568+
565569
class LLMOutput(TensorClass["nocast"]):
566570
"""Represents the output from a Large Language Model (LLM).
567571
@@ -581,6 +585,7 @@ class LLMOutput(TensorClass["nocast"]):
581585
.. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMData`.
582586
583587
"""
588+
584589
tokens: torch.Tensor
585590
tokens_response: torch.Tensor | None = None
586591
token_list: list[int] | list[list[int]] | None = None
@@ -594,6 +599,7 @@ def from_vllm_output(cls: type[LLMInpOut], vllm_output) -> LLMInpOut:
594599
# placeholder
595600
raise NotImplementedError
596601

602+
597603
class LLMData(TensorClass["nocast"]):
598604
"""Represents the input or output of a Large Language Model (LLM).
599605
@@ -619,6 +625,7 @@ class LLMData(TensorClass["nocast"]):
619625
.. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMOutput`.
620626
621627
"""
628+
622629
tokens: torch.Tensor
623630
tokens_response: torch.Tensor | None = None
624631
attention_mask: torch.Tensor | None = None

torchrl/data/map/tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def make_node(
122122
return cls(
123123
count=torch.zeros(()),
124124
wins=torch.zeros(()),
125-
node=data.exclude("action", "next"),
125+
node_data=data.exclude("action", "next"),
126126
rollout=rollout,
127127
subtree=subtree,
128128
device=device,

0 commit comments

Comments
 (0)