File tree Expand file tree Collapse file tree 1 file changed +4
-7
lines changed Expand file tree Collapse file tree 1 file changed +4
-7
lines changed Original file line number Diff line number Diff line change @@ -295,14 +295,11 @@ def forward(self, text):
295
295
return_overflowing_tokens = False , padding = "max_length" , return_tensors = "pt" )
296
296
tokens = batch_encoding ["input_ids" ].to (self .device )
297
297
298
- if len (self .modifier_token ) == 3 :
299
- indices = ((tokens == self .modifier_token_id [- 1 ]) | (tokens == self .modifier_token_id [- 2 ]) | (tokens == self .modifier_token_id [- 3 ]))* 1
300
- elif len (self .modifier_token ) == 2 :
301
- indices = ((tokens == self .modifier_token_id [- 1 ]) | (tokens == self .modifier_token_id [- 2 ]))* 1
302
- else :
303
- indices = (tokens == self .modifier_token_id [- 1 ])* 1
298
+ indices = tokens == self .modifier_token_id [- 1 ]
299
+ for token_id in self .modifier_token_id :
300
+ indices |= tokens == token_id
304
301
305
- indices = indices .unsqueeze (- 1 )
302
+ indices = ( indices * 1 ) .unsqueeze (- 1 )
306
303
307
304
input_shape = tokens .size ()
308
305
tokens = tokens .view (- 1 , input_shape [- 1 ])
You can’t perform that action at this time.
0 commit comments