Skip to content

Commit 78a2f3d

Browse files
committed
need a few more training steps to see results for unet example
1 parent c6f9ccf commit 78a2f3d

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

train_mnist_with_unet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
# constants
2525

2626
IMAGE_AFTER_TEXT = False
27-
NUM_TRAIN_STEPS = 10_000
28-
SAMPLE_EVERY = 250
27+
NUM_TRAIN_STEPS = 20_000
28+
SAMPLE_EVERY = 500
2929

3030
# functions
3131

transfusion_pytorch/transfusion.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import torch
2424
import torch.nn.functional as F
25-
from torch import nn, Tensor, tensor, is_tensor, stack
25+
from torch import nn, Tensor, tensor, is_tensor, cat, stack
2626
from torch.nn import Module, ModuleList, Linear
2727

2828
from torch.utils.data import Dataset, DataLoader
@@ -808,8 +808,8 @@ def forward(
808808

809809
if exists(cache):
810810
cached_k, cached_v = cache
811-
k = torch.cat((cached_k, k), dim = -2)
812-
v = torch.cat((cached_v, v), dim = -2)
811+
k = cat((cached_k, k), dim = -2)
812+
v = cat((cached_v, v), dim = -2)
813813

814814
# maybe kv cache
815815

@@ -1030,7 +1030,7 @@ def forward(
10301030
skip = skips.pop()
10311031

10321032
residual = x
1033-
x = torch.cat((x, skip), dim = -1)
1033+
x = cat((x, skip), dim = -1)
10341034
x = skip_proj(x) + residual
10351035

10361036
# attention and feedforward
@@ -1695,7 +1695,7 @@ def generate_text_only(
16951695

16961696
sample = gumbel_sample(logits, temperature = temperature, dim = -1)
16971697

1698-
out = torch.cat((out, sample), dim = -1)
1698+
out = cat((out, sample), dim = -1)
16991699

17001700
return out[..., prompt_seq_len:]
17011701

@@ -2147,7 +2147,7 @@ def forward(
21472147
precede_modality_tokens = len(modality_meta_info) + 2
21482148
succeed_modality_tokens = 1
21492149

2150-
text_tensor = torch.cat((
2150+
text_tensor = cat((
21512151
tensor_([self.meta_id]),
21522152
modality_meta_info,
21532153
tensor_([som_id]),
@@ -2200,12 +2200,12 @@ def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
22002200

22012201
batch_modality_pos_emb.append(pos_emb)
22022202

2203-
text.append(torch.cat(batch_text))
2203+
text.append(cat(batch_text))
22042204

22052205
if need_axial_pos_emb:
2206-
modality_pos_emb.append(torch.cat(batch_modality_pos_emb, dim = -2))
2206+
modality_pos_emb.append(cat(batch_modality_pos_emb, dim = -2))
22072207

2208-
modality_tokens.append(torch.cat(batch_modality_tokens))
2208+
modality_tokens.append(cat(batch_modality_tokens))
22092209
modality_positions.append(batch_modality_positions)
22102210

22112211
modality_index += 1

0 commit comments

Comments (0)