Skip to content

Commit 95d85e3

Browse files
committed
first make sure prompting is seamless with text first
1 parent d8523f1 commit 95d85e3

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.5.5"
3+
version = "0.5.6"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

train_mnist.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torchvision.transforms as T
1414
from torchvision.utils import save_image
1515

16-
from transfusion_pytorch import Transfusion, print_modality_sample
16+
from transfusion_pytorch.transfusion import Transfusion, print_modality_sample
1717

1818
rmtree('./results', ignore_errors = True)
1919
results_folder = Path('./results')
@@ -24,6 +24,7 @@
2424
IMAGE_AFTER_TEXT = True
2525
NUM_TRAIN_STEPS = 20_000
2626
SAMPLE_EVERY = 500
27+
USE_PROMPT = True
2728
CHANNEL_FIRST = True
2829

2930
# functions
@@ -127,7 +128,24 @@ def collate_fn(data):
127128
# eval
128129

129130
if divisible_by(step, SAMPLE_EVERY):
130-
one_multimodal_sample = ema_model.sample(max_length = 384)
131+
132+
if not USE_PROMPT:
133+
# sampling from start to finish
134+
135+
one_multimodal_sample = ema_model.sample(max_length = 384)
136+
else:
137+
# sampling using prompt
138+
# which differs depending on which comes first, text or images
139+
140+
if IMAGE_AFTER_TEXT:
141+
142+
maybe_label = torch.randint(0, 10, ()).cuda()
143+
one_multimodal_sample = ema_model.sample(prompt = maybe_label, max_length = 384)
144+
145+
else:
146+
raise NotImplementedError
147+
148+
        # make sure the overall order of modalities in the sample looks correct
131149

132150
print_modality_sample(one_multimodal_sample)
133151

transfusion_pytorch/transfusion.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,28 @@ def default_modality_length_to_time_fn(num_modalities: Int['b']) -> Float['b m']
195195

196196
# pretty print
197197

198+
def concat_contiguous_text(
199+
modality_sample: ModalitySample
200+
) -> ModalitySample:
201+
    """ within a modality sample, any two adjacent tensors of int / long dtype are concatenated together, so that every run of text becomes a single tensor — text is always followed by a modality, and a modality by text """
202+
203+
output = []
204+
curr_modality = None
205+
206+
for modality in modality_sample:
207+
if (
208+
len(output) > 0 and
209+
output[-1].dtype == modality.dtype and
210+
modality.dtype in (torch.int, torch.long)
211+
):
212+
packed_text, _ = pack((output[-1], modality), '*')
213+
output[-1] = packed_text
214+
215+
else:
216+
output.append(modality)
217+
218+
return output
219+
198220
def print_modality_sample(
199221
modality_sample: ModalitySample
200222
):
@@ -1394,12 +1416,18 @@ def sample(
13941416
device = self.device
13951417

13961418
init_text_seq = tensor([self.sos_id], device = device)
1419+
1420+
# just take care of prompt being zero dimensions
1421+
1422+
prompt = tree_map_tensor(prompt, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t)
1423+
13971424
modality_sample = [init_text_seq, *default(prompt, [])]
13981425

13991426
# take care of moving to device
14001427

14011428
modality_sample = tree_map_tensor(modality_sample, lambda t: t.to(device))
1402-
modality_sample = tree_map_tensor(modality_sample, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t)
1429+
1430+
modality_sample = concat_contiguous_text(modality_sample)
14031431

14041432
*_, last_modality_sample = modality_sample
14051433
assert last_modality_sample.dtype in (torch.int, torch.long), 'prompt must be text tokens'

0 commit comments

Comments
 (0)