Skip to content

Commit d2cc15e

Browse files
committed
Added the option of using multiple tabular components; it runs with model fusion. More tests are, of course, still needed.
1 parent 7ef0368 commit d2cc15e

File tree

6 files changed

+104
-59
lines changed

6 files changed

+104
-59
lines changed

pytorch_widedeep/models/model_fusion.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(
110110
"max",
111111
"sum",
112112
"mult",
113+
"dot",
113114
"head",
114115
],
115116
List[Literal["concatenate", "mean", "max", "sum", "mult", "head"]],
@@ -182,11 +183,20 @@ def forward(self, X: List[Tensor]) -> Tensor: # noqa: C901
182183
return self.head(
183184
torch.cat([model(x) for model, x in zip(self.models, X)], -1)
184185
)
186+
elif self.fusion_method == "dot":
187+
assert len(X) == 2, (
188+
"When using 'dot' as fusion_method, only two models "
189+
" can be fused. Accordingly, only two inputs should be provided"
190+
)
191+
outputs = [model(x) for model, x in zip(self.models, X)]
192+
return torch.bmm(outputs[1].unsqueeze(1), outputs[0].unsqueeze(2)).view(
193+
-1, 1
194+
)
185195
else:
186196
if isinstance(self.fusion_method, str):
187197
fusion_methods = [self.fusion_method]
188198
else:
189-
fusion_methods = self.fusion_method
199+
fusion_methods = self.fusion_method # type: ignore
190200

191201
fused_outputs: List[Tensor] = []
192202
for fm in fusion_methods:
@@ -210,7 +220,7 @@ def forward(self, X: List[Tensor]) -> Tensor: # noqa: C901
210220
else:
211221
# This should never happen, but avoids type errors
212222
raise ValueError(
213-
"fusion_method must be one of ['concatenate', 'mean', 'max', 'sum', 'mult', 'head'] "
223+
"fusion_method must be one of ['concatenate', 'mean', 'max', 'sum', 'mult', 'dot', 'head'] "
214224
"or a list of those"
215225
)
216226
fused_outputs.append(out)
@@ -260,12 +270,14 @@ def output_dim(self) -> int:
260270
if hasattr(self, "head_hidden_dims")
261271
else self.head.output_dim
262272
)
273+
elif self.fusion_method == "dot":
274+
output_dim = 1
263275
else:
264276
output_dim = 0
265277
if isinstance(self.fusion_method, str):
266278
fusion_methods = [self.fusion_method]
267279
else:
268-
fusion_methods = self.fusion_method
280+
fusion_methods = self.fusion_method # type: ignore
269281
for fm in fusion_methods:
270282
if fm == "concatenate":
271283
output_dim += sum([model.output_dim for model in self.models])
@@ -291,11 +303,20 @@ def check_input_parameters(self): # noqa: C901
291303
if isinstance(self.fusion_method, str):
292304
if not any(
293305
x == self.fusion_method
294-
for x in ["concatenate", "min", "max", "mean", "sum", "mult", "head"]
306+
for x in [
307+
"concatenate",
308+
"min",
309+
"max",
310+
"mean",
311+
"sum",
312+
"dot",
313+
"mult",
314+
"head",
315+
]
295316
):
296317
raise ValueError(
297-
"fusion_method must be one of ['concatenate', 'mean', 'max', 'sum', 'mult', 'head'] "
298-
"or a list of those"
318+
"fusion_method must be one of ['concatenate', 'mean', 'max', 'sum', 'mult', 'dot', 'head'] "
319+
"or a list of any those but 'dot'"
299320
)
300321

301322
if (
@@ -323,14 +344,15 @@ def check_input_parameters(self): # noqa: C901
323344
"mean",
324345
"sum",
325346
"mult",
347+
"dot",
326348
"head",
327349
]
328350
)
329351
for fm in self.fusion_method
330352
):
331353
raise ValueError(
332-
"fusion_method must be one of ['concatenate', 'mean', 'max', 'sum', 'mult', 'head'] "
333-
"or a list of those"
354+
"fusion_method must be one of ['concatenate', 'mean', 'max', 'sum', 'mult', 'dot', 'head'] "
355+
"or a list of those but 'dot'"
334356
)
335357

336358
if (
@@ -351,9 +373,11 @@ def check_input_parameters(self): # noqa: C901
351373
"projection_method must be one of ['min', 'max', 'mean']"
352374
)
353375

