Commit 16922ce

A further increase in coverage of edge cases

1 parent 07155ec commit 16922ce

2 files changed (+88, -3)

tests/test_multi_model_and_mutil_data/test_multi_tab_and_text_components.py (+1, -1)

@@ -677,6 +677,7 @@ def test_full_process_without_fusion(head_type):
     trainer = Trainer(
         model,
         objective="binary",
+        verbose=0,
     )
 
     X_train = {
@@ -694,7 +695,6 @@ def test_full_process_without_fusion(head_type):
         X_val=X_val,
         n_epochs=n_epochs,
         batch_size=4,
-        verbose=1,
     )
 
     # weak assertion, but anyway...
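Note: the one-line change above moves verbosity control from the fit call to the Trainer constructor, so the whole test runs silently. A minimal sketch of the resulting pattern, where model, X_train and X_val stand for the fixtures the test already defines (the diff itself shows that pytorch-widedeep's Trainer accepts verbose at construction):

from pytorch_widedeep import Trainer

# verbosity is set once on the Trainer instead of per fit() call
trainer = Trainer(model, objective="binary", verbose=0)  # 0 -> no epoch logs
trainer.fit(X_train=X_train, X_val=X_val, n_epochs=2, batch_size=4)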

tests/test_multi_model_and_mutil_data/test_multi_text_or_image_cols.py (+87, -2)

@@ -694,7 +694,8 @@ def test_model_fusion_projection_methods(projection_method):
     assert out.shape[1] == proj_dim == models_fuser.output_dim
 
 
-def test_model_fusion_full_process():
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_with_fusion(head_type):
 
     fused_text_model = ModelFuser(
         models=[rnn_1, rnn_2],
@@ -708,17 +709,102 @@ def test_model_fusion_full_process():
         projection_method="max",
     )
 
+    if head_type == "via_params":
+        head_hidden_dims = [
+            fused_text_model.output_dim,
+            fused_image_model.output_dim + tab_mlp.output_dim,
+            8,
+        ]
+        custom_head = None
+    elif head_type == "custom":
+        custom_head = CustomHead(
+            fused_text_model.output_dim
+            + fused_image_model.output_dim
+            + tab_mlp.output_dim,
+            8,
+        )
+        head_hidden_dims = None
+    else:
+        custom_head = None
+        head_hidden_dims = None
+
     model = WideDeep(
         deeptabular=tab_mlp,
         deeptext=fused_text_model,
         deepimage=fused_image_model,
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
+        pred_dim=1,
+    )
+
+    n_epochs = 2
+    trainer = Trainer(
+        model,
+        objective="binary",
+        verbose=0,
+    )
+
+    X_train = {
+        "X_tab": X_tab_tr,
+        "X_text": [X_text_tr_1, X_text_tr_2],
+        "X_img": [X_img_tr_1, X_img_tr_2],
+        "target": train_df["target"].values,
+    }
+    X_val = {
+        "X_tab": X_tab_val,
+        "X_text": [X_text_val_1, X_text_val_2],
+        "X_img": [X_img_val_1, X_img_val_2],
+        "target": valid_df["target"].values,
+    }
+    trainer.fit(
+        X_train=X_train,
+        X_val=X_val,
+        n_epochs=n_epochs,
+        batch_size=4,
+    )
+
+    # weak assertion, but anyway...
+    assert len(trainer.history["train_loss"]) == n_epochs
+
+
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_without_fusion(head_type):
+
+    if head_type == "via_params":
+        head_hidden_dims = [
+            rnn_1.output_dim + rnn_2.output_dim,
+            vision_1.output_dim + vision_2.output_dim + tab_mlp.output_dim,
+            8,
+        ]
+        custom_head = None
+    elif head_type == "custom":
+        custom_head = CustomHead(
+            rnn_1.output_dim
+            + rnn_2.output_dim
+            + vision_1.output_dim
+            + vision_2.output_dim
+            + tab_mlp.output_dim,
+            8,
+        )
+        head_hidden_dims = None
+    else:
+        custom_head = None
+        head_hidden_dims = None
+
+    model = WideDeep(
+        deeptabular=tab_mlp,
+        deeptext=[rnn_1, rnn_2],
+        deepimage=[vision_1, vision_2],
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
         pred_dim=1,
     )
 
     n_epochs = 2
     trainer = Trainer(
         model,
         objective="binary",
+        verbose=0,
     )
 
     X_train = {
@@ -738,7 +824,6 @@ def test_model_fusion_full_process():
         X_val=X_val,
         n_epochs=n_epochs,
         batch_size=4,
-        verbose=1,
     )
 
     # weak assertion, but anyway...
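Note: both new tests reference a CustomHead class that is defined elsewhere in the test module and does not appear in this diff. A hypothetical minimal sketch of such a head, assuming pytorch-widedeep's convention that a custom deephead is a plain torch nn.Module exposing an output_dim attribute so WideDeep can stack the final prediction layer on top (the repo's actual implementation may differ):

import torch
from torch import nn


class CustomHead(nn.Module):
    # hypothetical stand-in for the CustomHead used in the tests above
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_dim = output_dim  # read by WideDeep when wiring the head
        self.head = nn.Sequential(nn.Linear(input_dim, output_dim), nn.ReLU())

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X is the concatenation of the tabular, text and image component outputs
        return self.head(X)

The "via_params" branch exercises the same wiring through head_hidden_dims, letting WideDeep build the head MLP itself, while head_type=None leaves the components with no shared head at all.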
