@@ -694,7 +694,8 @@ def test_model_fusion_projection_methods(projection_method):
694694 assert out .shape [1 ] == proj_dim == models_fuser .output_dim
695695
696696
697- def test_model_fusion_full_process ():
697+ @pytest .mark .parametrize ("head_type" , [None , "via_params" , "custom" ])
698+ def test_full_process_with_fusion (head_type ):
698699
699700 fused_text_model = ModelFuser (
700701 models = [rnn_1 , rnn_2 ],
@@ -708,17 +709,102 @@ def test_model_fusion_full_process():
708709 projection_method = "max" ,
709710 )
710711
712+ if head_type == "via_params" :
713+ head_hidden_dims = [
714+ fused_text_model .output_dim ,
715+ fused_image_model .output_dim + tab_mlp .output_dim ,
716+ 8 ,
717+ ]
718+ custom_head = None
719+ elif head_type == "custom" :
720+ custom_head = CustomHead (
721+ fused_text_model .output_dim
722+ + fused_image_model .output_dim
723+ + tab_mlp .output_dim ,
724+ 8 ,
725+ )
726+ head_hidden_dims = None
727+ else :
728+ custom_head = None
729+ head_hidden_dims = None
730+
711731 model = WideDeep (
712732 deeptabular = tab_mlp ,
713733 deeptext = fused_text_model ,
714734 deepimage = fused_image_model ,
735+ head_hidden_dims = head_hidden_dims ,
736+ deephead = custom_head ,
737+ pred_dim = 1 ,
738+ )
739+
740+ n_epochs = 2
741+ trainer = Trainer (
742+ model ,
743+ objective = "binary" ,
744+ verbose = 0 ,
745+ )
746+
747+ X_train = {
748+ "X_tab" : X_tab_tr ,
749+ "X_text" : [X_text_tr_1 , X_text_tr_2 ],
750+ "X_img" : [X_img_tr_1 , X_img_tr_2 ],
751+ "target" : train_df ["target" ].values ,
752+ }
753+ X_val = {
754+ "X_tab" : X_tab_val ,
755+ "X_text" : [X_text_val_1 , X_text_val_2 ],
756+ "X_img" : [X_img_val_1 , X_img_val_2 ],
757+ "target" : valid_df ["target" ].values ,
758+ }
759+ trainer .fit (
760+ X_train = X_train ,
761+ X_val = X_val ,
762+ n_epochs = n_epochs ,
763+ batch_size = 4 ,
764+ )
765+
766+ # weak assertion, but anyway...
767+ assert len (trainer .history ["train_loss" ]) == n_epochs
768+
769+
770+ @pytest .mark .parametrize ("head_type" , [None , "via_params" , "custom" ])
771+ def test_full_process_without_fusion (head_type ):
772+
773+ if head_type == "via_params" :
774+ head_hidden_dims = [
775+ rnn_1 .output_dim + rnn_2 .output_dim ,
776+ vision_1 .output_dim + vision_2 .output_dim + tab_mlp .output_dim ,
777+ 8 ,
778+ ]
779+ custom_head = None
780+ elif head_type == "custom" :
781+ custom_head = CustomHead (
782+ rnn_1 .output_dim
783+ + rnn_2 .output_dim
784+ + vision_1 .output_dim
785+ + vision_2 .output_dim
786+ + tab_mlp .output_dim ,
787+ 8 ,
788+ )
789+ head_hidden_dims = None
790+ else :
791+ custom_head = None
792+ head_hidden_dims = None
793+
794+ model = WideDeep (
795+ deeptabular = tab_mlp ,
796+ deeptext = [rnn_1 , rnn_2 ],
797+ deepimage = [vision_1 , vision_2 ],
798+ head_hidden_dims = head_hidden_dims ,
799+ deephead = custom_head ,
715800 pred_dim = 1 ,
716801 )
717802
718803 n_epochs = 2
719804 trainer = Trainer (
720805 model ,
721806 objective = "binary" ,
807+ verbose = 0 ,
722808 )
723809
724810 X_train = {
@@ -738,7 +824,6 @@ def test_model_fusion_full_process():
738824 X_val = X_val ,
739825 n_epochs = n_epochs ,
740826 batch_size = 4 ,
741- verbose = 1 ,
742827 )
743828
744829 # weak assertion, but anyway...
0 commit comments