354-
if "head" in self.fusion_method and isinstance(self.fusion_method, list):
376+
if any(x in self.fusion_method for x in ["head", "dot"]) and isinstance(
377+
self.fusion_method, list
378+
):
355379
raise ValueError(
356-
"When using 'head' as fusion_method, no other method should be provided"
380+
"When using 'head' or 'dot' as fusion_method, no other method should be provided"
357381
)
358382

359383
def __repr__(self):

pytorch_widedeep/models/text/rnns/basic_rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import torch
5-
from torch import nn, lstm
5+
from torch import nn
66

77
from pytorch_widedeep.wdtypes import (
88
List,

pytorch_widedeep/models/wide_deep.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,13 +413,13 @@ def _set_model_component(
413413
if isinstance(component, list):
414414
component_: Optional[Union[nn.ModuleList, WDModel]] = nn.ModuleList()
415415
for cp in component:
416-
if self.with_deephead:
416+
if self.with_deephead or cp.output_dim == 1:
417417
component_.append(cp)
418418
else:
419419
component_.append(
420420
nn.Sequential(cp, nn.Linear(cp.output_dim, self.pred_dim))
421421
)
422-
elif self.with_deephead:
422+
elif self.with_deephead or component.output_dim == 1:
423423
component_ = component
424424
elif is_deeptabular and self.is_tabnet:
425425
component_ = nn.Sequential(
@@ -463,6 +463,11 @@ def _check_inputs( # noqa: C901
463463
else:
464464
if not hasattr(deeptabular, "output_dim"):
465465
raise AttributeError(err_msg)
466+
# The following assertion guards the case where fusion with the
467+
# 'dot' method is used: the component's output_dim is then 1, so
468+
# pred_dim must be 1 as well.
469+
if deeptabular.output_dim == 1:
470+
assert pred_dim == 1, "If 'output_dim' is 1, 'pred_dim' must be 1"
466471

467472
if deeptabular is not None:
468473
is_tabnet = False
@@ -502,6 +507,8 @@ def _check_inputs( # noqa: C901
502507
else:
503508
if not hasattr(deeptext, "output_dim"):
504509
raise AttributeError(err_msg)
510+
if deeptext.output_dim == 1:
511+
assert pred_dim == 1, "If 'output_dim' is 1, 'pred_dim' must be 1"
505512

506513
if deepimage is not None:
507514
err_msg = "deepimage model must have an 'output_dim' attribute or property."
@@ -512,6 +519,8 @@ def _check_inputs( # noqa: C901
512519
else:
513520
if not hasattr(deepimage, "output_dim"):
514521
raise AttributeError(err_msg)
522+
if deepimage.output_dim == 1:
523+
assert pred_dim == 1, "If 'output_dim' is 1, 'pred_dim' must be 1"
515524

516525
if deephead is not None and head_hidden_dims is not None:
517526
raise ValueError(

pytorch_widedeep/training/_trainer_utils.py

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
FocalR_RMSELoss,
2323
)
2424
from pytorch_widedeep.wdtypes import (
25+
Any,
2526
Dict,
2627
List,
28+
Tuple,
2729
Union,
2830
Compose,
2931
Literal,
@@ -115,7 +117,7 @@ def wd_train_val_split( # noqa: C901
115117
seed: int,
116118
method: Literal["regression", "binary", "multiclass", "qregression"],
117119
X_wide: Optional[np.ndarray] = None,
118-
X_tab: Optional[np.ndarray] = None,
120+
X_tab: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
119121
X_text: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
120122
X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
121123
X_train: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None,
@@ -174,6 +176,7 @@ def wd_train_val_split( # noqa: C901
174176
target is not None
175177
), "if the validation split is specified, the target must also be specified"
176178
X_train = _build_train_dict(X_wide, X_tab, X_text, X_img, target)
179+
177180
y_tr, y_val, idx_tr, idx_val = train_test_split(
178181
X_train["target"],
179182
np.arange(len(X_train["target"])),
@@ -187,46 +190,23 @@ def wd_train_val_split( # noqa: C901
187190
)
188191
X_tr, X_val = {"target": y_tr}, {"target": y_val}
189192
if "X_wide" in X_train.keys():
190-
X_tr["X_wide"], X_val["X_wide"] = (
191-
X_train["X_wide"][idx_tr],
192-
X_train["X_wide"][idx_val],
193+
# the wide component will never be a list, but can still be passed
194+
# to '_wd_train_val_split_component'
195+
X_tr, X_val = _wd_train_val_split_component(
196+
X_train, X_tr, X_val, idx_tr, idx_val, "X_wide"
193197
)
194198
if "X_tab" in X_train.keys():
195-
X_tr["X_tab"], X_val["X_tab"] = (
196-
X_train["X_tab"][idx_tr],
197-
X_train["X_tab"][idx_val],
199+
X_tr, X_val = _wd_train_val_split_component(
200+
X_train, X_tr, X_val, idx_tr, idx_val, "X_tab"
198201
)
199202
if "X_text" in X_train.keys():
200-
if isinstance(X_train["X_text"], list):
201-
X_tr["X_text"], X_val["X_text"] = (
202-
[
203-
X_train["X_text"][i][idx_tr]
204-
for i in range(len(X_train["X_text"]))
205-
],
206-
[
207-
X_train["X_text"][i][idx_val]
208-
for i in range(len(X_train["X_text"]))
209-
],
210-
)
211-
else:
212-
X_tr["X_text"], X_val["X_text"] = (
213-
X_train["X_text"][idx_tr],
214-
X_train["X_text"][idx_val],
215-
)
203+
X_tr, X_val = _wd_train_val_split_component(
204+
X_train, X_tr, X_val, idx_tr, idx_val, "X_text"
205+
)
216206
if "X_img" in X_train.keys():
217-
if isinstance(X_train["X_img"], list):
218-
X_tr["X_img"], X_val["X_img"] = (
219-
[X_train["X_img"][i][idx_tr] for i in range(len(X_train["X_img"]))],
220-
[
221-
X_train["X_img"][i][idx_val]
222-
for i in range(len(X_train["X_img"]))
223-
],
224-
)
225-
else:
226-
X_tr["X_img"], X_val["X_img"] = (
227-
X_train["X_img"][idx_tr],
228-
X_train["X_img"][idx_val],
229-
)
207+
X_tr, X_val = _wd_train_val_split_component(
208+
X_train, X_tr, X_val, idx_tr, idx_val, "X_img"
209+
)
230210
train_set = WideDeepDataset(**X_tr, transforms=transforms) # type: ignore
231211
eval_set = WideDeepDataset(**X_val, transforms=transforms) # type: ignore
232212
else:
@@ -239,9 +219,34 @@ def wd_train_val_split( # noqa: C901
239219
return train_set, eval_set
240220

241221

222+
def _wd_train_val_split_component(
    X: Dict[str, Union[np.ndarray, List[np.ndarray]]],
    X_tr: Dict[str, Union[np.ndarray, List[np.ndarray]]],
    X_val: Dict[str, Union[np.ndarray, List[np.ndarray]]],
    idx_tr: Any,  # index array from sklearn's train_test_split, whose declared return type is not useful
    idx_val: Any,
    component_type: Literal["X_wide", "X_tab", "X_text", "X_img"],
) -> Tuple[
    Dict[str, Union[np.ndarray, List[np.ndarray]]],
    Dict[str, Union[np.ndarray, List[np.ndarray]]],
]:
    """Split one model component of ``X`` into its train and validation parts.

    The entry ``X[component_type]`` is indexed with ``idx_tr`` and
    ``idx_val`` and the results are stored under the same key in ``X_tr``
    and ``X_val`` respectively. When the entry is a ``list`` every element
    is indexed separately and the result is a list of the same length.

    Parameters
    ----------
    X:
        Dictionary holding the full (unsplit) data, keyed by component name.
    X_tr:
        Dictionary that receives the training part. It is modified in place.
    X_val:
        Dictionary that receives the validation part. It is modified in
        place.
    idx_tr:
        Index used to select the training rows.
    idx_val:
        Index used to select the validation rows.
    component_type:
        Key of the component to split.

    Returns
    -------
    Tuple of the same ``X_tr`` and ``X_val`` objects that were passed in.

    Raises
    ------
    KeyError
        If ``component_type`` is not a key of ``X``.
    """
    component = X[component_type]

    # Build both halves before assigning anything, so that a failure while
    # indexing leaves 'X_tr' and 'X_val' untouched.
    if isinstance(component, list):
        train_part: Union[np.ndarray, List[np.ndarray]] = [
            arr[idx_tr] for arr in component
        ]
        val_part: Union[np.ndarray, List[np.ndarray]] = [
            arr[idx_val] for arr in component
        ]
    else:
        train_part = component[idx_tr]
        val_part = component[idx_val]

    X_tr[component_type], X_val[component_type] = train_part, val_part

    return X_tr, X_val
245+
246+
242247
def _build_train_dict(
243248
X_wide: Optional[np.ndarray],
244-
X_tab: Optional[np.ndarray],
249+
X_tab: Optional[Union[np.ndarray, List[np.ndarray]]],
245250
X_text: Optional[Union[np.ndarray, List[np.ndarray]]],
246251
X_img: Optional[Union[np.ndarray, List[np.ndarray]]],
247252
target: np.ndarray,

pytorch_widedeep/training/_wd_dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class WideDeepDataset(Dataset):
3232
def __init__(
3333
self,
3434
X_wide: Optional[np.ndarray] = None,
35-
X_tab: Optional[np.ndarray] = None,
35+
X_tab: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
3636
X_text: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
3737
X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
3838
target: Optional[np.ndarray] = None,
@@ -60,7 +60,10 @@ def __getitem__(self, idx: int): # noqa: C901
6060
if self.X_wide is not None:
6161
x.wide = self.X_wide[idx]
6262
if self.X_tab is not None:
63-
x.deeptabular = self.X_tab[idx]
63+
if isinstance(self.X_tab, list):
64+
x.deeptabular = [self.X_tab[i][idx] for i in range(len(self.X_tab))]
65+
else:
66+
x.deeptabular = self.X_tab[idx]
6467
if self.X_text is not None:
6568
if isinstance(self.X_text, list):
6669
x.deeptext = [self.X_text[i][idx] for i in range(len(self.X_text))]
@@ -112,7 +115,10 @@ def __len__(self):
112115
if self.X_wide is not None:
113116
return len(self.X_wide)
114117
if self.X_tab is not None:
115-
return len(self.X_tab)
118+
if isinstance(self.X_tab, list):
119+
return len(self.X_tab[0])
120+
else:
121+
return len(self.X_tab)
116122
if self.X_text is not None:
117123
if isinstance(self.X_text, list):
118124
return len(self.X_text[0])

pytorch_widedeep/training/trainer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def __init__(
270270
def fit( # noqa: C901
271271
self,
272272
X_wide: Optional[np.ndarray] = None,
273-
X_tab: Optional[np.ndarray] = None,
273+
X_tab: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
274274
X_text: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
275275
X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
276276
X_train: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None,
@@ -529,7 +529,7 @@ def fit( # noqa: C901
529529
def predict( # type: ignore[override, return]
530530
self,
531531
X_wide: Optional[np.ndarray] = None,
532-
X_tab: Optional[np.ndarray] = None,
532+
X_tab: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
533533
X_text: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
534534
X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
535535
X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None,
@@ -585,7 +585,7 @@ def predict( # type: ignore[override, return]
585585
def predict_uncertainty( # type: ignore[return]
586586
self,
587587
X_wide: Optional[np.ndarray] = None,
588-
X_tab: Optional[np.ndarray] = None,
588+
X_tab: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
589589
X_text: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
590590
X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
591591
X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None,
@@ -682,7 +682,7 @@ def predict_uncertainty( # type: ignore[return]
682682
def predict_proba( # type: ignore[override, return] # noqa: C901
683683
self,
684684
X_wide: Optional[np.ndarray] = None,
685-
X_tab: Optional[np.ndarray] = None,
685+
X_tab: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
686686
X_text: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
687687
X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
688688
X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None,
@@ -942,7 +942,7 @@ def _eval_step(
942942
X[k] = v.to(self.device)
943943
y = (
944944
target.view(-1, 1).float()
945-
if self.method not in ["multiclass", "qregression"]
945+
if self.method not in ["multiclass", "qregression", "multitarget"]
946946
else target
947947
)
948948
y = y.to(self.device)
@@ -971,14 +971,15 @@ def _get_score(self, y_pred, y):
971971
score = self.metric(y_pred, y)
972972
if self.method == "multiclass":
973973
score = self.metric(F.softmax(y_pred, dim=1), y)
974+
# TODO: handle multitarget
974975
return score
975976
else:
976977
return None
977978

978979
def _predict( # type: ignore[override, return] # noqa: C901
979980
self,
980981
X_wide: Optional[np.ndarray] = None,
981-
X_tab: Optional[np.ndarray] = None,
982+
X_tab: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
982983
X_text: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
983984
X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
984985
X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None,

0 commit comments

Comments
 (0)