
Commit cb9f8a2

Merge pull request #275 from ntumlgroup/word_dict_none
Set the default value for `word_dict` and `embed_vecs`
2 parents 2255a70 + 64d215c commit cb9f8a2

6 files changed: +76 -45 lines changed

docs/examples/plot_KimCNN_quickstart.py

Lines changed: 1 addition & 4 deletions
@@ -37,7 +37,6 @@
 datasets = load_datasets('data/rcv1/train.txt', 'data/rcv1/test.txt', tokenize_text=True)
 classes = load_or_build_label(datasets)
 word_dict, embed_vecs = load_or_build_text_dict(dataset=datasets['train'], embed_file='glove.6B.300d')
-tokenizer = None
 
 ######################################################################
 # Initialize a model
@@ -91,13 +90,12 @@
 for split in ['train', 'val', 'test']:
     loaders[split] = get_dataset_loader(
         data=datasets[split],
-        word_dict=word_dict,
         classes=classes,
         device=device,
         max_seq_length=512,
         batch_size=8,
         shuffle=True if split == 'train' else False,
-        tokenizer=tokenizer
+        word_dict=word_dict
     )
 
 ######################################################################
@@ -125,4 +123,3 @@
 # 'P@3': 0.7772253751754761,
 # 'P@5': 0.5449321269989014,
 # }
-

docs/examples/plot_bert_quickstart.py

Lines changed: 12 additions & 16 deletions
@@ -18,7 +18,7 @@
 ######################################################################
 # Setup device
 # --------------------
-# If you need to reproduce the results, please use the function ``set_seed``. 
+# If you need to reproduce the results, please use the function ``set_seed``.
 # For example, you will get the same result as you always use the seed ``1337``.
 #
 # For initial a hardware device, please use ``init_device`` to assign the hardware device that you want to use.
@@ -29,12 +29,12 @@
 ######################################################################
 # Load and tokenize data
 # ------------------------------------------
-# We assume that the ``rcv1`` data is located at the directory ``./data/rcv1``, 
+# We assume that the ``rcv1`` data is located at the directory ``./data/rcv1``,
 # and there exist the files ``train.txt`` and ``test.txt``.
-# You can utilize the function ``load_datasets()`` to load the data sets. 
-# By default, LibMultiLabel tokenizes documents, but the BERT model uses its own tokenizer. 
+# You can utilize the function ``load_datasets()`` to load the data sets.
+# By default, LibMultiLabel tokenizes documents, but the BERT model uses its own tokenizer.
 # Thus, we must set ``tokenize_text=False``.
-# Note that ``datasets`` contains three sets: ``datasets['train']``, ``datasets['val']`` and ``datasets['test']``, 
+# Note that ``datasets`` contains three sets: ``datasets['train']``, ``datasets['val']`` and ``datasets['test']``,
 # where ``datasets['train']`` and ``datasets['val']`` are randomly splitted from ``train.txt`` with the ratio ``8:2``.
 #
 # For the labels of the data, we apply the function ``load_or_build_label()`` to generate the label set.
@@ -44,7 +44,6 @@
 
 datasets = load_datasets('data/rcv1/train.txt', 'data/rcv1/test.txt', tokenize_text=False)
 classes = load_or_build_label(datasets)
-word_dict, embed_vecs = None, None
 tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 
 ######################################################################
@@ -63,8 +62,6 @@
     model_name=model_name,
     network_config=network_config,
     classes=classes,
-    word_dict=word_dict,
-    embed_vecs=embed_vecs,
     learning_rate=learning_rate,
     monitor_metrics=['Micro-F1', 'Macro-F1', 'P@1', 'P@3', 'P@5']
 )
@@ -80,7 +77,7 @@
 # Initialize a trainer
 # ----------------------------
 #
-# We use the function ``init_trainer`` to initialize a trainer. 
+# We use the function ``init_trainer`` to initialize a trainer.
 
 trainer = init_trainer(checkpoint_dir='runs/NN-example', epochs=15, val_metric='P@5')
 
@@ -97,7 +94,6 @@
 for split in ['train', 'val', 'test']:
     loaders[split] = get_dataset_loader(
         data=datasets[split],
-        word_dict=word_dict,
         classes=classes,
         device=device,
         max_seq_length=512,
@@ -112,7 +108,7 @@
 # Train and test a model
 # ------------------------------
 #
-# The bert model training process can be started via 
+# The bert model training process can be started via
 
 trainer.fit(model, loaders['train'], loaders['val'])
 
@@ -125,9 +121,9 @@
 # The results should be similar to::
 #
 # {
-# 'Macro-F1': 0.569891024909958, 
-# 'Micro-F1': 0.8142925500869751, 
-# 'P@1': 0.9552904367446899, 
-# 'P@3': 0.7907078266143799, 
+# 'Macro-F1': 0.569891024909958,
+# 'Micro-F1': 0.8142925500869751,
+# 'P@1': 0.9552904367446899,
+# 'P@3': 0.7907078266143799,
 # 'P@5': 0.5505486726760864
-# }
+# }
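
With this commit, the BERT quickstart no longer defines the placeholder line ``word_dict, embed_vecs = None, None``, and the data loader receives the HuggingFace tokenizer as a keyword-only argument. A minimal sketch of the updated flow, assuming the same RCV1 setup as the quickstart; the `model_name` and `network_config` values below are illustrative placeholders, not part of this commit:

from transformers import AutoTokenizer

from libmultilabel.nn.data_utils import get_dataset_loader, load_datasets, load_or_build_label
from libmultilabel.nn.nn_utils import init_device, init_model

device = init_device()  # pass use_cpu=True to stay on CPU
datasets = load_datasets('data/rcv1/train.txt', 'data/rcv1/test.txt', tokenize_text=False)
classes = load_or_build_label(datasets)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# word_dict and embed_vecs now default to None, so a BERT model simply omits them.
model = init_model(
    model_name='BERT',                                  # illustrative; substitute the quickstart's model_name
    network_config={'lm_weight': 'bert-base-uncased'},  # illustrative network settings
    classes=classes,
    learning_rate=0.0001,
    monitor_metrics=['Micro-F1', 'Macro-F1', 'P@1', 'P@3', 'P@5'],
)

# The loader now takes the tokenizer as a keyword-only argument; no word_dict is passed.
loaders = {
    split: get_dataset_loader(
        data=datasets[split],
        classes=classes,
        device=device,
        max_seq_length=512,
        batch_size=8,
        shuffle=(split == 'train'),
        tokenizer=tokenizer,
    )
    for split in ['train', 'val', 'test']
}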

docs/examples/plot_dataset_tutorial.py

Lines changed: 1 addition & 1 deletion
@@ -65,4 +65,4 @@
 
 from libmultilabel.nn.data_utils import load_datasets
 
-datasets = load_datasets(data_sets['train'], data_sets['test'], tokenize_text=False)
+datasets = load_datasets(data_sets['train'], data_sets['test'], tokenize_text=False)

libmultilabel/nn/data_utils.py

Lines changed: 47 additions & 14 deletions
@@ -11,7 +11,7 @@
 from sklearn.preprocessing import MultiLabelBinarizer
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
-from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases
+from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases, Vocab
 from tqdm import tqdm
 
 transformers.logging.set_verbosity_error()
@@ -22,24 +22,48 @@
 
 
 class TextDataset(Dataset):
-    """Class for text dataset"""
+    """Class for text dataset.
 
-    def __init__(self, data, word_dict, classes, max_seq_length, tokenizer=None, add_special_tokens=True):
+    Args:
+        data (list[dict]): List of instances with index, label, and text.
+        classes (list): List of labels.
+        max_seq_length (int, optional): The maximum number of tokens of a sample.
+        add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
+        tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
+            the transformer-based pretrained language model. Defaults to None.
+        word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
+            map tokens to indices. Defaults to None.
+    """
+    def __init__(
+        self,
+        data,
+        classes,
+        max_seq_length,
+        add_special_tokens=True,
+        *,
+        tokenizer=None,
+        word_dict=None,
+    ):
         self.data = data
-        self.word_dict = word_dict
         self.classes = classes
         self.max_seq_length = max_seq_length
-        self.num_classes = len(self.classes)
-        self.label_binarizer = MultiLabelBinarizer().fit([classes])
+        self.word_dict = word_dict
         self.tokenizer = tokenizer
         self.add_special_tokens = add_special_tokens
 
+        self.num_classes = len(self.classes)
+        self.label_binarizer = MultiLabelBinarizer().fit([classes])
+
+        if not isinstance(self.word_dict, Vocab) ^ isinstance(
+                self.tokenizer, transformers.PreTrainedTokenizerBase):
+            raise ValueError(
+                'Please specify exactly one of word_dict or tokenizer')
+
     def __len__(self):
         return len(self.data)
 
     def __getitem__(self, index):
         data = self.data[index]
-
         if self.tokenizer is not None:  # transformers tokenizer
             if self.add_special_tokens:  # tentatively hard code
                 input_ids = self.tokenizer.encode(data['text'],
@@ -83,35 +107,44 @@ def generate_batch(data_batch):
 
 def get_dataset_loader(
     data,
-    word_dict,
     classes,
     device,
     max_seq_length=500,
     batch_size=1,
     shuffle=False,
     data_workers=4,
+    add_special_tokens=True,
+    *,
     tokenizer=None,
-    add_special_tokens=True
+    word_dict=None,
 ):
     """Create a pytorch DataLoader.
 
     Args:
-        data (list): List of training instances with index, label, and tokenized text.
-        word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
+        data (list[dict]): List of training instances with index, label, and tokenized text.
        classes (list): List of labels.
         device (torch.device): One of cuda or cpu.
         max_seq_length (int, optional): The maximum number of tokens of a sample. Defaults to 500.
         batch_size (int, optional): Size of training batches. Defaults to 1.
         shuffle (bool, optional): Whether to shuffle training data before each epoch. Defaults to False.
         data_workers (int, optional): Use multi-cpu core for data pre-processing. Defaults to 4.
-        tokenizer (optional): Tokenizer of the transformer-based language model. Defaults to None.
         add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
+        tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
+            the transformer-based pretrained language model. Defaults to None.
+        word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
+            map tokens to indices. Defaults to None.
 
     Returns:
         torch.utils.data.DataLoader: A pytorch DataLoader.
     """
-    dataset = TextDataset(data, word_dict, classes, max_seq_length, tokenizer=tokenizer,
-                          add_special_tokens=add_special_tokens)
+    dataset = TextDataset(
+        data,
+        classes,
+        max_seq_length,
+        word_dict=word_dict,
+        tokenizer=tokenizer,
+        add_special_tokens=add_special_tokens
+    )
     dataset_loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
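
`TextDataset` (and therefore `get_dataset_loader`) now requires exactly one of the keyword-only arguments `word_dict` or `tokenizer`; passing both or neither raises the `ValueError` shown above. A small sketch of the two valid call patterns, assuming the RCV1 files used by the quickstarts:

from transformers import AutoTokenizer

from libmultilabel.nn.data_utils import (get_dataset_loader, load_datasets,
                                         load_or_build_label, load_or_build_text_dict)
from libmultilabel.nn.nn_utils import init_device

device = init_device()

# Word-based path (e.g. KimCNN): pre-tokenized text plus a vocabulary passed via word_dict.
datasets = load_datasets('data/rcv1/train.txt', 'data/rcv1/test.txt', tokenize_text=True)
classes = load_or_build_label(datasets)
word_dict, embed_vecs = load_or_build_text_dict(dataset=datasets['train'], embed_file='glove.6B.300d')
word_loader = get_dataset_loader(data=datasets['train'], classes=classes,
                                 device=device, word_dict=word_dict)

# Transformer path (e.g. BERT): raw text plus a HuggingFace tokenizer instead of a vocabulary.
bert_datasets = load_datasets('data/rcv1/train.txt', 'data/rcv1/test.txt', tokenize_text=False)
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
bert_loader = get_dataset_loader(data=bert_datasets['train'], classes=classes,
                                 device=device, tokenizer=bert_tokenizer)

# Supplying both word_dict and tokenizer, or neither, fails fast with:
#   ValueError: Please specify exactly one of word_dict or tokenizer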

libmultilabel/nn/nn_utils.py

Lines changed: 14 additions & 9 deletions
@@ -37,8 +37,8 @@ def init_device(use_cpu=False):
 def init_model(model_name,
                network_config,
                classes,
-               word_dict,
-               embed_vecs,
+               word_dict=None,
+               embed_vecs=None,
                init_weight=None,
                log_path=None,
                learning_rate=0.0001,
@@ -57,8 +57,10 @@ def init_model(model_name,
         model_name (str): Model to be used such as KimCNN.
         network_config (dict): Configuration for defining the network.
         classes (list): List of class names.
-        word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
-        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
+        word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
+            map tokens to indices. Defaults to None.
+        embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape
+            (vocab_size, embed_dim). Defaults to None.
         init_weight (str): Weight initialization method from `torch.nn.init`.
             For example, the `init_weight` of `torch.nn.init.kaiming_uniform_`
             is `kaiming_uniform`. Defaults to None.
@@ -79,11 +81,14 @@ def init_model(model_name,
         Model: A class that implements `MultiLabelModel` for initializing and training a neural network.
     """
 
-    network = getattr(networks, model_name)(
-        embed_vecs=embed_vecs,
-        num_classes=len(classes),
-        **dict(network_config)
-    )
+    try:
+        network = getattr(networks, model_name)(
+            embed_vecs=embed_vecs,
+            num_classes=len(classes),
+            **dict(network_config)
+        )
+    except:
+        raise AttributeError(f'Failed to initialize {model_name}.')
 
     if init_weight is not None:
         init_weight = networks.get_init_weight_func(
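
`init_model` now defaults `word_dict` and `embed_vecs` to `None` and wraps network construction, so an unknown `model_name` or an incompatible `network_config` surfaces as `AttributeError: Failed to initialize <model_name>.`. A hedged sketch of both call styles; the `network_config` keys and the `'BERT'` model name below are illustrative placeholders, not taken from this commit:

from libmultilabel.nn.data_utils import load_datasets, load_or_build_label, load_or_build_text_dict
from libmultilabel.nn.nn_utils import init_model

datasets = load_datasets('data/rcv1/train.txt', 'data/rcv1/test.txt', tokenize_text=True)
classes = load_or_build_label(datasets)
word_dict, embed_vecs = load_or_build_text_dict(dataset=datasets['train'], embed_file='glove.6B.300d')

# Word-based model: word_dict and embed_vecs are still supplied explicitly.
kimcnn = init_model(
    model_name='KimCNN',
    network_config={'filter_sizes': [2, 4, 8]},  # illustrative config keys
    classes=classes,
    word_dict=word_dict,
    embed_vecs=embed_vecs,
)

# Transformer model: the new defaults let callers omit word_dict and embed_vecs entirely.
bert = init_model(
    model_name='BERT',                                  # illustrative model name
    network_config={'lm_weight': 'bert-base-uncased'},  # illustrative config keys
    classes=classes,
)

# A misspelled model_name no longer escapes as a bare getattr error; it is reported as
#   AttributeError: Failed to initialize <model_name>.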

torch_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -182,13 +182,13 @@ def _get_dataset_loader(self, split, shuffle=False):
         """
         return data_utils.get_dataset_loader(
             data=self.datasets[split],
-            word_dict=self.model.word_dict,
             classes=self.model.classes,
             device=self.device,
             max_seq_length=self.config.max_seq_length,
             batch_size=self.config.batch_size if split == 'train' else self.config.eval_batch_size,
             shuffle=shuffle,
             data_workers=self.config.data_workers,
+            word_dict=self.model.word_dict,
             tokenizer=self.tokenizer,
             add_special_tokens=self.config.add_special_tokens
         )
