10
10
import pytest
11
11
12
12
from pytorch_widedeep import Trainer
13
- from pytorch_widedeep .models import TabMlp , BasicRNN , WideDeep , ModelFuser
13
+ from pytorch_widedeep .models import (
14
+ TabMlp ,
15
+ TabNet ,
16
+ BasicRNN ,
17
+ WideDeep ,
18
+ ModelFuser ,
19
+ )
14
20
from pytorch_widedeep .metrics import F1Score , Accuracy
15
21
from pytorch_widedeep .callbacks import LRHistory
16
22
from pytorch_widedeep .initializers import XavierNormal , KaimingNormal
@@ -569,7 +575,8 @@ def test_model_fusion_projection_methods(projection_method):
569
575
assert out .shape [1 ] == proj_dim == models_fuser .output_dim
570
576
571
577
572
- def test_model_fusion_full_process ():
578
+ @pytest .mark .parametrize ("head_type" , [None , "via_params" , "custom" ])
579
+ def test_full_process_with_fusion (head_type ):
573
580
574
581
fused_tab_model = ModelFuser (
575
582
models = [tab_mlp_user , tab_mlp_item ],
@@ -583,10 +590,87 @@ def test_model_fusion_full_process():
583
590
projection_method = "min" ,
584
591
)
585
592
593
+ if head_type == "via_params" :
594
+ head_hidden_dims = [fused_tab_model .output_dim + fused_text_model .output_dim , 8 ]
595
+ custom_head = None
596
+ elif head_type == "custom" :
597
+ head_hidden_dims = None
598
+ custom_head = CustomHead (
599
+ fused_tab_model .output_dim + fused_text_model .output_dim , 8
600
+ )
601
+ else :
602
+ head_hidden_dims = None
603
+ custom_head = None
604
+
586
605
model = WideDeep (
587
606
deeptabular = fused_tab_model ,
588
607
deeptext = fused_text_model ,
589
608
pred_dim = 1 ,
609
+ head_hidden_dims = head_hidden_dims ,
610
+ deephead = custom_head ,
611
+ )
612
+
613
+ n_epochs = 2
614
+ trainer = Trainer (
615
+ model ,
616
+ objective = "binary" ,
617
+ verbose = 0 ,
618
+ )
619
+
620
+ X_train = {
621
+ "X_tab" : [X_tab_user_tr , X_tab_item_tr ],
622
+ "X_text" : [X_text_review_tr , X_text_description_tr ],
623
+ "target" : train_df ["purchased" ].values ,
624
+ }
625
+ X_val = {
626
+ "X_tab" : [X_tab_user_val , X_tab_item_val ],
627
+ "X_text" : [X_text_review_val , X_text_description_val ],
628
+ "target" : valid_df ["purchased" ].values ,
629
+ }
630
+ trainer .fit (
631
+ X_train = X_train ,
632
+ X_val = X_val ,
633
+ n_epochs = n_epochs ,
634
+ batch_size = 4 ,
635
+ )
636
+
637
+ # weak assertion, but anyway...
638
+ assert len (trainer .history ["train_loss" ]) == n_epochs
639
+
640
+
641
+ @pytest .mark .parametrize ("head_type" , [None , "via_params" , "custom" ])
642
+ def test_full_process_without_fusion (head_type ):
643
+
644
+ # the 4 models to be combined are tab_mlp_user, tab_mlp_item, rnn_reviews,
645
+ # rnn_descriptions
646
+ if head_type == "via_params" :
647
+ head_hidden_dims = [
648
+ tab_mlp_user .output_dim
649
+ + tab_mlp_item .output_dim
650
+ + rnn_reviews .output_dim
651
+ + rnn_descriptions .output_dim ,
652
+ 8 ,
653
+ ]
654
+ custom_head = None
655
+ elif head_type == "custom" :
656
+ head_hidden_dims = None
657
+ custom_head = CustomHead (
658
+ tab_mlp_user .output_dim
659
+ + tab_mlp_item .output_dim
660
+ + rnn_reviews .output_dim
661
+ + rnn_descriptions .output_dim ,
662
+ 8 ,
663
+ )
664
+ else :
665
+ head_hidden_dims = None
666
+ custom_head = None
667
+
668
+ model = WideDeep (
669
+ deeptabular = [tab_mlp_user , tab_mlp_item ],
670
+ deeptext = [rnn_reviews , rnn_descriptions ],
671
+ pred_dim = 1 ,
672
+ head_hidden_dims = head_hidden_dims ,
673
+ deephead = custom_head ,
590
674
)
591
675
592
676
n_epochs = 2
@@ -615,3 +699,30 @@ def test_model_fusion_full_process():
615
699
616
700
# weak assertion, but anyway...
617
701
assert len (trainer .history ["train_loss" ]) == n_epochs
702
+
703
+
704
@pytest.mark.parametrize("fuse_models", [True, False])
def test_catch_tabnet_error(fuse_models):
    """TabNet is expected to be rejected as one of several tabular components.

    Whether the two tabular models are combined through a ``ModelFuser`` or
    passed directly as a list to ``WideDeep``, construction should raise a
    ``ValueError`` when one of them is a ``TabNet`` model.
    """
    # NOTE(review): relies on the module-level tab_preprocessor_user /
    # tab_preprocessor_item fixtures defined earlier in this file.
    user_tabnet = TabNet(
        column_idx=tab_preprocessor_user.column_idx,
        cat_embed_input=tab_preprocessor_user.cat_embed_input,
        continuous_cols=tab_preprocessor_user.continuous_cols,
    )
    item_mlp = TabMlp(
        column_idx=tab_preprocessor_item.column_idx,
        cat_embed_input=tab_preprocessor_item.cat_embed_input,
        continuous_cols=tab_preprocessor_item.continuous_cols,
    )
    tabular_components = [user_tabnet, item_mlp]

    # Both code paths are expected to fail at construction time
    with pytest.raises(ValueError):
        if fuse_models:
            ModelFuser(
                models=tabular_components,
                fusion_method="mean",
                projection_method="max",
            )
        else:
            WideDeep(deeptabular=tabular_components)
0 commit comments