
Commit d399781

Seems that a full list of PAD isn't needed; need to use the longest sentence to get the batch length.
1 parent aef7260 commit d399781

2 files changed, +3 -2 lines changed

stanza/models/tokenization/data.py

Lines changed: 1 addition & 1 deletion
@@ -431,7 +431,7 @@ def collate(self, samples):
             units[i, :len(u_)] = torch.from_numpy(u_)
             labels[i, :len(l_)] = torch.from_numpy(l_)
             features[i, :len(f_), :] = torch.from_numpy(f_)
-            raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))
+            raw_units.append(r_ + ['<PAD>'])

         return units, labels, features, raw_units
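
A minimal sketch of what this change means for the raw units, using a hypothetical toy batch (the variable names below mirror the diff but are assumptions, not the surrounding stanza code):

    # Toy batch of character lists, sorted ascending by length; names are illustrative.
    batch = [list("cat"), list("horse"), list("giraffe")]
    pad_len = max(len(r_) for r_ in batch)

    # Before: every raw sentence was right-padded with '<PAD>' to the batch maximum.
    old_raw_units = [r_ + ['<PAD>'] * (pad_len - len(r_)) for r_ in batch]

    # After: a single '<PAD>' sentinel is appended instead of a full pad run.
    new_raw_units = [r_ + ['<PAD>'] for r_ in batch]

    print(old_raw_units[0])  # ['c', 'a', 't', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
    print(new_raw_units[0])  # ['c', 'a', 't', '<PAD>']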

stanza/models/tokenization/utils.py

Lines changed: 2 additions & 1 deletion
@@ -258,7 +258,8 @@ def predict(trainer, data_generator, batch_size, max_seqlen, use_regex_tokens, n
     dataloader = TorchDataLoader(sorted_data, batch_size=batch_size, collate_fn=sorted_data.collate, num_workers=num_workers)
     for batch_idx, batch in enumerate(dataloader):
         num_sentences = len(batch[3])
-        N = len(batch[3][0])
+        # being sorted by length, we need to use -1 as the longest sentence
+        N = len(batch[3][-1])
         for paragraph in batch[3]:
             all_raw.append(list(paragraph))
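
A minimal sketch of why the index moves from 0 to -1, assuming (as the new comment says) the batch is sorted ascending by length so the longest sentence comes last; the toy list below stands in for batch[3] and is not the actual stanza data structure:

    # Hypothetical raw units after collate, sorted ascending by length.
    raw_units = [list("cat") + ['<PAD>'],
                 list("horse") + ['<PAD>'],
                 list("giraffe") + ['<PAD>']]

    old_N = len(raw_units[0])    # 4 -- the shortest sentence, too small for the batch
    new_N = len(raw_units[-1])   # 8 -- the longest sentence gives the batch length
    print(old_N, new_N)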
