Commit bc5f10f

make sure captioning works with mnist example, allow for single modality prompting

1 parent 95d85e3 commit bc5f10f
File tree

pyproject.toml
train_mnist.py
transfusion_pytorch/transfusion.py

3 files changed: +60 -13 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.5.6"
+version = "0.6.0"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

train_mnist.py

Lines changed: 11 additions & 7 deletions

```diff
@@ -21,10 +21,10 @@
 
 # constants
 
-IMAGE_AFTER_TEXT = True
+IMAGE_AFTER_TEXT = True # False for captioning, True for text-to-image
+USE_PROMPT = False # whether to use prompting, or synthesize from start token
 NUM_TRAIN_STEPS = 20_000
-SAMPLE_EVERY = 500
-USE_PROMPT = True
+SAMPLE_EVERY = 250
 CHANNEL_FIRST = True
 
 # functions
@@ -49,7 +49,7 @@ def forward(self, x):
         if CHANNEL_FIRST:
             x = rearrange(x, 'b d ... -> b ... d')
 
-        x = rearrange(x, '... h w (p1 p2) -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2, h = 14)
+        x = rearrange(x, '... h w (p1 p2) -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2)
         return ((x + 1) * 0.5).clamp(min = 0., max = 1.)
 
 model = Transfusion(
@@ -139,11 +139,15 @@ def collate_fn(data):
 
     if IMAGE_AFTER_TEXT:
 
-        maybe_label = torch.randint(0, 10, ()).cuda()
-        one_multimodal_sample = ema_model.sample(prompt = maybe_label, max_length = 384)
+        text_label = torch.randint(0, 10, ()).cuda()
+        one_multimodal_sample = ema_model.sample(prompt = text_label, max_length = 384)
 
     else:
-        raise NotImplementedError
+
+        rand_batch = next(iter_dl)
+        rand_image = rand_batch[0][0]
+
+        one_multimodal_sample = ema_model.sample(prompt = rand_image, max_length = 384)
 
     # make sure modality sample overall order of modalities look correct
```
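For reference, a minimal sketch of the two sampling modes the MNIST example now covers; `ema_model` and `mnist_image` stand in for the trained EMA wrapper and a single image drawn from the dataloader, as in the script above:

```python
import torch

# IMAGE_AFTER_TEXT = True: prompt with a digit label, model synthesizes the image
digit_label = torch.randint(0, 10, ()).cuda()  # 0-dim long tensor, i.e. a text prompt
text_to_image = ema_model.sample(prompt = digit_label, max_length = 384)

# IMAGE_AFTER_TEXT = False: prompt with a raw image, model emits the caption
image_to_text = ema_model.sample(prompt = mnist_image.cuda(), max_length = 384)
```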
transfusion_pytorch/transfusion.py

Lines changed: 48 additions & 5 deletions

```diff
@@ -201,11 +201,11 @@ def concat_contiguous_text(
     """ within a modality sample, any two tensors of type int / long will be concatted together if next to each other, so all text is followed by a modality, and all modality followed by text """
 
     output = []
-    curr_modality = None
 
     for modality in modality_sample:
         if (
             len(output) > 0 and
+            is_tensor(output[-1]) and is_tensor(modality) and
             output[-1].dtype == modality.dtype and
             modality.dtype in (torch.int, torch.long)
         ):
```
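The new `is_tensor` guard matters because a modality sample interleaves plain text tensors with `(modality_type, tensor)` tuples; the old condition read `.dtype` off whatever came last and would raise an `AttributeError` on a tuple. A self-contained toy run of the same merging logic, with made-up sample contents:

```python
import torch
from torch import is_tensor

sample = [
    torch.tensor([1, 2]),       # text
    torch.tensor([3]),          # contiguous text, should merge with the above
    (0, torch.randn(4, 384)),   # modality as a (type, latents) tuple
    torch.tensor([4, 5]),       # text following a tuple, must not touch its .dtype
]

output = []
for modality in sample:
    if (
        len(output) > 0 and
        is_tensor(output[-1]) and is_tensor(modality) and  # the added guard
        output[-1].dtype == modality.dtype and
        modality.dtype in (torch.int, torch.long)
    ):
        output[-1] = torch.cat((output[-1], modality))
    else:
        output.append(modality)

# output: [tensor([1, 2, 3]), (0, <latents>), tensor([4, 5])]
```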
```diff
@@ -1365,6 +1365,19 @@ def get_modality_info(
     def get_all_modality_info(self) -> list[ModalityInfo]:
         return [self.get_modality_info(i) for i in range(self.num_modalities)]
 
+    def get_modality_shape(
+        self,
+        modality: Float['...'],
+        modality_type: int | None = None
+    ) -> tuple[int, ...]:
+
+        mod = self.get_modality_info(modality_type)
+
+        if mod.channel_first_latent:
+            modality = rearrange(modality, 'c ... -> ... c')
+
+        return tuple(modality.shape[:-1])
+
     def parameters_without_encoder_decoder(self):
         return (
             set(self.parameters()) -
```
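A standalone sketch of what `get_modality_shape` computes: the shape of a latent with its feature channel stripped, after normalizing channel-first layouts to channel-last. The free function and the 384-dim latents here are illustrative assumptions:

```python
import torch
from einops import rearrange

def modality_shape(modality: torch.Tensor, channel_first_latent: bool) -> tuple[int, ...]:
    # move channels last so the feature dim is always the one dropped
    if channel_first_latent:
        modality = rearrange(modality, 'c ... -> ... c')
    return tuple(modality.shape[:-1])

assert modality_shape(torch.randn(384, 14, 14), channel_first_latent = True) == (14, 14)
assert modality_shape(torch.randn(14, 14, 384), channel_first_latent = False) == (14, 14)
```

This shape tuple is what gets serialized into the prompt's meta tokens during sampling below.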
```diff
@@ -1402,7 +1415,7 @@ def create_ema(
     @typecheck
     def sample(
         self,
-        prompt: ModalitySample | None = None,
+        prompt: ModalitySample | Tensor | tuple[int, Float['...']] | None = None,
         max_length = 2048,
         text_temperature = 1.5,
         text_min_p = 0.1,
```
```diff
@@ -1415,22 +1428,52 @@
 
         device = self.device
 
+        # take care of prompt being a raw tensor, either text or raw modality (image, video, actions, latents, etc)
+
+        if is_tensor(prompt) and prompt.dtype == torch.float: # is modality with type 0 implicit
+            prompt = (0, prompt)
+
+        if is_tensor(prompt) and prompt.dtype in (torch.int, torch.long): # is text only prompt
+            prompt = [prompt]
+
+        elif isinstance(prompt, tuple):
+            modality_type, modality = prompt
+
+            mod = self.get_modality_info(modality_type)
+
+            if exists(mod.encoder):
+                with torch.no_grad():
+                    mod.encoder.eval()
+                    modality = self.maybe_add_temp_batch_dim(mod.encoder)(modality).detach()
+
+            modality_shape_tuple = self.get_modality_shape(modality, modality_type)
+            modality_shape_str = join([*map(str, modality_shape_tuple)], ',')
+            modality_meta_info = self.char_tokenizer(modality_shape_str, device = device)
+
+            prompt = [
+                tensor([self.meta_id]),
+                modality_meta_info,
+                tensor([mod.som_id]),
+                (modality_type, modality),
+                tensor([mod.eom_id]),
+            ]
+
+        # sos
+
         init_text_seq = tensor([self.sos_id], device = device)
 
         # just take care of prompt being zero dimensions
 
-        prompt = tree_map_tensor(prompt, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t)
-
         modality_sample = [init_text_seq, *default(prompt, [])]
 
         # take care of moving to device
 
         modality_sample = tree_map_tensor(modality_sample, lambda t: t.to(device))
+        modality_sample = tree_map_tensor(modality_sample, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t)
 
         modality_sample = concat_contiguous_text(modality_sample)
 
         *_, last_modality_sample = modality_sample
-        assert last_modality_sample.dtype in (torch.int, torch.long), 'prompt must be text tokens'
 
         curr_length = 0
         curr_modality_id = None
```
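Taken together, `sample` now accepts three prompt forms besides a full `ModalitySample`. A hypothetical usage sketch; `model` and the `(14, 14, 384)` channel-last latent shape are assumptions, not taken from the commit:

```python
import torch

text_tokens = torch.tensor([3, 7])               # int/long tensor -> text-only prompt
image_latents = torch.randn(14, 14, 384)         # float tensor -> modality, type 0 implied
typed_modality = (0, torch.randn(14, 14, 384))   # explicit (modality_type, tensor) pair

from_text = model.sample(prompt = text_tokens, max_length = 384)
from_image = model.sample(prompt = image_latents, max_length = 384)
from_typed = model.sample(prompt = typed_modality, max_length = 384)
```

A modality prompt is wrapped with its shape meta tokens plus the som/eom ids before decoding continues, which is what lets the MNIST example above prompt with an image and sample a caption.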