Commit afdd146

rely on all zeros in the feature dimension being treated as padding to be masked out, removing the need to keep track of a mask. also always project the text encodings, so that in the case of multiple text models the projection acts like a model type embedding.
1 parent dbc2fc7 commit afdd146
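For context, the convention this commit adopts can be shown with a minimal PyTorch sketch (shapes and variable names below are illustrative, not part of the library): padding positions carry all-zero feature vectors, so the boolean mask can be recomputed on demand with `(text_embeds != 0).any(dim = -1)` instead of being passed around separately.

```python
import torch

# hypothetical shapes: (batch, sequence, feature)
text_embeds = torch.randn(2, 4, 8)
text_embeds[:, 2:] = 0.                      # pretend the last two positions are padding

# the mask is recoverable whenever needed: True wherever any feature is non-zero
mask = (text_embeds != 0).any(dim = -1)      # shape (2, 4)

# after any projection, re-zero the padded positions so the convention still holds
text_embeds = text_embeds.masked_fill(~mask[..., None], 0.)
```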

File tree

3 files changed: +20 -14 lines changed


README.md

Lines changed: 6 additions & 0 deletions
@@ -48,6 +48,12 @@ first_conditioned = first_condition_fn(first_hidden)
 second_conditioned = second_condition_fn(second_hidden)
 ```

+If you wish to use cross attention based conditioning (each hidden feature in your network can attend to individual subword tokens), just import the `AttentionTextConditioner` instead. Rest is the same
+
+```python
+from classifier_free_guidance_pytorch import AttentionTextConditioner
+```
+
 ## Magic Decorator (wip)

 This is a work in progress to make it as easy as possible to text condition your network.
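As a companion to the README addition above, here is a hedged sketch of how the swap might look, assuming `AttentionTextConditioner` mirrors the `TextConditioner` interface shown earlier in the README; the constructor keywords used here (`model_types`, `hidden_dims`, `cond_drop_prob`) are assumptions for illustration, not confirmed by this diff.

```python
from classifier_free_guidance_pytorch import AttentionTextConditioner

# assumed to mirror the TextConditioner interface from the README example
text_conditioner = AttentionTextConditioner(
    model_types = 't5',          # which pretrained text encoder(s) to use (assumed keyword)
    hidden_dims = (256, 512),    # one entry per hidden tensor to be conditioned (assumed keyword)
    cond_drop_prob = 0.2         # chance of dropping the text condition, for classifier free guidance
)

# one conditioning function per hidden, as in the surrounding README context
first_condition_fn, second_condition_fn = text_conditioner(['a dog chasing after a ball'])
```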

classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py

Lines changed: 13 additions & 13 deletions
@@ -271,7 +271,7 @@ def forward(
         hiddens,
         mask = None
     ):
-        return self.attn(hiddens, condition, mask = mask)
+        return self.attn(hiddens, condition, mask = mask) + hiddens

 # film text conditioning

@@ -440,7 +440,7 @@ def __init__(
         dim_latent = default(dim_latent, max([model.dim_latent for model in text_models]))

         for model in text_models:
-            self.to_latent_dims.append(nn.Linear(model.dim_latent, dim_latent) if model.dim_latent != dim_latent else nn.Identity())
+            self.to_latent_dims.append(nn.Linear(model.dim_latent, dim_latent))

         self.conditioners = nn.ModuleList([])

@@ -465,25 +465,25 @@ def embed_texts(self, texts: List[str]):
         device = self.device

         text_embeds = []
-        masks = []

         for text_model, to_latent in zip(self.text_models, self.to_latent_dims):
             text_embed = text_model.embed_text(texts, return_text_encodings = True)

             text_embed = text_embed.to(device)

             mask = (text_embed != 0).any(dim = -1)
-            mask = mask.to(device)

-            text_embeds.append(to_latent(text_embed))
-            masks.append(mask)
+            text_embed = to_latent(text_embed)
+            text_embed = text_embed.masked_fill(~mask[..., None], 0.)

-        return torch.cat(text_embeds, dim = -2), torch.cat(masks, dim = -1)
+            text_embeds.append(text_embed)
+
+        return torch.cat(text_embeds, dim = -2)

     def forward(
         self,
         texts: Optional[List[str]] = None,
-        text_embeds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        text_embeds: Optional[torch.Tensor] = None,
         cond_drop_prob = None,
         repeat_batch = 1, # for robotic transformer edge case
     ) -> Tuple[Callable, ...]:
@@ -497,14 +497,14 @@ def forward(

         if exists(texts):
             batch = len(texts)
-        elif exists(text_embeds):
-            batch = text_embeds[0].shape[0]

-        if exists(text_embeds):
-            text_embeds, mask = text_embeds
+        elif exists(text_embeds):
+            batch = text_embeds.shape[0]

         if not exists(text_embeds):
-            text_embeds, mask = self.embed_texts(texts)
+            text_embeds = self.embed_texts(texts)
+
+        mask = (text_embeds != 0).any(dim = -1)

         if cond_drop_prob > 0.:
             prob_keep_mask = prob_mask_like((batch, 1), 1. - cond_drop_prob, device = self.device)
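The second half of the commit message (always project the text encodings) shows up in the `__init__` hunk above: the `nn.Identity()` shortcut is dropped, so every text model gets its own learned `nn.Linear`, even when its latent dimension already matches. A hedged sketch of why that behaves like a model type embedding (module and variable names here are illustrative, not the library's internals):

```python
import torch
from torch import nn

dim_latent = 512
model_dim_latents = [512, 768]   # e.g. two text encoders, one already at dim_latent

# before: models whose dim_latent matched dim_latent passed through nn.Identity()
# after: each model always gets its own projection, so even identically sized
# encoders are mapped through distinct learned transforms - effectively tagging
# each encoding with which text model produced it
to_latent_dims = nn.ModuleList([
    nn.Linear(d, dim_latent) for d in model_dim_latents
])

encodings = [torch.randn(2, 6, d) for d in model_dim_latents]
projected = [proj(enc) for proj, enc in zip(to_latent_dims, encodings)]

# concatenated along the sequence dimension, as embed_texts does above;
# the mask is rederived later from the all-zeros padding convention
text_embeds = torch.cat(projected, dim = -2)
mask = (text_embeds != 0).any(dim = -1)
```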

setup.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
   name = 'classifier-free-guidance-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.1.0',
+  version = '0.1.2',
   license='MIT',
   description = 'Classifier Free Guidance - Pytorch',
   author = 'Phil Wang',
