Commit 07155ec

committed
Fixed a bug related to multiple tabular components with an FC head. Increased test coverage.
1 parent f244689 commit 07155ec

File tree

3 files changed: +137, -6 lines


pytorch_widedeep/models/model_fusion.py

Lines changed: 8 additions & 0 deletions

@@ -1,6 +1,7 @@
 import torch
 from torch import nn

+from pytorch_widedeep.models import TabNet
 from pytorch_widedeep.wdtypes import List, Union, Tensor, Literal, Optional
 from pytorch_widedeep.models.tabular.mlp._layers import MLP
 from pytorch_widedeep.models._base_wd_model_component import (
@@ -300,6 +301,13 @@ def output_dim(self) -> int:
         return output_dim

     def check_input_parameters(self):  # noqa: C901
+
+        if any(isinstance(model, TabNet) for model in self.models):
+            raise ValueError(
+                "TabNet is not supported in ModelFuser. "
+                "Please, use another model for tabular data"
+            )
+
         if isinstance(self.fusion_method, str):
             if not any(
                 x == self.fusion_method
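
As a usage-level illustration of the new guard, here is a minimal sketch: constructing a ModelFuser with a TabNet component now fails fast with the ValueError added above. The toy DataFrame and TabPreprocessor setup are assumptions for illustration only and are not part of this commit.

import pandas as pd

from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabMlp, TabNet, ModelFuser

# Toy tabular data, only used to fit a preprocessor (illustrative, not from this commit)
df = pd.DataFrame({"city": ["a", "b", "a", "c"], "age": [23, 31, 45, 27]})
tab_preprocessor = TabPreprocessor(cat_embed_cols=["city"], continuous_cols=["age"])
tab_preprocessor.fit(df)

tabnet = TabNet(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
)

try:
    # ModelFuser now rejects any TabNet component at construction time
    ModelFuser(models=[tabnet, tab_mlp], fusion_method="mean", projection_method="max")
except ValueError as err:
    print(err)  # TabNet is not supported in ModelFuser. Please, use another model for tabular data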

pytorch_widedeep/models/wide_deep.py

Lines changed: 16 additions & 4 deletions

@@ -400,7 +400,7 @@ def _forward_component_with_head(
     ) -> Tensor:
         if isinstance(component, nn.ModuleList):
             component_out = torch.cat(  # type: ignore[call-overload]
-                [cp(X[component_type]) for cp in component], axis=1
+                [cp(X[component_type][i]) for i, cp in enumerate(component)], axis=1
             )
         else:
             component_out = component(X[component_type])
@@ -547,11 +547,23 @@ def _check_inputs(  # noqa: C901
         deephead_inp_feat = next(deephead.parameters()).size(1)
         output_dim = 0
         if deeptabular is not None:
-            output_dim += deeptabular.output_dim
+            if isinstance(deeptabular, list):
+                for dt in deeptabular:
+                    output_dim += dt.output_dim
+            else:
+                output_dim += deeptabular.output_dim
         if deeptext is not None:
-            output_dim += deeptext.output_dim
+            if isinstance(deeptext, list):
+                for dt in deeptext:
+                    output_dim += dt.output_dim
+            else:
+                output_dim += deeptext.output_dim
         if deepimage is not None:
-            output_dim += deepimage.output_dim
+            if isinstance(deepimage, list):
+                for di in deepimage:
+                    output_dim += di.output_dim
+            else:
+                output_dim += deepimage.output_dim
         if deephead_inp_feat != output_dim:
             warnings.warn(
                 "A custom 'deephead' is used and it seems that the input features "

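The one-line change in _forward_component_with_head is the core bug fix: when a component is an nn.ModuleList (i.e. several tabular or text models passed as a list), each sub-model now receives its own entry of the input list instead of every sub-model receiving the same tensor. A self-contained sketch of that pattern, using toy nn.Linear components and random tensors rather than the library's actual models:

import torch
from torch import nn

# Two "components" with different input widths, as when two tabular models
# are built from two different preprocessors.
components = nn.ModuleList([nn.Linear(5, 8), nn.Linear(3, 8)])

# One input tensor per component (a batch of 4 rows each).
X_component = [torch.randn(4, 5), torch.randn(4, 3)]

# Fixed behaviour: index the input list per component before concatenating
# the per-component outputs along the feature dimension.
out = torch.cat([cp(X_component[i]) for i, cp in enumerate(components)], dim=1)
print(out.shape)  # torch.Size([4, 16])
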
tests/test_multi_model_and_mutil_data/test_multi_tab_and_text_components.py

Lines changed: 113 additions & 2 deletions

@@ -10,7 +10,13 @@
 import pytest

 from pytorch_widedeep import Trainer
-from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser
+from pytorch_widedeep.models import (
+    TabMlp,
+    TabNet,
+    BasicRNN,
+    WideDeep,
+    ModelFuser,
+)
 from pytorch_widedeep.metrics import F1Score, Accuracy
 from pytorch_widedeep.callbacks import LRHistory
 from pytorch_widedeep.initializers import XavierNormal, KaimingNormal
@@ -569,7 +575,8 @@ def test_model_fusion_projection_methods(projection_method):
     assert out.shape[1] == proj_dim == models_fuser.output_dim


-def test_model_fusion_full_process():
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_with_fusion(head_type):

     fused_tab_model = ModelFuser(
         models=[tab_mlp_user, tab_mlp_item],
@@ -583,10 +590,87 @@ def test_model_fusion_full_process():
         projection_method="min",
     )

+    if head_type == "via_params":
+        head_hidden_dims = [fused_tab_model.output_dim + fused_text_model.output_dim, 8]
+        custom_head = None
+    elif head_type == "custom":
+        head_hidden_dims = None
+        custom_head = CustomHead(
+            fused_tab_model.output_dim + fused_text_model.output_dim, 8
+        )
+    else:
+        head_hidden_dims = None
+        custom_head = None
+
     model = WideDeep(
         deeptabular=fused_tab_model,
         deeptext=fused_text_model,
         pred_dim=1,
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
+    )
+
+    n_epochs = 2
+    trainer = Trainer(
+        model,
+        objective="binary",
+        verbose=0,
+    )
+
+    X_train = {
+        "X_tab": [X_tab_user_tr, X_tab_item_tr],
+        "X_text": [X_text_review_tr, X_text_description_tr],
+        "target": train_df["purchased"].values,
+    }
+    X_val = {
+        "X_tab": [X_tab_user_val, X_tab_item_val],
+        "X_text": [X_text_review_val, X_text_description_val],
+        "target": valid_df["purchased"].values,
+    }
+    trainer.fit(
+        X_train=X_train,
+        X_val=X_val,
+        n_epochs=n_epochs,
+        batch_size=4,
+    )
+
+    # weak assertion, but anyway...
+    assert len(trainer.history["train_loss"]) == n_epochs
+
+
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_without_fusion(head_type):
+
+    # the 4 models to be combined are tab_mlp_user, tab_mlp_item, rnn_reviews,
+    # rnn_descriptions
+    if head_type == "via_params":
+        head_hidden_dims = [
+            tab_mlp_user.output_dim
+            + tab_mlp_item.output_dim
+            + rnn_reviews.output_dim
+            + rnn_descriptions.output_dim,
+            8,
+        ]
+        custom_head = None
+    elif head_type == "custom":
+        head_hidden_dims = None
+        custom_head = CustomHead(
+            tab_mlp_user.output_dim
+            + tab_mlp_item.output_dim
+            + rnn_reviews.output_dim
+            + rnn_descriptions.output_dim,
+            8,
+        )
+    else:
+        head_hidden_dims = None
+        custom_head = None
+
+    model = WideDeep(
+        deeptabular=[tab_mlp_user, tab_mlp_item],
+        deeptext=[rnn_reviews, rnn_descriptions],
+        pred_dim=1,
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
     )

     n_epochs = 2
@@ -615,3 +699,30 @@ def test_model_fusion_full_process():

     # weak assertion, but anyway...
     assert len(trainer.history["train_loss"]) == n_epochs
+
+
+@pytest.mark.parametrize("fuse_models", [True, False])
+def test_catch_tabnet_error(fuse_models):
+
+    tabnet_user = TabNet(
+        column_idx=tab_preprocessor_user.column_idx,
+        cat_embed_input=tab_preprocessor_user.cat_embed_input,
+        continuous_cols=tab_preprocessor_user.continuous_cols,
+    )
+
+    tab_mlp_item = TabMlp(
+        column_idx=tab_preprocessor_item.column_idx,
+        cat_embed_input=tab_preprocessor_item.cat_embed_input,
+        continuous_cols=tab_preprocessor_item.continuous_cols,
+    )
+
+    if fuse_models:
+        with pytest.raises(ValueError):
+            fused_model = ModelFuser(  # noqa: F841
+                models=[tabnet_user, tab_mlp_item],
+                fusion_method="mean",
+                projection_method="max",
+            )
+    else:
+        with pytest.raises(ValueError):
+            model = WideDeep(deeptabular=[tabnet_user, tab_mlp_item])  # noqa: F841
