@@ -694,7 +694,8 @@ def test_model_fusion_projection_methods(projection_method):
     assert out.shape[1] == proj_dim == models_fuser.output_dim
 
 
-def test_model_fusion_full_process():
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_with_fusion(head_type):
 
     fused_text_model = ModelFuser(
         models=[rnn_1, rnn_2],
@@ -708,17 +709,102 @@ def test_model_fusion_full_process():
         projection_method="max",
     )
 
+    if head_type == "via_params":
+        head_hidden_dims = [
+            fused_text_model.output_dim,
+            fused_image_model.output_dim + tab_mlp.output_dim,
+            8,
+        ]
+        custom_head = None
+    elif head_type == "custom":
+        custom_head = CustomHead(
+            fused_text_model.output_dim
+            + fused_image_model.output_dim
+            + tab_mlp.output_dim,
+            8,
+        )
+        head_hidden_dims = None
+    else:
+        custom_head = None
+        head_hidden_dims = None
+
     model = WideDeep(
         deeptabular=tab_mlp,
         deeptext=fused_text_model,
         deepimage=fused_image_model,
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
+        pred_dim=1,
+    )
+
+    n_epochs = 2
+    trainer = Trainer(
+        model,
+        objective="binary",
+        verbose=0,
+    )
+
+    X_train = {
+        "X_tab": X_tab_tr,
+        "X_text": [X_text_tr_1, X_text_tr_2],
+        "X_img": [X_img_tr_1, X_img_tr_2],
+        "target": train_df["target"].values,
+    }
+    X_val = {
+        "X_tab": X_tab_val,
+        "X_text": [X_text_val_1, X_text_val_2],
+        "X_img": [X_img_val_1, X_img_val_2],
+        "target": valid_df["target"].values,
+    }
+    trainer.fit(
+        X_train=X_train,
+        X_val=X_val,
+        n_epochs=n_epochs,
+        batch_size=4,
+    )
+
+    # weak assertion, but anyway...
+    assert len(trainer.history["train_loss"]) == n_epochs
+
+
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_without_fusion(head_type):
+
+    if head_type == "via_params":
+        head_hidden_dims = [
+            rnn_1.output_dim + rnn_2.output_dim,
+            vision_1.output_dim + vision_2.output_dim + tab_mlp.output_dim,
+            8,
+        ]
+        custom_head = None
+    elif head_type == "custom":
+        custom_head = CustomHead(
+            rnn_1.output_dim
+            + rnn_2.output_dim
+            + vision_1.output_dim
+            + vision_2.output_dim
+            + tab_mlp.output_dim,
+            8,
+        )
+        head_hidden_dims = None
+    else:
+        custom_head = None
+        head_hidden_dims = None
+
+    model = WideDeep(
+        deeptabular=tab_mlp,
+        deeptext=[rnn_1, rnn_2],
+        deepimage=[vision_1, vision_2],
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
         pred_dim=1,
     )
 
     n_epochs = 2
     trainer = Trainer(
         model,
         objective="binary",
+        verbose=0,
     )
 
     X_train = {
@@ -738,7 +824,6 @@ def test_model_fusion_full_process():
         X_val=X_val,
         n_epochs=n_epochs,
         batch_size=4,
-        verbose=1,
     )
 
     # weak assertion, but anyway...
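Note: CustomHead is referenced but not defined in the hunks above; it lives elsewhere in this test file. A minimal sketch of what such a head could look like, assuming the usual pytorch-widedeep convention that custom components expose an output_dim attribute, and matching the two-argument constructor (concatenated input dimension, head output dimension) used in the tests:

import torch
from torch import nn


class CustomHead(nn.Module):
    """Hypothetical stand-in for the CustomHead used in the tests above."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # pytorch-widedeep reads this attribute to size the final
        # prediction layer (pred_dim=1 in the tests)
        self.output_dim = output_dim
        self.head = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X is the concatenation of the deeptabular, deeptext and
        # deepimage component outputs
        return self.head(X)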