Skip to content

Commit 0f08a4f

Browse files
authored
Merge pull request #114 from NTMC-Community/dev
Version 1.1.1
2 parents 2d27487 + 068d8ac commit 0f08a4f

File tree

18 files changed

+541
-570
lines changed

18 files changed

+541
-570
lines changed

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,13 @@ trainset = mz.dataloader.Dataset(
9696
data_pack=train_processed,
9797
mode='pair',
9898
num_dup=1,
99-
num_neg=4
99+
num_neg=4,
100+
batch_size=32
100101
)
101102
validset = mz.dataloader.Dataset(
102103
data_pack=valid_processed,
103-
mode='point'
104+
mode='point',
105+
batch_size=32
104106
)
105107
```
106108

@@ -110,13 +112,11 @@ padding_callback = mz.models.ArcI.get_default_padding_callback()
110112

111113
trainloader = mz.dataloader.DataLoader(
112114
dataset=trainset,
113-
batch_size=32,
114115
stage='train',
115116
callback=padding_callback
116117
)
117118
validloader = mz.dataloader.DataLoader(
118119
dataset=validset,
119-
batch_size=32,
120120
stage='dev',
121121
callback=padding_callback
122122
)
@@ -127,6 +127,8 @@ Initialize the model, fine-tune the hyper-parameters:
127127
```python
128128
model = mz.models.ArcI()
129129
model.params['task'] = ranking_task
130+
model.params['embedding_output_dim'] = 100
131+
model.params['embedding_input_dim'] = preprocessor.context['embedding_input_dim']
130132
model.guess_and_fill_missing_params()
131133
model.build()
132134
```

matchzoo/auto/preparer/preparer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,20 @@ def _build_matrix(self, preprocessor, embedding):
159159
return np.random.uniform(-0.2, 0.2, matrix_shape)
160160

161161
def _build_dataset_builder(self, model, embedding_matrix, preprocessor):
162-
builder_kwargs = dict(callbacks=[])
162+
builder_kwargs = dict(
163+
callbacks=[],
164+
batch_size=self._config['batch_size'],
165+
shuffle=self._config['shuffle'],
166+
sort=self._config['sort']
167+
)
163168

164169
if isinstance(self._task.losses[0], (mz.losses.RankHingeLoss,
165170
mz.losses.RankCrossEntropyLoss)):
166171
builder_kwargs.update(dict(
167172
mode='pair',
168173
num_dup=self._config['num_dup'],
169-
num_neg=self._config['num_neg']
174+
num_neg=self._config['num_neg'],
175+
resample=self._config['resample'],
170176
))
171177

172178
if isinstance(model, mz.models.CDSSM):
@@ -201,11 +207,7 @@ def _build_dataset_builder(self, model, embedding_matrix, preprocessor):
201207

202208
def _build_dataloader_builder(self, model, callback):
203209
builder_kwargs = dict(
204-
batch_size=self._config['batch_size'],
205210
stage=self._config['stage'],
206-
resample=self._config['resample'],
207-
shuffle=self._config['shuffle'],
208-
sort=self._config['sort'],
209211
callback=callback
210212
)
211213
return DataLoaderBuilder(**builder_kwargs)

matchzoo/dataloader/callbacks/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .lambda_callback import LambdaCallback
2-
from .dynamic_pooling import DynamicPooling
32
from .histogram import Histogram
43
from .ngram import Ngram
54
from .padding import BasicPadding

matchzoo/dataloader/callbacks/dynamic_pooling.py

Lines changed: 0 additions & 92 deletions
This file was deleted.

matchzoo/dataloader/callbacks/padding.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,35 @@
11
import typing
2+
from collections import Iterable
23

34
import numpy as np
45

56
from matchzoo.engine.base_callback import BaseCallback
67

78

9+
def _infer_dtype(value):
10+
"""Infer the dtype for the features.
11+
12+
It is required as the input is usually array of objects before padding.
13+
"""
14+
while isinstance(value, (list, tuple)) and len(value) > 0:
15+
value = value[0]
16+
17+
if not isinstance(value, Iterable):
18+
return np.array(value).dtype
19+
20+
if value is not None and len(value) > 0 and np.issubdtype(
21+
np.array(value).dtype, np.generic):
22+
dtype = np.array(value[0]).dtype
23+
else:
24+
dtype = value.dtype
25+
26+
# Single Precision
27+
if dtype == np.double:
28+
dtype = np.float32
29+
30+
return dtype
31+
32+
833
def _padding_2D(input, output, mode: str = 'pre'):
934
"""
1035
Pad the input 2D-tensor to the output 2D-tensor.
@@ -122,24 +147,26 @@ def on_batch_unpacked(self, x: dict, y: np.ndarray):
122147
pad_length_right = self._fixed_length_right
123148

124149
for key, value in x.items():
150+
dtype = _infer_dtype(value)
151+
125152
if key == 'text_left':
126153
padded_value = np.full([batch_size, pad_length_left],
127-
self._pad_word_value, dtype=value.dtype)
154+
self._pad_word_value, dtype=dtype)
128155
_padding_2D(value, padded_value, self._pad_word_mode)
129156
elif key == 'text_right':
130157
padded_value = np.full([batch_size, pad_length_right],
131-
self._pad_word_value, dtype=value.dtype)
158+
self._pad_word_value, dtype=dtype)
132159
_padding_2D(value, padded_value, self._pad_word_mode)
133160
elif key == 'ngram_left':
134161
padded_value = np.full(
135162
[batch_size, pad_length_left, ngram_length],
136-
self._pad_ngram_value, dtype=value.dtype
163+
self._pad_ngram_value, dtype=dtype
137164
)
138165
_padding_3D(value, padded_value, self._pad_ngram_mode)
139166
elif key == 'ngram_right':
140167
padded_value = np.full(
141168
[batch_size, pad_length_right, ngram_length],
142-
self._pad_ngram_value, dtype=value.dtype
169+
self._pad_ngram_value, dtype=dtype
143170
)
144171
_padding_3D(value, padded_value, self._pad_ngram_mode)
145172
else:
@@ -193,18 +220,21 @@ def on_batch_unpacked(self, x: dict, y: np.ndarray):
193220
if key != 'text_left' and key != 'text_right' and \
194221
key != 'match_histogram':
195222
continue
196-
elif key == 'text_left':
223+
224+
dtype = _infer_dtype(value)
225+
226+
if key == 'text_left':
197227
padded_value = np.full([batch_size, pad_length_left],
198-
self._pad_value, dtype=value.dtype)
228+
self._pad_value, dtype=dtype)
199229
_padding_2D(value, padded_value, self._pad_mode)
200230
elif key == 'text_right':
201231
padded_value = np.full([batch_size, pad_length_right],
202-
self._pad_value, dtype=value.dtype)
232+
self._pad_value, dtype=dtype)
203233
_padding_2D(value, padded_value, self._pad_mode)
204234
else: # key == 'match_histogram'
205235
padded_value = np.full(
206236
[batch_size, pad_length_left, bin_size],
207-
self._pad_value, dtype=value.dtype)
237+
self._pad_value, dtype=dtype)
208238
_padding_3D(value, padded_value, self._pad_mode)
209239
x[key] = padded_value
210240

0 commit comments

Comments
 (0)