Skip to content

Packed tokenizer #1473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stanza/models/tokenization/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def __init__(self, dataset):
super().__init__()

self.dataset = dataset
self.data, self.indices = sort_with_indices(self.dataset.data, key=len)
self.data, self.indices = sort_with_indices(self.dataset.data, key=len, reverse=True)

def __len__(self):
return len(self.data)
Expand Down
15 changes: 9 additions & 6 deletions stanza/models/tokenization/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

class Tokenizer(nn.Module):
def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
Expand Down Expand Up @@ -41,15 +42,15 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):

self.toknoise = nn.Dropout(self.args['tok_noise'])

def forward(self, x, feats):
def forward(self, x, feats, lengths):
emb = self.embeddings(x)
emb = self.dropout(emb)
feats = self.dropout_feat(feats)


emb = torch.cat([emb, feats], 2)

emb = pack_padded_sequence(emb, lengths, batch_first=True)
inp, _ = self.rnn(emb)
inp, _ = pad_packed_sequence(inp, batch_first=True)

if self.args['conv_res'] is not None:
conv_input = emb.transpose(1, 2).contiguous()
Expand All @@ -73,10 +74,12 @@ def forward(self, x, feats):
mwt0 = self.mwt_clf(inp)

if self.args['hierarchical']:
inp2 = inp
if self.args['hier_invtemp'] > 0:
inp2, _ = self.rnn2(inp * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp']))))
else:
inp2, _ = self.rnn2(inp)
inp2 = inp2 * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp'])))
inp2 = pack_padded_sequence(inp2, lengths, batch_first=True)
inp2, _ = self.rnn2(inp2)
inp2, _ = pad_packed_sequence(inp2, batch_first=True)

inp2 = self.dropout(inp2)

Expand Down
10 changes: 6 additions & 4 deletions stanza/models/tokenization/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_f

def update(self, inputs):
self.model.train()
units, labels, features, _ = inputs
units, labels, features, text = inputs
lengths = [len(x) for x in text]

device = next(self.model.parameters()).device
units = units.to(device)
labels = labels.to(device)
features = features.to(device)

pred = self.model(units, features)
pred = self.model(units, features, lengths)

self.optimizer.zero_grad()
classes = pred.size(2)
Expand All @@ -54,13 +55,14 @@ def update(self, inputs):

def predict(self, inputs):
self.model.eval()
units, _, features, _ = inputs
units, _, features, text = inputs
lengths = [len(x) for x in text]

device = next(self.model.parameters()).device
units = units.to(device)
features = features.to(device)

pred = self.model(units, features)
pred = self.model(units, features, lengths)

return pred.data.cpu().numpy()

Expand Down
4 changes: 2 additions & 2 deletions stanza/models/tokenization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ def predict(trainer, data_generator, batch_size, max_seqlen, use_regex_tokens, n
dataloader = TorchDataLoader(sorted_data, batch_size=batch_size, collate_fn=sorted_data.collate, num_workers=num_workers)
for batch_idx, batch in enumerate(dataloader):
num_sentences = len(batch[3])
# being sorted by length, we need to use -1 as the longest sentence
N = len(batch[3][-1])
# being sorted by descending length, we need to use 0 as the longest sentence
N = len(batch[3][0])
for paragraph in batch[3]:
all_raw.append(list(paragraph))

Expand Down