@@ -37,49 +37,6 @@ def top_k(logits, thres = 0.5):
3737 probs .scatter_ (1 , ind , val )
3838 return probs
3939
@torch.no_grad()
@eval_decorator
def generate_images(
    model,
    vae,
    text,
    clipper = None,
    mask = None,
    filter_thres = 0.5,
    temperature = 1.
):
    """Autoregressively sample the remaining text + image tokens for `text`,
    decode the image tokens with `vae`, and optionally score with `clipper`.

    Args:
        model:        autoregressive transformer exposing `text_seq_len`,
                      `image_seq_len`, `num_text_tokens`, and a
                      `(text, image, mask=...) -> logits` forward
        vae:          discrete VAE with a `decode(img_token_ids)` method
        text:         (batch, n) tensor of text token ids, n <= model.text_seq_len
        clipper:      optional scorer called as `(text_seq, images, return_loss=False)`
        mask:         optional (batch, n) boolean attention mask for `text`
        filter_thres: top-k filtering threshold passed to `top_k`
        temperature:  softmax temperature for sampling

    Returns:
        images, or (images, scores) when `clipper` is given.
    """
    text_seq_len = model.text_seq_len
    image_seq_len = model.image_seq_len
    total_len = text_seq_len + image_seq_len

    out = text
    for _ in range(text.shape[1], total_len):
        # BUGFIX: slice the growing `out`, not the original `text` input —
        # otherwise sampled tokens are never fed back into the model and every
        # step re-samples from the same fixed context.
        text_part, image_part = out[:, :text_seq_len], out[:, text_seq_len:]

        logits = model(text_part, image_part, mask = mask)[:, -1, :]
        filtered_logits = top_k(logits, thres = filter_thres)
        probs = F.softmax(filtered_logits / temperature, dim = -1)

        sample = torch.multinomial(probs, 1)
        out = torch.cat((out, sample), dim = -1)

        # while still inside the text segment, grow the mask to cover the new
        # token; guard against the default mask=None (F.pad(None) would raise)
        if out.shape[1] <= text_seq_len and exists(mask):
            mask = F.pad(mask, (0, 1), value = True)

    # BUGFIX: the text sequence is simply the first text_seq_len tokens of `out`
    text_seq = out[:, :text_seq_len]

    # shift image tokens back to vae codebook ids (logit space is text tokens
    # followed by image tokens); avoid in-place `-=` on a view of `out`
    img_seq = out[:, -image_seq_len:] - model.num_text_tokens

    images = vae.decode(img_seq)

    if exists(clipper):
        scores = clipper(text_seq, images, return_loss = False)
        return images, scores

    return images
82-
8340# discrete vae class
8441
8542class DiscreteVAE (nn .Module ):
@@ -304,6 +261,49 @@ def __init__(
304261
305262 self .register_buffer ('logits_mask' , logits_mask )
306263
264+ @torch .no_grad ()
265+ @eval_decorator
266+ def generate_images (
267+ self ,
268+ vae ,
269+ text ,
270+ clipper = None ,
271+ mask = None ,
272+ filter_thres = 0.5 ,
273+ temperature = 1.
274+ ):
275+ text_seq_len , image_seq_len , num_text_tokens = self .text_seq_len , self .image_seq_len , self .num_text_tokens
276+ total_len = text_seq_len + image_seq_len
277+
278+ out = text
279+ for cur_len in range (text .shape [1 ], total_len ):
280+ is_image = cur_len >= text_seq_len
281+
282+ text , image = out [:, :text_seq_len ], out [:, text_seq_len :]
283+
284+ logits = self (text , image , mask = mask )[:, - 1 , :]
285+
286+ filtered_logits = top_k (logits , thres = filter_thres )
287+ probs = F .softmax (filtered_logits / temperature , dim = - 1 )
288+ sample = torch .multinomial (probs , 1 )
289+
290+ sample -= (num_text_tokens if is_image else 0 ) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
291+ out = torch .cat ((out , sample ), dim = - 1 )
292+
293+ if out .shape [1 ] <= text_seq_len :
294+ mask = F .pad (mask , (0 , 1 ), value = True )
295+
296+ text_seq = out [:, :text_seq_len ]
297+
298+ img_seq = out [:, - image_seq_len :]
299+ images = vae .decode (img_seq )
300+
301+ if exists (clipper ):
302+ scores = clipper (text_seq , images , return_loss = False )
303+ return images , scores
304+
305+ return images
306+
307307 def forward (
308308 self ,
309309 text ,
0 commit comments