Skip to content

Commit 7cc4dad

Browse files
committed
Adjusted docs when needed
1 parent 929198c commit 7cc4dad

File tree

6 files changed

+26
-127
lines changed

6 files changed

+26
-127
lines changed

examples/scripts/mutil_tabular_components.py

Lines changed: 0 additions & 105 deletions
This file was deleted.

pytorch_widedeep/models/model_fusion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ class ModelFuser(BaseWDModelComponent):
2020
List of models whose outputs will be fused
2121
fusion_method: Union[str, List[str]]
2222
Method to fuse the output of the models. It can be one of
23-
['concatenate', 'mean', 'max', 'sum', 'mult', 'head'] or a list of
24-
those. If a list is provided the output of the models will be fused
25-
using all the methods in the list and the final output will be the
26-
concatenation of the outputs of each method
23+
['concatenate', 'mean', 'max', 'sum', 'mult', 'dot', 'head'] or a
24+
list of those, but 'dot'. If a list is provided the output of the
25+
models will be fused using all the methods in the list and the final
26+
output will be the concatenation of the outputs of each method
2727
projection_method: Optional[str]
2828
If the fusion_method is not 'concatenate', this parameter will
2929
determine how to project the output of the models to a common

pytorch_widedeep/models/wide_deep.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ class WideDeep(nn.Module):
5353
deeptabular: BaseWDModelComponent, Optional, default = None
5454
Currently this library implements a number of possible architectures
5555
for the `deeptabular` component. See the documenation of the
56-
package.
56+
package. Note that `deeptabular` can be a list of models. This is
57+
useful when using multiple tabular inputs (e.g. for example in the
58+
context of a two-tower model for recommendation systems)
5759
deeptext: BaseWDModelComponent | List[BaseWDModelComponent], Optional, default = None
5860
Currently this library implements a number of possible architectures
5961
for the `deeptext` component. See the documenation of the

