Skip to content

Commit 1000e74

Browse files
authored
Merge pull request #152 from afiaka87/truncate_captions
Add --truncate_captions beneath token length arg.
2 parents d56e23a + 049c52a commit 1000e74

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

dalle_pytorch/dalle_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def __init__(
302302
ff_dropout = 0,
303303
sparse_attn = False,
304304
attn_types = None,
305-
loss_img_weight = 7
305+
loss_img_weight = 7,
306306
):
307307
super().__init__()
308308
assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024)), 'vae must be an instance of DiscreteVAE'

dalle_pytorch/simple_tokenizer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def decode(self, tokens):
122122

123123
tokenizer = SimpleTokenizer()
124124

125-
def tokenize(texts, context_length = 256, add_start_and_end = False):
125+
def tokenize(texts, context_length = 256, add_start_and_end = False, truncate_text=False):
126126
if isinstance(texts, str):
127127
texts = [texts]
128128

@@ -133,7 +133,10 @@ def tokenize(texts, context_length = 256, add_start_and_end = False):
133133

134134
for i, tokens in enumerate(all_tokens):
135135
if len(tokens) > context_length:
136-
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
136+
if truncate_text:
137+
tokens = tokens[:context_length]
138+
else:
139+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
137140
result[i, :len(tokens)] = torch.tensor(tokens)
138141

139142
return result

train_dalle.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
parser.add_argument('--image_text_folder', type = str, required = True,
3939
help='path to your folder of images and text for learning the DALL-E')
4040

41+
parser.add_argument('--truncate_captions', dest='truncate_captions',
42+
help='Captions passed in which exceed the max token length will be truncated if this is set.')
43+
4144
parser.add_argument('--taming', dest='taming', action='store_true')
4245

4346
parser = deepspeed_utils.wrap_arg_parser(parser)
@@ -197,7 +200,7 @@ def __getitem__(self, ind):
197200
descriptions = list(filter(lambda t: len(t) > 0, descriptions))
198201
description = choice(descriptions)
199202

200-
tokenized_text = tokenize(description, self.text_len).squeeze(0)
203+
tokenized_text = tokenize(description, self.text_len, truncate_text=args.truncate_captions).squeeze(0)
201204
mask = tokenized_text != 0
202205

203206
image_tensor = self.image_tranform(image)

0 commit comments

Comments (0)