From 1eaaded0eea36e59e3d20f94efb6e7e0df053aea Mon Sep 17 00:00:00 2001 From: Javier Date: Sat, 5 Oct 2024 22:24:43 +0100 Subject: [PATCH 01/24] bayesian functionalities updated to support MPS backend --- .../_base_contrastive_denoising_trainer.py | 3 ++- .../_base_encoder_decoder_trainer.py | 3 ++- pytorch_widedeep/tab2vec.py | 14 +++++++++++--- .../training/_base_bayesian_trainer.py | 3 ++- pytorch_widedeep/training/bayesian_trainer.py | 18 ++++++++++++++---- pytorch_widedeep/training/trainer.py | 3 ++- .../training/trainer_from_folder.py | 5 +++-- pytorch_widedeep/utils/general_utils.py | 11 +++++++++++ .../test_b_miscellaneous.py | 6 ++++-- .../test_bayes_tab2vec/test_b_t2v.py | 5 +++-- .../test_mc_attn_layers.py | 3 ++- tests/test_tab2vec/test_t2v.py | 4 ++-- 12 files changed, 58 insertions(+), 20 deletions(-) diff --git a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py index 95706b50..11c9b835 100644 --- a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py +++ b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py @@ -27,6 +27,7 @@ CallbackContainer, LRShedulerCallback, ) +from pytorch_widedeep.utils.general_utils import setup_device from pytorch_widedeep.models.tabular.self_supervised import ContrastiveDenoisingModel from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor @@ -270,7 +271,7 @@ def _set_device_and_num_workers(**kwargs): if sys.platform == "darwin" and sys.version_info.minor > 7 else os.cpu_count() ) - default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + default_device = setup_device() device = kwargs.get("device", default_device) num_workers = kwargs.get("num_workers", default_num_workers) return device, num_workers diff --git a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py index fe0aa4a8..97491b4f 100644 --- a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py +++ b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py @@ -25,6 +25,7 @@ CallbackContainer, LRShedulerCallback, ) +from pytorch_widedeep.utils.general_utils import setup_device from pytorch_widedeep.models.tabular.self_supervised import EncoderDecoderModel @@ -245,7 +246,7 @@ def _set_device_and_num_workers(**kwargs): if sys.platform == "darwin" and sys.version_info.minor > 7 else os.cpu_count() ) - default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + default_device = setup_device() device = kwargs.get("device", default_device) num_workers = kwargs.get("num_workers", default_num_workers) return device, num_workers diff --git a/pytorch_widedeep/tab2vec.py b/pytorch_widedeep/tab2vec.py index af3d9273..5ba2ab48 100644 --- a/pytorch_widedeep/tab2vec.py +++ b/pytorch_widedeep/tab2vec.py @@ -10,8 +10,6 @@ from pytorch_widedeep.bayesian_models import BayesianWide, BayesianTabMlp from pytorch_widedeep.bayesian_models._base_bayesian_model import BaseBayesianModel -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class Tab2Vec: r"""Class to transform an input dataframe into vectorized form. 
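# [Editor's note] Illustrative sketch only, not part of the patch: the hunks below make
# `device` an explicit argument of Tab2Vec rather than a module-level default. Assuming an
# already-fitted TabPreprocessor (`tab_preprocessor`), a trained model (`model`) and a
# dataframe `df` with a "target" column -- all placeholder names -- usage would look
# roughly like this:

from pytorch_widedeep import Tab2Vec
from pytorch_widedeep.utils.general_utils import setup_device

t2v = Tab2Vec(tab_preprocessor=tab_preprocessor, model=model, device=setup_device())
X_vec, y = t2v.transform(df, target_col="target")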
@@ -87,6 +85,7 @@ def __init__( self, tab_preprocessor: TabPreprocessor, model: Union[WideDeep, BayesianWide, BayesianTabMlp], + device: str | torch.device, return_dataframe: bool = False, verbose: bool = False, ): @@ -94,6 +93,11 @@ def __init__( self._check_inputs(tab_preprocessor, model, verbose) + if isinstance(device, str): + self.device = torch.device(device) + else: + self.device = device + self.tab_preprocessor = tab_preprocessor self.return_dataframe = return_dataframe self.verbose = verbose @@ -158,7 +162,11 @@ def transform( """ X_tab = self.tab_preprocessor.transform(df) - X = torch.from_numpy(X_tab.astype("float")).to(device) + + if self.device.type == "mps": + X = torch.from_numpy(X_tab.astype("float32")).to(self.device) + else: + X = torch.from_numpy(X_tab.astype("float")).to(self.device) with torch.no_grad(): if self.is_tab_transformer: diff --git a/pytorch_widedeep/training/_base_bayesian_trainer.py b/pytorch_widedeep/training/_base_bayesian_trainer.py index 02993ef9..fa508531 100644 --- a/pytorch_widedeep/training/_base_bayesian_trainer.py +++ b/pytorch_widedeep/training/_base_bayesian_trainer.py @@ -25,6 +25,7 @@ CallbackContainer, LRShedulerCallback, ) +from pytorch_widedeep.utils.general_utils import setup_device from pytorch_widedeep.training._trainer_utils import bayesian_alias_to_loss from pytorch_widedeep.bayesian_models._base_bayesian_model import BaseBayesianModel @@ -235,7 +236,7 @@ def _set_device_and_num_workers(**kwargs): if sys.platform == "darwin" and sys.version_info.minor > 7 else os.cpu_count() ) - default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + default_device = setup_device() device = kwargs.get("device", default_device) num_workers = kwargs.get("num_workers", default_num_workers) return device, num_workers diff --git a/pytorch_widedeep/training/bayesian_trainer.py b/pytorch_widedeep/training/bayesian_trainer.py index 88c426fe..65020c23 100644 --- a/pytorch_widedeep/training/bayesian_trainer.py +++ b/pytorch_widedeep/training/bayesian_trainer.py @@ -86,7 +86,8 @@ class BayesianTrainer(BaseBayesianTrainer): Other infrequently used arguments that can also be passed as kwargs are: - **device**: `str`
- string indicating the device. One of _'cpu'_ or _'gpu'_ + string indicating the device. One of _'cpu'_, _'gpu'_ or _'mps'_ if + running on a Mac with Apple silicon or AMD GPU(s) - **num_workers**: `int`
number of workers to be used internally by the data loaders @@ -396,7 +397,10 @@ def _train_step( ): self.model.train() - X = X_tab.to(self.device) + try: + X = X_tab.to(self.device) + except TypeError: + X = X_tab.to(self.device, dtype=torch.float32) y = target.view(-1, 1).float() if self.objective != "multiclass" else target y = y.to(self.device) @@ -424,7 +428,10 @@ def _eval_step( ): self.model.eval() with torch.no_grad(): - X = X_tab.to(self.device) + try: + X = X_tab.to(self.device) + except TypeError: + X = X_tab.to(self.device, dtype=torch.float32) y = target.view(-1, 1).float() if self.objective != "multiclass" else target y = y.to(self.device) @@ -479,7 +486,10 @@ def _predict( # noqa: C901 for _, Xl in zip(tt, test_loader): tt.set_description("predict") - X = Xl[0].to(self.device) + try: + X = Xl[0].to(self.device) + except TypeError: + X = Xl[0].to(self.device, dtype=torch.float32) if return_samples: preds = torch.stack([self.model(X) for _ in range(n_samples)]) diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index f4dd1a91..873a11b4 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -151,7 +151,8 @@ class Trainer(BaseTrainer): Other infrequently used arguments that can also be passed as kwargs are: - **device**: `str`
- string indicating the device. One of _'cpu'_ or _'gpu'_ + string indicating the device. One of _'cpu'_, _'gpu'_ or _'mps'_ if + running on a Mac with Apple silicon or AMD GPU(s) - **num_workers**: `int`
number of workers to be used internally by the data loaders diff --git a/pytorch_widedeep/training/trainer_from_folder.py b/pytorch_widedeep/training/trainer_from_folder.py index e2159b71..0641f3d8 100644 --- a/pytorch_widedeep/training/trainer_from_folder.py +++ b/pytorch_widedeep/training/trainer_from_folder.py @@ -156,8 +156,9 @@ class TrainerFromFolder(BaseTrainer): **kwargs: dict Other infrequently used arguments that can also be passed as kwargs are: - - **device**: `str`
- string indicating the device. One of _'cpu'_ or _'gpu'_ + - **device**: `str`
+ string indicating the device. One of _'cpu'_, _'gpu'_ or _'mps'_ if + running on a Mac with Apple silicon or AMD GPU(s) - **num_workers**: `int`
number of workers to be used internally by the data loaders diff --git a/pytorch_widedeep/utils/general_utils.py b/pytorch_widedeep/utils/general_utils.py index 6db4b4b3..cac80e84 100644 --- a/pytorch_widedeep/utils/general_utils.py +++ b/pytorch_widedeep/utils/general_utils.py @@ -1,5 +1,16 @@ from functools import wraps +import torch + + +def setup_device(): + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + return torch.device("cpu") + def alias(original_name: str, alternative_names: list): def decorator(func): diff --git a/tests/test_bayesian_models/test_bayes_model_functioning/test_b_miscellaneous.py b/tests/test_bayesian_models/test_bayes_model_functioning/test_b_miscellaneous.py index 5b472a21..9e0c6629 100644 --- a/tests/test_bayesian_models/test_bayes_model_functioning/test_b_miscellaneous.py +++ b/tests/test_bayesian_models/test_bayes_model_functioning/test_b_miscellaneous.py @@ -66,7 +66,8 @@ def test_save_and_load(model): ) new_model = torch.load( - "tests/test_bayesian_models/test_bayes_model_functioning/model_dir/bayesian_model.pt" + "tests/test_bayesian_models/test_bayes_model_functioning/model_dir/bayesian_model.pt", + weights_only=False, ) if model.__class__.__name__ == "BayesianWide": @@ -113,7 +114,8 @@ def test_save_and_load_dict(model_name): btrainer2.model.load_state_dict( torch.load( - "tests/test_bayesian_models/test_bayes_model_functioning/model_dir/bayesian_model.pt" + "tests/test_bayesian_models/test_bayes_model_functioning/model_dir/bayesian_model.pt", + weights_only=False, ) ) diff --git a/tests/test_bayesian_models/test_bayes_tab2vec/test_b_t2v.py b/tests/test_bayesian_models/test_bayes_tab2vec/test_b_t2v.py index 17d71f9f..ee82b931 100644 --- a/tests/test_bayesian_models/test_bayes_tab2vec/test_b_t2v.py +++ b/tests/test_bayesian_models/test_bayes_tab2vec/test_b_t2v.py @@ -2,15 +2,15 @@ from random import choices import numpy as np -import torch import pandas as pd import pytest from pytorch_widedeep import Tab2Vec from pytorch_widedeep.preprocessing import TabPreprocessor from pytorch_widedeep.bayesian_models import BayesianTabMlp +from pytorch_widedeep.utils.general_utils import setup_device -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = setup_device() colnames = list(string.ascii_lowercase)[:4] + ["target"] cat_col1_vals = ["a", "b", "c"] @@ -55,6 +55,7 @@ def test_bayesian_mlp_models(return_dataframe, embed_continuous): t2v = Tab2Vec( tab_preprocessor=tab_preprocessor, model=model, + device=device, return_dataframe=return_dataframe, ) t2v_out, _ = t2v.transform(df_t2v, target_col="target") diff --git a/tests/test_model_components/test_mc_attn_layers.py b/tests/test_model_components/test_mc_attn_layers.py index e5106993..ead4b5fc 100644 --- a/tests/test_model_components/test_mc_attn_layers.py +++ b/tests/test_model_components/test_mc_attn_layers.py @@ -4,6 +4,7 @@ import torch import pytest +from pytorch_widedeep.utils.general_utils import setup_device from pytorch_widedeep.models.tabular.transformers._attention_layers import ( MultiHeadedAttention, ) @@ -12,7 +13,7 @@ input_dim = 128 n_heads = 4 -device = "cuda" if torch.cuda.is_available() else "cpu" +device = setup_device() standard_attn = ( MultiHeadedAttention( diff --git a/tests/test_tab2vec/test_t2v.py b/tests/test_tab2vec/test_t2v.py index a466aadf..5a92b9e7 100644 --- a/tests/test_tab2vec/test_t2v.py +++ b/tests/test_tab2vec/test_t2v.py @@ -2,7 +2,6 @@ from random import 
choices import numpy as np -import torch import pandas as pd import pytest @@ -21,8 +20,9 @@ ContextAttentionMLP, ) from pytorch_widedeep.preprocessing import TabPreprocessor +from pytorch_widedeep.utils.general_utils import setup_device -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = setup_device() colnames = list(string.ascii_lowercase)[:4] + ["target"] cat_col1_vals = ["a", "b", "c"] From 69a1fab263e6fd070239d1435a0c375610c3bfcd Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 8 Oct 2024 18:10:26 +0100 Subject: [PATCH 02/24] Support for mps backend added. All unit test passing. Need to run some example scripts --- pytorch_widedeep/losses.py | 26 ++++++++-- .../training/_base_bayesian_trainer.py | 8 ++-- pytorch_widedeep/training/_base_trainer.py | 11 +++-- .../training/_feature_importance.py | 8 ++-- pytorch_widedeep/training/_finetune.py | 15 +++--- pytorch_widedeep/training/bayesian_trainer.py | 16 ++----- pytorch_widedeep/training/trainer.py | 18 +++---- .../training/trainer_from_folder.py | 18 +++---- pytorch_widedeep/utils/general_utils.py | 47 +++++++++++++++++-- .../test_finetune/test_finetuning_routines.py | 16 +++---- tests/test_hf_integration/test_models.py | 1 + tests/test_tab2vec/test_t2v.py | 4 ++ 12 files changed, 118 insertions(+), 70 deletions(-) diff --git a/pytorch_widedeep/losses.py b/pytorch_widedeep/losses.py index 2674f1ed..5f5a5769 100644 --- a/pytorch_widedeep/losses.py +++ b/pytorch_widedeep/losses.py @@ -272,8 +272,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: else: num_class = input_prob.size(1) binary_target = torch.eye(num_class)[target.squeeze().cpu().long()] - if use_cuda: - binary_target = binary_target.cuda() + binary_target = binary_target.to(input.device) binary_target = binary_target.contiguous() weight = self._get_weight(input_prob, binary_target) @@ -430,9 +429,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: # when using max the two input tensors (input and other) have to be of # the same type max_input = F.softplus(input[..., 2:]) - max_other = torch.sqrt(torch.Tensor([torch.finfo(torch.double).eps])).type( - max_input.type() - ) + max_other = self.get_eps(max_input) scale = torch.max(max_input, max_other) safe_labels = positive * target + (1 - positive) * torch.ones_like(target) @@ -446,6 +443,25 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: return torch.mean(classification_loss + regression_loss) + @staticmethod + def get_eps(max_input: Tensor) -> Tensor: + if max_input.device.type == "mps": + # For MPS, use float32 and then convert to the input type + eps = torch.finfo(torch.float32).eps + max_other = ( + torch.sqrt(torch.tensor([eps], device="cpu")) + .to(max_input.device) + .to(max_input.dtype) + ) + else: + # For other devices, use the original approach + eps = torch.finfo(torch.double).eps + max_other = ( + torch.sqrt(torch.tensor([eps])).to(max_input.device).to(max_input.dtype) + ) + + return max_other + class L1Loss(nn.Module): r"""L1 loss""" diff --git a/pytorch_widedeep/training/_base_bayesian_trainer.py b/pytorch_widedeep/training/_base_bayesian_trainer.py index fa508531..bbdf8255 100644 --- a/pytorch_widedeep/training/_base_bayesian_trainer.py +++ b/pytorch_widedeep/training/_base_bayesian_trainer.py @@ -12,6 +12,7 @@ from pytorch_widedeep.wdtypes import ( Any, List, + Tuple, Union, Module, Optional, @@ -25,7 +26,7 @@ CallbackContainer, LRShedulerCallback, ) -from pytorch_widedeep.utils.general_utils import setup_device +from 
pytorch_widedeep.utils.general_utils import setup_device, to_device_model from pytorch_widedeep.training._trainer_utils import bayesian_alias_to_loss from pytorch_widedeep.bayesian_models._base_bayesian_model import BaseBayesianModel @@ -58,8 +59,7 @@ def __init__( self.device, self.num_workers = self._set_device_and_num_workers(**kwargs) self.early_stop = False - self.model = model - self.model.to(self.device) + self.model = to_device_model(model, self.device) self.verbose = verbose self.seed = seed @@ -230,7 +230,7 @@ def _set_callbacks_and_metrics(self, callbacks: Any, metrics: Any): self.callback_container.set_trainer(self) @staticmethod - def _set_device_and_num_workers(**kwargs): + def _set_device_and_num_workers(**kwargs) -> Tuple[str, int]: default_num_workers = ( 0 if sys.platform == "darwin" and sys.version_info.minor > 7 diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py index 74a57202..45cb2e9e 100644 --- a/pytorch_widedeep/training/_base_trainer.py +++ b/pytorch_widedeep/training/_base_trainer.py @@ -15,6 +15,7 @@ Any, Dict, List, + Tuple, Union, Module, Optional, @@ -31,6 +32,7 @@ LRShedulerCallback, ) from pytorch_widedeep.initializers import Initializer, MultipleInitializer +from pytorch_widedeep.utils.general_utils import setup_device, to_device_model from pytorch_widedeep.training._trainer_utils import alias_to_loss from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer from pytorch_widedeep.training._multiple_transforms import MultipleTransforms @@ -74,10 +76,9 @@ def __init__( self.verbose = verbose self.seed = seed - self.model = model + self.model = to_device_model(model, self.device) if self.model.is_tabnet: self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3) - self.model.to(self.device) self.model.wd_device = self.device self.objective = objective @@ -406,7 +407,7 @@ def _check_inputs( ) @staticmethod - def _set_device_and_num_workers(**kwargs): + def _set_device_and_num_workers(**kwargs) -> Tuple[str, int]: # Important note for Mac users: Since python 3.8, the multiprocessing # library start method changed from 'fork' to 'spawn'. This affects the # data-loaders, which will not run in parallel. 
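# [Editor's note] Minimal sketch, not part of the patch, of how the helpers used throughout
# these trainers (setup_device, to_device, to_device_model from
# pytorch_widedeep.utils.general_utils) behave according to their definitions in this
# patch series:

import torch
from pytorch_widedeep.utils.general_utils import setup_device, to_device

device = setup_device()  # "cuda" if available, else "mps", else "cpu"
x = torch.randn(4, 3, dtype=torch.float64)
x = to_device(x, device)  # on "mps", float64 tensors are cast to float32 before .to(device)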
@@ -415,9 +416,9 @@ def _set_device_and_num_workers(**kwargs): if sys.platform == "darwin" and sys.version_info.minor > 7 else os.cpu_count() ) - default_device = "cuda" if torch.cuda.is_available() else "cpu" - device = kwargs.get("device", default_device) num_workers = kwargs.get("num_workers", default_num_workers) + default_device = setup_device() + device = kwargs.get("device", default_device) return device, num_workers def __repr__(self) -> str: # noqa: C901 diff --git a/pytorch_widedeep/training/_feature_importance.py b/pytorch_widedeep/training/_feature_importance.py index 79a37f19..8bf9588b 100644 --- a/pytorch_widedeep/training/_feature_importance.py +++ b/pytorch_widedeep/training/_feature_importance.py @@ -24,7 +24,7 @@ SelfAttentionMLP, ContextAttentionMLP, ) -from pytorch_widedeep.utils.general_utils import alias +from pytorch_widedeep.utils.general_utils import alias, to_device from pytorch_widedeep.training._wd_dataset import WideDeepDataset from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix @@ -136,7 +136,7 @@ def _sample_data(self, loader: DataLoader) -> Tensor: batches = [] for i, (data, _) in enumerate(loader): if i < n_iterations: - batches.append(data["deeptabular"].to(self.device)) + batches.append(to_device(data["deeptabular"], self.device)) else: break @@ -307,7 +307,7 @@ def explain( m_explain_l = [] for batch_nb, data in enumerate(loader): - X = data["deeptabular"].to(self.device) + X = to_device(data["deeptabular"], self.device) M_explain, masks = model_backbone.forward_masks(X) # type: ignore[operator] m_explain_l.append( csc_matrix.dot(M_explain.cpu().detach().numpy(), reducing_matrix) @@ -363,7 +363,7 @@ def explain( batch_feat_imp: Any = [] for _, data in enumerate(loader): - X = data["deeptabular"].to(self.device) + X = to_device(data["deeptabular"], self.device) _ = model.deeptabular(X) feat_imp, _ = self._feature_importance(model) diff --git a/pytorch_widedeep/training/_finetune.py b/pytorch_widedeep/training/_finetune.py index 90579179..5567fc42 100644 --- a/pytorch_widedeep/training/_finetune.py +++ b/pytorch_widedeep/training/_finetune.py @@ -16,10 +16,9 @@ DataLoader, LRScheduler, ) +from pytorch_widedeep.utils.general_utils import to_device, setup_device from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent -use_cuda = torch.cuda.is_available() - WDModel = Union[nn.Module, BaseWDModelComponent] @@ -67,6 +66,8 @@ def __init__( self.method = method self.verbose = verbose + self.device = setup_device() + def finetune_all( self, model: Union[WDModel, nn.ModuleList], @@ -432,20 +433,16 @@ def _train( # noqa: C901 data, target = packed_data if idx is not None: - X = ( - data[model_name][idx].cuda() - if use_cuda - else data[model_name][idx] - ) + X = to_device(data[model_name][idx], self.device) else: - X = data[model_name].cuda() if use_cuda else data[model_name] + X = to_device(data[model_name], self.device) y = ( target.view(-1, 1).float() if self.method not in ["multiclass", "qregression"] else target ) - y = y.cuda() if use_cuda else y + y = to_device(y, self.device) optimizer.zero_grad() y_pred = model(X) diff --git a/pytorch_widedeep/training/bayesian_trainer.py b/pytorch_widedeep/training/bayesian_trainer.py index 65020c23..3c314934 100644 --- a/pytorch_widedeep/training/bayesian_trainer.py +++ b/pytorch_widedeep/training/bayesian_trainer.py @@ -20,7 +20,7 @@ LRScheduler, ) from pytorch_widedeep.callbacks import Callback -from pytorch_widedeep.utils.general_utils import alias +from 
pytorch_widedeep.utils.general_utils import alias, to_device from pytorch_widedeep.training._trainer_utils import ( save_epoch_logs, print_loss_and_metric, @@ -397,12 +397,9 @@ def _train_step( ): self.model.train() - try: - X = X_tab.to(self.device) - except TypeError: - X = X_tab.to(self.device, dtype=torch.float32) + X = to_device(X_tab, self.device) y = target.view(-1, 1).float() if self.objective != "multiclass" else target - y = y.to(self.device) + y = to_device(y, self.device) self.optimizer.zero_grad() y_pred, loss = self.model.sample_elbo(X, y, self.loss_fn, n_samples, n_batches) # type: ignore[arg-type] @@ -428,12 +425,9 @@ def _eval_step( ): self.model.eval() with torch.no_grad(): - try: - X = X_tab.to(self.device) - except TypeError: - X = X_tab.to(self.device, dtype=torch.float32) + X = to_device(X_tab, self.device) y = target.view(-1, 1).float() if self.objective != "multiclass" else target - y = y.to(self.device) + y = to_device(y, self.device) y_pred, loss = self.model.sample_elbo( X, # type: ignore[arg-type] diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 873a11b4..8ca7d459 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -26,7 +26,7 @@ from pytorch_widedeep.callbacks import Callback from pytorch_widedeep.initializers import Initializer from pytorch_widedeep.training._finetune import FineTune -from pytorch_widedeep.utils.general_utils import alias +from pytorch_widedeep.utils.general_utils import alias, to_device from pytorch_widedeep.training._wd_dataset import WideDeepDataset from pytorch_widedeep.training._base_trainer import BaseTrainer from pytorch_widedeep.training._trainer_utils import ( @@ -914,15 +914,15 @@ def _train_step( X: Dict[str, Union[Tensor, List[Tensor]]] = {} for k, v in data.items(): if isinstance(v, list): - X[k] = [i.to(self.device) for i in v] + X[k] = [to_device(i, self.device) for i in v] else: - X[k] = v.to(self.device) + X[k] = to_device(v, self.device) y = ( target.view(-1, 1).float() if self.method not in ["multiclass", "qregression", "multitarget"] else target ) - y = y.to(self.device) + y = to_device(y, self.device) self.optimizer.zero_grad() @@ -954,15 +954,15 @@ def _eval_step( X: Dict[str, Union[Tensor, List[Tensor]]] = {} for k, v in data.items(): if isinstance(v, list): - X[k] = [i.to(self.device) for i in v] + X[k] = [to_device(i, self.device) for i in v] else: - X[k] = v.to(self.device) + X[k] = to_device(v, self.device) y = ( target.view(-1, 1).float() if self.method not in ["multiclass", "qregression", "multitarget"] else target ) - y = y.to(self.device) + y = to_device(y, self.device) y_pred = self.model(X) if self.model.is_tabnet: @@ -1061,9 +1061,9 @@ def _predict( # type: ignore[override, return] # noqa: C901 X: Dict[str, Union[Tensor, List[Tensor]]] = {} for k, v in data.items(): if isinstance(v, list): - X[k] = [i.to(self.device) for i in v] + X[k] = [to_device(i, self.device) for i in v] else: - X[k] = v.to(self.device) + X[k] = to_device(v, self.device) preds = ( self.model(X) if not self.model.is_tabnet diff --git a/pytorch_widedeep/training/trainer_from_folder.py b/pytorch_widedeep/training/trainer_from_folder.py index 0641f3d8..1ebc2629 100644 --- a/pytorch_widedeep/training/trainer_from_folder.py +++ b/pytorch_widedeep/training/trainer_from_folder.py @@ -23,7 +23,7 @@ from pytorch_widedeep.callbacks import Callback from pytorch_widedeep.initializers import Initializer from pytorch_widedeep.training._finetune import FineTune 
-from pytorch_widedeep.utils.general_utils import alias +from pytorch_widedeep.utils.general_utils import alias, to_device from pytorch_widedeep.training._wd_dataset import WideDeepDataset from pytorch_widedeep.training._base_trainer import BaseTrainer from pytorch_widedeep.training._trainer_utils import ( @@ -525,15 +525,15 @@ def _train_step( X: Dict[str, Union[Tensor, List[Tensor]]] = {} for k, v in data.items(): if isinstance(v, list): - X[k] = [i.to(self.device) for i in v] + X[k] = [to_device(i, self.device) for i in v] else: - X[k] = v.to(self.device) + X[k] = to_device(v, self.device) y = ( target.view(-1, 1).float() if self.method not in ["multiclass", "qregression"] else target ) - y = y.to(self.device) + y = to_device(y, self.device) self.optimizer.zero_grad() @@ -560,15 +560,15 @@ def _eval_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int): X: Dict[str, Union[Tensor, List[Tensor]]] = {} for k, v in data.items(): if isinstance(v, list): - X[k] = [i.to(self.device) for i in v] + X[k] = [to_device(i, self.device) for i in v] else: - X[k] = v.to(self.device) + X[k] = to_device(v, self.device) y = ( target.view(-1, 1).float() if self.method not in ["multiclass", "qregression"] else target ) - y = y.to(self.device) + y = to_device(y, self.device) y_pred = self.model(X) if self.model.is_tabnet: # pragma: no cover @@ -670,9 +670,9 @@ def _predict( # noqa: C901 X: Dict[str, Union[Tensor, List[Tensor]]] = {} for k, v in data.items(): if isinstance(v, list): - X[k] = [i.to(self.device) for i in v] + X[k] = [to_device(i, self.device) for i in v] else: - X[k] = v.to(self.device) + X[k] = to_device(v, self.device) preds = ( self.model(X) if not self.model.is_tabnet diff --git a/pytorch_widedeep/utils/general_utils.py b/pytorch_widedeep/utils/general_utils.py index cac80e84..781439ef 100644 --- a/pytorch_widedeep/utils/general_utils.py +++ b/pytorch_widedeep/utils/general_utils.py @@ -1,15 +1,54 @@ from functools import wraps import torch +from torch import Tensor -def setup_device(): +def setup_device() -> str: if torch.cuda.is_available(): - return torch.device("cuda") + return "cuda" elif torch.backends.mps.is_available(): - return torch.device("mps") + return "mps" else: - return torch.device("cpu") + return "cpu" + + +def to_device(X: Tensor, device: str) -> Tensor: + # Adjustmet in case the backend is mps which does not support float64 + if device == "mps" and X.dtype == torch.float64: + X = X.float() + return X.to(device) + + +def to_device_model(model, device: str): # noqa: C901 + # insistent transformation since it some cases overall approaches such as + # model.to('mps') do not work + + if device in ["cpu", "cuda"]: + return model.to(device) + + if device == "mps": + + try: + return model.to(device) + except (RuntimeError, TypeError): + + def convert(t): + if isinstance(t, torch.Tensor): + return t.float().to(device) + return t + + model.apply(lambda module: module._apply(convert)) + + for param in model.parameters(): + if param.device.type != "mps": + param.data = param.data.float().to(device) + + for buffer_name, buffer in model.named_buffers(): + if buffer.device.type != "mps": + model._buffers[buffer_name] = buffer.float().to(device) + + return model def alias(original_name: str, alternative_names: list): diff --git a/tests/test_finetune/test_finetuning_routines.py b/tests/test_finetune/test_finetuning_routines.py index b64eacc5..7e03f8c8 100644 --- a/tests/test_finetune/test_finetuning_routines.py +++ b/tests/test_finetune/test_finetuning_routines.py @@ -1,7 
+1,6 @@ import string import numpy as np -import torch import pytest import torch.nn.functional as F from torch import nn @@ -10,10 +9,11 @@ from pytorch_widedeep.models import Wide, TabMlp from pytorch_widedeep.metrics import Accuracy, MultipleMetrics from pytorch_widedeep.training._finetune import FineTune +from pytorch_widedeep.utils.general_utils import setup_device, to_device_model from pytorch_widedeep.models.image._layers import conv_layer from pytorch_widedeep.training._wd_dataset import WideDeepDataset -use_cuda = torch.cuda.is_available() +device = setup_device() # Define a series of simple models to quickly test the FineTune class @@ -87,8 +87,7 @@ def loss_fn(y_pred, y_true): # wide/linear wide = Wide(np.unique(X_wide).size, 1) -if use_cuda: - wide.cuda() +wide = to_device_model(wide, device) # tabular tab_mlp = TabMlp( @@ -99,18 +98,15 @@ def loss_fn(y_pred, y_true): mlp_dropout=0.2, ) tab_mlp = nn.Sequential(tab_mlp, nn.Linear(8, 1)) # type: ignore[assignment] -if use_cuda: - tab_mlp.cuda() +tab_mlp = to_device_model(tab_mlp, device) # text deeptext = TextModeTestClass() -if use_cuda: - deeptext.cuda() +deeptext = to_device_model(deeptext, device) # image deepimage = ImageModeTestClass() -if use_cuda: - deepimage.cuda() +ddeimage = to_device_model(deepimage, device) # Define the loader mmset = WideDeepDataset(X_wide, X_tab, X_text, X_image, target) diff --git a/tests/test_hf_integration/test_models.py b/tests/test_hf_integration/test_models.py index 847731d8..0c86039e 100644 --- a/tests/test_hf_integration/test_models.py +++ b/tests/test_hf_integration/test_models.py @@ -161,6 +161,7 @@ def test_full_training_process(model_name): model, objective="binary", verbose=0, + # device="cpu", ) trainer.fit( diff --git a/tests/test_tab2vec/test_t2v.py b/tests/test_tab2vec/test_t2v.py index 5a92b9e7..cc9b2746 100644 --- a/tests/test_tab2vec/test_t2v.py +++ b/tests/test_tab2vec/test_t2v.py @@ -85,6 +85,7 @@ def test_non_transformer_models(deeptabular, return_dataframe): # Let's assume the model is trained t2v = Tab2Vec( tab_preprocessor=tab_preprocessor, + device=device, model=model, return_dataframe=return_dataframe, ) @@ -163,6 +164,7 @@ def test_tab_transformer_models( t2v = Tab2Vec( tab_preprocessor=tab_preprocessor, model=model, + device=device, ) x_vec = t2v.transform(df_t2v) @@ -229,6 +231,7 @@ def test_attentive_mlp( t2v = Tab2Vec( tab_preprocessor=tab_preprocessor, model=model, + device=device, ) x_vec = t2v.transform(df_t2v) @@ -305,6 +308,7 @@ def test_transformer_family_models( tab_preprocessor=tab_preprocessor, model=model, return_dataframe=return_dataframe, + device=device, ) x_vec = t2v.transform(df_t2v) From 73894d4a45de727c47005a1b0a7a9252ece8a9b1 Mon Sep 17 00:00:00 2001 From: Javier Date: Fri, 11 Oct 2024 18:25:58 +0100 Subject: [PATCH 03/24] The 1st version of the DINPreprocessor seems to work. 
Need to test it --- pytorch_widedeep/models/rec/din.py | 2 +- pytorch_widedeep/preprocessing/__init__.py | 1 + .../preprocessing/din_preprocessor.py | 451 ++++++++++++++++++ 3 files changed, 453 insertions(+), 1 deletion(-) create mode 100644 pytorch_widedeep/preprocessing/din_preprocessor.py diff --git a/pytorch_widedeep/models/rec/din.py b/pytorch_widedeep/models/rec/din.py index 1eba4f97..293f249e 100644 --- a/pytorch_widedeep/models/rec/din.py +++ b/pytorch_widedeep/models/rec/din.py @@ -242,8 +242,8 @@ def __init__( self, *, column_idx: Dict[str, int], - target_item_col: str, user_behavior_confiq: Tuple[List[str], int, int], + target_item_col: str = "target_item", action_seq_config: Optional[Tuple[List[str], int]] = None, other_seq_cols_confiq: Optional[List[Tuple[List[str], int, int]]] = None, attention_unit_activation: Literal["prelu", "dice"] = "prelu", diff --git a/pytorch_widedeep/preprocessing/__init__.py b/pytorch_widedeep/preprocessing/__init__.py index 59f8b787..b4c5bcee 100644 --- a/pytorch_widedeep/preprocessing/__init__.py +++ b/pytorch_widedeep/preprocessing/__init__.py @@ -5,6 +5,7 @@ from pytorch_widedeep.preprocessing.tab_preprocessor import ( TabPreprocessor, ChunkTabPreprocessor, + embed_sz_rule, ) from pytorch_widedeep.preprocessing.text_preprocessor import ( TextPreprocessor, diff --git a/pytorch_widedeep/preprocessing/din_preprocessor.py b/pytorch_widedeep/preprocessing/din_preprocessor.py new file mode 100644 index 00000000..04655a82 --- /dev/null +++ b/pytorch_widedeep/preprocessing/din_preprocessor.py @@ -0,0 +1,451 @@ +import re +from typing import Dict, List, Tuple, Union, Literal, Optional + +import numpy as np +import pandas as pd + +from pytorch_widedeep import Trainer +from pytorch_widedeep.models import WideDeep +from pytorch_widedeep.metrics import Accuracy +from pytorch_widedeep.datasets import load_movielens100k +from pytorch_widedeep.preprocessing import TabPreprocessor, embed_sz_rule +from pytorch_widedeep.models.rec.din import DeepInterestNetwork +from pytorch_widedeep.utils.text_utils import pad_sequences +from pytorch_widedeep.utils.deeptabular_utils import LabelEncoder +from pytorch_widedeep.preprocessing.base_preprocessor import ( + BasePreprocessor, + check_is_fitted, +) + + +class DINPreprocessor(BasePreprocessor): + def __init__( + self, + *, + user_id_col: str, + target_col: str, + item_embed_col: Union[str, Tuple[str, int]], + max_seq_length: int, + action_col: Optional[str] = None, + other_seq_embed_cols: Optional[List[str] | List[Tuple[str, int]]] = None, + cat_embed_cols: Optional[Union[List[str], List[Tuple[str, int]]]] = None, + continuous_cols: Optional[List[str]] = None, + quantization_setup: Optional[ + Union[int, Dict[str, Union[int, List[float]]]] + ] = None, + cols_to_scale: Optional[Union[List[str], str]] = None, + auto_embed_dim: bool = True, + embedding_rule: Literal["google", "fastai_old", "fastai_new"] = "fastai_new", + default_embed_dim: int = 16, + verbose: int = 1, + scale: bool = False, + already_standard: Optional[List[str]] = None, + **kwargs, + ): + + self.user_id_col = user_id_col + self.item_embed_col = item_embed_col + self.max_seq_length = max_seq_length + self.target_col = target_col if target_col is not None else "target" + self.action_col = action_col + self.other_seq_embed_cols = other_seq_embed_cols + self.cat_embed_cols = cat_embed_cols + self.continuous_cols = continuous_cols + self.quantization_setup = quantization_setup + self.cols_to_scale = cols_to_scale + self.auto_embed_dim = auto_embed_dim + 
self.embedding_rule = embedding_rule + self.default_embed_dim = default_embed_dim + self.verbose = verbose + self.scale = scale + self.already_standard = already_standard + self.kwargs = kwargs + + self.has_standard_tab_data = False + if self.cat_embed_cols or self.continuous_cols: + self.has_standard_tab_data = True + self.tab_preprocessor = TabPreprocessor( + cat_embed_cols=self.cat_embed_cols, + continuous_cols=self.continuous_cols, + quantization_setup=self.quantization_setup, + cols_to_scale=self.cols_to_scale, + auto_embed_dim=self.auto_embed_dim, + embedding_rule=self.embedding_rule, + default_embed_dim=self.default_embed_dim, + verbose=self.verbose, + scale=self.scale, + already_standard=self.already_standard, + **self.kwargs, + ) + + self.is_fitted = False + + def fit(self, df: pd.DataFrame): + + if self.has_standard_tab_data: + self.tab_preprocessor.fit(df) + self.din_columns_idx = { + col: i for i, col in enumerate(self.tab_preprocessor.column_idx.keys()) + } + else: + self.din_columns_idx = {} + + self.item_le = LabelEncoder( + columns_to_encode=[ + ( + self.item_embed_col[0] + if isinstance(self.item_embed_col, tuple) + else self.item_embed_col + ) + ] + ) + self.item_le.fit(df) + self.n_items = max(self.item_le.encoding_dict[self.item_embed_col[0]].values()) + user_behaviour_embed_size = ( + self.item_embed_col[1] + if isinstance(self.item_embed_col, tuple) + else embed_sz_rule(self.n_items) + ) + self.user_behaviour_config: Tuple[List[str], int, int] = ( + [f"item_{i+1}" for i in range(self.max_seq_length)], + self.n_items, + user_behaviour_embed_size, + ) + + _current_len = len(self.din_columns_idx) + self.din_columns_idx.update( + {f"item_{i+1}": i + _current_len for i in range(self.max_seq_length)} + ) + self.din_columns_idx.update( + { + "target_item": len(self.din_columns_idx), + } + ) + + if self.action_col is not None: + self.action_le = LabelEncoder(columns_to_encode=[self.action_col]) + self.action_le.fit(df) + + self.n_actions = len(self.action_le.encoding_dict[self.action_col]) + self.action_seq_config: Tuple[List[str], int] = ( + [f"action_{i+1}" for i in range(self.max_seq_length)], + self.n_actions, + ) + + _current_len = len(self.din_columns_idx) + self.din_columns_idx.update( + {f"action_{i+1}": i + _current_len for i in range(self.max_seq_length)} + ) + + if self.other_seq_embed_cols is not None: + self.other_seq_le = LabelEncoder( + columns_to_encode=[ + col[0] if isinstance(col, tuple) else col + for col in self.other_seq_embed_cols + ] + ) + self.other_seq_le.fit(df) + + other_seq_cols = [ + col[0] if isinstance(col, tuple) else col + for col in self.other_seq_embed_cols + ] + self.n_other_seq_cols: Dict[str, int] = { + col: max(self.other_seq_le.encoding_dict[col].values()) + for col in other_seq_cols + } + + if isinstance(self.other_seq_embed_cols[0], tuple): + other_seq_embed_sizes: Dict[str, int] = { + tp[0]: tp[1] for tp in self.other_seq_embed_cols # type: ignore[misc] + } + else: + other_seq_embed_sizes = { + col: embed_sz_rule(self.n_other_seq_cols[col]) + for col in other_seq_cols + } + + self.other_seq_config: List[Tuple[List[str], int, int]] = [ + ( + [f"{col}_{i+1}" for i in range(self.max_seq_length)], + self.n_other_seq_cols[col], + other_seq_embed_sizes[col], + ) + for col in other_seq_cols + ] + + _current_len = len(self.din_columns_idx) + for col in other_seq_cols: + self.din_columns_idx.update( + { + f"{col}_{i+1}": i + _current_len + for i in range(self.max_seq_length) + } + ) + self.din_columns_idx.update( + {f"target_{col}": 
len(self.din_columns_idx)} + ) + _current_len = len(self.din_columns_idx) + + self.is_fitted = True + + return self + + def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: + + _df = self._pre_transform(df) + + df_w_sequences = self._build_sequences(_df) + + user_behaviour_seq = df_w_sequences["items_sequence"].tolist() + X_user_behaviour = np.vstack( + [ + pad_sequences(seq, self.max_seq_length, pad_first=False, pad_idx=0) + for seq in user_behaviour_seq + ] + ) + + X_target_item = np.array(df_w_sequences["target_item"].tolist()).reshape(-1, 1) + + X_all = np.concatenate([X_user_behaviour, X_target_item], axis=1) + + if self.has_standard_tab_data: + X_tab = self.tab_preprocessor.transform(df_w_sequences) + X_all = np.concatenate([X_tab, X_all], axis=1) + + if self.action_col is not None: + action_seq = df_w_sequences["actions_sequence"].tolist() + X_actions = np.vstack( + [ + pad_sequences(seq, self.max_seq_length, pad_first=False, pad_idx=0) + for seq in action_seq + ] + ) + X_all = np.concatenate([X_all, X_actions], axis=1) + + if self.other_seq_embed_cols is not None: + other_seq_cols = [ + col[0] if isinstance(col, tuple) else col + for col in self.other_seq_embed_cols + ] + other_seq = { + col: df_w_sequences[f"{col}_sequence"].tolist() + for col in other_seq_cols + } + other_seq_target = { + col: df_w_sequences[f"target_{col}"].tolist() for col in other_seq_cols + } + X_other_seq_arrays: List[np.ndarray] = [] + # to obsessively make sure that the order of the columns is the + # same + for col in other_seq_cols: + X_other_seq_arrays.append( + np.vstack( + [ + pad_sequences( + s, self.max_seq_length, pad_first=False, pad_idx=0 + ) + for s in other_seq[col] + ] + ) + ) + X_other_seq_arrays.append( + np.array(other_seq_target[col]).reshape(-1, 1) + ) + + X_all = np.concatenate([X_all] + X_other_seq_arrays, axis=1) + + assert len(self.din_columns_idx) == X_all.shape[1], ( + f"Something went wrong. 
The number of columns in the final array " + f"({X_all.shape[1]}) is different from the number of columns in " + f"self.din_columns_idx ({len(self.din_columns_idx)})" + ) + + return X_all, np.array(df_w_sequences["target"].tolist()) + + def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: + self.fit(df) + return self.transform(df) + + def _pre_transform(self, df: pd.DataFrame) -> pd.DataFrame: + + check_is_fitted(self, attributes=["din_columns_idx"]) + + df = self.item_le.transform(df) + + if self.action_col is not None: + df = self.action_le.transform(df) + + if self.other_seq_embed_cols is not None: + df = self.other_seq_le.transform(df) + + return df + + def _build_sequences(self, df: pd.DataFrame) -> pd.DataFrame: + + res_df = ( + df.groupby(self.user_id_col) + .apply(self._group_sequences) + .reset_index(drop=True) + ) + + return res_df + + def _group_sequences(self, group: pd.DataFrame) -> pd.DataFrame: + + item_col = ( + self.item_embed_col[0] + if isinstance(self.item_embed_col, tuple) + else self.item_embed_col + ) + items = group[item_col].tolist() + + targets = group[self.target_col].tolist() + drop_cols = [item_col, self.target_col] + + sequences: List[Dict[str, str | List[int]]] = [] + for i in range(len(items) - self.max_seq_length): + item_sequences = items[i : i + self.max_seq_length] + target_item = items[i + self.max_seq_length] + target = targets[i + self.max_seq_length] + + sequence = { + "user_id": group.name, + "items_sequence": item_sequences, + "target_item": target_item, + } + + if self.action_col is not None: + if self.action_col != self.target_col: + actions = group[self.action_col].tolist() + action_sequences = actions[i : i + self.max_seq_length] + drop_cols.append(self.action_col) + else: + target -= ( + 1 # the 'transform' method adds 1 as it saves 0 for padding + ) + action_sequences = targets[i : i + self.max_seq_length] + + sequence["target"] = target + sequence["actions_sequence"] = action_sequences + + if self.other_seq_embed_cols is not None: + other_seq_cols = [ + col[0] if isinstance(col, tuple) else col + for col in self.other_seq_embed_cols + ] + drop_cols += other_seq_cols + other_seqs: Dict[str, List[int]] = { + col: group[col].tolist() for col in other_seq_cols + } + other_seqs_sequences = { + col: other_seqs[col][i : i + self.max_seq_length] + for col in other_seq_cols + } + + for col in other_seq_cols: + sequence[f"{col}_sequence"] = other_seqs_sequences[col] + sequence[f"target_{col}"] = other_seqs[col][i + self.max_seq_length] + + sequences.append(sequence) + + seq_df = pd.DataFrame(sequences) + + non_seq_cols = group.drop_duplicates(["user_id"]).drop(drop_cols, axis=1) + + return pd.merge(seq_df, non_seq_cols, on="user_id") + + def inverse_transform(self, X: np.ndarray) -> pd.DataFrame: + # trasform mutates the df and is complex to revert the transformation + raise NotImplementedError( + "inverse_transform is not implemented for this preprocessor" + ) + + +if __name__ == "__main__": + + def clean_genre_list(genre_list): + return "_".join( + sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list]) + ) + + data, users, items = load_movielens100k(as_frame=True) + + list_of_genres = [ + "unknown", + "Action", + "Adventure", + "Animation", + "Children's", + "Comedy", + "Crime", + "Documentary", + "Drama", + "Fantasy", + "Film-Noir", + "Horror", + "Musical", + "Mystery", + "Romance", + "Sci-Fi", + "Thriller", + "War", + "Western", + ] + + assert ( + isinstance(items, pd.DataFrame) + and isinstance(data, 
pd.DataFrame) + and isinstance(users, pd.DataFrame) + ) + items["genre_list"] = items[list_of_genres].apply( + lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1 + ) + + items["genre_list"] = items["genre_list"].apply(clean_genre_list) + + df = pd.merge(data, items[["movie_id", "genre_list"]], on="movie_id") + df = pd.merge( + df, + users[["user_id", "age", "gender", "occupation"]], + on="user_id", + ) + + df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0) + + df = df.sort_values(by=["timestamp"]).reset_index(drop=True) + df = df.drop("timestamp", axis=1) + + din_preprocessor = DINPreprocessor( + user_id_col="user_id", + target_col="rating", + item_embed_col=("movie_id", 16), + max_seq_length=5, + action_col="rating", + other_seq_embed_cols=[("genre_list", 16)], + cat_embed_cols=["user_id", "age", "gender", "occupation"], + ) + + X, y = din_preprocessor.fit_transform(df) + + din = DeepInterestNetwork( + column_idx=din_preprocessor.din_columns_idx, + user_behavior_confiq=din_preprocessor.user_behaviour_config, + action_seq_config=din_preprocessor.action_seq_config, + other_seq_cols_confiq=din_preprocessor.other_seq_config, + cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore[attr-defined] + mlp_hidden_dims=[128, 64], + ) + + # And from here on, everything is standard + model = WideDeep(deeptabular=din) + + trainer = Trainer(model=model, objective="binary", metrics=[Accuracy()]) + + # in the real world you would have to split the data into train, val and test + trainer.fit( + X_tab=X, + target=y, + n_epochs=5, + batch_size=512, + ) From 45fce4d207ecb70b339d4b9b4d04be8504af92b8 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 14 Oct 2024 09:55:33 +0100 Subject: [PATCH 04/24] first test added for the din preprocessor --- pytorch_widedeep/preprocessing/__init__.py | 1 + .../preprocessing/din_preprocessor.py | 768 ++++++++++-------- .../interactions_data.csv | 113 +++ .../generate_interactions_data.py | 194 +++++ .../test_data_utils/test_din_preprocessor.py | 87 ++ 5 files changed, 842 insertions(+), 321 deletions(-) create mode 100644 tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv create mode 100644 tests/test_data_utils/generate_interactions_data.py create mode 100644 tests/test_data_utils/test_din_preprocessor.py diff --git a/pytorch_widedeep/preprocessing/__init__.py b/pytorch_widedeep/preprocessing/__init__.py index b4c5bcee..f52a5e73 100644 --- a/pytorch_widedeep/preprocessing/__init__.py +++ b/pytorch_widedeep/preprocessing/__init__.py @@ -2,6 +2,7 @@ HFPreprocessor, ChunkHFPreprocessor, ) +from pytorch_widedeep.preprocessing.din_preprocessor import DINPreprocessor from pytorch_widedeep.preprocessing.tab_preprocessor import ( TabPreprocessor, ChunkTabPreprocessor, diff --git a/pytorch_widedeep/preprocessing/din_preprocessor.py b/pytorch_widedeep/preprocessing/din_preprocessor.py index 04655a82..d28e6926 100644 --- a/pytorch_widedeep/preprocessing/din_preprocessor.py +++ b/pytorch_widedeep/preprocessing/din_preprocessor.py @@ -1,17 +1,22 @@ -import re +# import re +import warnings from typing import Dict, List, Tuple, Union, Literal, Optional import numpy as np import pandas as pd -from pytorch_widedeep import Trainer -from pytorch_widedeep.models import WideDeep -from pytorch_widedeep.metrics import Accuracy -from pytorch_widedeep.datasets import load_movielens100k -from pytorch_widedeep.preprocessing import TabPreprocessor, embed_sz_rule -from pytorch_widedeep.models.rec.din import DeepInterestNetwork 
+# from pytorch_widedeep.models import WideDeep +# from pytorch_widedeep.metrics import Accuracy +# from pytorch_widedeep.datasets import load_movielens100k +# from pytorch_widedeep.training import Trainer +# from pytorch_widedeep.models.rec.din import DeepInterestNetwork from pytorch_widedeep.utils.text_utils import pad_sequences +from pytorch_widedeep.utils.general_utils import alias from pytorch_widedeep.utils.deeptabular_utils import LabelEncoder +from pytorch_widedeep.preprocessing.tab_preprocessor import ( + TabPreprocessor, + embed_sz_rule, +) from pytorch_widedeep.preprocessing.base_preprocessor import ( BasePreprocessor, check_is_fitted, @@ -19,15 +24,94 @@ class DINPreprocessor(BasePreprocessor): + """ + Preprocessor for Deep Interest Network (DIN) models. + + This preprocessor handles the preparation of data for DIN models, + including sequence building, label encoding, and handling of various + types of input columns (categorical, continuous, and sequential). + + Parameters: + ----------- + user_id_col : str + Name of the column containing user IDs. + item_embed_col : Union[str, Tuple[str, int]] + Name of the column containing item IDs to be embedded, or a tuple of + (column_name, embedding_dim). + target_col : str + Name of the column containing the target variable. + max_seq_length : int + Maximum length of sequences to be created. + action_col : Optional[str], default=None + Name of the column containing user actions (if applicable). + other_seq_embed_cols : Optional[List[str] | List[Tuple[str, int]]], default=None + List of other columns to be treated as sequences. + cat_embed_cols : Optional[Union[List[str], List[Tuple[str, int]]]], default=None + List of categorical columns to be represented by embeddings. + continuous_cols: List, default = None + List with the name of the continuous cols. + quantization_setup: int or Dict[str, Union[int, List[float]]], default=None + Continuous columns can be turned into categorical via `pd.cut`. + cols_to_scale: List or str, default = None, + List with the names of the columns that will be standardized via + sklearn's `StandardScaler`. + auto_embed_dim: bool, default = True + Boolean indicating whether the embedding dimensions will be + automatically defined via rule of thumb. + embedding_rule: str, default = 'fastai_new' + Rule of thumb for embedding size. + default_embed_dim: int, default=16 + Default dimension for the embeddings. + verbose : int, default=1 + Verbosity level. + scale: bool, default = False + Boolean indicating whether or not to scale/standardize continuous cols. + already_standard: List, default = None + List with the name of the continuous cols that do not need to be + scaled/standardized. + **kwargs : + Additional keyword arguments to be passed to the TabPreprocessor. + + Attributes: + ----------- + is_fitted : bool + Whether the preprocessor has been fitted. + has_standard_tab_data : bool + Whether the data includes standard tabular data. + tab_preprocessor : TabPreprocessor + Preprocessor for standard tabular data. + din_columns_idx : Dict[str, int] + Dictionary mapping column names to their indices in the processed data. + item_le : LabelEncoder + Label encoder for item IDs. + n_items : int + Number of unique items. + user_behaviour_config : Tuple[List[str], int, int] + Configuration for user behavior sequences. + action_le : LabelEncoder + Label encoder for action column (if applicable). + n_actions : int + Number of unique actions (if applicable). 
+ action_seq_config : Tuple[List[str], int] + Configuration for action sequences (if applicable). + other_seq_le : LabelEncoder + Label encoder for other sequence columns (if applicable). + n_other_seq_cols : Dict[str, int] + Number of unique values in each other sequence column. + other_seq_config : List[Tuple[List[str], int, int]] + Configuration for other sequence columns. + """ + + @alias("item_embed_col", ["item_id_col"]) def __init__( self, *, user_id_col: str, - target_col: str, item_embed_col: Union[str, Tuple[str, int]], + target_col: str, max_seq_length: int, action_col: Optional[str] = None, - other_seq_embed_cols: Optional[List[str] | List[Tuple[str, int]]] = None, + other_seq_embed_cols: Optional[Union[List[str], List[Tuple[str, int]]]] = None, cat_embed_cols: Optional[Union[List[str], List[Tuple[str, int]]]] = None, continuous_cols: Optional[List[str]] = None, quantization_setup: Optional[ @@ -42,7 +126,6 @@ def __init__( already_standard: Optional[List[str]] = None, **kwargs, ): - self.user_id_col = user_id_col self.item_embed_col = item_embed_col self.max_seq_length = max_seq_length @@ -61,9 +144,8 @@ def __init__( self.already_standard = already_standard self.kwargs = kwargs - self.has_standard_tab_data = False - if self.cat_embed_cols or self.continuous_cols: - self.has_standard_tab_data = True + self.has_standard_tab_data = bool(self.cat_embed_cols or self.continuous_cols) + if self.has_standard_tab_data: self.tab_preprocessor = TabPreprocessor( cat_embed_cols=self.cat_embed_cols, continuous_cols=self.continuous_cols, @@ -80,8 +162,7 @@ def __init__( self.is_fitted = False - def fit(self, df: pd.DataFrame): - + def fit(self, df: pd.DataFrame) -> "DINPreprocessor": if self.has_standard_tab_data: self.tab_preprocessor.fit(df) self.din_columns_idx = { @@ -90,362 +171,407 @@ def fit(self, df: pd.DataFrame): else: self.din_columns_idx = {} - self.item_le = LabelEncoder( - columns_to_encode=[ - ( - self.item_embed_col[0] - if isinstance(self.item_embed_col, tuple) - else self.item_embed_col - ) - ] + self._fit_label_encoders(df) + self.is_fitted = True + return self + + def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: + _df = self._pre_transform(df) + df_w_sequences = self._build_sequences(_df) + X_all = self._concatenate_features(df_w_sequences) + return X_all, np.array(df_w_sequences["target"].tolist()) + + def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: + self.fit(df) + return self.transform(df) + + def inverse_transform(self, X: np.ndarray) -> pd.DataFrame: + raise NotImplementedError( + "inverse_transform is not implemented for this preprocessor" ) + + def _fit_label_encoders(self, df: pd.DataFrame) -> None: + self.item_le = LabelEncoder(columns_to_encode=[self._get_item_col()]) self.item_le.fit(df) - self.n_items = max(self.item_le.encoding_dict[self.item_embed_col[0]].values()) - user_behaviour_embed_size = ( - self.item_embed_col[1] - if isinstance(self.item_embed_col, tuple) - else embed_sz_rule(self.n_items) + self.n_items = max(self.item_le.encoding_dict[self._get_item_col()].values()) + user_behaviour_embed_size = self._get_embed_size( + self.item_embed_col, self.n_items ) - self.user_behaviour_config: Tuple[List[str], int, int] = ( + self.user_behaviour_config = ( [f"item_{i+1}" for i in range(self.max_seq_length)], self.n_items, user_behaviour_embed_size, ) - _current_len = len(self.din_columns_idx) - self.din_columns_idx.update( - {f"item_{i+1}": i + _current_len for i in range(self.max_seq_length)} - ) - 
self.din_columns_idx.update( - { - "target_item": len(self.din_columns_idx), - } + self._update_din_columns_idx("item", self.max_seq_length) + self.din_columns_idx["target_item"] = len(self.din_columns_idx) + + if self.action_col: + self._fit_action_label_encoder(df) + if self.other_seq_embed_cols: + self._fit_other_seq_label_encoders(df) + + def _fit_action_label_encoder(self, df: pd.DataFrame) -> None: + self.action_le = LabelEncoder(columns_to_encode=[self.action_col]) + self.action_le.fit(df) + self.n_actions = len(self.action_le.encoding_dict[self.action_col]) + self.action_seq_config = ( + [f"action_{i+1}" for i in range(self.max_seq_length)], + self.n_actions, ) + self._update_din_columns_idx("action", self.max_seq_length) + + def _other_seq_cols_float_warning(self, df: pd.DataFrame) -> None: + other_seq_cols = [ + col[0] if isinstance(col, tuple) else col + for col in self.other_seq_embed_cols + ] + for col in other_seq_cols: + if df[col].dtype == float: + warnings.warn( + f"{col} is a float column. It will be converted to integers. " + "If this is not what you want, please convert it beforehand.", + UserWarning, + ) - if self.action_col is not None: - self.action_le = LabelEncoder(columns_to_encode=[self.action_col]) - self.action_le.fit(df) - - self.n_actions = len(self.action_le.encoding_dict[self.action_col]) - self.action_seq_config: Tuple[List[str], int] = ( - [f"action_{i+1}" for i in range(self.max_seq_length)], - self.n_actions, - ) - - _current_len = len(self.din_columns_idx) - self.din_columns_idx.update( - {f"action_{i+1}": i + _current_len for i in range(self.max_seq_length)} - ) - - if self.other_seq_embed_cols is not None: - self.other_seq_le = LabelEncoder( - columns_to_encode=[ - col[0] if isinstance(col, tuple) else col - for col in self.other_seq_embed_cols - ] - ) - self.other_seq_le.fit(df) - - other_seq_cols = [ + def _convert_other_seq_cols_to_int(self, df: pd.DataFrame) -> None: + other_seq_cols = [ + col[0] if isinstance(col, tuple) else col + for col in self.other_seq_embed_cols + ] + for col in other_seq_cols: + if df[col].dtype == float: + df[col] = df[col].astype(int) + + def _fit_other_seq_label_encoders(self, df: pd.DataFrame) -> None: + self._other_seq_cols_float_warning(df) + self._convert_other_seq_cols_to_int(df) + self.other_seq_le = LabelEncoder( + columns_to_encode=[ col[0] if isinstance(col, tuple) else col for col in self.other_seq_embed_cols ] - self.n_other_seq_cols: Dict[str, int] = { - col: max(self.other_seq_le.encoding_dict[col].values()) - for col in other_seq_cols - } - - if isinstance(self.other_seq_embed_cols[0], tuple): - other_seq_embed_sizes: Dict[str, int] = { - tp[0]: tp[1] for tp in self.other_seq_embed_cols # type: ignore[misc] - } - else: - other_seq_embed_sizes = { - col: embed_sz_rule(self.n_other_seq_cols[col]) - for col in other_seq_cols - } - - self.other_seq_config: List[Tuple[List[str], int, int]] = [ - ( - [f"{col}_{i+1}" for i in range(self.max_seq_length)], - self.n_other_seq_cols[col], - other_seq_embed_sizes[col], - ) - for col in other_seq_cols - ] - - _current_len = len(self.din_columns_idx) - for col in other_seq_cols: - self.din_columns_idx.update( - { - f"{col}_{i+1}": i + _current_len - for i in range(self.max_seq_length) - } - ) - self.din_columns_idx.update( - {f"target_{col}": len(self.din_columns_idx)} - ) - _current_len = len(self.din_columns_idx) - - self.is_fitted = True - - return self - - def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: - - _df = self._pre_transform(df) - - 
df_w_sequences = self._build_sequences(_df) - - user_behaviour_seq = df_w_sequences["items_sequence"].tolist() - X_user_behaviour = np.vstack( - [ - pad_sequences(seq, self.max_seq_length, pad_first=False, pad_idx=0) - for seq in user_behaviour_seq - ] ) - - X_target_item = np.array(df_w_sequences["target_item"].tolist()).reshape(-1, 1) - - X_all = np.concatenate([X_user_behaviour, X_target_item], axis=1) - - if self.has_standard_tab_data: - X_tab = self.tab_preprocessor.transform(df_w_sequences) - X_all = np.concatenate([X_tab, X_all], axis=1) - - if self.action_col is not None: - action_seq = df_w_sequences["actions_sequence"].tolist() - X_actions = np.vstack( - [ - pad_sequences(seq, self.max_seq_length, pad_first=False, pad_idx=0) - for seq in action_seq - ] + self.other_seq_le.fit(df) + other_seq_cols = [ + col[0] if isinstance(col, tuple) else col + for col in self.other_seq_embed_cols + ] + self.n_other_seq_cols = { + col: max(self.other_seq_le.encoding_dict[col].values()) + for col in other_seq_cols + } + other_seq_embed_sizes = { + col: self._get_embed_size(col, self.n_other_seq_cols[col]) + for col in other_seq_cols + } + self.other_seq_config = [ + ( + self._get_seq_col_names(col), + self.n_other_seq_cols[col], + other_seq_embed_sizes[col], ) - X_all = np.concatenate([X_all, X_actions], axis=1) + for col in other_seq_cols + ] + self._update_din_columns_idx_for_other_seq(other_seq_cols) - if self.other_seq_embed_cols is not None: - other_seq_cols = [ - col[0] if isinstance(col, tuple) else col - for col in self.other_seq_embed_cols - ] - other_seq = { - col: df_w_sequences[f"{col}_sequence"].tolist() - for col in other_seq_cols - } - other_seq_target = { - col: df_w_sequences[f"target_{col}"].tolist() for col in other_seq_cols - } - X_other_seq_arrays: List[np.ndarray] = [] - # to obsessively make sure that the order of the columns is the - # same - for col in other_seq_cols: - X_other_seq_arrays.append( - np.vstack( - [ - pad_sequences( - s, self.max_seq_length, pad_first=False, pad_idx=0 - ) - for s in other_seq[col] - ] - ) - ) - X_other_seq_arrays.append( - np.array(other_seq_target[col]).reshape(-1, 1) - ) + def _get_item_col(self) -> str: + return ( + self.item_embed_col[0] + if isinstance(self.item_embed_col, tuple) + else self.item_embed_col + ) - X_all = np.concatenate([X_all] + X_other_seq_arrays, axis=1) + def _get_embed_size(self, col: Union[str, Tuple[str, int]], n_unique: int) -> int: + return col[1] if isinstance(col, tuple) else embed_sz_rule(n_unique) - assert len(self.din_columns_idx) == X_all.shape[1], ( - f"Something went wrong. 
The number of columns in the final array " - f"({X_all.shape[1]}) is different from the number of columns in " - f"self.din_columns_idx ({len(self.din_columns_idx)})" - ) + def _get_seq_col_names(self, col: str) -> List[str]: + return [f"{col}_{i+1}" for i in range(self.max_seq_length)] - return X_all, np.array(df_w_sequences["target"].tolist()) + def _update_din_columns_idx(self, prefix: str, length: int) -> None: + current_len = len(self.din_columns_idx) + self.din_columns_idx.update( + {f"{prefix}_{i+1}": i + current_len for i in range(length)} + ) - def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: - self.fit(df) - return self.transform(df) + def _update_din_columns_idx_for_other_seq(self, other_seq_cols: List[str]) -> None: + current_len = len(self.din_columns_idx) + for col in other_seq_cols: + self.din_columns_idx.update( + {f"{col}_{i+1}": i + current_len for i in range(self.max_seq_length)} + ) + self.din_columns_idx[f"target_{col}"] = len(self.din_columns_idx) + current_len = len(self.din_columns_idx) def _pre_transform(self, df: pd.DataFrame) -> pd.DataFrame: - check_is_fitted(self, attributes=["din_columns_idx"]) - df = self.item_le.transform(df) - - if self.action_col is not None: + if self.action_col: df = self.action_le.transform(df) - - if self.other_seq_embed_cols is not None: + if self.other_seq_embed_cols: df = self.other_seq_le.transform(df) - return df def _build_sequences(self, df: pd.DataFrame) -> pd.DataFrame: - - res_df = ( + # TO DO: do something to avoid warning here + return ( df.groupby(self.user_id_col) .apply(self._group_sequences) .reset_index(drop=True) ) - return res_df - def _group_sequences(self, group: pd.DataFrame) -> pd.DataFrame: - - item_col = ( - self.item_embed_col[0] - if isinstance(self.item_embed_col, tuple) - else self.item_embed_col - ) + item_col = self._get_item_col() items = group[item_col].tolist() - targets = group[self.target_col].tolist() drop_cols = [item_col, self.target_col] - sequences: List[Dict[str, str | List[int]]] = [] - for i in range(len(items) - self.max_seq_length): - item_sequences = items[i : i + self.max_seq_length] - target_item = items[i + self.max_seq_length] - target = targets[i + self.max_seq_length] + # sequences cannot be built with user, item pairs with only one + # interaction + if len(items) <= 1: + return pd.DataFrame() - sequence = { - "user_id": group.name, - "items_sequence": item_sequences, - "target_item": target_item, - } + sequences = [ + self._create_sequence(group, items, targets, i, drop_cols) + for i in range(max(1, len(items) - self.max_seq_length)) + ] - if self.action_col is not None: - if self.action_col != self.target_col: - actions = group[self.action_col].tolist() - action_sequences = actions[i : i + self.max_seq_length] - drop_cols.append(self.action_col) - else: - target -= ( - 1 # the 'transform' method adds 1 as it saves 0 for padding - ) - action_sequences = targets[i : i + self.max_seq_length] - - sequence["target"] = target - sequence["actions_sequence"] = action_sequences - - if self.other_seq_embed_cols is not None: - other_seq_cols = [ - col[0] if isinstance(col, tuple) else col - for col in self.other_seq_embed_cols - ] - drop_cols += other_seq_cols - other_seqs: Dict[str, List[int]] = { - col: group[col].tolist() for col in other_seq_cols - } - other_seqs_sequences = { - col: other_seqs[col][i : i + self.max_seq_length] - for col in other_seq_cols - } + seq_df = pd.DataFrame(sequences) + non_seq_cols = 
group.drop_duplicates([self.user_id_col]).drop(drop_cols, axis=1) + return pd.merge(seq_df, non_seq_cols, on=self.user_id_col) - for col in other_seq_cols: - sequence[f"{col}_sequence"] = other_seqs_sequences[col] - sequence[f"target_{col}"] = other_seqs[col][i + self.max_seq_length] + def _create_sequence( + self, + group: pd.DataFrame, + items: List[int], + targets: List[int], + i: int, + drop_cols: List[str], + ) -> Dict[str, Union[str, List[int]]]: + end_idx = min(i + self.max_seq_length, len(items) - 1) + item_sequences = items[i:end_idx] + target_item = items[end_idx] + target = targets[end_idx] + + sequence = { + self.user_id_col: group.name, + "items_sequence": item_sequences, + "target_item": target_item, + } + + if self.action_col: + sequence.update( + self._create_action_sequence(group, targets, i, target, drop_cols) + ) + if self.other_seq_embed_cols: + sequence.update(self._create_other_seq_sequence(group, i, drop_cols)) - sequences.append(sequence) + return sequence - seq_df = pd.DataFrame(sequences) + def _create_action_sequence( + self, + group: pd.DataFrame, + targets: List[int], + i: int, + target: int, + drop_cols: List[str], + ) -> Dict[str, Union[int, List[int]]]: + if self.action_col != self.target_col: + actions = group[self.action_col].tolist() + action_sequences = ( + actions[i : i + self.max_seq_length] + if i + self.max_seq_length < len(actions) + else actions[i:-1] + ) + drop_cols.append(self.action_col) + else: + target -= 1 # the 'transform' method adds 1 as it saves 0 for padding + action_sequences = ( + targets[i : i + self.max_seq_length] + if i + self.max_seq_length < len(targets) + else targets[i:-1] + ) + + return {"target": target, "actions_sequence": action_sequences} + + def _create_other_seq_sequence( + self, group: pd.DataFrame, i: int, drop_cols: List[str] + ) -> Dict[str, Union[int, List[int]]]: + other_seq_cols = [ + col[0] if isinstance(col, tuple) else col + for col in self.other_seq_embed_cols + ] + drop_cols += other_seq_cols + other_seqs = {col: group[col].tolist() for col in other_seq_cols} + + other_seqs_sequences: Dict[str, Union[int, List[int]]] = {} + for col in other_seq_cols: + other_seqs_sequences[col] = ( + other_seqs[col][i : i + self.max_seq_length] + if i + self.max_seq_length < len(other_seqs[col]) + else other_seqs[col][i:-1] + ) - non_seq_cols = group.drop_duplicates(["user_id"]).drop(drop_cols, axis=1) + sequence: Dict[str, Union[int, List[int]]] = {} + for col in other_seq_cols: + sequence[f"{col}_sequence"] = other_seqs_sequences[col] + sequence[f"target_{col}"] = ( + other_seqs[col][i + self.max_seq_length] + if i + self.max_seq_length < len(other_seqs[col]) + else other_seqs[col][-1] + ) - return pd.merge(seq_df, non_seq_cols, on="user_id") + return sequence - def inverse_transform(self, X: np.ndarray) -> pd.DataFrame: - # trasform mutates the df and is complex to revert the transformation - raise NotImplementedError( - "inverse_transform is not implemented for this preprocessor" + def _concatenate_features(self, df_w_sequences: pd.DataFrame) -> np.ndarray: + user_behaviour_seq = df_w_sequences["items_sequence"].tolist() + X_user_behaviour = np.vstack( + [ + pad_sequences(seq, self.max_seq_length, pad_idx=0) + for seq in user_behaviour_seq + ] ) + X_target_item = np.array(df_w_sequences["target_item"].tolist()).reshape(-1, 1) + X_all = np.concatenate([X_user_behaviour, X_target_item], axis=1) + if self.has_standard_tab_data: + X_tab = self.tab_preprocessor.transform(df_w_sequences) + X_all = np.concatenate([X_tab, X_all], 
axis=1) -if __name__ == "__main__": + if self.action_col: + action_seq = df_w_sequences["actions_sequence"].tolist() + X_actions = np.vstack( + [ + pad_sequences(seq, self.max_seq_length, pad_idx=0) + for seq in action_seq + ] + ) + X_all = np.concatenate([X_all, X_actions], axis=1) + + if self.other_seq_embed_cols: + X_all = self._concatenate_other_seq_features(df_w_sequences, X_all) - def clean_genre_list(genre_list): - return "_".join( - sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list]) + assert len(self.din_columns_idx) == X_all.shape[1], ( + f"Something went wrong. The number of columns in the final array " + f"({X_all.shape[1]}) is different from the number of columns in " + f"self.din_columns_idx ({len(self.din_columns_idx)})" ) - data, users, items = load_movielens100k(as_frame=True) - - list_of_genres = [ - "unknown", - "Action", - "Adventure", - "Animation", - "Children's", - "Comedy", - "Crime", - "Documentary", - "Drama", - "Fantasy", - "Film-Noir", - "Horror", - "Musical", - "Mystery", - "Romance", - "Sci-Fi", - "Thriller", - "War", - "Western", - ] - - assert ( - isinstance(items, pd.DataFrame) - and isinstance(data, pd.DataFrame) - and isinstance(users, pd.DataFrame) - ) - items["genre_list"] = items[list_of_genres].apply( - lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1 - ) - - items["genre_list"] = items["genre_list"].apply(clean_genre_list) - - df = pd.merge(data, items[["movie_id", "genre_list"]], on="movie_id") - df = pd.merge( - df, - users[["user_id", "age", "gender", "occupation"]], - on="user_id", - ) - - df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0) - - df = df.sort_values(by=["timestamp"]).reset_index(drop=True) - df = df.drop("timestamp", axis=1) - - din_preprocessor = DINPreprocessor( - user_id_col="user_id", - target_col="rating", - item_embed_col=("movie_id", 16), - max_seq_length=5, - action_col="rating", - other_seq_embed_cols=[("genre_list", 16)], - cat_embed_cols=["user_id", "age", "gender", "occupation"], - ) - - X, y = din_preprocessor.fit_transform(df) - - din = DeepInterestNetwork( - column_idx=din_preprocessor.din_columns_idx, - user_behavior_confiq=din_preprocessor.user_behaviour_config, - action_seq_config=din_preprocessor.action_seq_config, - other_seq_cols_confiq=din_preprocessor.other_seq_config, - cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore[attr-defined] - mlp_hidden_dims=[128, 64], - ) - - # And from here on, everything is standard - model = WideDeep(deeptabular=din) - - trainer = Trainer(model=model, objective="binary", metrics=[Accuracy()]) - - # in the real world you would have to split the data into train, val and test - trainer.fit( - X_tab=X, - target=y, - n_epochs=5, - batch_size=512, - ) + return X_all + + def _concatenate_other_seq_features( + self, df_w_sequences: pd.DataFrame, X_all: np.ndarray + ) -> np.ndarray: + other_seq_cols = [ + col[0] if isinstance(col, tuple) else col + for col in self.other_seq_embed_cols + ] + other_seq = { + col: df_w_sequences[f"{col}_sequence"].tolist() for col in other_seq_cols + } + other_seq_target = { + col: df_w_sequences[f"target_{col}"].tolist() for col in other_seq_cols + } + X_other_seq_arrays = [] + + for col in other_seq_cols: + X_other_seq_arrays.append( + np.vstack( + [ + pad_sequences(s, self.max_seq_length, pad_idx=0) + for s in other_seq[col] + ] + ) + ) + X_other_seq_arrays.append(np.array(other_seq_target[col]).reshape(-1, 1)) + + return np.concatenate([X_all] + X_other_seq_arrays, 
axis=1) + + +# if __name__ == "__main__": + +# def clean_genre_list(genre_list): +# return "_".join( +# sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list]) +# ) + +# data, users, items = load_movielens100k(as_frame=True) + +# list_of_genres = [ +# "unknown", +# "Action", +# "Adventure", +# "Animation", +# "Children's", +# "Comedy", +# "Crime", +# "Documentary", +# "Drama", +# "Fantasy", +# "Film-Noir", +# "Horror", +# "Musical", +# "Mystery", +# "Romance", +# "Sci-Fi", +# "Thriller", +# "War", +# "Western", +# ] + +# assert ( +# isinstance(items, pd.DataFrame) +# and isinstance(data, pd.DataFrame) +# and isinstance(users, pd.DataFrame) +# ) +# items["genre_list"] = items[list_of_genres].apply( +# lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1 +# ) + +# items["genre_list"] = items["genre_list"].apply(clean_genre_list) + +# df = pd.merge(data, items[["movie_id", "genre_list"]], on="movie_id") +# df = pd.merge( +# df, +# users[["user_id", "age", "gender", "occupation"]], +# on="user_id", +# ) + +# df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0) + +# df = df.sort_values(by=["timestamp"]).reset_index(drop=True) +# df = df.drop("timestamp", axis=1) + +# din_preprocessor = DINPreprocessor( +# user_id_col="user_id", +# target_col="rating", +# item_embed_col=("movie_id", 16), +# max_seq_length=5, +# action_col="rating", +# other_seq_embed_cols=[("genre_list", 16)], +# cat_embed_cols=["user_id", "age", "gender", "occupation"], +# ) + +# X, y = din_preprocessor.fit_transform(df) + +# din = DeepInterestNetwork( +# column_idx=din_preprocessor.din_columns_idx, +# user_behavior_confiq=din_preprocessor.user_behaviour_config, +# action_seq_config=din_preprocessor.action_seq_config, +# other_seq_cols_confiq=din_preprocessor.other_seq_config, +# cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore[attr-defined] +# mlp_hidden_dims=[128, 64], +# ) + +# # And from here on, everything is standard +# model = WideDeep(deeptabular=din) + +# trainer = Trainer(model=model, objective="binary", metrics=[Accuracy()]) + +# # in the real world you would have to split the data into train, val and test +# trainer.fit( +# X_tab=X, +# target=y, +# n_epochs=5, +# batch_size=512, +# ) diff --git a/tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv b/tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv new file mode 100644 index 00000000..e7587df7 --- /dev/null +++ b/tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv @@ -0,0 +1,113 @@ +user_id,item_id,age,gender,height,weight,price,category,timestamp,interaction +5,2,25,M,182.5,99.6,65.05,C,2023-01-02,1 +3,1,32,M,166.7,96.9,65.57,B,2023-01-02,0 +5,6,25,M,182.5,99.6,45.99,A,2023-01-04,0 +3,7,32,M,166.7,96.9,14.2,C,2023-01-05,1 +3,8,32,M,166.7,96.9,97.64,C,2023-01-09,0 +4,2,60,F,157.1,50.0,65.05,C,2023-01-13,1 +2,6,46,M,173.0,86.1,45.99,A,2023-01-15,1 +5,1,25,M,182.5,99.6,65.57,B,2023-01-17,0 +5,7,25,M,182.5,99.6,14.2,C,2023-01-22,0 +5,6,25,M,182.5,99.6,45.99,A,2023-01-27,0 +4,2,60,F,157.1,50.0,65.05,C,2023-01-28,1 +4,8,60,F,157.1,50.0,97.64,C,2023-01-29,1 +3,2,32,M,166.7,96.9,65.05,C,2023-02-02,1 +3,6,32,M,166.7,96.9,45.99,A,2023-02-04,1 +3,9,32,M,166.7,96.9,30.95,A,2023-02-06,0 +2,7,46,M,173.0,86.1,14.2,C,2023-02-10,1 +3,4,32,M,166.7,96.9,12.08,A,2023-02-11,1 +4,2,60,F,157.1,50.0,65.05,C,2023-02-13,1 +4,10,60,F,157.1,50.0,18.15,A,2023-02-14,1 +5,10,25,M,182.5,99.6,18.15,A,2023-02-20,0 +5,2,25,M,182.5,99.6,65.05,C,2023-02-21,1 
+1,2,56,M,155.0,52.8,65.05,C,2023-02-22,0 +3,1,32,M,166.7,96.9,65.57,B,2023-03-01,0 +2,6,46,M,173.0,86.1,45.99,A,2023-03-06,1 +4,2,60,F,157.1,50.0,65.05,C,2023-03-07,1 +2,5,46,M,173.0,86.1,57.23,C,2023-03-22,1 +2,9,46,M,173.0,86.1,30.95,A,2023-04-08,0 +3,5,32,M,166.7,96.9,57.23,C,2023-04-09,1 +3,1,32,M,166.7,96.9,65.57,B,2023-04-09,0 +3,3,32,M,166.7,96.9,10.64,C,2023-04-09,0 +4,1,60,F,157.1,50.0,65.57,B,2023-04-09,0 +4,5,60,F,157.1,50.0,57.23,C,2023-04-14,0 +4,4,60,F,157.1,50.0,12.08,A,2023-04-22,0 +3,3,32,M,166.7,96.9,10.64,C,2023-04-22,0 +4,4,60,F,157.1,50.0,12.08,A,2023-04-23,0 +5,9,25,M,182.5,99.6,30.95,A,2023-04-23,1 +5,3,25,M,182.5,99.6,10.64,C,2023-04-23,1 +3,1,32,M,166.7,96.9,65.57,B,2023-04-26,0 +1,1,56,M,155.0,52.8,65.57,B,2023-05-06,0 +5,6,25,M,182.5,99.6,45.99,A,2023-05-08,0 +2,8,46,M,173.0,86.1,97.64,C,2023-05-08,0 +3,5,32,M,166.7,96.9,57.23,C,2023-05-11,1 +4,6,60,F,157.1,50.0,45.99,A,2023-05-14,1 +5,7,25,M,182.5,99.6,14.2,C,2023-05-15,0 +3,2,32,M,166.7,96.9,65.05,C,2023-05-15,1 +2,3,46,M,173.0,86.1,10.64,C,2023-05-16,1 +5,8,25,M,182.5,99.6,97.64,C,2023-05-17,1 +4,9,60,F,157.1,50.0,30.95,A,2023-05-23,1 +3,9,32,M,166.7,96.9,30.95,A,2023-05-24,0 +2,4,46,M,173.0,86.1,12.08,A,2023-05-27,0 +2,8,46,M,173.0,86.1,97.64,C,2023-05-31,0 +3,8,32,M,166.7,96.9,97.64,C,2023-06-01,0 +5,5,25,M,182.5,99.6,57.23,C,2023-06-01,1 +5,8,25,M,182.5,99.6,97.64,C,2023-06-09,1 +5,7,25,M,182.5,99.6,14.2,C,2023-06-10,0 +4,7,60,F,157.1,50.0,14.2,C,2023-06-18,0 +5,1,25,M,182.5,99.6,65.57,B,2023-06-25,0 +5,2,25,M,182.5,99.6,65.05,C,2023-06-29,1 +5,2,25,M,182.5,99.6,65.05,C,2023-07-06,1 +2,8,46,M,173.0,86.1,97.64,C,2023-07-06,0 +2,7,46,M,173.0,86.1,14.2,C,2023-07-07,1 +2,7,46,M,173.0,86.1,14.2,C,2023-07-09,1 +5,6,25,M,182.5,99.6,45.99,A,2023-07-11,0 +2,1,46,M,173.0,86.1,65.57,B,2023-07-20,0 +2,10,46,M,173.0,86.1,18.15,A,2023-07-21,1 +2,8,46,M,173.0,86.1,97.64,C,2023-07-22,0 +1,10,56,M,155.0,52.8,18.15,A,2023-07-25,1 +3,6,32,M,166.7,96.9,45.99,A,2023-07-26,1 +4,4,60,F,157.1,50.0,12.08,A,2023-07-27,0 +2,6,46,M,173.0,86.1,45.99,A,2023-07-27,1 +2,10,46,M,173.0,86.1,18.15,A,2023-08-05,1 +1,4,56,M,155.0,52.8,12.08,A,2023-08-06,0 +5,3,25,M,182.5,99.6,10.64,C,2023-08-08,1 +5,7,25,M,182.5,99.6,14.2,C,2023-08-08,0 +5,9,25,M,182.5,99.6,30.95,A,2023-08-12,1 +5,9,25,M,182.5,99.6,30.95,A,2023-08-13,1 +3,7,32,M,166.7,96.9,14.2,C,2023-08-15,1 +5,8,25,M,182.5,99.6,97.64,C,2023-08-22,1 +2,4,46,M,173.0,86.1,12.08,A,2023-08-28,0 +4,8,60,F,157.1,50.0,97.64,C,2023-08-31,1 +2,8,46,M,173.0,86.1,97.64,C,2023-09-09,0 +4,5,60,F,157.1,50.0,57.23,C,2023-09-11,0 +3,4,32,M,166.7,96.9,12.08,A,2023-09-12,1 +2,4,46,M,173.0,86.1,12.08,A,2023-09-12,0 +5,2,25,M,182.5,99.6,65.05,C,2023-09-16,1 +2,9,46,M,173.0,86.1,30.95,A,2023-09-16,0 +1,5,56,M,155.0,52.8,57.23,C,2023-09-21,1 +2,1,46,M,173.0,86.1,65.57,B,2023-09-24,0 +2,7,46,M,173.0,86.1,14.2,C,2023-09-27,1 +5,4,25,M,182.5,99.6,12.08,A,2023-10-04,0 +2,6,46,M,173.0,86.1,45.99,A,2023-10-07,1 +5,1,25,M,182.5,99.6,65.57,B,2023-10-11,0 +4,2,60,F,157.1,50.0,65.05,C,2023-10-13,1 +5,7,25,M,182.5,99.6,14.2,C,2023-10-16,0 +5,5,25,M,182.5,99.6,57.23,C,2023-10-22,1 +3,8,32,M,166.7,96.9,97.64,C,2023-10-31,0 +3,1,32,M,166.7,96.9,65.57,B,2023-10-31,0 +5,10,25,M,182.5,99.6,18.15,A,2023-10-31,0 +4,3,60,F,157.1,50.0,10.64,C,2023-11-06,0 +5,1,25,M,182.5,99.6,65.57,B,2023-11-13,0 +5,2,25,M,182.5,99.6,65.05,C,2023-11-14,1 +2,3,46,M,173.0,86.1,10.64,C,2023-11-23,1 +3,7,32,M,166.7,96.9,14.2,C,2023-11-24,1 +4,3,60,F,157.1,50.0,10.64,C,2023-12-04,0 +2,6,46,M,173.0,86.1,45.99,A,2023-12-04,1 
+1,9,56,M,155.0,52.8,30.95,A,2023-12-06,1 +2,3,46,M,173.0,86.1,10.64,C,2023-12-10,1 +2,4,46,M,173.0,86.1,12.08,A,2023-12-12,0 +4,1,60,F,157.1,50.0,65.57,B,2023-12-15,0 +2,8,46,M,173.0,86.1,97.64,C,2023-12-18,0 +3,2,32,M,166.7,96.9,65.05,C,2023-12-26,1 +2,2,46,M,173.0,86.1,65.05,C,2023-12-26,0 diff --git a/tests/test_data_utils/generate_interactions_data.py b/tests/test_data_utils/generate_interactions_data.py new file mode 100644 index 00000000..df128e37 --- /dev/null +++ b/tests/test_data_utils/generate_interactions_data.py @@ -0,0 +1,194 @@ +import os +from pathlib import Path +from datetime import datetime, timedelta + +import numpy as np +import pandas as pd + +full_path = os.path.realpath(__file__) +path = os.path.split(full_path)[0] + +save_dir = Path(path) / "data_for_rec_preprocessor" + + +def generate_sample_data(n_users=5, n_items=10, seed=42, return_df=False): + """ + Generate a sample dataset for recommendation system testing. + + Parameters: + ----------- + - n_users: int, number of users to generate (default 5) + - n_items: int, number of items to generate (default 10) + - seed: int, random seed for reproducibility (default 42) + """ + np.random.seed(seed) + + # Generate user data + users = pd.DataFrame( + { + "user_id": range(1, n_users + 1), + "age": np.random.randint(18, 65, n_users), + "gender": np.random.choice(["M", "F"], n_users), + "height": np.random.uniform(150, 200, n_users).round(1), + "weight": np.random.uniform(50, 100, n_users).round(1), + } + ) + + # Generate item data + items = pd.DataFrame( + { + "item_id": range(1, n_items + 1), + "price": np.random.uniform(10, 100, n_items).round(2), + "category": np.random.choice(["A", "B", "C"], n_items), + } + ) + + # Generate positive interactions + positive_interactions = [] + start_date = datetime(2023, 1, 1) + + for user_id in users["user_id"]: + # Randomly select 5 items for this user to interact with + user_items = np.random.choice(items["item_id"], 5, replace=False) + + if user_id == 1: + n_interactions = 3 + else: + n_interactions = np.random.randint(4, 21) + + for _ in range(n_interactions): + item_id = np.random.choice(user_items) + timestamp = start_date + timedelta(days=np.random.randint(0, 365)) + positive_interactions.append( + { + "user_id": user_id, + "item_id": item_id, + "timestamp": timestamp, + "interaction": 1, # Positive interaction + } + ) + + positive_df = pd.DataFrame(positive_interactions) + + # Generate negative interactions + negative_interactions = [] + for user_id in users["user_id"]: + positive_items = positive_df[positive_df["user_id"] == user_id][ + "item_id" + ].unique() + negative_items = items[~items["item_id"].isin(positive_items)]["item_id"] + + n_negative = len(positive_df[positive_df["user_id"] == user_id]) + + for _ in range(n_negative): + item_id = np.random.choice(negative_items) + timestamp = start_date + timedelta(days=np.random.randint(0, 365)) + negative_interactions.append( + { + "user_id": user_id, + "item_id": item_id, + "timestamp": timestamp, + "interaction": 0, # Negative interaction + } + ) + + negative_df = pd.DataFrame(negative_interactions) + + # Combine positive and negative interactions + interactions_df = pd.concat([positive_df, negative_df], ignore_index=True) + + # Merge all data + final_df = interactions_df.merge(users, on="user_id").merge(items, on="item_id") + + # Sort by timestamp + final_df = final_df.sort_values("timestamp").reset_index(drop=True) + + # Reorder columns + column_order = [ + "user_id", + "item_id", + "age", + "gender", + "height", + 
"weight", + "price", + "category", + "timestamp", + "interaction", + ] + final_df = final_df[column_order] + + if not save_dir.exists(): + save_dir.mkdir(parents=True) + + if not return_df: + final_df.to_csv(save_dir / "interactions_data.csv", index=False) + else: + return final_df + + +def split_by_timestamp(df, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1): + """ + Split the dataframe into train, validation, and test sets based on timestamp. + + Parameters: + - df: pd.DataFrame, the input dataframe + - train_ratio: float, ratio of data for training (default 0.8) + - val_ratio: float, ratio of data for validation (default 0.1) + - test_ratio: float, ratio of data for testing (default 0.1) + + Returns: + - train_df, val_df, test_df: pd.DataFrame, the split dataframes + """ + assert np.isclose(train_ratio + val_ratio + test_ratio, 1.0), "Ratios must sum to 1" + + df_sorted = df.sort_values("timestamp") + total_rows = len(df_sorted) + train_rows = int(total_rows * train_ratio) + val_rows = int(total_rows * val_ratio) + + train_df = df_sorted.iloc[:train_rows] + val_df = df_sorted.iloc[train_rows : train_rows + val_rows] + test_df = df_sorted.iloc[train_rows + val_rows :] + + return train_df, val_df, test_df + + +def split_by_last_interactions(df): + """ + Split the dataframe into train, validation, and test sets based on the last interactions. + Train: all interactions but last two + Validation: second to last interaction + Test: last interaction + + Parameters: + - df: pd.DataFrame, the input dataframe + + Returns: + - train_df, val_df, test_df: pd.DataFrame, the split dataframes + """ + df_sorted = df.sort_values(["user_id", "timestamp"]) + + # Get the last two interactions for each user + last_two = df_sorted.groupby("user_id").tail(2) + + # Split the last two interactions into validation and test + test_df = last_two.groupby("user_id").last().reset_index() + val_df = last_two.groupby("user_id").first().reset_index() + + # Remove the last two interactions from the original dataframe to create the train set + train_df = df_sorted[~df_sorted.index.isin(last_two.index)].reset_index(drop=True) + + return train_df, val_df, test_df + + +if __name__ == "__main__": + generate_sample_data() + + # df = generate_sample_data(return_df=True) + + # # Split the data by timestamp + # train_df, val_df, test_df = split_by_timestamp(df) + + # # Split the data by last interactions + # train_df2, val_df2, test_df2 = split_by_last_interactions(df) diff --git a/tests/test_data_utils/test_din_preprocessor.py b/tests/test_data_utils/test_din_preprocessor.py new file mode 100644 index 00000000..36523b77 --- /dev/null +++ b/tests/test_data_utils/test_din_preprocessor.py @@ -0,0 +1,87 @@ +import os +from pathlib import Path + +import numpy as np +import pandas as pd + +from pytorch_widedeep.preprocessing import DINPreprocessor + +full_path = os.path.realpath(__file__) +path = os.path.split(full_path)[0] + +save_dir = Path(path) / "data_for_rec_preprocessor" + +df_interactions = pd.read_csv(save_dir / "interactions_data.csv") + +df_interactions = df_interactions.sort_values(by=["user_id", "timestamp"]).reset_index( + drop=True +) + +cat_embed_cols = ["user_id", "age", "gender"] +continuous_cols = ["height", "weight"] + + +def user_behaviour_well_encoded( + input_df: pd.DataFrame, + din_preprocessor: DINPreprocessor, + X: np.ndarray, + max_seq_length: int, +): + + user_id: int = 5 + + encoded_user_id = ( # type: ignore + din_preprocessor.tab_preprocessor.label_encoder.encoding_dict["user_id"][ + user_id + ] + ) + + 
user_items = ( + input_df[input_df["user_id"] == user_id] + .groupby("user_id")["item_id"] + .agg(list) + .values[0] + )[:max_seq_length] + + encoded_items = [ + din_preprocessor.item_le.encoding_dict["item_id"][item] for item in user_items + ] + + rows = np.where( + X[:, din_preprocessor.din_columns_idx["user_id"]] == encoded_user_id + )[0] + X_user_id = X[rows][:1] + + item_seq_cols = din_preprocessor.user_behaviour_config[0] + item_seq_cols_idx = [din_preprocessor.din_columns_idx[col] for col in item_seq_cols] + + X_user_items = list(X_user_id[:, item_seq_cols_idx].astype(int)[0]) + + return X_user_items == encoded_items + + +def test_din_preprocessor(): + + din_preprocessor = DINPreprocessor( + user_id_col="user_id", + item_embed_col="item_id", + target_col="interaction", + action_col="interaction", + other_seq_embed_cols=["category", "price"], + cat_embed_cols=cat_embed_cols, + continuous_cols=continuous_cols, + cols_to_scale=continuous_cols, + max_seq_length=5, + ) + + X, y = din_preprocessor.fit_transform(df_interactions) + # 5 items + 5 actions + 2 * 5 other_seq_embed_cols + (1 target item + 1 + # target category + 1 target price) + 5 continuous_cols and + # cat_embed_cols + expected_n_cols = 5 + 5 + (2 * 5) + 3 + 5 + assert X.shape[1] == expected_n_cols + assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5) + + +if __name__ == "__main__": + test_din_preprocessor() From ecc6ea0532cf9a1cb97abee134f1b72cf8921d74 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 14 Oct 2024 20:34:07 +0100 Subject: [PATCH 05/24] ready to test the whole thing. I would have to add an extra text combining the preprocessor and the model --- .../preprocessing/din_preprocessor.py | 8 +- .../generate_interactions_data.py | 2 +- .../interactions_data.csv | 0 .../test_data_utils/test_din_preprocessor.py | 87 ---------- .../test_du_din_preprocessor.py | 162 ++++++++++++++++++ 5 files changed, 169 insertions(+), 90 deletions(-) rename tests/test_data_utils/{data_for_rec_preprocessor => interactions_df_for_rec_preprocessor}/interactions_data.csv (100%) delete mode 100644 tests/test_data_utils/test_din_preprocessor.py create mode 100644 tests/test_data_utils/test_du_din_preprocessor.py diff --git a/pytorch_widedeep/preprocessing/din_preprocessor.py b/pytorch_widedeep/preprocessing/din_preprocessor.py index d28e6926..5819875f 100644 --- a/pytorch_widedeep/preprocessing/din_preprocessor.py +++ b/pytorch_widedeep/preprocessing/din_preprocessor.py @@ -179,7 +179,7 @@ def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: _df = self._pre_transform(df) df_w_sequences = self._build_sequences(_df) X_all = self._concatenate_features(df_w_sequences) - return X_all, np.array(df_w_sequences["target"].tolist()) + return X_all, np.array(df_w_sequences[self.target_col].tolist()) def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: self.fit(df) @@ -363,6 +363,10 @@ def _create_sequence( sequence.update( self._create_action_sequence(group, targets, i, target, drop_cols) ) + else: + sequence[self.target_col] = target + drop_cols.append(self.target_col) + if self.other_seq_embed_cols: sequence.update(self._create_other_seq_sequence(group, i, drop_cols)) @@ -392,7 +396,7 @@ def _create_action_sequence( else targets[i:-1] ) - return {"target": target, "actions_sequence": action_sequences} + return {self.target_col: target, "actions_sequence": action_sequences} def _create_other_seq_sequence( self, group: pd.DataFrame, i: int, drop_cols: List[str] diff --git 
a/tests/test_data_utils/generate_interactions_data.py b/tests/test_data_utils/generate_interactions_data.py index df128e37..7df32497 100644 --- a/tests/test_data_utils/generate_interactions_data.py +++ b/tests/test_data_utils/generate_interactions_data.py @@ -8,7 +8,7 @@ full_path = os.path.realpath(__file__) path = os.path.split(full_path)[0] -save_dir = Path(path) / "data_for_rec_preprocessor" +save_dir = Path(path) / "interactions_df_for_rec_preprocessor" def generate_sample_data(n_users=5, n_items=10, seed=42, return_df=False): diff --git a/tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv b/tests/test_data_utils/interactions_df_for_rec_preprocessor/interactions_data.csv similarity index 100% rename from tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv rename to tests/test_data_utils/interactions_df_for_rec_preprocessor/interactions_data.csv diff --git a/tests/test_data_utils/test_din_preprocessor.py b/tests/test_data_utils/test_din_preprocessor.py deleted file mode 100644 index 36523b77..00000000 --- a/tests/test_data_utils/test_din_preprocessor.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -from pathlib import Path - -import numpy as np -import pandas as pd - -from pytorch_widedeep.preprocessing import DINPreprocessor - -full_path = os.path.realpath(__file__) -path = os.path.split(full_path)[0] - -save_dir = Path(path) / "data_for_rec_preprocessor" - -df_interactions = pd.read_csv(save_dir / "interactions_data.csv") - -df_interactions = df_interactions.sort_values(by=["user_id", "timestamp"]).reset_index( - drop=True -) - -cat_embed_cols = ["user_id", "age", "gender"] -continuous_cols = ["height", "weight"] - - -def user_behaviour_well_encoded( - input_df: pd.DataFrame, - din_preprocessor: DINPreprocessor, - X: np.ndarray, - max_seq_length: int, -): - - user_id: int = 5 - - encoded_user_id = ( # type: ignore - din_preprocessor.tab_preprocessor.label_encoder.encoding_dict["user_id"][ - user_id - ] - ) - - user_items = ( - input_df[input_df["user_id"] == user_id] - .groupby("user_id")["item_id"] - .agg(list) - .values[0] - )[:max_seq_length] - - encoded_items = [ - din_preprocessor.item_le.encoding_dict["item_id"][item] for item in user_items - ] - - rows = np.where( - X[:, din_preprocessor.din_columns_idx["user_id"]] == encoded_user_id - )[0] - X_user_id = X[rows][:1] - - item_seq_cols = din_preprocessor.user_behaviour_config[0] - item_seq_cols_idx = [din_preprocessor.din_columns_idx[col] for col in item_seq_cols] - - X_user_items = list(X_user_id[:, item_seq_cols_idx].astype(int)[0]) - - return X_user_items == encoded_items - - -def test_din_preprocessor(): - - din_preprocessor = DINPreprocessor( - user_id_col="user_id", - item_embed_col="item_id", - target_col="interaction", - action_col="interaction", - other_seq_embed_cols=["category", "price"], - cat_embed_cols=cat_embed_cols, - continuous_cols=continuous_cols, - cols_to_scale=continuous_cols, - max_seq_length=5, - ) - - X, y = din_preprocessor.fit_transform(df_interactions) - # 5 items + 5 actions + 2 * 5 other_seq_embed_cols + (1 target item + 1 - # target category + 1 target price) + 5 continuous_cols and - # cat_embed_cols - expected_n_cols = 5 + 5 + (2 * 5) + 3 + 5 - assert X.shape[1] == expected_n_cols - assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5) - - -if __name__ == "__main__": - test_din_preprocessor() diff --git a/tests/test_data_utils/test_du_din_preprocessor.py b/tests/test_data_utils/test_du_din_preprocessor.py new file mode 100644 index 
00000000..40ea8dfd --- /dev/null +++ b/tests/test_data_utils/test_du_din_preprocessor.py @@ -0,0 +1,162 @@ +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from pytorch_widedeep.preprocessing import DINPreprocessor + +full_path = os.path.realpath(__file__) +path = os.path.split(full_path)[0] + +save_dir = Path(path) / "interactions_df_for_rec_preprocessor" + +df_interactions = pd.read_csv(save_dir / "interactions_data.csv") + +df_interactions = df_interactions.sort_values(by=["user_id", "timestamp"]).reset_index( + drop=True +) + +cat_embed_cols = ["user_id", "age", "gender"] +continuous_cols = ["height", "weight"] + + +def user_behaviour_well_encoded( + input_df: pd.DataFrame, + din_preprocessor: DINPreprocessor, + X: np.ndarray, + max_seq_length: int, +): + + user_id: int = 5 + + encoded_user_id = ( # type: ignore + din_preprocessor.tab_preprocessor.label_encoder.encoding_dict["user_id"][ + user_id + ] + ) + + user_items = ( + input_df[input_df["user_id"] == user_id] + .groupby("user_id")["item_id"] + .agg(list) + .values[0] + )[:max_seq_length] + + encoded_items = [ + din_preprocessor.item_le.encoding_dict["item_id"][item] for item in user_items + ] + + rows = np.where( + X[:, din_preprocessor.din_columns_idx["user_id"]] == encoded_user_id + )[0] + X_user_id = X[rows][:1] + + item_seq_cols = din_preprocessor.user_behaviour_config[0] + item_seq_cols_idx = [din_preprocessor.din_columns_idx[col] for col in item_seq_cols] + + X_user_items = list(X_user_id[:, item_seq_cols_idx].astype(int)[0]) + + return X_user_items == encoded_items + + +def test_din_preprocessor(): + + din_preprocessor = DINPreprocessor( + user_id_col="user_id", + item_embed_col="item_id", + target_col="interaction", + action_col="interaction", + other_seq_embed_cols=["category", "price"], + cat_embed_cols=cat_embed_cols, + continuous_cols=continuous_cols, + cols_to_scale=continuous_cols, + max_seq_length=5, + ) + + X, y = din_preprocessor.fit_transform(df_interactions) + # 5 items + 5 actions + 2 * 5 other_seq_embed_cols + (1 target item + 1 + # target category + 1 target price) + 5 continuous_cols and + # cat_embed_cols + expected_n_cols = 5 + 5 + (2 * 5) + 3 + 5 + assert X.shape[1] == expected_n_cols + assert din_preprocessor.is_fitted + assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5) + + +@pytest.mark.parametrize( + "with_cat_embed_cols, with_continuous_cols, with_action_col, with_other_seq_embed_cols", + [ + (True, True, True, True), + (True, True, True, False), + (True, True, False, True), + (True, True, False, False), + (True, False, True, True), + (True, False, True, False), + (True, False, False, True), + (True, False, False, False), + (False, True, True, True), + (False, True, True, False), + (False, True, False, True), + (False, True, False, False), + (False, False, True, True), + (False, False, True, False), + (False, False, False, True), + (False, False, False, False), + ], +) +def test_din_preprocessor_diff_input_params( + with_cat_embed_cols, + with_continuous_cols, + with_action_col, + with_other_seq_embed_cols, +): + + max_seq_length = 5 + + din_preprocessor = DINPreprocessor( + user_id_col="user_id", + item_embed_col="item_id", + target_col="interaction", + action_col="interaction" if with_action_col else None, + other_seq_embed_cols=( + ["category", "price"] if with_other_seq_embed_cols else None + ), + cat_embed_cols=cat_embed_cols if with_cat_embed_cols else None, + continuous_cols=continuous_cols if with_continuous_cols else 
None, + cols_to_scale=continuous_cols if with_continuous_cols else None, + max_seq_length=max_seq_length, + ) + + X, y = din_preprocessor.fit_transform(df_interactions) + + expected_n_cols = max_seq_length + 1 + + assert din_preprocessor.user_behaviour_config[0] == [ + f"item_{i+1}" for i in range(max_seq_length) + ] + + if with_action_col: + expected_n_cols += 5 + assert din_preprocessor.action_seq_config[0] == [ + f"action_{i+1}" for i in range(max_seq_length) + ] + + if with_other_seq_embed_cols: + for i, col in enumerate(["category", "price"]): + expected_n_cols += 5 + 1 + assert din_preprocessor.other_seq_config[i][0] == [ + f"{col}_{i+1}" for i in range(max_seq_length) + ] + + if with_cat_embed_cols: + expected_n_cols += len(cat_embed_cols) + + if with_continuous_cols: + expected_n_cols += len(continuous_cols) + + assert din_preprocessor.is_fitted + assert X.shape[1] == expected_n_cols + if with_cat_embed_cols: + assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5) From 112d8c317b804fe74c041de3521a8cc11bf2a627 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 14 Oct 2024 21:34:50 +0100 Subject: [PATCH 06/24] Added test for the full process. Ready to test the whole library --- pytorch_widedeep/models/rec/din.py | 3 +- .../test_du_din_preprocessor.py | 42 +++++++++++++++++++ tests/test_rec/test_din.py | 2 +- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/pytorch_widedeep/models/rec/din.py b/pytorch_widedeep/models/rec/din.py index 293f249e..571717a1 100644 --- a/pytorch_widedeep/models/rec/din.py +++ b/pytorch_widedeep/models/rec/din.py @@ -405,7 +405,8 @@ def forward(self, X: Tensor) -> Tensor: [ self.other_seq_cols_embed[col]._get_embeddings(X_other_seq[col]) for col in self.other_seq_cols_indexes.keys() - ] + ], + dim=-1, ).sum(1) deep_out = torch.cat([deep_out, other_seq_embed], dim=1) diff --git a/tests/test_data_utils/test_du_din_preprocessor.py b/tests/test_data_utils/test_du_din_preprocessor.py index 40ea8dfd..b56cd04d 100644 --- a/tests/test_data_utils/test_du_din_preprocessor.py +++ b/tests/test_data_utils/test_du_din_preprocessor.py @@ -5,6 +5,9 @@ import pandas as pd import pytest +from pytorch_widedeep import Trainer +from pytorch_widedeep.models import WideDeep +from pytorch_widedeep.models.rec import DeepInterestNetwork from pytorch_widedeep.preprocessing import DINPreprocessor full_path = os.path.realpath(__file__) @@ -160,3 +163,42 @@ def test_din_preprocessor_diff_input_params( assert X.shape[1] == expected_n_cols if with_cat_embed_cols: assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5) + + +def test_din_full_process_w_processor(): + + din_preprocessor = DINPreprocessor( + user_id_col="user_id", + item_embed_col="item_id", + target_col="interaction", + action_col="interaction", + other_seq_embed_cols=["category", "price"], + cat_embed_cols=cat_embed_cols, + continuous_cols=continuous_cols, + cols_to_scale=continuous_cols, + max_seq_length=5, + ) + + X, y = din_preprocessor.fit_transform(df_interactions) + + din = DeepInterestNetwork( + column_idx=din_preprocessor.din_columns_idx, + user_behavior_confiq=din_preprocessor.user_behaviour_config, + target_item_col="target_item", + action_seq_config=din_preprocessor.action_seq_config, + other_seq_cols_confiq=din_preprocessor.other_seq_config, + cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore + continuous_cols=din_preprocessor.tab_preprocessor.continuous_cols, + mlp_hidden_dims=[16, 8], + ) + + model = WideDeep(deeptabular=din) 
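+    # From here on everything is standard: the DIN component wrapped in WideDeep
+    # is trained and used for prediction through the usual Trainer API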
+ + trainer = Trainer(model, objective="binary", verbose=0) + + trainer.fit(X_tab=X, target=y, n_epochs=2) + + preds = trainer.predict(X_tab=X) + + assert preds.shape[0] == X.shape[0] + assert trainer.history is not None and "train_loss" in trainer.history diff --git a/tests/test_rec/test_din.py b/tests/test_rec/test_din.py index 92724791..dcdbde86 100644 --- a/tests/test_rec/test_din.py +++ b/tests/test_rec/test_din.py @@ -79,8 +79,8 @@ def test_din_with_params( din = DeepInterestNetwork( column_idx=column_idx, - target_item_col="target_item", user_behavior_confiq=item_seq_config, + target_item_col="target_item", action_seq_config=item_purchased_seq_config, other_seq_cols_confiq=item_category_seq_config, cat_embed_input=cat_embed_input, From adc77d7bf87fe80b94e4041058b29d3264c533ae Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 15 Oct 2024 10:42:08 +0100 Subject: [PATCH 07/24] All test passed. Need to increase coverage a bit --- .../sources/pytorch-widedeep/preprocessing.md | 2 + pytorch_widedeep/models/rec/din.py | 7 +- .../preprocessing/din_preprocessor.py | 124 ++++-------------- pytorch_widedeep/tab2vec.py | 2 +- 4 files changed, 34 insertions(+), 101 deletions(-) diff --git a/mkdocs/sources/pytorch-widedeep/preprocessing.md b/mkdocs/sources/pytorch-widedeep/preprocessing.md index ab608846..520dd914 100644 --- a/mkdocs/sources/pytorch-widedeep/preprocessing.md +++ b/mkdocs/sources/pytorch-widedeep/preprocessing.md @@ -21,6 +21,8 @@ available: one for the case when no Hugging Face model is used ::: pytorch_widedeep.preprocessing.image_preprocessor.ImagePreprocessor +::: pytorch_widedeep.preprocessing.din_preprocessor.DINPreprocessor + ## Chunked versions diff --git a/pytorch_widedeep/models/rec/din.py b/pytorch_widedeep/models/rec/din.py index 571717a1..3c42a975 100644 --- a/pytorch_widedeep/models/rec/din.py +++ b/pytorch_widedeep/models/rec/din.py @@ -23,10 +23,9 @@ class DeepInterestNetwork(BaseWDModelComponent): sequential columns and will be treated as standard tabular data. This model requires some specific data preparation that allows for quite a - lot of flexibility. Therefore, this library does not currently include a - preprocessor designed specifically for this model. Please, see the - example 'movielens_din.py' in the examples folder to understand the data - preparation process. + lot of flexibility. Therefore, I have included a preprocessor + (`DINPreprocessor`) in the preprocessing module that will take care of + the data preparation. Parameters ---------- diff --git a/pytorch_widedeep/preprocessing/din_preprocessor.py b/pytorch_widedeep/preprocessing/din_preprocessor.py index 5819875f..f66bbdf7 100644 --- a/pytorch_widedeep/preprocessing/din_preprocessor.py +++ b/pytorch_widedeep/preprocessing/din_preprocessor.py @@ -5,11 +5,6 @@ import numpy as np import pandas as pd -# from pytorch_widedeep.models import WideDeep -# from pytorch_widedeep.metrics import Accuracy -# from pytorch_widedeep.datasets import load_movielens100k -# from pytorch_widedeep.training import Trainer -# from pytorch_widedeep.models.rec.din import DeepInterestNetwork from pytorch_widedeep.utils.text_utils import pad_sequences from pytorch_widedeep.utils.general_utils import alias from pytorch_widedeep.utils.deeptabular_utils import LabelEncoder @@ -100,6 +95,32 @@ class DINPreprocessor(BasePreprocessor): Number of unique values in each other sequence column. other_seq_config : List[Tuple[List[str], int, int]] Configuration for other sequence columns. 
+ + Examples: + --------- + >>> import pandas as pd + >>> from pytorch_widedeep.preprocessing import DINPreprocessor + >>> data = { + ... 'user_id': [1, 1, 1, 2, 2, 2, 3, 3, 3], + ... 'item_id': [101, 102, 103, 101, 103, 104, 102, 103, 104], + ... 'timestamp': [1, 2, 3, 1, 2, 3, 1, 2, 3], + ... 'category': ['A', 'B', 'A', 'B', 'A', 'C', 'B', 'A', 'C'], + ... 'price': [10.5, 15.0, 12.0, 10.5, 12.0, 20.0, 15.0, 12.0, 20.0], + ... 'rating': [0, 1, 0, 1, 0, 1, 0, 1, 0] + ... } + >>> df = pd.DataFrame(data) + >>> din_preprocessor = DINPreprocessor( + ... user_id_col='user_id', + ... item_embed_col='item_id', + ... target_col='rating', + ... max_seq_length=2, + ... action_col='rating', + ... other_seq_embed_cols=['category'], + ... cat_embed_cols=['user_id'], + ... continuous_cols=['price'], + ... cols_to_scale=['price'] + ... ) + >>> X, y = din_preprocessor.fit_transform(df) """ @alias("item_embed_col", ["item_id_col"]) @@ -313,9 +334,9 @@ def _pre_transform(self, df: pd.DataFrame) -> pd.DataFrame: return df def _build_sequences(self, df: pd.DataFrame) -> pd.DataFrame: - # TO DO: do something to avoid warning here + # df.columns.tolist() -> to avoid annoying pandas warning return ( - df.groupby(self.user_id_col) + df.groupby(self.user_id_col)[df.columns.tolist()] # type: ignore[index] .apply(self._group_sequences) .reset_index(drop=True) ) @@ -490,92 +511,3 @@ def _concatenate_other_seq_features( X_other_seq_arrays.append(np.array(other_seq_target[col]).reshape(-1, 1)) return np.concatenate([X_all] + X_other_seq_arrays, axis=1) - - -# if __name__ == "__main__": - -# def clean_genre_list(genre_list): -# return "_".join( -# sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list]) -# ) - -# data, users, items = load_movielens100k(as_frame=True) - -# list_of_genres = [ -# "unknown", -# "Action", -# "Adventure", -# "Animation", -# "Children's", -# "Comedy", -# "Crime", -# "Documentary", -# "Drama", -# "Fantasy", -# "Film-Noir", -# "Horror", -# "Musical", -# "Mystery", -# "Romance", -# "Sci-Fi", -# "Thriller", -# "War", -# "Western", -# ] - -# assert ( -# isinstance(items, pd.DataFrame) -# and isinstance(data, pd.DataFrame) -# and isinstance(users, pd.DataFrame) -# ) -# items["genre_list"] = items[list_of_genres].apply( -# lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1 -# ) - -# items["genre_list"] = items["genre_list"].apply(clean_genre_list) - -# df = pd.merge(data, items[["movie_id", "genre_list"]], on="movie_id") -# df = pd.merge( -# df, -# users[["user_id", "age", "gender", "occupation"]], -# on="user_id", -# ) - -# df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0) - -# df = df.sort_values(by=["timestamp"]).reset_index(drop=True) -# df = df.drop("timestamp", axis=1) - -# din_preprocessor = DINPreprocessor( -# user_id_col="user_id", -# target_col="rating", -# item_embed_col=("movie_id", 16), -# max_seq_length=5, -# action_col="rating", -# other_seq_embed_cols=[("genre_list", 16)], -# cat_embed_cols=["user_id", "age", "gender", "occupation"], -# ) - -# X, y = din_preprocessor.fit_transform(df) - -# din = DeepInterestNetwork( -# column_idx=din_preprocessor.din_columns_idx, -# user_behavior_confiq=din_preprocessor.user_behaviour_config, -# action_seq_config=din_preprocessor.action_seq_config, -# other_seq_cols_confiq=din_preprocessor.other_seq_config, -# cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore[attr-defined] -# mlp_hidden_dims=[128, 64], -# ) - -# # And from here on, everything is standard -# model = 
WideDeep(deeptabular=din) - -# trainer = Trainer(model=model, objective="binary", metrics=[Accuracy()]) - -# # in the real world you would have to split the data into train, val and test -# trainer.fit( -# X_tab=X, -# target=y, -# n_epochs=5, -# batch_size=512, -# ) diff --git a/pytorch_widedeep/tab2vec.py b/pytorch_widedeep/tab2vec.py index 5ba2ab48..c10ddb56 100644 --- a/pytorch_widedeep/tab2vec.py +++ b/pytorch_widedeep/tab2vec.py @@ -77,7 +77,7 @@ class Tab2Vec: >>> # ...train the model... >>> >>> # vectorise the dataframe - >>> t2v = Tab2Vec(tab_preprocessor, model) + >>> t2v = Tab2Vec(tab_preprocessor, model, device="cpu") >>> X_vec = t2v.transform(df_t2v) """ From f076867711b34d065b250fb8223adde29c0aa35a Mon Sep 17 00:00:00 2001 From: Javier Date: Fri, 18 Oct 2024 23:02:58 +0100 Subject: [PATCH 08/24] Added tests for the new metrics. Still need to test the full functioning with the Trainer --- pytorch_widedeep/metrics.py | 417 ++++++++++++++++++++- pytorch_widedeep/training/_base_trainer.py | 33 +- pytorch_widedeep/training/trainer.py | 45 ++- tests/test_metrics/test_ranking_metrics.py | 188 ++++++++++ 4 files changed, 661 insertions(+), 22 deletions(-) create mode 100644 tests/test_metrics/test_ranking_metrics.py diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index a2947380..e4b871ed 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -2,7 +2,8 @@ import torch from torchmetrics import Metric as TorchMetric -from pytorch_widedeep.wdtypes import Dict, List, Union, Tensor +from pytorch_widedeep.wdtypes import Dict, List, Union, Tensor, Optional +from pytorch_widedeep.utils.general_utils import alias class Metric(object): @@ -394,3 +395,417 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: y_true_avg = self.y_true_sum / self.num_examples self.denominator += ((y_true - y_true_avg) ** 2).sum().item() return np.array((1 - (self.numerator / self.denominator))) + + +def reshape_1d_to_2d(tensor: Tensor, n_columns: int) -> Tensor: + if tensor.dim() != 1: + raise ValueError("Input tensor must be 1-dimensional") + if tensor.size(0) % n_columns != 0: + raise ValueError( + f"Tensor length ({tensor.size(0)}) must be divisible by n_columns ({n_columns})" + ) + n_rows = tensor.size(0) // n_columns + return tensor.reshape(n_rows, n_columns) + + +class NDCG_at_k(Metric): + r""" + Normalized Discounted Cumulative Gain (NDCG) at k. + + Parameters + ---------- + n_cols: int, default = 10 + Number of columns in the input tensors. This parameter is neccessary + because the input tensors are reshaped to 2D tensors. n_cols is the + number of columns in the reshaped tensor. Alias for this parameter + are: 'n_items', 'n_items_per_query', + 'n_items_per_id', 'n_items_per_user' + k: int, Optional, default = None + Number of top items to consider. It must be less than or equal to n_cols. + If is None, k will be equal to n_cols. + eps: float, default = 1e-8 + Small value to avoid division by zero. 
+
+    Examples
+    --------
+    >>> import torch
+    >>> from pytorch_widedeep.metrics import NDCG_at_k
+    >>>
+    >>> ndcg = NDCG_at_k(k=10)
+    >>> y_pred = torch.rand(100, 10).flatten()
+    >>> y_true = torch.randint(0, 5, (100, 10)).flatten()
+    >>> score = ndcg(y_pred, y_true)
+    """
+
+    @alias(
+        "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"]
+    )
+    def __init__(self, n_cols: int = 10, k: Optional[int] = None, eps: float = 1e-8):
+        super(NDCG_at_k, self).__init__()
+
+        if k is not None and k > n_cols:
+            raise ValueError(
+                f"k must be less than or equal to n_cols. Got k: {k}, n_cols: {n_cols}"
+            )
+
+        self.n_cols = n_cols
+        self.k = k if k is not None else n_cols
+        self.eps = eps
+        self._name = f"ndcg@{self.k}"
+        self.reset()
+
+    def reset(self):
+        self.sum_ndcg = 0.0
+        self.count = 0
+
+    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+        device = y_pred.device
+
+        y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
+        y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+
+        batch_size = y_true_reshaped.shape[0]
+
+        _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
+        top_k_relevance = y_true_reshaped.gather(1, top_k_indices)
+        discounts = 1.0 / torch.log2(
+            torch.arange(2, top_k_relevance.shape[1] + 2, device=device)
+        )
+
+        dcg = (torch.pow(2, top_k_relevance) - 1) * discounts.unsqueeze(0)
+        dcg = dcg.sum(dim=1)
+
+        sorted_relevance, _ = torch.sort(y_true_reshaped, dim=1, descending=True)
+        ideal_relevance = sorted_relevance[:, : self.k]
+
+        idcg = (torch.pow(2, ideal_relevance) - 1) * discounts[
+            : ideal_relevance.shape[1]
+        ].unsqueeze(0)
+
+        idcg = idcg.sum(dim=1)
+        ndcg = dcg / (idcg + self.eps)
+
+        self.sum_ndcg += ndcg.sum().item()
+        self.count += batch_size
+
+        return np.array(self.sum_ndcg / max(self.count, 1))
+
+
+class BinaryNDCG_at_k(Metric):
+    r"""
+    Normalized Discounted Cumulative Gain (NDCG) at k for binary relevance.
+
+    Parameters
+    ----------
+    n_cols: int, default = 10
+        Number of columns in the input tensors. This parameter is necessary
+        because the input tensors are reshaped to 2D tensors. n_cols is the
+        number of columns in the reshaped tensor. Aliases for this parameter
+        are: 'n_items', 'n_items_per_query',
+        'n_items_per_id', 'n_items_per_user'
+    k: int, Optional, default = None
+        Number of top items to consider. It must be less than or equal to n_cols.
+        If it is None, k will be equal to n_cols.
+    eps: float, default = 1e-8
+        Small value to avoid division by zero.
+
+    Examples
+    --------
+    >>> import torch
+    >>> from pytorch_widedeep.metrics import BinaryNDCG_at_k
+    >>>
+    >>> ndcg = BinaryNDCG_at_k(k=10)
+    >>> y_pred = torch.rand(100, 10).flatten()
+    >>> y_true = torch.randint(0, 2, (100, 10)).flatten()
+    >>> score = ndcg(y_pred, y_true)
+    """
+
+    @alias(
+        "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"]
+    )
+    def __init__(self, n_cols: int = 10, k: Optional[int] = None, eps: float = 1e-8):
+        super(BinaryNDCG_at_k, self).__init__()
+
+        if k is not None and k > n_cols:
+            raise ValueError(
+                f"k must be less than or equal to n_cols. 
Got k: {k}, n_cols: {n_cols}"
+            )
+
+        self.n_cols = n_cols
+        self.k = k if k is not None else n_cols
+        self.eps = eps
+        self._name = f"binary_ndcg@{self.k}"
+        self.reset()
+
+    def reset(self):
+        self.sum_ndcg = 0.0
+        self.count = 0
+
+    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+        device = y_pred.device
+
+        y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
+        y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+
+        batch_size = y_pred_reshaped.shape[0]
+
+        _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
+        top_k_mask = torch.zeros_like(y_pred_reshaped, dtype=torch.bool).scatter_(
+            1, top_k_indices, 1
+        )
+
+        _discounts = 1.0 / torch.log2(
+            torch.arange(2, self.k + 2, device=device).float()
+        )
+        expanded_discounts = _discounts.repeat(1, batch_size)
+        discounts = torch.zeros_like(top_k_mask, dtype=torch.float)
+        discounts[top_k_mask] = expanded_discounts
+
+        dcg = (y_true_reshaped * top_k_mask * discounts).sum(dim=1)
+        n_relevant = torch.minimum(
+            y_true_reshaped.sum(dim=1), torch.tensor(self.k, device=device)
+        ).int()
+        ideal_discounts = 1.0 / torch.log2(
+            torch.arange(2, y_true_reshaped.shape[1] + 2, device=device).float()
+        )
+
+        idcg = torch.zeros(batch_size, device=device)
+        for i in range(batch_size):
+            idcg[i] = ideal_discounts[: n_relevant[i]].sum()
+
+        ndcg = dcg / (idcg + self.eps)
+
+        self.sum_ndcg += ndcg.sum().item()
+        self.count += batch_size
+
+        return np.array(self.sum_ndcg / max(self.count, 1))
+
+
+class MAP_at_k(Metric):
+    r"""
+    Mean Average Precision (MAP) at k.
+
+    Parameters
+    ----------
+    n_cols: int, default = 10
+        Number of columns in the input tensors. This parameter is necessary
+        because the input tensors are reshaped to 2D tensors. n_cols is the
+        number of columns in the reshaped tensor. Aliases for this parameter
+        are: 'n_items', 'n_items_per_query', 'n_items_per_id', 'n_items_per_user'
+    k: int, Optional, default = None
+        Number of top items to consider. It must be less than or equal to n_cols.
+        If it is None, k will be equal to n_cols.
+
+    Examples
+    --------
+    >>> import torch
+    >>> from pytorch_widedeep.metrics import MAP_at_k
+    >>>
+    >>> map_at_k = MAP_at_k(k=10)
+    >>> y_pred = torch.rand(100, 10).flatten()
+    >>> y_true = torch.randint(0, 2, (100, 10)).flatten()
+    >>> score = map_at_k(y_pred, y_true)
+    """
+
+    def __init__(self, n_cols: int = 10, k: Optional[int] = None):
+        super(MAP_at_k, self).__init__()
+
+        if k is not None and k > n_cols:
+            raise ValueError(
+                f"k must be less than or equal to n_cols. 
Got k: {k}, n_cols: {n_cols}" + ) + + self.n_cols = n_cols + self.k = k if k is not None else n_cols + self._name = f"map@{k}" + self.reset() + + def reset(self): + self.sum_avg_precision = 0.0 + self.count = 0 + + def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: + + y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) + + batch_size = y_pred_reshaped.shape[0] + _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1) + batch_relevance = y_true_reshaped.gather(1, top_k_indices) + cumsum_relevance = torch.cumsum(batch_relevance, dim=1) + precision_at_i = cumsum_relevance / torch.arange( + 1, self.k + 1, device=y_pred_reshaped.device + ).float().unsqueeze(0) + avg_precision = (precision_at_i * batch_relevance).sum(dim=1) / torch.clamp( + y_true_reshaped.sum(dim=1), min=1 + ) + self.sum_avg_precision += avg_precision.sum().item() + self.count += batch_size + return np.array(self.sum_avg_precision / max(self.count, 1)) + + +class HitRatio_at_k(Metric): + r""" + Hit Ratio (HR) at k. + + Parameters + ---------- + n_cols: int, default = 10 + Number of columns in the input tensors. This parameter is neccessary + because the input tensors are reshaped to 2D tensors. n_cols is the + number of columns in the reshaped tensor. Alias for this parameter + are: 'n_items', 'n_items_per_query', 'n_items_per_id', 'n_items_per_user' + k: int, Optional, default = None + Number of top items to consider. It must be less than or equal to n_cols. + If is None, k will be equal to n_cols. + + Examples + -------- + >>> import torch + >>> from pytorch_widedeep.metrics import HitRatio_at_k + >>> + >>> hr_at_k = HitRatio_at_k(k=10) + >>> y_pred = torch.rand(100, 10) + >>> y_true = torch.randint(0, 2, (100, 10)) + >>> score = hr_at_k(y_pred, y_true + """ + + def __init__(self, n_cols: int = 10, k: Optional[int] = None): + super(HitRatio_at_k, self).__init__() + + if k is not None and k > n_cols: + raise ValueError( + f"k must be less than or equal to n_cols. Got k: {k}, n_cols: {n_cols}" + ) + + self.n_cols = n_cols + self.k = k if k is not None else n_cols + self._name = f"hr@{k}" + self.reset() + + def reset(self): + self.sum_hr = 0.0 + self.count = 0 + + def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: + y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) + batch_size = y_pred_reshaped.shape[0] + _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1) + batch_relevance = y_true_reshaped.gather(1, top_k_indices) + hit = (batch_relevance.sum(dim=1) > 0).float() + self.sum_hr += hit.sum().item() + self.count += batch_size + return np.array(self.sum_hr / max(self.count, 1)) + + +class Precision_at_k(Metric): + r""" + Precision at k. + + Parameters + ---------- + n_cols: int, default = 10 + Number of columns in the input tensors. This parameter is neccessary + because the input tensors are reshaped to 2D tensors. n_cols is the + number of columns in the reshaped tensor. Alias for this parameter + are: 'n_items', 'n_items_per_query', + 'n_items_per_id', 'n_items_per + k: int, Optional, default = None + Number of top items to consider. It must be less than or equal to n_cols. + If is None, k will be equal to n_cols. 
+ + Examples + -------- + >>> import torch + >>> from pytorch_widedeep.metrics import Precision_at_k + >>> + >>> prec_at_k = Precision_at_k(k=10) + >>> y_pred = torch.rand(100, 10) + >>> y_true = torch.randint(0, 2, (100, 10)) + >>> score = prec_at_k(y_pred, y_true) + """ + + def __init__(self, n_cols: int = 10, k: Optional[int] = None): + super(Precision_at_k, self).__init__() + + if k is not None and k > n_cols: + raise ValueError( + f"k must be less than or equal to n_cols. Got k: {k}, n_cols: {n_cols}" + ) + + self.n_cols = n_cols + self.k = k if k is not None else n_cols + self._name = f"precision@{k}" + self.reset() + + def reset(self): + self.sum_precision = 0.0 + self.count = 0 + + def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: + y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) + batch_size = y_pred_reshaped.shape[0] + _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1) + batch_relevance = y_true_reshaped.gather(1, top_k_indices) + precision = batch_relevance.sum(dim=1) / self.k + self.sum_precision += precision.sum().item() + self.count += batch_size + return np.array(self.sum_precision / max(self.count, 1)) + + +class Recall_at_k(Metric): + r""" + Recall at k. + + Parameters + ---------- + k: int, default = 10 + Number of top items to consider. + + Examples + -------- + >>> import torch + >>> from pytorch_widedeep.metrics import Recall_at_k + >>> + >>> rec_at_k = Recall_at_k(k=10) + >>> y_pred = torch.rand(100, 10) + >>> y_true = torch.randint(0, 2, (100, 10)) + >>> score = rec_at_k(y_pred, y_true) + """ + + def __init__(self, n_cols: int = 10, k: Optional[int] = None): + super(Recall_at_k, self).__init__() + + if k is not None and k > n_cols: + raise ValueError( + f"k must be less than or equal to n_cols. 
Got k: {k}, n_cols: {n_cols}" + ) + + self.n_cols = n_cols + self.k = k if k is not None else n_cols + self._name = f"recall@{k}" + self.reset() + + def reset(self): + self.sum_recall = 0.0 + self.count = 0 + + def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: + y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) + batch_size = y_pred_reshaped.shape[0] + _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1) + batch_relevance = y_true_reshaped.gather(1, top_k_indices) + recall = batch_relevance.sum(dim=1) / torch.clamp( + y_true_reshaped.sum(dim=1), min=1 + ) + self.sum_recall += recall.sum().item() + self.count += batch_size + return np.array(self.sum_recall / max(self.count, 1)) + + +RankingMetrics = Union[ + BinaryNDCG_at_k, MAP_at_k, HitRatio_at_k, Precision_at_k, Recall_at_k +] diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py index 45cb2e9e..21d8985b 100644 --- a/pytorch_widedeep/training/_base_trainer.py +++ b/pytorch_widedeep/training/_base_trainer.py @@ -10,7 +10,7 @@ from torchmetrics import Metric as TorchMetric from torch.optim.lr_scheduler import ReduceLROnPlateau -from pytorch_widedeep.metrics import Metric, MultipleMetrics +from pytorch_widedeep.metrics import Metric, RankingMetrics, MultipleMetrics from pytorch_widedeep.wdtypes import ( Any, Dict, @@ -285,13 +285,14 @@ def _set_transforms(transforms): else: return None - # this needs type fixing to adjust for the fact that the main class can - # take an 'object', a non-instastiated Class, so, should be something like: - # callbacks: Optional[List[Union[object, Callback]]] in all places + # TO DO: this needs type fixing to adjust for the fact that the main class + # can take an 'object', a non-instastiated Class, so, should be something + # like: callbacks: Optional[List[Union[object, Callback]]] in all places def _set_callbacks_and_metrics( self, callbacks: Any, - metrics: Any, + metrics: Any, # Union[List[Metric], List[TorchMetric]], + eval_metrics: Optional[Any] = None, # Union[List[Metric], List[TorchMetric]], ): self.callbacks: List = [History(), LRShedulerCallback()] if callbacks is not None: @@ -301,9 +302,31 @@ def _set_callbacks_and_metrics( self.callbacks.append(callback) if metrics is not None: self.metric = MultipleMetrics(metrics) + # assert that no ranking metric is used during training assertion + # here so all metrics are instanciated + assert not any( + [isinstance(m, RankingMetrics) for m in self.metric._metrics] # type: ignore[arg-type, misc] + ), "Currently, ranking metrics are not supported during training" + self.callbacks += [MetricCallback(self.metric)] else: self.metric = None + if eval_metrics is not None: + self.eval_metric = MultipleMetrics(eval_metrics) + # assert that if any of the metrics is a ranking metric, all metrics + # must be ranking metrics + if any( + [isinstance(m, RankingMetrics) for m in self.eval_metric._metrics] # type: ignore[arg-type, misc] + ): + assert all( + [isinstance(m, RankingMetrics) for m in self.eval_metric._metrics] # type: ignore[arg-type, misc] + ), ( + "All eval metrics must be ranking metrics if any of the eval" + " metrics is a ranking metric" + ) + self.callbacks += [MetricCallback(self.eval_metric)] + else: + self.eval_metric = None self.callback_container = CallbackContainer(self.callbacks) self.callback_container.set_model(self.model) self.callback_container.set_trainer(self) diff --git a/pytorch_widedeep/training/trainer.py 
b/pytorch_widedeep/training/trainer.py index 8ca7d459..cff69bfd 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -228,6 +228,7 @@ class Trainer(BaseTrainer): "objective", ["loss_function", "loss_fn", "loss", "cost_function", "cost_fn", "cost"], ) + @alias("metrics", ["train_metrics"]) def __init__( self, model: WideDeep, @@ -245,6 +246,7 @@ def __init__( transforms: Optional[List[Transforms]] = None, callbacks: Optional[List[Callback]] = None, metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None, + eval_metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None, verbose: int = 1, seed: int = 1, **kwargs, @@ -259,6 +261,7 @@ def __init__( transforms=transforms, callbacks=callbacks, metrics=metrics, + eval_metrics=eval_metrics, verbose=verbose, seed=seed, **kwargs, @@ -600,7 +603,7 @@ def predict_uncertainty( # type: ignore[return] X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None, X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None, batch_size: Optional[int] = None, - uncertainty_granularity=1000, + uncertainty_granularity: int = 1000, ) -> np.ndarray: r"""Returns the predicted ucnertainty of the model for the test dataset using a Monte Carlo method during which dropout layers are activated @@ -930,10 +933,10 @@ def _train_step( if self.model.is_tabnet: loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1] - score = self._get_score(y_pred[0], y) + score = self._get_score(y_pred[0], y, is_train=True) else: loss = self.loss_fn(y_pred, y) - score = self._get_score(y_pred, y) + score = self._get_score(y_pred, y, is_train=True) loss.backward() self.optimizer.step() @@ -967,9 +970,9 @@ def _eval_step( y_pred = self.model(X) if self.model.is_tabnet: loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1] - score = self._get_score(y_pred[0], y) + score = self._get_score(y_pred[0], y, is_train=False) else: - score = self._get_score(y_pred, y) + score = self._get_score(y_pred, y, is_train=False) loss = self.loss_fn(y_pred, y) self.valid_running_loss += loss.item() @@ -978,20 +981,30 @@ def _eval_step( self.model.train() return score, avg_loss - def _get_score(self, y_pred, y): - if self.metric is not None: + def _get_score( + self, y_pred: Tensor, y: Tensor, is_train: bool + ) -> Optional[Dict[str, float]]: + + score = None + metric = None + + if hasattr(self, "metric") and not hasattr(self, "eval_metric"): + metric = self.metric + elif hasattr(self, "metric") and hasattr(self, "eval_metric"): + metric = self.metric if is_train else self.eval_metric + elif not hasattr(self, "metric") and hasattr(self, "eval_metric"): + metric = None if is_train else self.eval_metric + + if metric is not None: if self.method == "regression": - score = self.metric(y_pred, y) + score = metric(y_pred, y) if self.method == "binary": - score = self.metric(torch.sigmoid(y_pred), y) + score = metric(torch.sigmoid(y_pred), y) if self.method == "qregression": - score = self.metric(y_pred, y) + score = metric(y_pred, y) if self.method == "multiclass": - score = self.metric(F.softmax(y_pred, dim=1), y) - # TO DO: handle multitarget - return score - else: - return None + score = metric(F.softmax(y_pred, dim=1), y) + return score def _predict( # type: ignore[override, return] # noqa: C901 self, @@ -1001,7 +1014,7 @@ def _predict( # type: ignore[override, return] # noqa: C901 X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None, X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None, 
batch_size: Optional[int] = None, - uncertainty_granularity=1000, + uncertainty_granularity: int = 1000, uncertainty: bool = False, ) -> List: r"""Private method to avoid code repetition in predict and diff --git a/tests/test_metrics/test_ranking_metrics.py b/tests/test_metrics/test_ranking_metrics.py new file mode 100644 index 00000000..b08ecb4a --- /dev/null +++ b/tests/test_metrics/test_ranking_metrics.py @@ -0,0 +1,188 @@ +# Thanks claude 3.5 for the test cases. +import numpy as np +import torch +import pytest + +from pytorch_widedeep.metrics import ( + MAP_at_k, + NDCG_at_k, + Recall_at_k, + HitRatio_at_k, + Precision_at_k, + BinaryNDCG_at_k, +) + + +@pytest.fixture +def setup_data(): + y_pred = torch.tensor( + [ + [0.7, 0.9, 0.5, 0.8, 0.6], + [0.3, 0.1, 0.5, 0.2, 0.4], + ] + ) + y_true = torch.tensor( + [ + [0, 1, 0, 1, 1], + [1, 1, 0, 0, 0], + ] + ) + return y_pred, y_true + + +@pytest.fixture +def setup_data_ndcg(): + y_pred = torch.tensor( + [ + [0.7, 0.9, 0.5, 0.8, 0.6], + [0.3, 0.1, 0.5, 0.2, 0.4], + ] + ) + # Using relevance scores from 0 to 3 + y_true = torch.tensor( + [ + [1, 3, 1, 2, 0], + [3, 1, 0, 2, 1], + ] + ) + return y_pred, y_true + + +def test_binary_ndcg_at_k(setup_data): + y_pred, y_true = setup_data + binary_ndcg = BinaryNDCG_at_k(n_cols=5, k=3) + result = binary_ndcg(y_pred.flatten(), y_true.flatten()) + expected = np.array(0.5719) + np.testing.assert_almost_equal(result, expected, decimal=4) + + +def test_map_at_k(setup_data): + y_pred, y_true = setup_data + map_at_k = MAP_at_k(n_cols=5, k=3) + result = map_at_k(y_pred.flatten(), y_true.flatten()) + expected = np.array(0.4166) + np.testing.assert_almost_equal(result, expected, decimal=4) + + +def test_hit_ratio_at_k(setup_data): + y_pred, y_true = setup_data + hr_at_k = HitRatio_at_k(n_cols=5, k=3) + result = hr_at_k(y_pred.flatten(), y_true.flatten()) + expected = np.array(1.0) + np.testing.assert_almost_equal(result, expected, decimal=4) + + +def test_precision_at_k(setup_data): + y_pred, y_true = setup_data + prec_at_k = Precision_at_k(n_cols=5, k=3) + result = prec_at_k(y_pred.flatten(), y_true.flatten()) + expected = np.array(0.5) + np.testing.assert_almost_equal(result, expected, decimal=4) + + +def test_recall_at_k(setup_data): + y_pred, y_true = setup_data + rec_at_k = Recall_at_k(n_cols=5, k=3) + result = rec_at_k(y_pred.flatten(), y_true.flatten()) + expected = np.array(0.5833) + np.testing.assert_almost_equal(result, expected, decimal=4) + + +def test_edge_cases_all_relevant_items(): + # Test with all relevant items + y_pred = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5]) + y_true = torch.tensor([1, 1, 1, 1, 1]) + + ndcg = NDCG_at_k(n_cols=5, k=3) + assert ndcg(y_pred, y_true) == 1.0 + + binary_ndcg = BinaryNDCG_at_k(n_cols=5, k=3) + assert binary_ndcg(y_pred, y_true) == 1.0 + + map_at_k = MAP_at_k(n_cols=5, k=3) + assert np.isclose(map_at_k(y_pred, y_true), 0.6) + + hr_at_k = HitRatio_at_k(n_cols=5, k=3) + assert hr_at_k(y_pred, y_true) == 1.0 + + prec_at_k = Precision_at_k(n_cols=5, k=3) + assert prec_at_k(y_pred, y_true) == 1.0 + + rec_at_k = Recall_at_k(n_cols=5, k=3) + assert np.isclose(rec_at_k(y_pred, y_true), 0.6) + + +def test_edge_cases_no_relevant_items(): + + # Test with no relevant items + y_pred = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5]) + y_true = torch.tensor([0, 0, 0, 0, 0]) + + ndcg = NDCG_at_k(n_cols=5, k=3) + assert ndcg(y_pred, y_true) == 0.0 + + binary_ndcg = BinaryNDCG_at_k(n_cols=5, k=3) + assert binary_ndcg(y_pred, y_true) == 0.0 + + map_at_k = MAP_at_k(n_cols=5, k=3) + assert 
map_at_k(y_pred, y_true) == 0.0 + + hr_at_k = HitRatio_at_k(n_cols=5, k=3) + assert hr_at_k(y_pred, y_true) == 0.0 + + prec_at_k = Precision_at_k(n_cols=5, k=3) + assert prec_at_k(y_pred, y_true) == 0.0 + + rec_at_k = Recall_at_k(n_cols=5, k=3) + assert rec_at_k(y_pred, y_true) == 0.0 + + +def test_k_greater_than_n_cols(): + with pytest.raises(ValueError): + NDCG_at_k(n_cols=5, k=10) + with pytest.raises(ValueError): + BinaryNDCG_at_k(n_cols=5, k=10) + with pytest.raises(ValueError): + MAP_at_k(n_cols=5, k=10) + with pytest.raises(ValueError): + HitRatio_at_k(n_cols=5, k=10) + with pytest.raises(ValueError): + Precision_at_k(n_cols=5, k=10) + with pytest.raises(ValueError): + Recall_at_k(n_cols=5, k=10) + + +def test_ndcg_at_k(setup_data_ndcg): + y_pred, y_true = setup_data_ndcg + ndcg = NDCG_at_k(n_cols=5, k=3) + result = ndcg(y_pred.flatten(), y_true.flatten()) + expected = np.array(0.7198) + np.testing.assert_almost_equal(result, expected, decimal=4) + + +def test_ndcg_at_k_edge_cases(): + # Test with non-decreasing ranking + y_pred = torch.tensor([0.5, 0.6, 0.7, 0.8, 0.9]) + y_true = torch.tensor([0, 1, 2, 3, 4]) + + ndcg = NDCG_at_k(n_cols=5, k=3) + result = ndcg(y_pred, y_true) + assert result == 1.0 + + # Test with non-increasing ranking + ndcg = NDCG_at_k(n_cols=5, k=3) + y_pred = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5]) + result = ndcg(y_pred, y_true) + expected = np.array(0.1019) + np.testing.assert_almost_equal(result, expected, decimal=4) + + # Test with all zero relevance + ndcg = NDCG_at_k(n_cols=5, k=3) + y_true = torch.tensor([0, 0, 0, 0, 0]) + result = ndcg(y_pred, y_true) + assert result == 0.0 + + +def test_ndcg_at_k_k_greater_than_n_cols(): + with pytest.raises(ValueError): + NDCG_at_k(n_cols=5, k=10) From d32397a36c2049ba0817f477bb1a9f0a000f4f39 Mon Sep 17 00:00:00 2001 From: Javier Date: Sat, 19 Oct 2024 16:15:48 +0100 Subject: [PATCH 09/24] Added explanation to the metrics and a bit of refactoring --- pytorch_widedeep/metrics.py | 76 ++++---- pytorch_widedeep/training/_base_trainer.py | 35 ++-- pytorch_widedeep/training/trainer.py | 192 ++++++++++++++------- tests/test_metrics/test_ranking_metrics.py | 42 +++++ 4 files changed, 224 insertions(+), 121 deletions(-) diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index e4b871ed..5286952b 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -461,13 +461,13 @@ def reset(self): def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: device = y_pred.device - y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) + y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) - batch_size = y_true_reshaped.shape[0] + batch_size = y_true_2d.shape[0] - _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1) - top_k_relevance = y_true_reshaped.gather(1, top_k_indices) + _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) + top_k_relevance = y_true_2d.gather(1, top_k_indices) discounts = 1.0 / torch.log2( torch.arange(2, top_k_relevance.shape[1] + 2, device=device) ) @@ -475,7 +475,7 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: dcg = (torch.pow(2, top_k_relevance) - 1) * discounts.unsqueeze(0) dcg = dcg.sum(dim=1) - sorted_relevance, _ = torch.sort(y_true_reshaped, dim=1, descending=True) + sorted_relevance, _ = torch.sort(y_true_2d, dim=1, descending=True) ideal_relevance = sorted_relevance[:, : self.k] idcg = (torch.pow(2, ideal_relevance) 
- 1) * discounts[ @@ -544,13 +544,13 @@ def reset(self): def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: device = y_pred.device - y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) + y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) - batch_size = y_pred_reshaped.shape[0] + batch_size = y_pred_2d.shape[0] - _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1) - top_k_mask = torch.zeros_like(y_pred_reshaped, dtype=torch.bool).scatter_( + _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) + top_k_mask = torch.zeros_like(y_pred_2d, dtype=torch.bool).scatter_( 1, top_k_indices, 1 ) @@ -561,12 +561,12 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: discounts = torch.zeros_like(top_k_mask, dtype=torch.float) discounts[top_k_mask] = expanded_discounts - dcg = (y_true_reshaped * top_k_mask * discounts).sum(dim=1) + dcg = (y_true_2d * top_k_mask * discounts).sum(dim=1) n_relevant = torch.minimum( - y_true_reshaped.sum(dim=1), torch.tensor(self.k, device=device) + y_true_2d.sum(dim=1), torch.tensor(self.k, device=device) ).int() ideal_discounts = 1.0 / torch.log2( - torch.arange(2, y_true_reshaped.shape[1] + 2, device=device).float() + torch.arange(2, y_true_2d.shape[1] + 2, device=device).float() ) idcg = torch.zeros(batch_size, device=device) @@ -626,18 +626,18 @@ def reset(self): def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: - y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) + y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) - batch_size = y_pred_reshaped.shape[0] - _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1) - batch_relevance = y_true_reshaped.gather(1, top_k_indices) + batch_size = y_pred_2d.shape[0] + _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) + batch_relevance = y_true_2d.gather(1, top_k_indices) cumsum_relevance = torch.cumsum(batch_relevance, dim=1) precision_at_i = cumsum_relevance / torch.arange( - 1, self.k + 1, device=y_pred_reshaped.device + 1, self.k + 1, device=y_pred_2d.device ).float().unsqueeze(0) avg_precision = (precision_at_i * batch_relevance).sum(dim=1) / torch.clamp( - y_true_reshaped.sum(dim=1), min=1 + y_true_2d.sum(dim=1), min=1 ) self.sum_avg_precision += avg_precision.sum().item() self.count += batch_size @@ -688,11 +688,11 @@ def reset(self): self.count = 0 def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: - y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) - batch_size = y_pred_reshaped.shape[0] - _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1) - batch_relevance = y_true_reshaped.gather(1, top_k_indices) + y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) + batch_size = y_pred_2d.shape[0] + _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) + batch_relevance = y_true_2d.gather(1, top_k_indices) hit = (batch_relevance.sum(dim=1) > 0).float() self.sum_hr += hit.sum().item() self.count += batch_size @@ -744,11 +744,11 @@ def reset(self): self.count = 0 def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: - y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) - batch_size = y_pred_reshaped.shape[0] - _, top_k_indices = 
torch.topk(y_pred_reshaped, self.k, dim=1) - batch_relevance = y_true_reshaped.gather(1, top_k_indices) + y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) + batch_size = y_pred_2d.shape[0] + _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) + batch_relevance = y_true_2d.gather(1, top_k_indices) precision = batch_relevance.sum(dim=1) / self.k self.sum_precision += precision.sum().item() self.count += batch_size @@ -793,14 +793,12 @@ def reset(self): self.count = 0 def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: - y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols) - batch_size = y_pred_reshaped.shape[0] - _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1) - batch_relevance = y_true_reshaped.gather(1, top_k_indices) - recall = batch_relevance.sum(dim=1) / torch.clamp( - y_true_reshaped.sum(dim=1), min=1 - ) + y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) + batch_size = y_pred_2d.shape[0] + _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) + batch_relevance = y_true_2d.gather(1, top_k_indices) + recall = batch_relevance.sum(dim=1) / torch.clamp(y_true_2d.sum(dim=1), min=1) self.sum_recall += recall.sum().item() self.count += batch_size return np.array(self.sum_recall / max(self.count, 1)) diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py index 21d8985b..bf2c3875 100644 --- a/pytorch_widedeep/training/_base_trainer.py +++ b/pytorch_widedeep/training/_base_trainer.py @@ -302,27 +302,32 @@ def _set_callbacks_and_metrics( self.callbacks.append(callback) if metrics is not None: self.metric = MultipleMetrics(metrics) - # assert that no ranking metric is used during training assertion - # here so all metrics are instanciated - assert not any( - [isinstance(m, RankingMetrics) for m in self.metric._metrics] # type: ignore[arg-type, misc] - ), "Currently, ranking metrics are not supported during training" - + if ( + any( + [isinstance(m, RankingMetrics) for m in self.metric._metrics] # type: ignore[arg-type, misc] + ) + and self.verbose + ): + UserWarning( + "There are ranking metrics in the 'metrics' list. The implementation " + "in this library requires that all query or user ids must have the " + "same number of entries or items." + ) self.callbacks += [MetricCallback(self.metric)] else: self.metric = None if eval_metrics is not None: self.eval_metric = MultipleMetrics(eval_metrics) - # assert that if any of the metrics is a ranking metric, all metrics - # must be ranking metrics - if any( - [isinstance(m, RankingMetrics) for m in self.eval_metric._metrics] # type: ignore[arg-type, misc] - ): - assert all( + if ( + any( [isinstance(m, RankingMetrics) for m in self.eval_metric._metrics] # type: ignore[arg-type, misc] - ), ( - "All eval metrics must be ranking metrics if any of the eval" - " metrics is a ranking metric" + ) + and self.verbose + ): + UserWarning( + "There are ranking metrics in the 'eval_metric' list. The implementation " + "in this library requires that all query or user ids must have the " + "same number of entries or items." 
) self.callbacks += [MetricCallback(self.eval_metric)] else: diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index cff69bfd..7400409b 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -12,8 +12,10 @@ from pytorch_widedeep.losses import ZILNLoss from pytorch_widedeep.metrics import Metric from pytorch_widedeep.wdtypes import ( + Any, Dict, List, + Tuple, Union, Tensor, Literal, @@ -268,6 +270,8 @@ def __init__( ) @alias("finetune", ["warmup"]) + @alias("train_dataloader", ["custom_dataloader"]) + @alias("eval_dataloader", ["custom_eval_dataloader"]) def fit( # noqa: C901 self, X_wide: Optional[np.ndarray] = None, @@ -281,7 +285,8 @@ def fit( # noqa: C901 n_epochs: int = 1, validation_freq: int = 1, batch_size: int = 32, - custom_dataloader: Optional[DataLoader] = None, + train_dataloader: Optional[DataLoader] = None, + eval_dataloader: Optional[DataLoader] = None, feature_importance_sample_size: Optional[int] = None, finetune: bool = False, **kwargs, @@ -418,11 +423,8 @@ def fit( # noqa: C901 [Examples](https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples) folder in the repo """ - dataloader_args, finetune_args = self._extract_kwargs(kwargs) - self.batch_size = batch_size - train_set, eval_set = wd_train_val_split( self.seed, self.method, # type: ignore @@ -436,35 +438,18 @@ def fit( # noqa: C901 target, self.transforms, ) - if custom_dataloader is not None: - # make sure is callable (and HAS to be an subclass of DataLoader) - assert isinstance(custom_dataloader, type) - train_loader = custom_dataloader( # type: ignore[misc] - dataset=train_set, - batch_size=batch_size, - num_workers=self.num_workers, - **dataloader_args, - ) - else: - train_loader = DataLoader( - dataset=train_set, - batch_size=batch_size, - num_workers=self.num_workers, - **dataloader_args, - ) - train_steps = len(train_loader) - if eval_set is not None: - eval_loader = DataLoader( - dataset=eval_set, - batch_size=batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - eval_steps = len(eval_loader) + train_loader, eval_loader = self._get_dataloaders( + train_dataloader, + eval_dataloader, + train_set, + eval_set, + batch_size, + dataloader_args, + ) if finetune: self.with_finetuning: bool = True - self._finetune(train_loader, **finetune_args) + self._do_finetune(train_loader, **finetune_args) if self.verbose: print( "Fine-tuning (or warmup) of individual components completed. 
" @@ -474,54 +459,32 @@ def fit( # noqa: C901 self.with_finetuning = False self.callback_container.on_train_begin( - {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs} + { + "batch_size": batch_size, + "train_steps": len(train_loader), + "n_epochs": n_epochs, + } ) for epoch in range(n_epochs): - epoch_logs: Dict[str, float] = {} - self.callback_container.on_epoch_begin(epoch, logs=epoch_logs) - - self.train_running_loss = 0.0 - with trange(train_steps, disable=self.verbose != 1) as t: - for batch_idx, (data, targett) in zip(t, train_loader): - t.set_description("epoch %i" % (epoch + 1)) - train_score, train_loss = self._train_step(data, targett, batch_idx) - print_loss_and_metric(t, train_loss, train_score) - self.callback_container.on_batch_end(batch=batch_idx) - epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train") - - on_epoch_end_metric = None - if eval_set is not None and epoch % validation_freq == ( + epoch_logs = self._train_epoch(train_loader, epoch) + if eval_loader is not None and epoch % validation_freq == ( validation_freq - 1 ): - self.callback_container.on_eval_begin() - self.valid_running_loss = 0.0 - with trange(eval_steps, disable=self.verbose != 1) as v: - for i, (data, targett) in zip(v, eval_loader): - v.set_description("valid") - val_score, val_loss = self._eval_step(data, targett, i) - print_loss_and_metric(v, val_loss, val_score) - epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val") - - if self.reducelronplateau: - if self.reducelronplateau_criterion == "loss": - on_epoch_end_metric = val_loss - else: - on_epoch_end_metric = val_score[ - self.reducelronplateau_criterion - ] + epoch_logs, on_epoch_end_metric = self._eval_epoch( + eval_loader, epoch_logs + ) else: + on_epoch_end_metric = None if self.reducelronplateau: raise NotImplementedError( "ReduceLROnPlateau scheduler can be used only with validation data." 
) self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric) - if self.early_stop: # self.callback_container.on_train_end(epoch_logs) break self.callback_container.on_train_end(epoch_logs) - if feature_importance_sample_size is not None: self.feature_importance = FeatureImportance( self.device, feature_importance_sample_size @@ -817,9 +780,55 @@ def save( with open(Path(path) / "feature_importance.json", "w") as fi: json.dump(self.feature_importance, fi) + def _get_dataloaders( + self, + train_dataloader: Optional[DataLoader], + eval_dataloader: Optional[DataLoader], + train_set: WideDeepDataset, + eval_set: Optional[WideDeepDataset], + batch_size: int, + dataloader_args: Dict[str, Any], + ) -> Tuple[DataLoader, Optional[DataLoader]]: + if train_dataloader is not None: + # make sure is callable (and HAS to be an subclass of DataLoader) + assert isinstance(train_dataloader, type) + train_loader = train_dataloader( # type: ignore[misc] + dataset=train_set, + batch_size=batch_size, + num_workers=self.num_workers, + **dataloader_args, + ) + else: + train_loader = DataLoader( + dataset=train_set, + batch_size=batch_size, + num_workers=self.num_workers, + **dataloader_args, + ) + + eval_loader = None + if eval_set is not None: + if eval_dataloader is not None: + assert isinstance(eval_dataloader, type) + eval_loader = eval_dataloader( # type: ignore[misc] + dataset=eval_set, + batch_size=batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + else: + eval_loader = DataLoader( + dataset=eval_set, + batch_size=batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + return train_loader, eval_loader + @alias("n_epochs", ["finetune_epochs", "warmup_epochs"]) @alias("max_lr", ["finetune_max_lr", "warmup_max_lr"]) - def _finetune( + def _do_finetune( self, loader: DataLoader, n_epochs: int = 5, @@ -905,6 +914,27 @@ def _finetune( self.model.deepimage, "deepimage", loader, n_epochs, max_lr ) + def _train_epoch( + self, + train_loader: DataLoader, + epoch: int, + ): + epoch_logs: Dict[str, float] = {} + self.callback_container.on_epoch_begin(epoch, logs=epoch_logs) + self.train_running_loss = 0.0 + + train_steps = len(train_loader) + with trange(train_steps, disable=self.verbose != 1) as t: + for batch_idx, (data, targett) in zip(t, train_loader): + t.set_description("epoch %i" % (epoch + 1)) + train_score, train_loss = self._train_step(data, targett, batch_idx) + print_loss_and_metric(t, train_loss, train_score) + self.callback_container.on_batch_end(batch=batch_idx) + + epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train") + + return epoch_logs + def _train_step( self, data: Dict[str, Union[Tensor, List[Tensor]]], @@ -946,6 +976,33 @@ def _train_step( return score, avg_loss + def _eval_epoch( + self, + eval_loader: DataLoader, + epoch_logs: Dict[str, float], + ) -> Tuple[Dict[str, float], Optional[float]]: + self.callback_container.on_eval_begin() + self.valid_running_loss = 0.0 + + eval_steps = len(eval_loader) + with trange(eval_steps, disable=self.verbose != 1) as v: + for i, (data, targett) in zip(v, eval_loader): + v.set_description("valid") + val_score, val_loss = self._eval_step(data, targett, i) + print_loss_and_metric(v, val_loss, val_score) + + epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val") + + if self.reducelronplateau: + if self.reducelronplateau_criterion == "loss": + on_epoch_end_metric = val_loss + else: + on_epoch_end_metric = val_score[self.reducelronplateau_criterion] + else: + on_epoch_end_metric = 
None + + return epoch_logs, on_epoch_end_metric + def _eval_step( self, data: Dict[str, Union[Tensor, List[Tensor]]], @@ -988,11 +1045,11 @@ def _get_score( score = None metric = None - if hasattr(self, "metric") and not hasattr(self, "eval_metric"): + if self.metric and not self.eval_metric: metric = self.metric - elif hasattr(self, "metric") and hasattr(self, "eval_metric"): + elif self.metric and self.eval_metric: metric = self.metric if is_train else self.eval_metric - elif not hasattr(self, "metric") and hasattr(self, "eval_metric"): + elif not self.metric and self.eval_metric: metric = None if is_train else self.eval_metric if metric is not None: @@ -1097,6 +1154,7 @@ def _predict( # type: ignore[override, return] # noqa: C901 @staticmethod def _predict_ziln(preds: Tensor) -> Tensor: + # Legacy implementation. It will be removed in future versions """Calculates predicted mean of zero inflated lognormal logits. Adjusted implementaion of `code diff --git a/tests/test_metrics/test_ranking_metrics.py b/tests/test_metrics/test_ranking_metrics.py index b08ecb4a..b221658e 100644 --- a/tests/test_metrics/test_ranking_metrics.py +++ b/tests/test_metrics/test_ranking_metrics.py @@ -55,6 +55,15 @@ def test_binary_ndcg_at_k(setup_data): expected = np.array(0.5719) np.testing.assert_almost_equal(result, expected, decimal=4) + # Explaining the expected value (@k=3): + # DCG_row1 = 0*1 + 1*0.6309 + 1*0.5 = 1.1309 -> (0.7 -> 0, 0.9 -> 1, 0.8 -> 1) + # DCG_row2 = 1*1 + 0*0.6309 + 0*0.5 = 1.0 -> (0.3 -> 1, 0.5 -> 0, 0.4 -> 0) + # IDCG_row1 = 1*1 + 1*0.6309 + 1*0.5 = 2.1309 (3 relevant items) + # IDCG_row2 = 1*1 + 1*0.6309 = 1.6309 (2 relevant items) + # NDCG_row1 = DCG_row1 / IDCG_row1 = 1.1309 / 2.1309 = 0.5309 + # NDCG_row2 = DCG_row2 / IDCG_row2 = 1.0 / 1.6309 = 0.6129 + # binary_NDCG = (NDCG_row1 + NDCG_row2) / 2 = (0.5309 + 0.6129) / 2 = 0.5719 + def test_map_at_k(setup_data): y_pred, y_true = setup_data @@ -63,6 +72,12 @@ def test_map_at_k(setup_data): expected = np.array(0.4166) np.testing.assert_almost_equal(result, expected, decimal=4) + # Explaining the expected value (@k=3): + # batch relevance for row top 3 predictions: [[1, 1, 0], [0, 0, 1]] + # AP_row1 = (1/1 + 2/2 + 0/3) / 3 = 0.6666 -> [(1/1) * 1 + (2/2) * 1 + (2/3) * 0] / 3 + # AP_row2 = (0/1 + 0/2 + 1/3) / 2 = 0.1666 -> [(0/1) * 0 + (0/2) * 0 + (1/3) * 1] / 2 + # MAP = (AP_row1 + AP_row2) / 2 = (0.6666 + 0.1666) / 2 = 0.4166 + def test_hit_ratio_at_k(setup_data): y_pred, y_true = setup_data @@ -79,6 +94,14 @@ def test_precision_at_k(setup_data): expected = np.array(0.5) np.testing.assert_almost_equal(result, expected, decimal=4) + # Explaining the expected value (@k=3): + # batch relevance for row top 3 predictions: [[1, 1, 0], [0, 0, 1]] + # Precision_row1 = 2 / 3 = 0.666 + # Precision_row2 = 1 / 3 = 0.333 + # Precision = (Precision_row1 + Precision_row2) / 2 = (0.666 + 0.333) / 2 = 0.5 + # caveat: if k is higher than the number of relevant items for a given row + # this metric will never be 1.0 + def test_recall_at_k(setup_data): y_pred, y_true = setup_data @@ -87,6 +110,12 @@ def test_recall_at_k(setup_data): expected = np.array(0.5833) np.testing.assert_almost_equal(result, expected, decimal=4) + # Explaining the expected value (@k=3): + # batch relevance for row top 3 predictions: [[1, 1, 0], [0, 0, 1]] + # Recall_row1 = 2 / 3 = 0.666 + # Recall_row2 = 1 / 2 = 0.5 + # Recall = (Recall_row1 + Recall_row2) / 2 = (0.666 + 0.5) / 2 = 0.5833 + def test_edge_cases_all_relevant_items(): # Test with all relevant items @@ -159,6 +188,19 
@@ def test_ndcg_at_k(setup_data_ndcg): expected = np.array(0.7198) np.testing.assert_almost_equal(result, expected, decimal=4) + # Explaining the expected value (@k=3): + # top 3 predictions for row 1: [0.9, 0.8, 0.7] -> [3, 2, 1] + # top 3 predictions for row 2: [0.5, 0.4, 0.3] -> [0, 1, 3] + # DCG = sum[(2^rel - 1) / log2(rank + 1)] + # DCG_row1 = 2^3 - 1 / log2(1+1) + 2^2 - 1 / log2(2+1) + 2^1 - 1 / log2(3+1) = 7.0 + 1.8928 + 0.5000 = 9.3928 + # DCG_row2 = 2^0 - 1 / log2(1+1) + 2^1 - 1 / log2(2+1) + 2^3 - 1 / log2(3+1) = 0.0 + 0.6309 + 3.5 = 4.1309 + # IDCG = sum[(2^rel - 1) / log2(rank + 1)] for the ideal ranking at k=3 [[3, 2, 1], [3, 2, 1]] + # IDCG_row1 = 2^3 - 1 / log2(1+1) + 2^2 - 1 / log2(2+1) + 2^1 - 1 / log2(3+1) = 7.0 + 1.8928 + 0.5000 = 9.3928 + # IDCG_row2 = IDCG_row1 = 9.3928 + # NDGC_row1 = DCG_row1 / IDCG_row1 = 9.3928 / 9.3928 = 1.0 + # NDGC_row2 = DCG_row2 / IDCG_row2 = 4.1309 / 9.3928 = 0.4398 + # NDCG = (NDGC_row1 + NDGC_row2) / 2 = (1.0 + 0.4398) / 2 = 0.7198 + def test_ndcg_at_k_edge_cases(): # Test with non-decreasing ranking From f091194bf69be95558be667ea337f2ff518f1bda Mon Sep 17 00:00:00 2001 From: Javier Date: Sat, 19 Oct 2024 23:11:08 +0100 Subject: [PATCH 10/24] First integration test run. I will have to adjust all examples containing the class --- pytest.ini | 2 + pytorch_widedeep/dataloaders.py | 79 +++++++-- pytorch_widedeep/metrics.py | 66 +++++--- pytorch_widedeep/training/trainer.py | 67 +++----- .../test_ranking_metrics_integration.py | 153 ++++++++++++++++++ 5 files changed, 285 insertions(+), 82 deletions(-) create mode 100644 pytest.ini create mode 100644 tests/test_rec/test_ranking_metrics_integration.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..c1fa8785 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = -p no:warnings \ No newline at end of file diff --git a/pytorch_widedeep/dataloaders.py b/pytorch_widedeep/dataloaders.py index 3c17bc5c..db53ace0 100644 --- a/pytorch_widedeep/dataloaders.py +++ b/pytorch_widedeep/dataloaders.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional import numpy as np from torch.utils.data import DataLoader, WeightedRandomSampler @@ -6,6 +6,45 @@ from pytorch_widedeep.training._wd_dataset import WideDeepDataset +class DatasetAlreadySetError(Exception): + """Exception raised when attempting to set a dataset that has already been set.""" + + def __init__(self, message="Dataset has already been set and cannot be changed."): + self.message = message + super().__init__(self.message) + + +class CustomDataLoader(DataLoader): + r""" + Wrapper around the `torch.utils.data.DataLoader` class that allows + to set the dataset after the class has been instantiated. + """ + + def __init__(self, dataset: Optional[WideDeepDataset] = None, *args, **kwargs): + self.dataset_set = dataset is not None + if self.dataset_set: + super().__init__(dataset, *args, **kwargs) + else: + self.args = args + self.kwargs = kwargs + + def set_dataset(self, dataset: WideDeepDataset): + if self.dataset_set: + raise DatasetAlreadySetError() + + self.dataset_set = True + super().__init__(dataset, *self.args, **self.kwargs) + + def __iter__(self): + if not self.dataset_set: + raise ValueError( + "Dataset has not been set. Use set_dataset method to set a dataset." + ) + return super().__iter__() + + +# From here on is legacy code and there are better ways to do it. It will be +# removed in the next version. 
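For reference (an illustration, not part of the diff): after this refactor a custom loader is handed to `Trainer.fit` as an instance created without a dataset, and the trainer attaches the internally built `WideDeepDataset` through `set_dataset`. A minimal, hedged sketch of the intended usage, where `model`, `X_tab_train` and `y_train` are placeholders assumed to have been prepared beforehand:

from pytorch_widedeep.dataloaders import CustomDataLoader
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.training import Trainer

# the loader is created WITHOUT a dataset; the batch size and any other
# standard DataLoader kwargs are fixed here and reused once the dataset
# is attached via set_dataset
train_loader = CustomDataLoader(batch_size=32, shuffle=True)

trainer = Trainer(model, objective="binary", metrics=[Accuracy()])  # model: placeholder WideDeep
trainer.fit(
    X_train={"X_tab": X_tab_train, "target": y_train},  # placeholder arrays
    train_dataloader=train_loader,  # the trainer calls train_loader.set_dataset(...)
    n_epochs=1,
)

Note that when a custom loader is supplied, the batch size and sampler fixed at construction are the ones used; the `batch_size` argument of `fit` only applies to the default `DataLoader` branch of `_set_dataloader`. `DataLoaderImbalanced` follows the same two-step pattern, building its `WeightedRandomSampler` inside `set_dataset`.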
def get_class_weights(dataset: WideDeepDataset) -> Tuple[np.ndarray, int, int]: """Helper function to get weights of classes in the imbalanced dataset. @@ -37,7 +76,7 @@ def get_class_weights(dataset: WideDeepDataset) -> Tuple[np.ndarray, int, int]: return weights, minor_class_count, num_classes -class DataLoaderImbalanced(DataLoader): +class DataLoaderImbalanced(CustomDataLoader): r"""Class to load and shuffle batches with adjusted weights for imbalanced datasets. If the classes do not begin from 0 remapping is necessary. See [here](https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab). @@ -46,13 +85,11 @@ class DataLoaderImbalanced(DataLoader): ---------- dataset: `WideDeepDataset` see `pytorch_widedeep.training._wd_dataset` - batch_size: int - size of batch - num_workers: int - number of workers Other Parameters ---------------- + *args: Any + Positional arguments to be passed to the parent CustomDataLoader. **kwargs: Dict This can include any parameter that can be passed to the _'standard'_ pytorch @@ -69,23 +106,31 @@ class DataLoaderImbalanced(DataLoader): $$ """ - def __init__( - self, dataset: WideDeepDataset, batch_size: int, num_workers: int, **kwargs - ): + def __init__(self, dataset: Optional[WideDeepDataset] = None, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + if dataset is not None: + self._setup_sampler(dataset) + super().__init__(dataset, *args, sampler=self.sampler, **kwargs) + else: + super().__init__() + + def set_dataset(self, dataset: WideDeepDataset): + sampler = self._setup_sampler(dataset) + # update the kwargs with the new sampler + self.kwargs["sampler"] = sampler + super().set_dataset(dataset) + + def _setup_sampler(self, dataset: WideDeepDataset) -> WeightedRandomSampler: assert dataset.Y is not None, ( "The 'dataset' instance of WideDeepDataset must contain a " "target array 'Y'" ) - if "oversample_mul" in kwargs: - oversample_mul = kwargs["oversample_mul"] - del kwargs["oversample_mul"] - else: - oversample_mul = 1 + oversample_mul = self.kwargs.pop("oversample_mul", 1) weights, minor_cls_cnt, num_clss = get_class_weights(dataset) num_samples = int(minor_cls_cnt * num_clss * oversample_mul) samples_weight = list(np.array([weights[i] for i in dataset.Y])) sampler = WeightedRandomSampler(samples_weight, num_samples, replacement=True) - super().__init__( - dataset, batch_size, num_workers=num_workers, sampler=sampler, **kwargs - ) + return sampler diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index 5286952b..c708ff9b 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -397,15 +397,25 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: return np.array((1 - (self.numerator / self.denominator))) -def reshape_1d_to_2d(tensor: Tensor, n_columns: int) -> Tensor: - if tensor.dim() != 1: - raise ValueError("Input tensor must be 1-dimensional") - if tensor.size(0) % n_columns != 0: +def reshape_to_2d(tensor: Tensor, n_columns: int) -> Tensor: + if tensor.dim() == 1: + if tensor.size(0) % n_columns != 0: + raise ValueError( + f"Tensor length ({tensor.size(0)}) must be divisible by n_columns ({n_columns})" + ) + n_rows = tensor.size(0) // n_columns + return tensor.reshape(n_rows, n_columns) + elif tensor.dim() == 2 and tensor.size(1) == 1: + if tensor.size(0) % n_columns != 0: + raise ValueError( + f"Tensor length ({tensor.size(0)}) must be divisible by n_columns ({n_columns})" + ) + n_rows = tensor.size(0) // n_columns + return tensor.reshape(n_rows, 
n_columns) + else: raise ValueError( - f"Tensor length ({tensor.size(0)}) must be divisible by n_columns ({n_columns})" + "Input tensor must be 1-dimensional or 2-dimensional with one column" ) - n_rows = tensor.size(0) // n_columns - return tensor.reshape(n_rows, n_columns) class NDCG_at_k(Metric): @@ -461,8 +471,8 @@ def reset(self): def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: device = y_pred.device - y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) + y_pred_2d = reshape_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_to_2d(y_true, self.n_cols) batch_size = y_true_2d.shape[0] @@ -544,8 +554,8 @@ def reset(self): def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: device = y_pred.device - y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) + y_pred_2d = reshape_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_to_2d(y_true, self.n_cols) batch_size = y_pred_2d.shape[0] @@ -607,6 +617,9 @@ class MAP_at_k(Metric): >>> score = map_at_k(y_pred, y_true) """ + @alias( + "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"] + ) def __init__(self, n_cols: int = 10, k: Optional[int] = None): super(MAP_at_k, self).__init__() @@ -626,8 +639,8 @@ def reset(self): def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: - y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) + y_pred_2d = reshape_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_to_2d(y_true, self.n_cols) batch_size = y_pred_2d.shape[0] _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) @@ -670,6 +683,9 @@ class HitRatio_at_k(Metric): >>> score = hr_at_k(y_pred, y_true """ + @alias( + "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"] + ) def __init__(self, n_cols: int = 10, k: Optional[int] = None): super(HitRatio_at_k, self).__init__() @@ -688,8 +704,8 @@ def reset(self): self.count = 0 def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: - y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) + y_pred_2d = reshape_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_to_2d(y_true, self.n_cols) batch_size = y_pred_2d.shape[0] _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) batch_relevance = y_true_2d.gather(1, top_k_indices) @@ -726,6 +742,9 @@ class Precision_at_k(Metric): >>> score = prec_at_k(y_pred, y_true) """ + @alias( + "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"] + ) def __init__(self, n_cols: int = 10, k: Optional[int] = None): super(Precision_at_k, self).__init__() @@ -744,8 +763,8 @@ def reset(self): self.count = 0 def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: - y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) + y_pred_2d = reshape_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_to_2d(y_true, self.n_cols) batch_size = y_pred_2d.shape[0] _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) batch_relevance = y_true_2d.gather(1, top_k_indices) @@ -761,6 +780,12 @@ class Recall_at_k(Metric): Parameters ---------- + n_cols: int, default = 10 + Number of columns in the input tensors. This parameter is neccessary + because the input tensors are reshaped to 2D tensors. n_cols is the + number of columns in the reshaped tensor. 
Alias for this parameter + are: 'n_items', 'n_items_per_query', + 'n_items_per_id', 'n_items_per_user' k: int, default = 10 Number of top items to consider. @@ -775,6 +800,9 @@ class Recall_at_k(Metric): >>> score = rec_at_k(y_pred, y_true) """ + @alias( + "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"] + ) def __init__(self, n_cols: int = 10, k: Optional[int] = None): super(Recall_at_k, self).__init__() @@ -793,8 +821,8 @@ def reset(self): self.count = 0 def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: - y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols) - y_true_2d = reshape_1d_to_2d(y_true, self.n_cols) + y_pred_2d = reshape_to_2d(y_pred, self.n_cols) + y_true_2d = reshape_to_2d(y_true, self.n_cols) batch_size = y_pred_2d.shape[0] _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1) batch_relevance = y_true_2d.gather(1, top_k_indices) diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 7400409b..bd871f3b 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -26,6 +26,7 @@ LRScheduler, ) from pytorch_widedeep.callbacks import Callback +from pytorch_widedeep.dataloaders import CustomDataLoader, DatasetAlreadySetError from pytorch_widedeep.initializers import Initializer from pytorch_widedeep.training._finetune import FineTune from pytorch_widedeep.utils.general_utils import alias, to_device @@ -285,8 +286,8 @@ def fit( # noqa: C901 n_epochs: int = 1, validation_freq: int = 1, batch_size: int = 32, - train_dataloader: Optional[DataLoader] = None, - eval_dataloader: Optional[DataLoader] = None, + train_dataloader: Optional[CustomDataLoader] = None, + eval_dataloader: Optional[CustomDataLoader] = None, feature_importance_sample_size: Optional[int] = None, finetune: bool = False, **kwargs, @@ -438,13 +439,13 @@ def fit( # noqa: C901 target, self.transforms, ) - train_loader, eval_loader = self._get_dataloaders( - train_dataloader, - eval_dataloader, - train_set, - eval_set, - batch_size, - dataloader_args, + train_loader = self._set_dataloader( + train_dataloader, train_set, batch_size, dataloader_args + ) + eval_loader = ( + self._set_dataloader(eval_dataloader, eval_set, batch_size, dataloader_args) + if eval_set is not None + else None ) if finetune: @@ -780,51 +781,25 @@ def save( with open(Path(path) / "feature_importance.json", "w") as fi: json.dump(self.feature_importance, fi) - def _get_dataloaders( + def _set_dataloader( self, - train_dataloader: Optional[DataLoader], - eval_dataloader: Optional[DataLoader], - train_set: WideDeepDataset, - eval_set: Optional[WideDeepDataset], + dataloader: Optional[CustomDataLoader], + dataset: WideDeepDataset, batch_size: int, dataloader_args: Dict[str, Any], - ) -> Tuple[DataLoader, Optional[DataLoader]]: - if train_dataloader is not None: - # make sure is callable (and HAS to be an subclass of DataLoader) - assert isinstance(train_dataloader, type) - train_loader = train_dataloader( # type: ignore[misc] - dataset=train_set, - batch_size=batch_size, - num_workers=self.num_workers, - **dataloader_args, - ) + ) -> DataLoader | CustomDataLoader: + if dataloader is not None: + dataloader.set_dataset(dataset) + return dataloader else: - train_loader = DataLoader( - dataset=train_set, + # var name 'loader' to avoid reassigment and type errors + loader = DataLoader( + dataset=dataset, batch_size=batch_size, num_workers=self.num_workers, **dataloader_args, ) - - eval_loader = None - if eval_set is not None: - if 
eval_dataloader is not None: - assert isinstance(eval_dataloader, type) - eval_loader = eval_dataloader( # type: ignore[misc] - dataset=eval_set, - batch_size=batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - else: - eval_loader = DataLoader( - dataset=eval_set, - batch_size=batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - return train_loader, eval_loader + return loader @alias("n_epochs", ["finetune_epochs", "warmup_epochs"]) @alias("max_lr", ["finetune_max_lr", "warmup_max_lr"]) diff --git a/tests/test_rec/test_ranking_metrics_integration.py b/tests/test_rec/test_ranking_metrics_integration.py new file mode 100644 index 00000000..1239ae54 --- /dev/null +++ b/tests/test_rec/test_ranking_metrics_integration.py @@ -0,0 +1,153 @@ +# silence DeprecationWarning +import warnings +from typing import Tuple + +import numpy as np +import pandas as pd +import pytest + +from pytorch_widedeep.models import TabMlp, WideDeep +from pytorch_widedeep.metrics import ( + F1Score, + Accuracy, + MAP_at_k, + NDCG_at_k, + Recall_at_k, + HitRatio_at_k, + Precision_at_k, + BinaryNDCG_at_k, +) +from pytorch_widedeep.training import Trainer +from pytorch_widedeep.preprocessing import TabPreprocessor +from pytorch_widedeep.training._wd_dataset import WideDeepDataset + + +def generate_user_item_interactions( + n_users: int = 10, + n_interactions_per_user: int = 10, + n_items: int = 5, + random_seed: int = 42, +) -> pd.DataFrame: + np.random.seed(random_seed) + + user_ids = np.repeat(range(1, n_users + 1), n_interactions_per_user) + item_ids = np.random.randint(1, n_items + 1, size=n_users * n_interactions_per_user) + user_categories = np.random.choice( + ["A", "B", "C"], size=n_users * n_interactions_per_user + ) + item_categories = np.random.choice( + ["X", "Y", "Z"], size=n_users * n_interactions_per_user + ) + binary_target = np.random.randint(0, 2, size=n_users * n_interactions_per_user) + categorical_target = np.random.randint(0, 3, size=n_users * n_interactions_per_user) + + df = pd.DataFrame( + { + "user_id": user_ids, + "item_id": item_ids, + "user_category": user_categories, + "item_category": item_categories, + "liked": binary_target, + "rating": categorical_target, + } + ) + + return df + + +def split_train_validation( + df: pd.DataFrame, validation_interactions_per_user: int +) -> Tuple[pd.DataFrame, pd.DataFrame]: + grouped = ( + df.groupby("user_id") + .apply(lambda x: x.sample(frac=1, random_state=42)) + .reset_index(drop=True) + ) + + train_df = ( + grouped.groupby("user_id") + .apply(lambda x: x.iloc[:-validation_interactions_per_user]) + .reset_index(drop=True) + ) + val_df = ( + grouped.groupby("user_id") + .apply(lambda x: x.iloc[-validation_interactions_per_user:]) + .reset_index(drop=True) + ) + + return train_df, val_df + + +@pytest.fixture +def user_item_data(request): + validation_interactions_per_user = request.param.get( + "validation_interactions_per_user", 5 + ) + df = generate_user_item_interactions() + train_df, val_df = split_train_validation(df, validation_interactions_per_user) + return train_df, val_df + + +# strategy 1: +# * ranking metrics for both training and validation sets. 
+# * Same number of items per user in both sets + + +@pytest.mark.parametrize( + "user_item_data", [{"validation_interactions_per_user": 5}], indirect=True +) +@pytest.mark.parametrize( + "metric", + [ + MAP_at_k(n_cols=5, k=3), + NDCG_at_k(n_cols=5, k=3), + Recall_at_k(n_cols=5, k=3), + HitRatio_at_k(n_cols=5, k=3), + Precision_at_k(n_cols=5, k=3), + BinaryNDCG_at_k(n_cols=5, k=3), + ], +) +def test_binary_classification_strategy_1(user_item_data, metric): + + train_df, val_df = user_item_data + + categorical_cols = ["user_id", "item_id", "user_category", "item_category"] + + target_col = "liked" + + tab_preprocessor = TabPreprocessor( + embed_cols=categorical_cols, + for_transformer=False, + ) + + tab_preprocessor.fit(train_df) + + X_train = tab_preprocessor.transform(train_df) + X_val = tab_preprocessor.transform(val_df) + + tab_mlp = TabMlp( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type] + ) + model = WideDeep(deeptabular=tab_mlp) + + trainer = Trainer(model, objective="binary", metrics=[metric]) + + trainer.fit( + X_train={"X_tab": X_train, "target": train_df[target_col].values}, + X_val={"X_tab": X_val, "target": val_df[target_col].values}, + n_epochs=1, + batch_size=2 * 5, + ) + + # predict on validation, this is just a test... + preds = trainer.predict(X_tab=X_val) + + assert preds.shape[0] == X_val.shape[0] + assert ( + trainer.history is not None + and "train_loss" in trainer.history + and "val_loss" in trainer.history + and f"train_{metric._name}" in trainer.history + and f"val_{metric._name}" in trainer.history + ) From 62db3ca23cc540dfec6de0afba583afd0d441942 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 21 Oct 2024 09:57:14 +0100 Subject: [PATCH 11/24] All test passed. I need to update examples for the Imbalanced Data Loader. Then review all the docs and cleans/rmove the docs folder --- .gitignore | 11 +- .isort.cfg | 4 - .travis.yml | 34 -- code_style.sh | 8 - pytest.ini | 2 - pytorch_widedeep/metrics.py | 10 + pytorch_widedeep/training/_base_trainer.py | 3 +- pytorch_widedeep/training/trainer.py | 2 +- .../training/trainer_from_folder.py | 330 ++--------------- .../test_fit_methods.py | 4 +- .../test_ranking_metrics_integration.py | 339 +++++++++++++++++- 11 files changed, 382 insertions(+), 365 deletions(-) delete mode 100644 .isort.cfg delete mode 100644 .travis.yml delete mode 100755 code_style.sh delete mode 100644 pytest.ini diff --git a/.gitignore b/.gitignore index 06c4cc68..710f3c01 100644 --- a/.gitignore +++ b/.gitignore @@ -60,4 +60,13 @@ checkpoints # wnb wandb/ -wandb_api.key \ No newline at end of file +wandb_api.key + +# pytest +pytest.ini + +# code style +code_style.sh + +# isort +.isort.cfg diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index 023dbdb4..00000000 --- a/.isort.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[settings] -profile=black -multi_line_output=3 -length_sort=1 \ No newline at end of file diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 75550745..00000000 --- a/.travis.yml +++ /dev/null @@ -1,34 +0,0 @@ -dist: xenial -language: python -python: - - "3.7.9" - - "3.8" - - "3.9" -matrix: - fast_finish: true - include: - - name: "Code Style (Black/Flake8)" - install: - - pip install black - - pip install flake8 - script: - # Black code style - - black --check --diff pytorch_widedeep tests examples setup.py - # Stop the build if there are Python syntax errors or undefined names - - flake8 . 
--count --select=E901,E999,F821,F822,F823 --ignore=E266 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --ignore=E203,E266,E501,E722,F401,F403,F405,F811,W503,C901 --statistics - after_success: skip -install: - - pip install --upgrade pip - - pip install pytest-cov - - pip install codecov - - pip install . -script: - - pytest --doctest-modules pytorch_widedeep --cov-report html --cov-report term --disable-pytest-warnings --cov=pytorch_widedeep tests/ - # - pytest tests -after_success: - # Combine coverage reports from Travis jobs - - coverage combine - # Send reports to Codecov. It automatically handles merging coverage results - - codecov \ No newline at end of file diff --git a/code_style.sh b/code_style.sh deleted file mode 100755 index 52eb2d64..00000000 --- a/code_style.sh +++ /dev/null @@ -1,8 +0,0 @@ -# sort imports -isort --quiet . pytorch_widedeep tests examples setup.py -# Black code style -black . pytorch_widedeep tests examples setup.py -# flake8 standards -flake8 . --max-complexity=10 --max-line-length=127 --ignore=E203,E266,E501,E722,E721,F401,F403,F405,W503,C901,F811 -# mypy -mypy pytorch_widedeep --ignore-missing-imports --no-strict-optional \ No newline at end of file diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index c1fa8785..00000000 --- a/pytest.ini +++ /dev/null @@ -1,2 +0,0 @@ -[pytest] -addopts = -p no:warnings \ No newline at end of file diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index c708ff9b..772f4c57 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -469,8 +469,18 @@ def reset(self): self.count = 0 def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: + # NDGC@k is supposed to be used when the output reflects interest + # scores, i.e, could be used in a regression or a multiclass problem. + # If regression y_pred will be a float tensor, if multiclass, y_pred + # will be a float tensor with the output of a softmax activation + # function and we need to turn it into a 1D tensor with the class. 
+ # Finally, for binary problems, please use BinaryNDCG_at_k device = y_pred.device + if y_pred.ndim > 1 and y_pred.size(1) > 1: + # multiclass + y_pred = y_pred.topk(1, 1)[1] + y_pred_2d = reshape_to_2d(y_pred, self.n_cols) y_true_2d = reshape_to_2d(y_true, self.n_cols) diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py index bf2c3875..2f15a3bb 100644 --- a/pytorch_widedeep/training/_base_trainer.py +++ b/pytorch_widedeep/training/_base_trainer.py @@ -63,6 +63,7 @@ def __init__( transforms: Optional[List[Transforms]], callbacks: Optional[List[Callback]], metrics: Optional[Union[List[Metric], List[TorchMetric]]], + eval_metrics: Optional[Union[List[Metric], List[TorchMetric]]], verbose: int, seed: int, **kwargs, @@ -90,7 +91,7 @@ def __init__( self.optimizer = self._set_optimizer(optimizers) self.lr_scheduler = self._set_lr_scheduler(lr_schedulers, **kwargs) self.transforms = self._set_transforms(transforms) - self._set_callbacks_and_metrics(callbacks, metrics) + self._set_callbacks_and_metrics(callbacks, metrics, eval_metrics) @abstractmethod def fit(self, **kwargs): diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index bd871f3b..e259268d 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -26,7 +26,7 @@ LRScheduler, ) from pytorch_widedeep.callbacks import Callback -from pytorch_widedeep.dataloaders import CustomDataLoader, DatasetAlreadySetError +from pytorch_widedeep.dataloaders import CustomDataLoader from pytorch_widedeep.initializers import Initializer from pytorch_widedeep.training._finetune import FineTune from pytorch_widedeep.utils.general_utils import alias, to_device diff --git a/pytorch_widedeep/training/trainer_from_folder.py b/pytorch_widedeep/training/trainer_from_folder.py index 1ebc2629..fac5a925 100644 --- a/pytorch_widedeep/training/trainer_from_folder.py +++ b/pytorch_widedeep/training/trainer_from_folder.py @@ -13,7 +13,6 @@ List, Union, Tensor, - Literal, Optional, WideDeep, Optimizer, @@ -22,30 +21,12 @@ ) from pytorch_widedeep.callbacks import Callback from pytorch_widedeep.initializers import Initializer -from pytorch_widedeep.training._finetune import FineTune +from pytorch_widedeep.training.trainer import Trainer from pytorch_widedeep.utils.general_utils import alias, to_device from pytorch_widedeep.training._wd_dataset import WideDeepDataset -from pytorch_widedeep.training._base_trainer import BaseTrainer -from pytorch_widedeep.training._trainer_utils import ( - save_epoch_logs, - print_loss_and_metric, -) - -# Observation 1: I am annoyed by sublime highlighting an override issue with -# the abstractmethods. There is no override issue. The Signature of -# the 'predict' method is compatible with the supertype. Buy for whatever -# issue sublime highlights this as an error (not vscode and is not returned -# as an error when running mypy). I am ignoring it -# There is a lot of code repetition between this class and the 'Trainer' -# class (and in consquence a lot of ignore methods for test coverage). Maybe -# in the future I decided to merge the two of them and offer the ability to -# laod from folder based on the input parameters. For now, I'll leave it like -# this, separated, since it is the easiest and most manageable(i.e. easier to -# debug) implementation - -class TrainerFromFolder(BaseTrainer): +class TrainerFromFolder(Trainer): r"""Class to set the of attributes that will be used during the training process. 
@@ -187,6 +168,7 @@ class TrainerFromFolder(BaseTrainer): "objective", ["loss_function", "loss_fn", "loss", "cost_function", "cost_fn", "cost"], ) + @alias("metrics", ["train_metrics"]) def __init__( self, model: WideDeep, @@ -204,6 +186,7 @@ def __init__( transforms: Optional[List[Transforms]] = None, callbacks: Optional[List[Callback]] = None, metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None, + eval_metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None, verbose: int = 1, seed: int = 1, **kwargs, @@ -218,13 +201,19 @@ def __init__( transforms=transforms, callbacks=callbacks, metrics=metrics, + eval_metrics=eval_metrics, verbose=verbose, seed=seed, **kwargs, ) + if self.method == "multitarget": + raise NotImplementedError( + "Training from folder is not supported for multitarget models" + ) + @alias("finetune", ["warmup"]) - def fit( # noqa: C901 + def fit( # type: ignore[override] # noqa: C901 self, train_loader: DataLoader, eval_loader: Optional[DataLoader] = None, @@ -233,68 +222,44 @@ def fit( # noqa: C901 finetune: bool = False, **kwargs, ): - finetune_args = self._extract_kwargs(kwargs) - train_steps = len(train_loader) + # There will never be dataloader_args when using 'TrainingFromFolder' + # as the loaders are passed as arguments + _, finetune_args = self._extract_kwargs(kwargs) if finetune: - self._finetune(train_loader, **finetune_args) + self.with_finetuning: bool = True + self._do_finetune(train_loader, **finetune_args) if self.verbose: print( "Fine-tuning (or warmup) of individual components completed. " "Training the whole model for {} epochs".format(n_epochs) ) + else: + self.with_finetuning = False self.callback_container.on_train_begin( { "batch_size": train_loader.batch_size, - "train_steps": train_steps, + "train_steps": len(train_loader), "n_epochs": n_epochs, } ) for epoch in range(n_epochs): - epoch_logs: Dict[str, float] = {} - self.callback_container.on_epoch_begin(epoch, logs=epoch_logs) - - self.train_running_loss = 0.0 - with trange(train_steps, disable=self.verbose != 1) as t: - for batch_idx, (data, targett) in zip(t, train_loader): - t.set_description("epoch %i" % (epoch + 1)) - train_score, train_loss = self._train_step( - data, targett, batch_idx, epoch - ) - print_loss_and_metric(t, train_loss, train_score) - self.callback_container.on_batch_end(batch=batch_idx) - epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train") - - on_epoch_end_metric = None + epoch_logs = self._train_epoch(train_loader, epoch) if eval_loader is not None and epoch % validation_freq == ( validation_freq - 1 ): - eval_steps = len(eval_loader) - self.callback_container.on_eval_begin() - self.valid_running_loss = 0.0 - with trange(eval_steps, disable=self.verbose != 1) as v: - for i, (data, targett) in zip(v, eval_loader): - v.set_description("valid") - val_score, val_loss = self._eval_step(data, targett, i) - print_loss_and_metric(v, val_loss, val_score) - epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val") - - if self.reducelronplateau: # pragma: no cover - if self.reducelronplateau_criterion == "loss": - on_epoch_end_metric = val_loss - else: - on_epoch_end_metric = val_score[ - self.reducelronplateau_criterion - ] + epoch_logs, on_epoch_end_metric = self._eval_epoch( + eval_loader, epoch_logs + ) else: + on_epoch_end_metric = None if self.reducelronplateau: raise NotImplementedError( "ReduceLROnPlateau scheduler can be used only with validation data." 
) self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric) - if self.early_stop: # self.callback_container.on_train_end(epoch_logs) break @@ -327,7 +292,7 @@ def predict( # type: ignore[override, return] preds = np.vstack(preds_l) return np.argmax(preds, 1) - def predict_uncertainty( # type: ignore[return] + def predict_uncertainty( # type: ignore[override, return] self, X_wide: Optional[np.ndarray] = None, X_tab: Optional[np.ndarray] = None, @@ -402,202 +367,7 @@ def predict_proba( # type: ignore[override, return] if self.method == "multiclass": return np.vstack(preds_l) - def save( - self, - path: str, - save_state_dict: bool = False, - save_optimizer: bool = False, - model_filename: str = "wd_model.pt", - ): # pragma: no cover - """ - Parameters - ---------- - path: str - path to the directory where the model and the feature importance - attribute will be saved. - save_state_dict: bool, default = False - Boolean indicating whether to save directly the model - (and optimizer) or the model's (and optimizer's) state - dictionary - save_optimizer: bool, default = False - Boolean indicating whether to save the optimizer - model_filename: str, Optional, default = "wd_model.pt" - filename where the model weights will be store - """ - self._save_history(path) - - self._save_model_and_optimizer( - path, save_state_dict, save_optimizer, model_filename - ) - - @alias("n_epochs", ["finetune_epochs", "warmup_epochs"]) - @alias("max_lr", ["finetune_max_lr", "warmup_max_lr"]) - def _finetune( - self, - loader: DataLoader, - n_epochs: int = 5, - max_lr: float = 0.01, - routine: Literal["howard", "felbo"] = "howard", - deeptabular_gradual: bool = False, - deeptabular_layers: Optional[List[nn.Module]] = None, - deeptabular_max_lr: float = 0.01, - deeptext_gradual: bool = False, - deeptext_layers: Optional[List[nn.Module]] = None, - deeptext_max_lr: float = 0.01, - deepimage_gradual: bool = False, - deepimage_layers: Optional[List[nn.Module]] = None, - deepimage_max_lr: float = 0.01, - ): - r""" - Simple wrap-up to individually fine-tune model components - """ - if self.model.deephead is not None: - raise ValueError( - "Currently warming up is only supported without a fully connected 'DeepHead'" - ) - - finetuner = FineTune(self.loss_fn, self.metric, self.method, self.verbose) # type: ignore[arg-type] - if self.model.wide: - finetuner.finetune_all(self.model.wide, "wide", loader, n_epochs, max_lr) - - if self.model.deeptabular: - if deeptabular_gradual: - assert ( - deeptabular_layers is not None - ), "deeptabular_layers must be passed if deeptabular_gradual=True" - finetuner.finetune_gradual( - self.model.deeptabular, - "deeptabular", - loader, - deeptabular_max_lr, - deeptabular_layers, - routine, - ) - else: - finetuner.finetune_all( - self.model.deeptabular, "deeptabular", loader, n_epochs, max_lr - ) - - if self.model.deeptext: - if deeptext_gradual: - assert ( - deeptext_layers is not None - ), "deeptext_layers must be passed if deeptabular_gradual=True" - finetuner.finetune_gradual( - self.model.deeptext, - "deeptext", - loader, - deeptext_max_lr, - deeptext_layers, - routine, - ) - else: - finetuner.finetune_all( - self.model.deeptext, "deeptext", loader, n_epochs, max_lr - ) - - if self.model.deepimage: - if deepimage_gradual: - assert ( - deepimage_layers is not None - ), "deepimage_layers must be passed if deeptabular_gradual=True" - finetuner.finetune_gradual( - self.model.deepimage, - "deepimage", - loader, - deepimage_max_lr, - deepimage_layers, - routine, - ) - else: - 
finetuner.finetune_all( - self.model.deepimage, "deepimage", loader, n_epochs, max_lr - ) - - def _train_step( - self, - data: Dict[str, Union[Tensor, List[Tensor]]], - target: Tensor, - batch_idx: int, - epoch: int, - ): - self.model.train() - X: Dict[str, Union[Tensor, List[Tensor]]] = {} - for k, v in data.items(): - if isinstance(v, list): - X[k] = [to_device(i, self.device) for i in v] - else: - X[k] = to_device(v, self.device) - y = ( - target.view(-1, 1).float() - if self.method not in ["multiclass", "qregression"] - else target - ) - y = to_device(y, self.device) - - self.optimizer.zero_grad() - - y_pred = self.model(X) - - if self.model.is_tabnet: # pragma: no cover - loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1] - score = self._get_score(y_pred[0], y) - else: - loss = self.loss_fn(y_pred, y) - score = self._get_score(y_pred, y) - - loss.backward() - self.optimizer.step() - - self.train_running_loss += loss.item() - avg_loss = self.train_running_loss / (batch_idx + 1) - - return score, avg_loss - - def _eval_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int): - self.model.eval() - with torch.no_grad(): - X: Dict[str, Union[Tensor, List[Tensor]]] = {} - for k, v in data.items(): - if isinstance(v, list): - X[k] = [to_device(i, self.device) for i in v] - else: - X[k] = to_device(v, self.device) - y = ( - target.view(-1, 1).float() - if self.method not in ["multiclass", "qregression"] - else target - ) - y = to_device(y, self.device) - - y_pred = self.model(X) - if self.model.is_tabnet: # pragma: no cover - loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1] - score = self._get_score(y_pred[0], y) - else: - score = self._get_score(y_pred, y) - loss = self.loss_fn(y_pred, y) - - self.valid_running_loss += loss.item() - avg_loss = self.valid_running_loss / (batch_idx + 1) - - return score, avg_loss - - def _get_score(self, y_pred, y): # pragma: no cover - if self.metric is not None: - if self.method == "regression": - score = self.metric(y_pred, y) - if self.method == "binary": - score = self.metric(torch.sigmoid(y_pred), y) - if self.method == "qregression": - score = self.metric(y_pred, y) - if self.method == "multiclass": - score = self.metric(F.softmax(y_pred, dim=1), y) - return score - else: - return None - - def _predict( # noqa: C901 + def _predict( # type: ignore[override] # noqa: C901 self, X_wide: Optional[np.ndarray] = None, X_tab: Optional[np.ndarray] = None, @@ -606,7 +376,7 @@ def _predict( # noqa: C901 X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None, test_loader: Optional[DataLoader] = None, batch_size: Optional[int] = None, - uncertainty_granularity=1000, + uncertainty_granularity: int = 1000, uncertainty: bool = False, ) -> List: r"""Private method to avoid code repetition in predict and @@ -690,49 +460,3 @@ def _predict( # noqa: C901 preds_l.append(preds) self.model.train() return preds_l - - @staticmethod - def _predict_ziln(preds: Tensor) -> Tensor: # pragma: no cover - """Calculates predicted mean of zero inflated lognormal logits. - - Adjusted implementaion of `code - ` - - Arguments: - preds: [batch_size, 3] tensor of logits. - Returns: - ziln_preds: [batch_size, 1] tensor of predicted mean. 
- """ - positive_probs = torch.sigmoid(preds[..., :1]) - loc = preds[..., 1:2] - scale = F.softplus(preds[..., 2:]) - ziln_preds = positive_probs * torch.exp(loc + 0.5 * torch.square(scale)) - return ziln_preds - - @staticmethod - def _extract_kwargs(kwargs): - finetune_params = [ - "n_epochs", - "finetune_epochs", - "warmup_epochs", - "max_lr", - "finetune_max_lr", - "warmup_max_lr", - "routine", - "deeptabular_gradual", - "deeptabular_layers", - "deeptabular_max_lr", - "deeptext_gradual", - "deeptext_layers", - "deeptext_max_lr", - "deepimage_gradual", - "deepimage_layers", - "deepimage_max_lr", - ] - - finetune_args = {} - for k, v in kwargs.items(): - if k in finetune_params: - finetune_args[k] = v - - return finetune_args diff --git a/tests/test_model_functioning/test_fit_methods.py b/tests/test_model_functioning/test_fit_methods.py index f873f75d..a1188c02 100644 --- a/tests/test_model_functioning/test_fit_methods.py +++ b/tests/test_model_functioning/test_fit_methods.py @@ -286,12 +286,12 @@ def test_custom_dataloader(): ) model = WideDeep(wide=wide, deeptabular=deeptabular) trainer = Trainer(model, loss="binary", verbose=0) + trainer.fit( X_wide=X_wide, X_tab=X_tab, target=target_binary_imbalanced, - batch_size=16, - custom_dataloader=DataLoaderImbalanced, + custom_dataloader=DataLoaderImbalanced(batch_size=16), ) # simply checking that runs with DataLoaderImbalanced assert "train_loss" in trainer.history.keys() diff --git a/tests/test_rec/test_ranking_metrics_integration.py b/tests/test_rec/test_ranking_metrics_integration.py index 1239ae54..81f8881b 100644 --- a/tests/test_rec/test_ranking_metrics_integration.py +++ b/tests/test_rec/test_ranking_metrics_integration.py @@ -1,5 +1,3 @@ -# silence DeprecationWarning -import warnings from typing import Tuple import numpy as np @@ -18,8 +16,8 @@ BinaryNDCG_at_k, ) from pytorch_widedeep.training import Trainer +from pytorch_widedeep.dataloaders import CustomDataLoader from pytorch_widedeep.preprocessing import TabPreprocessor -from pytorch_widedeep.training._wd_dataset import WideDeepDataset def generate_user_item_interactions( @@ -89,8 +87,8 @@ def user_item_data(request): # strategy 1: -# * ranking metrics for both training and validation sets. -# * Same number of items per user in both sets +# * Same ranking metrics for both training and validation sets. +# * Same number of items per user in both sets @pytest.mark.parametrize( @@ -100,11 +98,14 @@ def user_item_data(request): "metric", [ MAP_at_k(n_cols=5, k=3), - NDCG_at_k(n_cols=5, k=3), Recall_at_k(n_cols=5, k=3), HitRatio_at_k(n_cols=5, k=3), Precision_at_k(n_cols=5, k=3), BinaryNDCG_at_k(n_cols=5, k=3), + [Accuracy(), MAP_at_k(n_cols=5, k=3)], + [F1Score(), MAP_at_k(n_cols=5, k=3)], + [Accuracy(), BinaryNDCG_at_k(n_cols=5, k=3)], + [F1Score(), BinaryNDCG_at_k(n_cols=5, k=3)], ], ) def test_binary_classification_strategy_1(user_item_data, metric): @@ -131,13 +132,17 @@ def test_binary_classification_strategy_1(user_item_data, metric): ) model = WideDeep(deeptabular=tab_mlp) - trainer = Trainer(model, objective="binary", metrics=[metric]) + if isinstance(metric, list): + trainer = Trainer(model, objective="binary", metrics=metric) + else: + trainer = Trainer(model, objective="binary", metrics=[metric]) trainer.fit( X_train={"X_tab": X_train, "target": train_df[target_col].values}, X_val={"X_tab": X_val, "target": val_df[target_col].values}, n_epochs=1, batch_size=2 * 5, + verbose=0, ) # predict on validation, this is just a test... 
@@ -148,6 +153,322 @@ def test_binary_classification_strategy_1(user_item_data, metric): trainer.history is not None and "train_loss" in trainer.history and "val_loss" in trainer.history - and f"train_{metric._name}" in trainer.history - and f"val_{metric._name}" in trainer.history ) + if not isinstance(metric, list): + metric = [metric] + for m in metric: + assert f"train_{m._name}" in trainer.history + assert f"val_{m._name}" in trainer.history + + +@pytest.mark.parametrize( + "user_item_data", [{"validation_interactions_per_user": 5}], indirect=True +) +def test_multiclass_classification_strategy_1(user_item_data): + + train_df, val_df = user_item_data + + categorical_cols = ["user_id", "item_id", "user_category", "item_category"] + + target_col = "rating" + + tab_preprocessor = TabPreprocessor( + embed_cols=categorical_cols, + for_transformer=False, + ) + + tab_preprocessor.fit(train_df) + + X_train = tab_preprocessor.transform(train_df) + X_val = tab_preprocessor.transform(val_df) + + tab_mlp = TabMlp( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type] + ) + model = WideDeep(deeptabular=tab_mlp, pred_dim=3) + + trainer = Trainer( + model, + objective="multiclass", + metrics=[Accuracy(), NDCG_at_k(n_cols=5, k=3)], + verbose=0, + ) + + trainer.fit( + X_train={"X_tab": X_train, "target": train_df[target_col].values}, + X_val={"X_tab": X_val, "target": val_df[target_col].values}, + n_epochs=1, + batch_size=2 * 5, # 2 * n_items + ) + + # predict on validation, this is just a test... + preds = trainer.predict(X_tab=X_val) + + assert preds.shape[0] == X_val.shape[0] + assert ( + trainer.history is not None + and "train_loss" in trainer.history + and "val_loss" in trainer.history + ) + + +# strategy 2: +# * Diff metrics for training and validation sets. +# * Same number of items per user in both sets + + +@pytest.mark.parametrize( + "user_item_data", [{"validation_interactions_per_user": 5}], indirect=True +) +@pytest.mark.parametrize("target_col", ["liked", "rating"]) +def test_strategy_2(user_item_data, target_col): + + train_df, val_df = user_item_data + + categorical_cols = ["user_id", "item_id", "user_category", "item_category"] + + tab_preprocessor = TabPreprocessor( + embed_cols=categorical_cols, + for_transformer=False, + ) + + tab_preprocessor.fit(train_df) + + X_train = tab_preprocessor.transform(train_df) + X_val = tab_preprocessor.transform(val_df) + + tab_mlp = TabMlp( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type] + ) + model = WideDeep(deeptabular=tab_mlp, pred_dim=1 if target_col == "liked" else 3) + + eval_metrics = ( + [BinaryNDCG_at_k(n_cols=5, k=3)] + if target_col == "liked" + else [NDCG_at_k(n_cols=5, k=3)] + ) + trainer = Trainer( + model, + objective="binary" if target_col == "liked" else "multiclass", + metrics=[Accuracy()], + eval_metrics=eval_metrics, + verbose=0, + ) + + trainer.fit( + X_train={"X_tab": X_train, "target": train_df[target_col].values}, + X_val={"X_tab": X_val, "target": val_df[target_col].values}, + n_epochs=1, + batch_size=2 * 5, + ) + + # predict on validation, this is just a test... 
+ preds = trainer.predict(X_tab=X_val) + + assert preds.shape[0] == X_val.shape[0] + assert ( + trainer.history is not None + and "train_loss" in trainer.history + and "val_loss" in trainer.history + and "train_acc" in trainer.history + and "val_binary_ndcg@3" in trainer.history + if target_col == "liked" + else "val_ndcg@3" in trainer.history + ) + + +# strategy 3: +# * Diff number of items per user in both sets which implies CustomDataLoaders +# * everything else + + +@pytest.mark.parametrize( + "user_item_data", [{"validation_interactions_per_user": 6}], indirect=True +) +@pytest.mark.parametrize( + "metrics", + [ + [[BinaryNDCG_at_k(n_cols=4, k=3)], [BinaryNDCG_at_k(n_cols=6, k=3)]], + [[Accuracy(), F1Score()], [F1Score(), BinaryNDCG_at_k(n_cols=6, k=3)]], + [[Accuracy(), F1Score()], [F1Score(), MAP_at_k(n_cols=6, k=3)]], + [[Accuracy(), F1Score()], [F1Score(), Precision_at_k(n_cols=6, k=3)]], + [[Accuracy(), F1Score()], [F1Score(), Recall_at_k(n_cols=6, k=3)]], + [[Accuracy(), F1Score()], [F1Score(), HitRatio_at_k(n_cols=6, k=3)]], + ], +) +@pytest.mark.parametrize("with_train_data_loader", [True, False]) +def test_binary_classification_strategy_3( + user_item_data, metrics, with_train_data_loader +): + + train_metrics, val_metrics = metrics[0], metrics[1] + + train_df, val_df = user_item_data + + categorical_cols = ["user_id", "item_id", "user_category", "item_category"] + + target_col = "liked" + + tab_preprocessor = TabPreprocessor( + embed_cols=categorical_cols, + for_transformer=False, + ) + + tab_preprocessor.fit(train_df) + + X_train = tab_preprocessor.transform(train_df) + X_val = tab_preprocessor.transform(val_df) + + tab_mlp = TabMlp( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type] + ) + model = WideDeep(deeptabular=tab_mlp) + + trainer = Trainer( + model, + objective="binary", + metrics=train_metrics, + eval_metrics=val_metrics, + ) + + if with_train_data_loader: + train_dl = CustomDataLoader( + batch_size=2 * 4, shuffle=False # in case there is a ranking metric + ) + else: + train_dl = None + + valid_dl = CustomDataLoader( + batch_size=2 * 6, + shuffle=False, + ) + + trainer.fit( + X_train={"X_tab": X_train, "target": train_df[target_col].values}, + X_val={"X_tab": X_val, "target": val_df[target_col].values}, + n_epochs=1, + batch_size=2 + * 4, # it only applies to the training set. Will be ignored if train_dl is not None + train_dataloader=train_dl, + eval_dataloader=valid_dl, + ) + + train_metric_names = [m._name for m in train_metrics] + val_metric_names = [m._name for m in val_metrics] + + # predict on validation, this is just a test... 
+ preds = trainer.predict(X_tab=X_val) + + assert valid_dl.dataset.X_tab.shape[0] == X_val.shape[0] == 60 + + if with_train_data_loader: + assert train_dl.dataset.X_tab.shape[0] == X_train.shape[0] == 40 + + assert preds.shape[0] == X_val.shape[0] + assert ( + trainer.history is not None + and "train_loss" in trainer.history + and "val_loss" in trainer.history + ) + + for train_metric_name in train_metric_names: + assert f"train_{train_metric_name}" in trainer.history + + for val_metric_name in val_metric_names: + assert f"val_{val_metric_name}" in trainer.history + + +@pytest.mark.parametrize( + "user_item_data", [{"validation_interactions_per_user": 6}], indirect=True +) +@pytest.mark.parametrize( + "metrics", + [ + [[NDCG_at_k(n_cols=4, k=3)], [NDCG_at_k(n_cols=6, k=3)]], + [[Accuracy(), F1Score()], [F1Score(), NDCG_at_k(n_cols=6, k=3)]], + ], +) +@pytest.mark.parametrize("with_train_data_loader", [True, False]) +def test_multiclass_classification_strategy_3( + user_item_data, metrics, with_train_data_loader +): + + train_metrics, val_metrics = metrics[0], metrics[1] + + train_df, val_df = user_item_data + + categorical_cols = ["user_id", "item_id", "user_category", "item_category"] + + target_col = "rating" + + tab_preprocessor = TabPreprocessor( + embed_cols=categorical_cols, + for_transformer=False, + ) + + tab_preprocessor.fit(train_df) + + X_train = tab_preprocessor.transform(train_df) + X_val = tab_preprocessor.transform(val_df) + + tab_mlp = TabMlp( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type] + ) + model = WideDeep(deeptabular=tab_mlp, pred_dim=3) + + trainer = Trainer( + model, + objective="multiclass", + metrics=train_metrics, + eval_metrics=val_metrics, + ) + + if with_train_data_loader: + train_dl = CustomDataLoader( + batch_size=2 * 4, shuffle=False # in case there is a ranking metric + ) + else: + train_dl = None + + valid_dl = CustomDataLoader( + batch_size=2 * 6, + shuffle=False, + ) + + trainer.fit( + X_train={"X_tab": X_train, "target": train_df[target_col].values}, + X_val={"X_tab": X_val, "target": val_df[target_col].values}, + n_epochs=1, + batch_size=2 + * 4, # it only applies to the training set. Will be ignored if train_dl is not None + train_dataloader=train_dl, + eval_dataloader=valid_dl, + ) + + train_metric_names = [m._name for m in train_metrics] + val_metric_names = [m._name for m in val_metrics] + + # predict on validation, this is just a test... 
+ preds = trainer.predict(X_tab=X_val) + + assert valid_dl.dataset.X_tab.shape[0] == X_val.shape[0] == 60 + + if with_train_data_loader: + assert train_dl.dataset.X_tab.shape[0] == X_train.shape[0] == 40 + + assert preds.shape[0] == X_val.shape[0] + assert ( + trainer.history is not None + and "train_loss" in trainer.history + and "val_loss" in trainer.history + ) + + for train_metric_name in train_metric_names: + assert f"train_{train_metric_name}" in trainer.history + + for val_metric_name in val_metric_names: + assert f"val_{val_metric_name}" in trainer.history From 808a7bbbedbc6908ca8c61966ce5e0488c391e6b Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 21 Oct 2024 10:33:48 +0100 Subject: [PATCH 12/24] remove unnecessary, legacy docs folder and modified figures path in README --- mkdocs/site/index.html | 3 +- mkdocs/site/objects.inv | Bin 1996 -> 2007 bytes .../pytorch-widedeep/bayesian_trainer.html | 33 +- mkdocs/site/pytorch-widedeep/dataloaders.html | 87 ++-- mkdocs/site/pytorch-widedeep/losses.html | 432 +++++++++-------- mkdocs/site/pytorch-widedeep/metrics.html | 72 +-- .../site/pytorch-widedeep/preprocessing.html | 297 ++++++++++++ mkdocs/site/pytorch-widedeep/preprocessing.md | 2 + mkdocs/site/pytorch-widedeep/tab2vec.html | 73 +-- mkdocs/site/pytorch-widedeep/trainer.html | 435 ++++++++---------- mkdocs/site/search/search_index.json | 2 +- mkdocs/site/sitemap.xml | 88 ++-- mkdocs/site/sitemap.xml.gz | Bin 796 -> 795 bytes 13 files changed, 869 insertions(+), 655 deletions(-) diff --git a/mkdocs/site/index.html b/mkdocs/site/index.html index 1f68a10a..666a04ab 100644 --- a/mkdocs/site/index.html +++ b/mkdocs/site/index.html @@ -2382,8 +2382,7 @@

APA¶< Javier, Pavol Mulinka, - Javier Rodriguez Zaurin, - Not Committed Yet + Javier Rodriguez Zaurin diff --git a/mkdocs/site/objects.inv b/mkdocs/site/objects.inv index 2818ee0861c603228b2f2d166adb8ebdc801c0d6..2208df9aa6290a7b647eaabf35faa393dd5d954b 100644 GIT binary patch delta 1881 zcmV-f2d4PU57!TnlYdQ)d&?x_N~#=J)=teV`~F~l|(+~*RLeNU)V-CYB84t zWa#&4b+-TsyL?xLpzHXRC6EA=(UQD_WQ50gkw7LRMV3`Y#I`+S9Cr83nbEgq?{B^H zseJ#FW<<(QU-hOn|G72)<F3KYjOT^TD`vvj+NH7L2P8TTX}~SwRvIGP>2< zQ@eeiFG)r?h2*Tj?#ZVm%n@-aiabt>OdlJuoUEY!yVrhiS|++-V)xHVcZ?3uZ`fH!6T?7OcU0*JpRvB@0?s^j{tKH;|h3kZW zDJ!4ZiIkNOi>Xh>DGKCY72%3y;0<)wpl?{OV-nV2H1`_!Iac7GUI7y3FYY_`e53wz zKb?8wY>jz_jPoi}OcAjH6+M{kC%b)Pcv;FQF#^84IHfj+l>n`Wrlc$TQ0bZ{zMafL z2}WgfO@C=62t7ZA&>rk=F4E+n`U3B6ck*>X3jv-e8*APnzkwoNIX4DNRaPoa zSZ>6KQo@f`eMVYY5j$f!sbtSAjyILiUPYxS{(lA9p9_Npb$RUV-?euc9s$&;(dKYnq=jSOI*Za!L70sd$Ex4V& z)_(=APu3E4spzI|y>1&HrjJNo4`PV6N4x@Ct(EjUHeCjVHB#=xwV z4nC@d-nBzf??%_{-h4L0@V`eAJsCF)^@O3(|7Nx@-R(-@Cb$)qiqL0B;{Co})U83H>`V|Y z=7qyjA5A)CYJV)bbVy7jArKOuu;N;*GB^%5mtdI%Ehm#`bbx z>zmccBju%=vGcE^{7Z-;roEGP96aD1icqQFPx1>9px$QKC$w%An#VQdf`4#S#;Xcf zfRYHsNEQbK8KwKj?D=Eb*rR|0MWFG>fy6t3ed$D2C5RU$Au%p8ZcZ_;1PLO|)X?VZ zpikz*p4-nugBoOG(=SHUn(pfh&q({IB6qPf2^ULvgqmpNI!`j z#L?rC8O`Y}&JgeGX*LnpGk<+z<9AOn;Cgvn9L=pU0@(v{#8>@5cl83)JjY%A0DAu2 z;pW)Wb$?*ZjNYRoZZ?jsd>hvRgbqU|(AWq0Eb5?+7X`Nix7}Cx4$#jvH65$>=^N zeFUtNe*XEa#D^G+?i2(9uGKh4+1hof8b=0~A~}$Csfl6g zy(Mr+;F$3PDO?i#1c_TdU&iB*zNe^8T_pUm?J5i*bg8F-NTjq7JK-Yu782k7d{tBmpdCgW%maSB!47PIs{xGZ*G*RMO(+t;!Ls_btpIIjWLx>CDVJSgt0Qn~qkss4hQ z1@1u3cEwQrv@8Af=CWd05<|X(B+<8VGR}z-?9EMc$IbxB6(~FMOH7fHB(a45D$sYv Ts6x_)E*=`LV`BdUpl4e@;*O~J delta 1870 zcmV-U2eJ6q56lmclYb@0y=9WPk}4;awNtZId!b~}*xG<#C6SN$^(zVR7q$_OTFfN@ z8Tx%%-7P@EE+cK8R2nWB#_BSkwujev1!j3hwXiHX7s(;`&;jP zDL?+C8IkhKh2FI0zqaPT?VDezJp28RFJFHv2ZQ9<)(<^&8*?lX`UW#=B$DH|{N`dYBS7VE=hz?wa)tJ}LB=>5X@7 z4Fz)+PcQtUwic0aA3T=Zu1O@|83OSP3Hh#Cc&xOwHGkij_dG2gul@DvmObC79;*wd ztbArCQdU09r#=~{D3E_uge#VTH_&Z^zG1zNNmzr?%xm0dSb=|j14x*^xNq6>jr!01 z^wAqi1!z4qC0*J3N>@Db?PLy0 zFe;-fN`EUs=;xtXHd zY@h5+g(v5NuJ;trr zJy)P>Sv`V!E5wU;o%hDE+jk`(hamp7)631Oj$O>@UBnrlpT}fe?<>z&G>byC;CA|2 z7k{`uSxeZZqU*Z#s%?CjJ|THMh#}e>@dj+UQqu3(bRHDeNV!cTAxrmA=DI0RfE?6A zhQR7@I0+OyG$jZM?6n4YYIi~f=CJx`VL5^mJSaQ! z5@ZqhqH&sFwBVdTLJN?2nQK)omy9n{LLuIPTBY^dkkQ@5PFwko%^>kkSP*(baDOga zomYtUZgkb|%^n{y{O^%OPsTMvJz;3{znd*gce_%!32sHDBJ>54c(K1GQ}^V9!rv!8Bz+?-%8lwuwP3+u6GyN5P9gt zfy`hZa5`hWdC?;aG(e@9yySBC$MjT$f^YK$|NMlMaIo3=KUN&q?xB#e>d~leAsjQxo=Q| ztZn+mh+5Nqec>5t_iLXc^vOYhv3i4U0u(20gzAqqlv5U6o&o76k%KsTI5ML-y}=pc zeLc-4;(De}to`mO23#)>i+`iJHAWzNKo0n-ALuS$ftu&I%O5~bzsTDhd%W)Vthw9_ zoMqJg+B)34Pt$E&);|H`F#KEjVDP0`Jf`ky1+{m;8G-evMx0-Oue@R_6Zy_ejtTQ zf*&Dq%jffW9Mbm~)v=3&Kela!A%rgVG!W@n`Iw&ukoJp=?eDtPEmN|9ta-3FdDmhO z|JF5d{KDrV>2ha-r+@7Jpn=3Ncrp=*W59ei5&?P-n773V+w`=Io{JaYy-Vn|Uia)z z=ta$JY}|bDzu#RPd9&gaPX9+i$s$l!Tdjx+APNF|=su$PlRR@6@dv=G90a2z)A(^b diff --git a/mkdocs/site/pytorch-widedeep/bayesian_trainer.html b/mkdocs/site/pytorch-widedeep/bayesian_trainer.html index 51e79f11..2b758b4b 100644 --- a/mkdocs/site/pytorch-widedeep/bayesian_trainer.html +++ b/mkdocs/site/pytorch-widedeep/bayesian_trainer.html @@ -1755,7 +1755,8 @@