Commit 220eb3f

Merge pull request #226 from jrzaurin/multiple_tab_components
Multiple tab components
2 parents: deb4f2e + 16922ce

File tree

274 files changed: +15113 −17716 lines changed

README.md

+100 −2

````diff
@@ -587,13 +587,111 @@ trainer.fit(
 )
 ```
 
-**7. Tabular with a multi-target loss**
+**7. A two-tower model**
+
+This is a popular model in the context of recommendation systems. Let's say we
+have a tabular dataset formed by triples (user features, item features,
+target). We can create a two-tower model where the user and item features are
+passed through two separate models and then "fused" via a dot product.
+
+<p align="center">
+  <img width="350" src="docs/figures/arch_7.png">
+</p>
+
+```python
+import numpy as np
+import pandas as pd
+
+from pytorch_widedeep import Trainer
+from pytorch_widedeep.preprocessing import TabPreprocessor
+from pytorch_widedeep.models import TabMlp, WideDeep, ModelFuser
+
+# Let's create the interaction dataset
+# user_features dataframe
+np.random.seed(42)
+user_ids = np.arange(1, 101)
+ages = np.random.randint(18, 60, size=100)
+genders = np.random.choice(["male", "female"], size=100)
+locations = np.random.choice(["city_a", "city_b", "city_c", "city_d"], size=100)
+user_features = pd.DataFrame(
+    {"id": user_ids, "age": ages, "gender": genders, "location": locations}
+)
+
+# item_features dataframe
+item_ids = np.arange(1, 101)
+prices = np.random.uniform(10, 500, size=100).round(2)
+colors = np.random.choice(["red", "blue", "green", "black"], size=100)
+categories = np.random.choice(["electronics", "clothing", "home", "toys"], size=100)
+
+item_features = pd.DataFrame(
+    {"id": item_ids, "price": prices, "color": colors, "category": categories}
+)
+
+# Interactions dataframe
+interaction_user_ids = np.random.choice(user_ids, size=1000)
+interaction_item_ids = np.random.choice(item_ids, size=1000)
+purchased = np.random.choice([0, 1], size=1000, p=[0.7, 0.3])
+interactions = pd.DataFrame(
+    {
+        "user_id": interaction_user_ids,
+        "item_id": interaction_item_ids,
+        "purchased": purchased,
+    }
+)
+user_item_purchased = interactions.merge(
+    user_features, left_on="user_id", right_on="id"
+).merge(item_features, left_on="item_id", right_on="id")
+
+# Users
+tab_preprocessor_user = TabPreprocessor(
+    cat_embed_cols=["gender", "location"],
+    continuous_cols=["age"],
+)
+X_user = tab_preprocessor_user.fit_transform(user_item_purchased)
+tab_mlp_user = TabMlp(
+    column_idx=tab_preprocessor_user.column_idx,
+    cat_embed_input=tab_preprocessor_user.cat_embed_input,
+    continuous_cols=["age"],
+    mlp_hidden_dims=[16, 8],
+    mlp_dropout=[0.2, 0.2],
+)
+
+# Items
+tab_preprocessor_item = TabPreprocessor(
+    cat_embed_cols=["color", "category"],
+    continuous_cols=["price"],
+)
+X_item = tab_preprocessor_item.fit_transform(user_item_purchased)
+tab_mlp_item = TabMlp(
+    column_idx=tab_preprocessor_item.column_idx,
+    cat_embed_input=tab_preprocessor_item.cat_embed_input,
+    continuous_cols=["price"],
+    mlp_hidden_dims=[16, 8],
+    mlp_dropout=[0.2, 0.2],
+)
+
+two_tower_model = ModelFuser([tab_mlp_user, tab_mlp_item], fusion_method="dot")
+
+model = WideDeep(deeptabular=two_tower_model)
+
+trainer = Trainer(model, objective="binary")
+
+trainer.fit(
+    X_tab=[X_user, X_item],
+    target=interactions.purchased.values,
+    n_epochs=1,
+    batch_size=32,
+)
+```
+
+**8. Tabular with a multi-target loss**
 
 This one is "a bonus" to illustrate the use of multi-target losses, more than
 actually a different architecture.
 
 <p align="center">
-  <img width="200" src="docs/figures/arch_7.png">
+  <img width="200" src="docs/figures/arch_8.png">
 </p>
 
````
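The `fusion_method="dot"` used above reduces each (user, item) pair of tower outputs to a single score. As a minimal, library-independent sketch of what that fusion computes (plain NumPy on made-up embeddings, not the library's implementation):

```python
import numpy as np

# Sketch of "dot" fusion: each tower maps its inputs to a fixed-size
# embedding; the row-wise dot product of the two embeddings yields one
# logit per (user, item) pair, squashed to a purchase probability.
rng = np.random.default_rng(42)
user_emb = rng.normal(size=(32, 8))  # batch of 32 user-tower outputs
item_emb = rng.normal(size=(32, 8))  # matching batch of item-tower outputs

logits = np.sum(user_emb * item_emb, axis=1)  # row-wise dot product
probs = 1.0 / (1.0 + np.exp(-logits))         # sigmoid -> probability
```

A binary objective can then be trained against these probabilities, which is what `objective="binary"` does end to end.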
VERSION

+1 −1

```diff
@@ -1 +1 @@
-1.6.2
+1.6.3
```

docs/examples.rst

−1

```diff
@@ -17,5 +17,4 @@ them to address different problems
 * `HyperParameter Tuning With RayTune <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/notebooks/10_hyperParameter_tuning_w_raytune_n_wnb.ipynb>`__
 * `Model Uncertainty Prediction <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/notebooks/13_Model_Uncertainty_prediction.ipynb>`__
 * `Bayesian Models <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/notebooks/14_bayesian_models.ipynb>`__
-* `Deep Imbalanced Regression <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/notebooks/15_DIR-LDS_and_FDS.ipynb>`__
 
```

docs/figures/arch_7.png (23.8 KB, binary file changed)

docs/figures/arch_8.png (32.6 KB, binary file changed)

examples/notebooks/15_DIR-LDS_and_FDS.ipynb (−847): this file was deleted.

examples/scripts/california_housing_fds_lds.py (−53): this file was deleted.

examples/scripts/readme_snippets.py

+84 −1

```diff
@@ -407,7 +407,90 @@ def output_dim(self):
 )
 
 
-# 7. Simply Tabular with a multi-target loss
+# 7. A two-tower model
+np.random.seed(42)
+
+# user_features dataframe
+user_ids = np.arange(1, 101)
+ages = np.random.randint(18, 60, size=100)
+genders = np.random.choice(["male", "female"], size=100)
+locations = np.random.choice(["city_a", "city_b", "city_c", "city_d"], size=100)
+user_features = pd.DataFrame(
+    {"id": user_ids, "age": ages, "gender": genders, "location": locations}
+)
+
+# item_features dataframe
+item_ids = np.arange(1, 101)
+prices = np.random.uniform(10, 500, size=100).round(2)
+colors = np.random.choice(["red", "blue", "green", "black"], size=100)
+categories = np.random.choice(["electronics", "clothing", "home", "toys"], size=100)
+
+item_features = pd.DataFrame(
+    {"id": item_ids, "price": prices, "color": colors, "category": categories}
+)
+
+# Interactions dataframe
+interaction_user_ids = np.random.choice(user_ids, size=1000)
+interaction_item_ids = np.random.choice(item_ids, size=1000)
+purchased = np.random.choice([0, 1], size=1000, p=[0.7, 0.3])
+interactions = pd.DataFrame(
+    {
+        "user_id": interaction_user_ids,
+        "item_id": interaction_item_ids,
+        "purchased": purchased,
+    }
+)
+user_item_purchased = interactions.merge(
+    user_features, left_on="user_id", right_on="id"
+).merge(item_features, left_on="item_id", right_on="id")
+
+
+# Users
+tab_preprocessor_user = TabPreprocessor(
+    cat_embed_cols=["gender", "location"],
+    continuous_cols=["age"],
+)
+X_user = tab_preprocessor_user.fit_transform(user_item_purchased)
+tab_mlp_user = TabMlp(
+    column_idx=tab_preprocessor_user.column_idx,
+    cat_embed_input=tab_preprocessor_user.cat_embed_input,
+    continuous_cols=["age"],
+    mlp_hidden_dims=[16, 8],
+    mlp_dropout=[0.2, 0.2],
+)
+
+# Items
+tab_preprocessor_item = TabPreprocessor(
+    cat_embed_cols=["color", "category"],
+    continuous_cols=["price"],
+)
+X_item = tab_preprocessor_item.fit_transform(user_item_purchased)
+tab_mlp_item = TabMlp(
+    column_idx=tab_preprocessor_item.column_idx,
+    cat_embed_input=tab_preprocessor_item.cat_embed_input,
+    continuous_cols=["price"],
+    mlp_hidden_dims=[16, 8],
+    mlp_dropout=[0.2, 0.2],
+)
+
+two_tower_model = ModelFuser([tab_mlp_user, tab_mlp_item], fusion_method="dot")
+
+model = WideDeep(deeptabular=two_tower_model)
+
+trainer = Trainer(
+    model,
+    objective="binary",
+)
+
+trainer.fit(
+    X_tab=[X_user, X_item],
+    target=interactions.purchased.values,
+    n_epochs=1,
+    batch_size=32,
+)
+
+
+# 8. Simply Tabular with a multi-target loss
 
 # let's add a second target to the dataframe
 df["target2"] = [random.choice([0, 1]) for _ in range(100)]
```

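The snippet above only adds a random second target column; the multi-target objective itself is set up later in the script. As a hedged, library-independent sketch (the `bce` helper is made up here, not pytorch_widedeep's API), a multi-target loss can be thought of as a sum of one loss per target column:

```python
import numpy as np

# Illustrative only: a multi-target loss is conceptually the sum of
# per-target losses, here binary cross-entropy over two 0/1 targets.
def bce(p, y, eps=1e-7):
    """Binary cross-entropy between predicted probs p and 0/1 labels y."""
    p = np.clip(p, eps, 1 - eps)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))

preds = np.array([[0.9, 0.2], [0.1, 0.8]])  # columns: target, target2
targets = np.array([[1, 0], [0, 1]])
multi_target_loss = bce(preds[:, 0], targets[:, 0]) + bce(preds[:, 1], targets[:, 1])
```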
mkdocs/mkdocs.yml

+8 −9

```diff
@@ -52,15 +52,14 @@ nav:
 - 12_ZILNLoss_origkeras_vs_pytorch_widedeep: examples/12_ZILNLoss_origkeras_vs_pytorch_widedeep.ipynb
 - 13_model_uncertainty_prediction: examples/13_model_uncertainty_prediction.ipynb
 - 14_bayesian_models: examples/14_bayesian_models.ipynb
-- 15_DIR-LDS_and_FDS: examples/15_DIR-LDS_and_FDS.ipynb
-- 16_Self-Supervised Pre-Training pt 1: examples/16_Self_Supervised_Pretraning_pt1.ipynb
-- 16_Self-Supervised Pre-Training pt 2: examples/16_Self_Supervised_Pretraning_pt2.ipynb
-- 17_Usign-a-custom-hugging-face-model: examples/17_Usign_a_custom_hugging_face_model.ipynb
-- 18_feature_importance_via_attention_weights: examples/18_feature_importance_via_attention_weights.ipynb
-- 19_wide_and_deep_for_recsys_pt1: examples/19_wide_and_deep_for_recsys_pt1.ipynb
-- 19_wide_and_deep_for_recsys_pt2: examples/19_wide_and_deep_for_recsys_pt2.ipynb
-- 20_load_from_folder_functionality: examples/20_load_from_folder_functionality.ipynb
-- 21-Using-huggingface-within-widedeep: examples/21_Using_huggingface_within_widedeep.ipynb
+- 15_Self-Supervised Pre-Training pt 1: examples/16_Self_Supervised_Pretraning_pt1.ipynb
+- 15_Self-Supervised Pre-Training pt 2: examples/16_Self_Supervised_Pretraning_pt2.ipynb
+- 16_Usign-a-custom-hugging-face-model: examples/17_Usign_a_custom_hugging_face_model.ipynb
+- 17_feature_importance_via_attention_weights: examples/18_feature_importance_via_attention_weights.ipynb
+- 18_wide_and_deep_for_recsys_pt1: examples/19_wide_and_deep_for_recsys_pt1.ipynb
+- 18_wide_and_deep_for_recsys_pt2: examples/19_wide_and_deep_for_recsys_pt2.ipynb
+- 19_load_from_folder_functionality: examples/20_load_from_folder_functionality.ipynb
+- 20-Using-huggingface-within-widedeep: examples/21_Using_huggingface_within_widedeep.ipynb
 - Contributing: contributing.md
 
 theme:
```
