@@ -587,13 +587,111 @@ trainer.fit(
)
```

- **7. Tabular with a multi-target loss**
+ **7. A two-tower model**
+
+ This is a popular model in the context of recommendation systems. Let's say we
+ have a tabular dataset formed by triples (user features, item features,
+ target). We can create a two-tower model where the user and item features are
+ passed through two separate models and then "fused" via a dot product.
+
+ <p align="center">
+ <img width="350" src="docs/figures/arch_7.png">
+ </p>
+
+
+ ```python
+ import numpy as np
+ import pandas as pd
+
+ from pytorch_widedeep import Trainer
+ from pytorch_widedeep.preprocessing import TabPreprocessor
+ from pytorch_widedeep.models import TabMlp, WideDeep, ModelFuser
+
+ # Let's create the interaction dataset
+ # user_features dataframe
+ np.random.seed(42)
+ user_ids = np.arange(1, 101)
+ ages = np.random.randint(18, 60, size=100)
+ genders = np.random.choice(["male", "female"], size=100)
+ locations = np.random.choice(["city_a", "city_b", "city_c", "city_d"], size=100)
+ user_features = pd.DataFrame(
+     {"id": user_ids, "age": ages, "gender": genders, "location": locations}
+ )
+
+ # item_features dataframe
+ item_ids = np.arange(1, 101)
+ prices = np.random.uniform(10, 500, size=100).round(2)
+ colors = np.random.choice(["red", "blue", "green", "black"], size=100)
+ categories = np.random.choice(["electronics", "clothing", "home", "toys"], size=100)
+
+ item_features = pd.DataFrame(
+     {"id": item_ids, "price": prices, "color": colors, "category": categories}
+ )
+
+ # Interactions dataframe
+ interaction_user_ids = np.random.choice(user_ids, size=1000)
+ interaction_item_ids = np.random.choice(item_ids, size=1000)
+ purchased = np.random.choice([0, 1], size=1000, p=[0.7, 0.3])
+ interactions = pd.DataFrame(
+     {
+         "user_id": interaction_user_ids,
+         "item_id": interaction_item_ids,
+         "purchased": purchased,
+     }
+ )
+ user_item_purchased = interactions.merge(
+     user_features, left_on="user_id", right_on="id"
+ ).merge(item_features, left_on="item_id", right_on="id")
+
+ # Users
+ tab_preprocessor_user = TabPreprocessor(
+     cat_embed_cols=["gender", "location"],
+     continuous_cols=["age"],
+ )
+ X_user = tab_preprocessor_user.fit_transform(user_item_purchased)
+ tab_mlp_user = TabMlp(
+     column_idx=tab_preprocessor_user.column_idx,
+     cat_embed_input=tab_preprocessor_user.cat_embed_input,
+     continuous_cols=["age"],
+     mlp_hidden_dims=[16, 8],
+     mlp_dropout=[0.2, 0.2],
+ )
+
+ # Items
+ tab_preprocessor_item = TabPreprocessor(
+     cat_embed_cols=["color", "category"],
+     continuous_cols=["price"],
+ )
+ X_item = tab_preprocessor_item.fit_transform(user_item_purchased)
+ tab_mlp_item = TabMlp(
+     column_idx=tab_preprocessor_item.column_idx,
+     cat_embed_input=tab_preprocessor_item.cat_embed_input,
+     continuous_cols=["price"],
+     mlp_hidden_dims=[16, 8],
+     mlp_dropout=[0.2, 0.2],
+ )
+
+ two_tower_model = ModelFuser([tab_mlp_user, tab_mlp_item], fusion_method="dot")
+
+ model = WideDeep(deeptabular=two_tower_model)
+
+ trainer = Trainer(model, objective="binary")
+
+ trainer.fit(
+     X_tab=[X_user, X_item],
+     target=interactions.purchased.values,
+     n_epochs=1,
+     batch_size=32,
+ )
+ ```
+
+ **8. Tabular with a multi-target loss**

This one is "a bonus" to illustrate the use of multi-target losses, more than
actually a different architecture.

<p align="center">
- <img width="200" src="docs/figures/arch_7.png">
+ <img width="200" src="docs/figures/arch_8.png">
</p>
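
As a rough, non-authoritative sketch of what a multi-target loss can look like (plain PyTorch; the class name, target layout and weighting below are invented for illustration and are not pytorch_widedeep's own multi-target losses): one prediction head emits two columns, the first is scored against a binary target, the second against a continuous target, and the two terms are summed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BinaryPlusRegressionLoss(nn.Module):
    """Toy multi-target loss: BCE on the first output column, MSE on the second."""

    def __init__(self, regression_weight: float = 1.0):
        super().__init__()
        self.regression_weight = regression_weight

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # preds and targets have shape (batch_size, 2):
        # column 0 -> binary target (logit vs. {0, 1}), column 1 -> continuous target
        binary_loss = F.binary_cross_entropy_with_logits(preds[:, 0], targets[:, 0])
        regression_loss = F.mse_loss(preds[:, 1], targets[:, 1])
        return binary_loss + self.regression_weight * regression_loss


# Quick check on a random batch of 8 samples with two targets each
preds = torch.randn(8, 2)
targets = torch.cat([torch.randint(0, 2, (8, 1)).float(), torch.randn(8, 1)], dim=1)
print(BinaryPlusRegressionLoss(regression_weight=0.5)(preds, targets).item())
```

The gist of this "bonus" setup is simply that the model's prediction dimension matches the number of targets and each output column gets its own loss term.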