Skip to content

Commit b1d9b44

Browse files
committed
fix generation and further cleanup by moving generate method into Dalle class
1 parent 57f6cf2 commit b1d9b44

File tree

4 files changed

+46
-51
lines changed

4 files changed

+46
-51
lines changed

README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,8 @@ loss.backward()
117117
Finally, to generate images
118118

119119
```python
120-
from dalle_pytorch import generate_images
121120

122-
images = generate_images(
123-
dalle,
121+
images = dalle.generate_images(
124122
vae = vae,
125123
text = text,
126124
mask = mask
@@ -132,10 +130,8 @@ images.shape # (2, 3, 256, 256)
132130
To get the similarity scores from your trained Clipper, just do
133131

134132
```python
135-
from dalle_pytorch import generate_images
136133

137-
images, scores = generate_images(
138-
dalle,
134+
images, scores = dalle.generate_images(
139135
vae = vae,
140136
text = text,
141137
mask = mask,

dalle_pytorch/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
2-
from dalle_pytorch.dalle_pytorch import generate_images

dalle_pytorch/dalle_pytorch.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -37,49 +37,6 @@ def top_k(logits, thres = 0.5):
3737
probs.scatter_(1, ind, val)
3838
return probs
3939

40-
@torch.no_grad()
41-
@eval_decorator
42-
def generate_images(
43-
model,
44-
vae,
45-
text,
46-
clipper = None,
47-
mask = None,
48-
filter_thres = 0.5,
49-
temperature = 1.
50-
):
51-
x = text
52-
53-
text_seq_len = model.text_seq_len
54-
image_seq_len = model.image_seq_len
55-
total_len = text_seq_len + model.image_seq_len - text.shape[1]
56-
57-
out = x
58-
for _ in range(total_len):
59-
text, image = x[:, :text_seq_len], x[:, text_seq_len:]
60-
logits = model(text, image, mask = mask)[:, -1, :]
61-
filtered_logits = top_k(logits, thres = filter_thres)
62-
probs = F.softmax(filtered_logits / temperature, dim=-1)
63-
64-
sample = torch.multinomial(probs, 1)
65-
out = torch.cat((out, sample), dim=-1)
66-
67-
if out.shape[1] <= text_seq_len:
68-
mask = F.pad(mask, (0, 1), value=True)
69-
70-
text_seq = torch.cat((x[:, :1], out[:, :(text_seq_len - 1)]), dim = 1)
71-
72-
img_seq = out[:, -image_seq_len:]
73-
img_seq -= model.num_text_tokens
74-
75-
images = vae.decode(img_seq)
76-
77-
if exists(clipper):
78-
scores = clipper(text_seq, images, return_loss = False)
79-
return images, scores
80-
81-
return images
82-
8340
# discrete vae class
8441

8542
class DiscreteVAE(nn.Module):
@@ -304,6 +261,49 @@ def __init__(
304261

305262
self.register_buffer('logits_mask', logits_mask)
306263

264+
@torch.no_grad()
265+
@eval_decorator
266+
def generate_images(
267+
self,
268+
vae,
269+
text,
270+
clipper = None,
271+
mask = None,
272+
filter_thres = 0.5,
273+
temperature = 1.
274+
):
275+
text_seq_len, image_seq_len, num_text_tokens = self.text_seq_len, self.image_seq_len, self.num_text_tokens
276+
total_len = text_seq_len + image_seq_len
277+
278+
out = text
279+
for cur_len in range(text.shape[1], total_len):
280+
is_image = cur_len >= text_seq_len
281+
282+
text, image = out[:, :text_seq_len], out[:, text_seq_len:]
283+
284+
logits = self(text, image, mask = mask)[:, -1, :]
285+
286+
filtered_logits = top_k(logits, thres = filter_thres)
287+
probs = F.softmax(filtered_logits / temperature, dim = -1)
288+
sample = torch.multinomial(probs, 1)
289+
290+
sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
291+
out = torch.cat((out, sample), dim=-1)
292+
293+
if out.shape[1] <= text_seq_len:
294+
mask = F.pad(mask, (0, 1), value = True)
295+
296+
text_seq = out[:, :text_seq_len]
297+
298+
img_seq = out[:, -image_seq_len:]
299+
images = vae.decode(img_seq)
300+
301+
if exists(clipper):
302+
scores = clipper(text_seq, images, return_loss = False)
303+
return images, scores
304+
305+
return images
306+
307307
def forward(
308308
self,
309309
text,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.0.19',
6+
version = '0.0.21',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)