Skip to content

Commit 4e27bfc

Browse files
Nupur KumariNupur Kumari
authored andcommitted
diffusers==0.14.0 update
1 parent 32786fb commit 4e27bfc

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

src/custom_modules.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -295,14 +295,11 @@ def forward(self, text):
295295
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
296296
tokens = batch_encoding["input_ids"].to(self.device)
297297

298-
if len(self.modifier_token) == 3:
299-
indices = ((tokens == self.modifier_token_id[-1]) | (tokens == self.modifier_token_id[-2]) | (tokens == self.modifier_token_id[-3]))*1
300-
elif len(self.modifier_token) == 2:
301-
indices = ((tokens == self.modifier_token_id[-1]) | (tokens == self.modifier_token_id[-2]))*1
302-
else:
303-
indices = (tokens == self.modifier_token_id[-1])*1
298+
indices = tokens == self.modifier_token_id[-1]
299+
for token_id in self.modifier_token_id:
300+
indices |= tokens == token_id
304301

305-
indices = indices.unsqueeze(-1)
302+
indices = (indices*1).unsqueeze(-1)
306303

307304
input_shape = tokens.size()
308305
tokens = tokens.view(-1, input_shape[-1])

0 commit comments

Comments
 (0)