
Commit 48f31bf

Pack & pad tensors to the LSTM in the tokenizer using PackedSequence
Sorting in the other direction means we don't need to use enforce_sorted=False. Unfortunately, things are faster without the packed sequences, but they wind up having unstable results: #1472
1 parent 59ebbe0 commit 48f31bf
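
For readers unfamiliar with the pattern: when a padded batch arrives sorted by descending length, pack_padded_sequence can run with its default enforce_sorted=True, avoiding the internal re-sort that enforce_sorted=False triggers, and the LSTM never iterates over padding positions. A minimal sketch of the pack/pad round trip this commit applies (shapes, dimensions, and names here are illustrative, not Stanza's actual configuration):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Batch of 3 sequences, zero-padded and already sorted by descending length.
lengths = [5, 4, 2]
emb = torch.randn(3, 5, 32)              # (batch, max_len, emb_dim)

rnn = nn.LSTM(32, 64, batch_first=True, bidirectional=True)

# The default enforce_sorted=True is valid because lengths are descending.
packed = pack_padded_sequence(emb, lengths, batch_first=True)
out, _ = rnn(packed)                     # the recurrence skips the padding
out, out_lengths = pad_packed_sequence(out, batch_first=True)

print(out.shape)                         # torch.Size([3, 5, 128])
print(out_lengths)                       # tensor([5, 4, 2])
```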

4 files changed (+18 -13 lines)

stanza/models/tokenization/data.py (+1 -1)

@@ -397,7 +397,7 @@ def __init__(self, dataset):
         super().__init__()
 
         self.dataset = dataset
-        self.data, self.indices = sort_with_indices(self.dataset.data, key=len)
+        self.data, self.indices = sort_with_indices(self.dataset.data, key=len, reverse=True)
 
     def __len__(self):
         return len(self.data)
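
sort_with_indices is a Stanza helper not shown in this diff; what matters here is that it returns both the reordered data and each item's original position, so predictions can be restored to input order afterwards. A hedged reimplementation for illustration only (the real helper may differ):

```python
def sort_with_indices(data, key=None, reverse=False):
    """Illustrative stand-in: sort data while remembering original positions."""
    if not data:
        return [], []
    indices, sorted_data = zip(*sorted(enumerate(data),
                                       key=lambda pair: key(pair[1]) if key else pair[1],
                                       reverse=reverse))
    return sorted_data, indices

docs = ["a short one", "the longest paragraph of the three", "mid length text"]
sorted_docs, indices = sort_with_indices(docs, key=len, reverse=True)
# sorted_docs[0] is now the longest paragraph, matching the descending order
# that pack_padded_sequence's default enforce_sorted=True expects.
```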

stanza/models/tokenization/model.py (+9 -6)

@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
 
 class Tokenizer(nn.Module):
     def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
@@ -41,15 +42,15 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
 
         self.toknoise = nn.Dropout(self.args['tok_noise'])
 
-    def forward(self, x, feats):
+    def forward(self, x, feats, lengths):
         emb = self.embeddings(x)
         emb = self.dropout(emb)
         feats = self.dropout_feat(feats)
 
-
         emb = torch.cat([emb, feats], 2)
-
+        emb = pack_padded_sequence(emb, lengths, batch_first=True)
         inp, _ = self.rnn(emb)
+        inp, _ = pad_packed_sequence(inp, batch_first=True)
 
         if self.args['conv_res'] is not None:
             conv_input = emb.transpose(1, 2).contiguous()
@@ -73,10 +74,12 @@ def forward(self, x, feats):
         mwt0 = self.mwt_clf(inp)
 
         if self.args['hierarchical']:
+            inp2 = inp
             if self.args['hier_invtemp'] > 0:
-                inp2, _ = self.rnn2(inp * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp']))))
-            else:
-                inp2, _ = self.rnn2(inp)
+                inp2 = inp2 * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp'])))
+            inp2 = pack_padded_sequence(inp2, lengths, batch_first=True)
+            inp2, _ = self.rnn2(inp2)
+            inp2, _ = pad_packed_sequence(inp2, batch_first=True)
 
             inp2 = self.dropout(inp2)
 
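
These forward changes only work together with the data.py sort above: with the default enforce_sorted=True used in both pack_padded_sequence calls, PyTorch requires lengths in decreasing order and raises otherwise. A quick self-contained demonstration (shapes are illustrative):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

emb = torch.zeros(2, 4, 8)       # (batch, max_len, dim)

# Ascending lengths fail under the default enforce_sorted=True ...
try:
    pack_padded_sequence(emb, [3, 4], batch_first=True)
except RuntimeError as err:
    print("rejected:", err)

# ... while descending lengths pack without the extra internal sort that
# enforce_sorted=False would otherwise perform.
packed = pack_padded_sequence(emb, [4, 3], batch_first=True)
print(packed.batch_sizes)        # tensor([2, 2, 2, 1])
```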

stanza/models/tokenization/trainer.py (+6 -4)

@@ -33,14 +33,15 @@ def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_f
 
     def update(self, inputs):
         self.model.train()
-        units, labels, features, _ = inputs
+        units, labels, features, text = inputs
+        lengths = [len(x) for x in text]
 
         device = next(self.model.parameters()).device
         units = units.to(device)
         labels = labels.to(device)
         features = features.to(device)
 
-        pred = self.model(units, features)
+        pred = self.model(units, features, lengths)
 
         self.optimizer.zero_grad()
         classes = pred.size(2)
@@ -54,13 +55,14 @@ def update(self, inputs):
 
     def predict(self, inputs):
         self.model.eval()
-        units, _, features, _ = inputs
+        units, _, features, text = inputs
+        lengths = [len(x) for x in text]
 
         device = next(self.model.parameters()).device
         units = units.to(device)
         features = features.to(device)
 
-        pred = self.model(units, features)
+        pred = self.model(units, features, lengths)
 
         return pred.data.cpu().numpy()
 
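
Two details worth noting in these trainer changes: the true lengths come from the raw text (batch element 3) because the padded tensors all share the same second dimension, and lengths stays a plain CPU list because pack_padded_sequence requires lengths as a Python list or CPU int64 tensor even when the model runs on GPU. A small illustration (the strings are made up):

```python
import torch

# Illustrative batch, already sorted longest-first by the dataset above.
text = ["the longest paragraph here", "a short one"]    # raw, unpadded
units = torch.zeros(2, 26, dtype=torch.long)            # padded unit ids

# Per-example lengths must come from the raw text; the padded tensor's
# second dimension is the same (26) for every example.
lengths = [len(x) for x in text]
print(lengths)                                          # [26, 11], descending

# Unlike units/labels/features, lengths is never moved to the GPU:
# pack_padded_sequence insists on CPU lengths (list or CPU int64 tensor).
```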

stanza/models/tokenization/utils.py (+2 -2)

@@ -258,8 +258,8 @@ def predict(trainer, data_generator, batch_size, max_seqlen, use_regex_tokens, n
     dataloader = TorchDataLoader(sorted_data, batch_size=batch_size, collate_fn=sorted_data.collate, num_workers=num_workers)
     for batch_idx, batch in enumerate(dataloader):
         num_sentences = len(batch[3])
-        # being sorted by length, we need to use -1 as the longest sentence
-        N = len(batch[3][-1])
+        # being sorted by descending length, we need to use 0 as the longest sentence
+        N = len(batch[3][0])
         for paragraph in batch[3]:
             all_raw.append(list(paragraph))
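
With the dataset now sorted longest-first, the longest paragraph of each batch sits at index 0 rather than index -1, which is all this change accounts for. A one-line check of the invariant on made-up data:

```python
batch_text = sorted(["bb", "dddd", "a", "ccc"], key=len, reverse=True)
print(batch_text)                                       # ['dddd', 'ccc', 'bb', 'a']
assert len(batch_text[0]) == max(len(x) for x in batch_text)
```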