pytorch_widedeep/training/_trainer_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def tabular_train_val_split(
4747
X_val: Optional[np.ndarray] = None,
4848
y_val: Optional[np.ndarray] = None,
4949
val_split: Optional[float] = None,
50-
):
50+
) -> Tuple[TensorDataset, Optional[TensorDataset]]:
5151
r"""
5252
Function to create the train/val split for the BayesianTrainer where only
5353
tabular data is used
@@ -86,7 +86,7 @@ def tabular_train_val_split(
8686
torch.from_numpy(y_val),
8787
)
8888
elif val_split is not None:
89-
y_tr, y_val, idx_tr, idx_val = train_test_split(
89+
y_tr, y_val, idx_tr, idx_val = train_test_split( # type: ignore
9090
y,
9191
np.arange(len(y)),
9292
test_size=val_split,
@@ -140,13 +140,13 @@ def wd_train_val_split( # noqa: C901
140140
random seed to be used during train/val split
141141
method: str
142142
'regression', 'binary' or 'multiclass'
143-
X_wide: np.ndarray, Optional, default = None
143+
X_wide: np.ndaaray, Optional, default = None
144144
wide dataset
145-
X_tab: np.ndarray, Optional, default = None
145+
X_tab: np.ndarray or List[np.ndarray], Optional, default = None
146146
tabular dataset (categorical and continuous features)
147-
X_img: np.ndarray, Optional, default = None
147+
X_img: np.ndarray or List[np.ndarray], Optional, default = None
148148
image dataset
149-
X_text: np.ndarray, Optional, default = None
149+
X_text: np.ndarray or List[np.ndarray], Optional, default = None
150150
text dataset
151151
X_val: Dict, Optional, default = None
152152
Dict with the validation set, where the keys are the component names

pytorch_widedeep/training/_wd_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class WideDeepDataset(Dataset):
1717
----------
1818
X_wide: np.ndarray
1919
wide input
20-
X_tab: np.ndarray
20+
X_tab: np.ndarray or List[np.ndarray]
2121
deeptabular input
2222
X_text: np.ndarray or List[np.ndarray]
2323
deeptext input

pytorch_widedeep/training/trainer.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,9 @@ def fit( # noqa: C901
298298
See `pytorch_widedeep.preprocessing.WidePreprocessor`
299299
X_tab: np.ndarray, Optional. default=None
300300
Input for the `deeptabular` model component.
301-
See `pytorch_widedeep.preprocessing.TabPreprocessor`
301+
See `pytorch_widedeep.preprocessing.TabPreprocessor`. If multiple
302+
tabular models are used for different columns, this should be a
303+
list of numpy arrays
302304
X_text: Union[np.ndarray, List[np.ndarray]], Optional. default=None
303305
Input for the `deeptext` model component.
304306
See `pytorch_widedeep.preprocessing.TextPreprocessor`.
@@ -547,13 +549,13 @@ def predict( # type: ignore[override, return]
547549
X_wide: np.ndarray, Optional. default=None
548550
Input for the `wide` model component.
549551
See `pytorch_widedeep.preprocessing.WidePreprocessor`
550-
X_tab: np.ndarray, Optional. default=None
552+
X_tab: np.ndarray or List[np.ndarray], Optional. default=None
551553
Input for the `deeptabular` model component.
552554
See `pytorch_widedeep.preprocessing.TabPreprocessor`
553-
X_text: np.ndarray, Optional. default=None
555+
X_text: np.ndarray or List[np.ndarray], Optional. default=None
554556
Input for the `deeptext` model component.
555557
See `pytorch_widedeep.preprocessing.TextPreprocessor`
556-
X_img: np.ndarray, Optional. default=None
558+
X_img: np.ndarray or List[np.ndarray], Optional. default=None
557559
Input for the `deepimage` model component.
558560
See `pytorch_widedeep.preprocessing.ImagePreprocessor`
559561
X_test: Dict, Optional. default=None
@@ -606,13 +608,13 @@ def predict_uncertainty( # type: ignore[return]
606608
X_wide: np.ndarray, Optional. default=None
607609
Input for the `wide` model component.
608610
See `pytorch_widedeep.preprocessing.WidePreprocessor`
609-
X_tab: np.ndarray, Optional. default=None
611+
X_tab: np.ndarray or List[np.ndarray], Optional. default=None
610612
Input for the `deeptabular` model component.
611613
See `pytorch_widedeep.preprocessing.TabPreprocessor`
612-
X_text: np.ndarray, Optional. default=None
614+
X_text: np.ndarray or List[np.ndarray], Optional. default=None
613615
Input for the `deeptext` model component.
614616
See `pytorch_widedeep.preprocessing.TextPreprocessor`
615-
X_img: np.ndarray, Optional. default=None
617+
X_img: np.ndarray or List[np.ndarray], Optional. default=None
616618
Input for the `deepimage` model component.
617619
See `pytorch_widedeep.preprocessing.ImagePreprocessor`
618620
X_test: Dict, Optional. default=None
@@ -700,13 +702,13 @@ def predict_proba( # type: ignore[override, return] # noqa: C901
700702
X_wide: np.ndarray, Optional. default=None
701703
Input for the `wide` model component.
702704
See `pytorch_widedeep.preprocessing.WidePreprocessor`
703-
X_tab: np.ndarray, Optional. default=None
705+
X_tab: np.ndarray or List[np.ndarray], Optional. default=None
704706
Input for the `deeptabular` model component.
705707
See `pytorch_widedeep.preprocessing.TabPreprocessor`
706-
X_text: np.ndarray, Optional. default=None
708+
X_text: np.ndarray or List[np.ndarray], Optional. default=None
707709
Input for the `deeptext` model component.
708710
See `pytorch_widedeep.preprocessing.TextPreprocessor`
709-
X_img: np.ndarray, Optional. default=None
711+
X_img: np.ndarray or List[np.ndarray], Optional. default=None
710712
Input for the `deepimage` model component.
711713
See `pytorch_widedeep.preprocessing.ImagePreprocessor`
712714
X_test: Dict, Optional. default=None

0 commit comments

Comments
 (0)