
Commit 22a9c50

complete batched processing of modality tensors of the same shape for each modality type, if they need initial encoding (latent flow matching)
1 parent: 865b893
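In short: rather than encoding each modality tensor individually inside the forward loop, the forward pass now moves everything to the model's device up front and, for each modality type that has an encoder, encodes all of its tensors in one place, so that same-shaped tensors can share a single batched encoder call. Below is a minimal sketch of that batching idea, under stated assumptions; `batched_encode` and `encode_fn` are hypothetical names, not the library's implementation.

```python
from collections import defaultdict

import torch

def batched_encode(tensors, encode_fn):
    # group indices of same-shaped tensors so each shape group
    # is encoded with a single forward pass of the encoder
    groups = defaultdict(list)
    for index, tensor in enumerate(tensors):
        groups[tuple(tensor.shape)].append(index)

    encoded = [None] * len(tensors)

    with torch.no_grad():
        for indices in groups.values():
            batch = torch.stack([tensors[i] for i in indices])  # (group_size, *shape)
            outputs = encode_fn(batch)                          # one call per shape group

            for i, out in zip(indices, outputs.unbind(dim = 0)):
                encoded[i] = out.detach()

    return encoded

# e.g. three tensors in two shapes -> two encoder calls instead of three
images = [torch.randn(3, 64, 64), torch.randn(3, 32, 32), torch.randn(3, 64, 64)]
latents = batched_encode(images, encode_fn = torch.nn.Identity())
```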

3 files changed (+34 −13 lines)

README.md

Lines changed: 14 additions & 0 deletions
````diff
@@ -155,6 +155,20 @@ loss.backward()
 sampled = model.generate_text_only(text[:, :1], 1024)
 ```
 
+## Examples
+
+To run any of the examples `train_{example_name}.py` in the project root, first install the dependencies like so
+
+```bash
+$ pip install .[examples]
+```
+
+If you run into some weird error with `safetensors`, run this too
+
+```bash
+$ pip install -U diffusers transformers accelerate scipy ftfy safetensors
+```
+
 ## Citations
 
 ```bibtex
````
pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.9.5"
+version = "0.10.0"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

transfusion_pytorch/transfusion.py

Lines changed: 19 additions & 12 deletions
```diff
@@ -113,6 +113,7 @@ class ModalityInfo(NamedTuple):
     eom_id: int
     to_shape_fn: Callable | None
     channel_first_latent: bool
+    modality_type: int
 
 # helper functions
 
@@ -1489,7 +1490,8 @@ def get_modality_info(
             som_id = som_id,
             eom_id = eom_id,
             to_shape_fn = to_shape_fn,
-            channel_first_latent = channel_first_latent
+            channel_first_latent = channel_first_latent,
+            modality_type = modality_type
         )
 
     def get_all_modality_info(self) -> list[ModalityInfo]:
@@ -2264,10 +2266,24 @@ def forward(
 
         text = []
 
-        flows = defaultdict(list) # store flows for loss
+        # auto move all tensors to device of model
+
+        modalities = tree_map_tensor(modalities, lambda t: t.to(device))
+
+        # for all modalities, batch process same shaped modalities of the same type
+
+        if not is_decoding:
+            for mod in self.get_all_modality_info():
+                encode_fn = default(mod.encoder, nn.Identity())
+
+                with torch.no_grad():
+                    encode_fn.eval()
+                    modalities = apply_fn_modality_type(encode_fn, modalities, modality_type = mod.modality_type)
 
         # for parsing out the predicted flow from flattened sequence of tokens coming out of transformer
 
+        flows = defaultdict(list) # store flows for loss
+
         get_pred_flows: GetPredFlows = defaultdict(list) # functions for parsing modalities from Float['b n d'] for model back to latents or pixel space
 
         def model_to_pred_flow(batch_index, start_index, modality_length, unpack_fn):
@@ -2322,22 +2338,13 @@ def inner(pred_flow):
             if is_text:
                 modality_tensor = modality
             else:
-                modality_type, modality_tensor = modality
+                modality_type, modality_tensor, *_ = modality
 
             # auto move modality tensor to correct device
 
-            modality_tensor = modality_tensor.to(device)
-
             mod = self.get_modality_info(modality_type)
 
             if is_modality:
-                if not is_decoding:
-
-                    if exists(mod.encoder):
-                        with torch.no_grad():
-                            mod.encoder.eval()
-                            modality_tensor = self.maybe_add_temp_batch_dim(mod.encoder)(modality_tensor).detach()
-
                 assert 0 <= modality_type < self.num_modalities, f'received a modality index that is out of range. only {self.num_modalities} modalities specified'
 
             channel_dim = 0 if mod.channel_first_latent else -1
```