Skip to content

Commit 55e3bbe

Browse files
committed
clear up confusion from direct correspondance with a phd student
1 parent a80532c commit 55e3bbe

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.9.3"
3+
version = "0.9.4"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,10 @@ def __init__(
12171217

12181218
self.fallback_to_default_shape_if_invalid = fallback_to_default_shape_if_invalid
12191219

1220+
# store number of text tokens
1221+
1222+
self.num_text_tokens = num_text_tokens
1223+
12201224
# entire "sentence" start and end id
12211225

12221226
num_text_special_ids = 2
@@ -1442,6 +1446,13 @@ def sample(
14421446

14431447
device = self.device
14441448

1449+
# handle edge case where there are no text tokens
1450+
1451+
if self.num_text_tokens == 0:
1452+
logger.warning(f'you have `num_text_tokens` set to 0, so `sample` will be forwarded to `generate_modality_only(batch_size: int, modality_type: int)` method')
1453+
1454+
return self.generate_modality_only(batch_size = 1)
1455+
14451456
# take care of prompt being a raw tensor, either text or raw modality (image, video, actions, latents, etc)
14461457

14471458
if is_tensor(prompt) and prompt.dtype == torch.float: # is modality with type 0 implicit

0 commit comments

Comments
 (0)