From 1eaaded0eea36e59e3d20f94efb6e7e0df053aea Mon Sep 17 00:00:00 2001
From: Javier
Date: Sat, 5 Oct 2024 22:24:43 +0100
Subject: [PATCH 01/24] Bayesian functionalities updated to support the MPS backend
---
.../_base_contrastive_denoising_trainer.py | 3 ++-
.../_base_encoder_decoder_trainer.py | 3 ++-
pytorch_widedeep/tab2vec.py | 14 +++++++++++---
.../training/_base_bayesian_trainer.py | 3 ++-
pytorch_widedeep/training/bayesian_trainer.py | 18 ++++++++++++++----
pytorch_widedeep/training/trainer.py | 3 ++-
.../training/trainer_from_folder.py | 5 +++--
pytorch_widedeep/utils/general_utils.py | 11 +++++++++++
.../test_b_miscellaneous.py | 6 ++++--
.../test_bayes_tab2vec/test_b_t2v.py | 5 +++--
.../test_mc_attn_layers.py | 3 ++-
tests/test_tab2vec/test_t2v.py | 4 ++--
12 files changed, 58 insertions(+), 20 deletions(-)
diff --git a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
index 95706b50..11c9b835 100644
--- a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
+++ b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
@@ -27,6 +27,7 @@
CallbackContainer,
LRShedulerCallback,
)
+from pytorch_widedeep.utils.general_utils import setup_device
from pytorch_widedeep.models.tabular.self_supervised import ContrastiveDenoisingModel
from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor
@@ -270,7 +271,7 @@ def _set_device_and_num_workers(**kwargs):
if sys.platform == "darwin" and sys.version_info.minor > 7
else os.cpu_count()
)
- default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ default_device = setup_device()
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
diff --git a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
index fe0aa4a8..97491b4f 100644
--- a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
+++ b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
@@ -25,6 +25,7 @@
CallbackContainer,
LRShedulerCallback,
)
+from pytorch_widedeep.utils.general_utils import setup_device
from pytorch_widedeep.models.tabular.self_supervised import EncoderDecoderModel
@@ -245,7 +246,7 @@ def _set_device_and_num_workers(**kwargs):
if sys.platform == "darwin" and sys.version_info.minor > 7
else os.cpu_count()
)
- default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ default_device = setup_device()
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
diff --git a/pytorch_widedeep/tab2vec.py b/pytorch_widedeep/tab2vec.py
index af3d9273..5ba2ab48 100644
--- a/pytorch_widedeep/tab2vec.py
+++ b/pytorch_widedeep/tab2vec.py
@@ -10,8 +10,6 @@
from pytorch_widedeep.bayesian_models import BayesianWide, BayesianTabMlp
from pytorch_widedeep.bayesian_models._base_bayesian_model import BaseBayesianModel
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
class Tab2Vec:
r"""Class to transform an input dataframe into vectorized form.
@@ -87,6 +85,7 @@ def __init__(
self,
tab_preprocessor: TabPreprocessor,
model: Union[WideDeep, BayesianWide, BayesianTabMlp],
+ device: str | torch.device,
return_dataframe: bool = False,
verbose: bool = False,
):
@@ -94,6 +93,11 @@ def __init__(
self._check_inputs(tab_preprocessor, model, verbose)
+ if isinstance(device, str):
+ self.device = torch.device(device)
+ else:
+ self.device = device
+
self.tab_preprocessor = tab_preprocessor
self.return_dataframe = return_dataframe
self.verbose = verbose
@@ -158,7 +162,11 @@ def transform(
"""
X_tab = self.tab_preprocessor.transform(df)
- X = torch.from_numpy(X_tab.astype("float")).to(device)
+
+ if self.device.type == "mps":
+ X = torch.from_numpy(X_tab.astype("float32")).to(self.device)
+ else:
+ X = torch.from_numpy(X_tab.astype("float")).to(self.device)
with torch.no_grad():
if self.is_tab_transformer:
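Note on the Tab2Vec change above: the class now takes an explicit device and casts the tabular array to float32 when that device is MPS (the backend has no float64 support). A minimal usage sketch, assuming `tab_preprocessor` is a fitted TabPreprocessor, `model` a trained model and `df` a pandas DataFrame with the corresponding columns:

    from pytorch_widedeep import Tab2Vec
    from pytorch_widedeep.utils.general_utils import setup_device

    device = setup_device()  # cuda > mps > cpu, same default the trainers now use

    t2v = Tab2Vec(
        tab_preprocessor=tab_preprocessor,  # assumed fitted
        model=model,                        # assumed trained
        device=device,                      # a plain string such as "mps" also works
        return_dataframe=False,
    )
    X_vec = t2v.transform(df)  # or t2v.transform(df, target_col="target") -> (X_vec, y)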
diff --git a/pytorch_widedeep/training/_base_bayesian_trainer.py b/pytorch_widedeep/training/_base_bayesian_trainer.py
index 02993ef9..fa508531 100644
--- a/pytorch_widedeep/training/_base_bayesian_trainer.py
+++ b/pytorch_widedeep/training/_base_bayesian_trainer.py
@@ -25,6 +25,7 @@
CallbackContainer,
LRShedulerCallback,
)
+from pytorch_widedeep.utils.general_utils import setup_device
from pytorch_widedeep.training._trainer_utils import bayesian_alias_to_loss
from pytorch_widedeep.bayesian_models._base_bayesian_model import BaseBayesianModel
@@ -235,7 +236,7 @@ def _set_device_and_num_workers(**kwargs):
if sys.platform == "darwin" and sys.version_info.minor > 7
else os.cpu_count()
)
- default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ default_device = setup_device()
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
diff --git a/pytorch_widedeep/training/bayesian_trainer.py b/pytorch_widedeep/training/bayesian_trainer.py
index 88c426fe..65020c23 100644
--- a/pytorch_widedeep/training/bayesian_trainer.py
+++ b/pytorch_widedeep/training/bayesian_trainer.py
@@ -86,7 +86,8 @@ class BayesianTrainer(BaseBayesianTrainer):
Other infrequently used arguments that can also be passed as kwargs are:
- **device**: `str`
-        string indicating the device. One of _'cpu'_ or _'gpu'_
+        string indicating the device. One of _'cpu'_, _'gpu'_ or _'mps'_ when
+        running on a Mac with Apple silicon or AMD GPU(s)
- **num_workers**: `int`
number of workers to be used internally by the data loaders
@@ -396,7 +397,10 @@ def _train_step(
):
self.model.train()
- X = X_tab.to(self.device)
+ try:
+ X = X_tab.to(self.device)
+ except TypeError:
+ X = X_tab.to(self.device, dtype=torch.float32)
y = target.view(-1, 1).float() if self.objective != "multiclass" else target
y = y.to(self.device)
@@ -424,7 +428,10 @@ def _eval_step(
):
self.model.eval()
with torch.no_grad():
- X = X_tab.to(self.device)
+ try:
+ X = X_tab.to(self.device)
+ except TypeError:
+ X = X_tab.to(self.device, dtype=torch.float32)
y = target.view(-1, 1).float() if self.objective != "multiclass" else target
y = y.to(self.device)
@@ -479,7 +486,10 @@ def _predict( # noqa: C901
for _, Xl in zip(tt, test_loader):
tt.set_description("predict")
- X = Xl[0].to(self.device)
+ try:
+ X = Xl[0].to(self.device)
+ except TypeError:
+ X = Xl[0].to(self.device, dtype=torch.float32)
if return_samples:
preds = torch.stack([self.model(X) for _ in range(n_samples)])
diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py
index f4dd1a91..873a11b4 100644
--- a/pytorch_widedeep/training/trainer.py
+++ b/pytorch_widedeep/training/trainer.py
@@ -151,7 +151,8 @@ class Trainer(BaseTrainer):
Other infrequently used arguments that can also be passed as kwargs are:
- **device**: `str`
- string indicating the device. One of _'cpu'_ or _'gpu'_
+        string indicating the device. One of _'cpu'_, _'gpu'_ or _'mps'_ when
+        running on a Mac with Apple silicon or AMD GPU(s)
- **num_workers**: `int`
number of workers to be used internally by the data loaders
diff --git a/pytorch_widedeep/training/trainer_from_folder.py b/pytorch_widedeep/training/trainer_from_folder.py
index e2159b71..0641f3d8 100644
--- a/pytorch_widedeep/training/trainer_from_folder.py
+++ b/pytorch_widedeep/training/trainer_from_folder.py
@@ -156,8 +156,9 @@ class TrainerFromFolder(BaseTrainer):
**kwargs: dict
Other infrequently used arguments that can also be passed as kwargs are:
- - **device**: `str`
- string indicating the device. One of _'cpu'_ or _'gpu'_
+ - **device**: `str`
+          string indicating the device. One of _'cpu'_, _'gpu'_ or _'mps'_ when
+          running on a Mac with Apple silicon or AMD GPU(s)
- **num_workers**: `int`
number of workers to be used internally by the data loaders
diff --git a/pytorch_widedeep/utils/general_utils.py b/pytorch_widedeep/utils/general_utils.py
index 6db4b4b3..cac80e84 100644
--- a/pytorch_widedeep/utils/general_utils.py
+++ b/pytorch_widedeep/utils/general_utils.py
@@ -1,5 +1,16 @@
from functools import wraps
+import torch
+
+
+def setup_device():
+ if torch.cuda.is_available():
+ return torch.device("cuda")
+ elif torch.backends.mps.is_available():
+ return torch.device("mps")
+ else:
+ return torch.device("cpu")
+
def alias(original_name: str, alternative_names: list):
def decorator(func):
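For reference, a small sketch of how the new helper resolves the default device (illustrative only; at this point in the series it returns a torch.device, the next patch changes it to return a plain string):

    import torch
    from pytorch_widedeep.utils.general_utils import setup_device

    default_device = setup_device()  # torch.device("cuda"), torch.device("mps") or torch.device("cpu")
    # trainers still honour an explicit override, e.g. Trainer(..., device="cpu"),
    # via: device = kwargs.get("device", default_device)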
diff --git a/tests/test_bayesian_models/test_bayes_model_functioning/test_b_miscellaneous.py b/tests/test_bayesian_models/test_bayes_model_functioning/test_b_miscellaneous.py
index 5b472a21..9e0c6629 100644
--- a/tests/test_bayesian_models/test_bayes_model_functioning/test_b_miscellaneous.py
+++ b/tests/test_bayesian_models/test_bayes_model_functioning/test_b_miscellaneous.py
@@ -66,7 +66,8 @@ def test_save_and_load(model):
)
new_model = torch.load(
- "tests/test_bayesian_models/test_bayes_model_functioning/model_dir/bayesian_model.pt"
+ "tests/test_bayesian_models/test_bayes_model_functioning/model_dir/bayesian_model.pt",
+ weights_only=False,
)
if model.__class__.__name__ == "BayesianWide":
@@ -113,7 +114,8 @@ def test_save_and_load_dict(model_name):
btrainer2.model.load_state_dict(
torch.load(
- "tests/test_bayesian_models/test_bayes_model_functioning/model_dir/bayesian_model.pt"
+ "tests/test_bayesian_models/test_bayes_model_functioning/model_dir/bayesian_model.pt",
+ weights_only=False,
)
)
diff --git a/tests/test_bayesian_models/test_bayes_tab2vec/test_b_t2v.py b/tests/test_bayesian_models/test_bayes_tab2vec/test_b_t2v.py
index 17d71f9f..ee82b931 100644
--- a/tests/test_bayesian_models/test_bayes_tab2vec/test_b_t2v.py
+++ b/tests/test_bayesian_models/test_bayes_tab2vec/test_b_t2v.py
@@ -2,15 +2,15 @@
from random import choices
import numpy as np
-import torch
import pandas as pd
import pytest
from pytorch_widedeep import Tab2Vec
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.bayesian_models import BayesianTabMlp
+from pytorch_widedeep.utils.general_utils import setup_device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = setup_device()
colnames = list(string.ascii_lowercase)[:4] + ["target"]
cat_col1_vals = ["a", "b", "c"]
@@ -55,6 +55,7 @@ def test_bayesian_mlp_models(return_dataframe, embed_continuous):
t2v = Tab2Vec(
tab_preprocessor=tab_preprocessor,
model=model,
+ device=device,
return_dataframe=return_dataframe,
)
t2v_out, _ = t2v.transform(df_t2v, target_col="target")
diff --git a/tests/test_model_components/test_mc_attn_layers.py b/tests/test_model_components/test_mc_attn_layers.py
index e5106993..ead4b5fc 100644
--- a/tests/test_model_components/test_mc_attn_layers.py
+++ b/tests/test_model_components/test_mc_attn_layers.py
@@ -4,6 +4,7 @@
import torch
import pytest
+from pytorch_widedeep.utils.general_utils import setup_device
from pytorch_widedeep.models.tabular.transformers._attention_layers import (
MultiHeadedAttention,
)
@@ -12,7 +13,7 @@
input_dim = 128
n_heads = 4
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = setup_device()
standard_attn = (
MultiHeadedAttention(
diff --git a/tests/test_tab2vec/test_t2v.py b/tests/test_tab2vec/test_t2v.py
index a466aadf..5a92b9e7 100644
--- a/tests/test_tab2vec/test_t2v.py
+++ b/tests/test_tab2vec/test_t2v.py
@@ -2,7 +2,6 @@
from random import choices
import numpy as np
-import torch
import pandas as pd
import pytest
@@ -21,8 +20,9 @@
ContextAttentionMLP,
)
from pytorch_widedeep.preprocessing import TabPreprocessor
+from pytorch_widedeep.utils.general_utils import setup_device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = setup_device()
colnames = list(string.ascii_lowercase)[:4] + ["target"]
cat_col1_vals = ["a", "b", "c"]
From 69a1fab263e6fd070239d1435a0c375610c3bfcd Mon Sep 17 00:00:00 2001
From: Javier
Date: Tue, 8 Oct 2024 18:10:26 +0100
Subject: [PATCH 02/24] Support for the MPS backend added. All unit tests passing.
Need to run some example scripts
---
pytorch_widedeep/losses.py | 26 ++++++++--
.../training/_base_bayesian_trainer.py | 8 ++--
pytorch_widedeep/training/_base_trainer.py | 11 +++--
.../training/_feature_importance.py | 8 ++--
pytorch_widedeep/training/_finetune.py | 15 +++---
pytorch_widedeep/training/bayesian_trainer.py | 16 ++-----
pytorch_widedeep/training/trainer.py | 18 +++----
.../training/trainer_from_folder.py | 18 +++----
pytorch_widedeep/utils/general_utils.py | 47 +++++++++++++++++--
.../test_finetune/test_finetuning_routines.py | 16 +++----
tests/test_hf_integration/test_models.py | 1 +
tests/test_tab2vec/test_t2v.py | 4 ++
12 files changed, 118 insertions(+), 70 deletions(-)
diff --git a/pytorch_widedeep/losses.py b/pytorch_widedeep/losses.py
index 2674f1ed..5f5a5769 100644
--- a/pytorch_widedeep/losses.py
+++ b/pytorch_widedeep/losses.py
@@ -272,8 +272,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
else:
num_class = input_prob.size(1)
binary_target = torch.eye(num_class)[target.squeeze().cpu().long()]
- if use_cuda:
- binary_target = binary_target.cuda()
+ binary_target = binary_target.to(input.device)
binary_target = binary_target.contiguous()
weight = self._get_weight(input_prob, binary_target)
@@ -430,9 +429,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
# when using max the two input tensors (input and other) have to be of
# the same type
max_input = F.softplus(input[..., 2:])
- max_other = torch.sqrt(torch.Tensor([torch.finfo(torch.double).eps])).type(
- max_input.type()
- )
+ max_other = self.get_eps(max_input)
scale = torch.max(max_input, max_other)
safe_labels = positive * target + (1 - positive) * torch.ones_like(target)
@@ -446,6 +443,25 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
return torch.mean(classification_loss + regression_loss)
+ @staticmethod
+ def get_eps(max_input: Tensor) -> Tensor:
+ if max_input.device.type == "mps":
+ # For MPS, use float32 and then convert to the input type
+ eps = torch.finfo(torch.float32).eps
+ max_other = (
+ torch.sqrt(torch.tensor([eps], device="cpu"))
+ .to(max_input.device)
+ .to(max_input.dtype)
+ )
+ else:
+ # For other devices, use the original approach
+ eps = torch.finfo(torch.double).eps
+ max_other = (
+ torch.sqrt(torch.tensor([eps])).to(max_input.device).to(max_input.dtype)
+ )
+
+ return max_other
+
class L1Loss(nn.Module):
r"""L1 loss"""
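The `get_eps` branch above exists because the MPS backend cannot hold float64 tensors, so the epsilon constant is created in float32 on CPU and then moved and cast to match the input. A small, hedged reproduction of the pattern (the exact error raised by a float64 allocation on MPS depends on the PyTorch version):

    import torch
    import torch.nn.functional as F

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    max_input = F.softplus(torch.randn(4, device=device))

    # torch.tensor([...], dtype=torch.float64, device="mps") would raise on MPS,
    # hence the float32 eps built on CPU and moved afterwards:
    eps = torch.finfo(torch.float32).eps
    max_other = torch.sqrt(torch.tensor([eps], device="cpu")).to(max_input.device).to(max_input.dtype)
    scale = torch.max(max_input, max_other)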
diff --git a/pytorch_widedeep/training/_base_bayesian_trainer.py b/pytorch_widedeep/training/_base_bayesian_trainer.py
index fa508531..bbdf8255 100644
--- a/pytorch_widedeep/training/_base_bayesian_trainer.py
+++ b/pytorch_widedeep/training/_base_bayesian_trainer.py
@@ -12,6 +12,7 @@
from pytorch_widedeep.wdtypes import (
Any,
List,
+ Tuple,
Union,
Module,
Optional,
@@ -25,7 +26,7 @@
CallbackContainer,
LRShedulerCallback,
)
-from pytorch_widedeep.utils.general_utils import setup_device
+from pytorch_widedeep.utils.general_utils import setup_device, to_device_model
from pytorch_widedeep.training._trainer_utils import bayesian_alias_to_loss
from pytorch_widedeep.bayesian_models._base_bayesian_model import BaseBayesianModel
@@ -58,8 +59,7 @@ def __init__(
self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)
self.early_stop = False
- self.model = model
- self.model.to(self.device)
+ self.model = to_device_model(model, self.device)
self.verbose = verbose
self.seed = seed
@@ -230,7 +230,7 @@ def _set_callbacks_and_metrics(self, callbacks: Any, metrics: Any):
self.callback_container.set_trainer(self)
@staticmethod
- def _set_device_and_num_workers(**kwargs):
+ def _set_device_and_num_workers(**kwargs) -> Tuple[str, int]:
default_num_workers = (
0
if sys.platform == "darwin" and sys.version_info.minor > 7
diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py
index 74a57202..45cb2e9e 100644
--- a/pytorch_widedeep/training/_base_trainer.py
+++ b/pytorch_widedeep/training/_base_trainer.py
@@ -15,6 +15,7 @@
Any,
Dict,
List,
+ Tuple,
Union,
Module,
Optional,
@@ -31,6 +32,7 @@
LRShedulerCallback,
)
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
+from pytorch_widedeep.utils.general_utils import setup_device, to_device_model
from pytorch_widedeep.training._trainer_utils import alias_to_loss
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
@@ -74,10 +76,9 @@ def __init__(
self.verbose = verbose
self.seed = seed
- self.model = model
+ self.model = to_device_model(model, self.device)
if self.model.is_tabnet:
self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3)
- self.model.to(self.device)
self.model.wd_device = self.device
self.objective = objective
@@ -406,7 +407,7 @@ def _check_inputs(
)
@staticmethod
- def _set_device_and_num_workers(**kwargs):
+ def _set_device_and_num_workers(**kwargs) -> Tuple[str, int]:
# Important note for Mac users: Since python 3.8, the multiprocessing
# library start method changed from 'fork' to 'spawn'. This affects the
# data-loaders, which will not run in parallel.
@@ -415,9 +416,9 @@ def _set_device_and_num_workers(**kwargs):
if sys.platform == "darwin" and sys.version_info.minor > 7
else os.cpu_count()
)
- default_device = "cuda" if torch.cuda.is_available() else "cpu"
- device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
+ default_device = setup_device()
+ device = kwargs.get("device", default_device)
return device, num_workers
def __repr__(self) -> str: # noqa: C901
diff --git a/pytorch_widedeep/training/_feature_importance.py b/pytorch_widedeep/training/_feature_importance.py
index 79a37f19..8bf9588b 100644
--- a/pytorch_widedeep/training/_feature_importance.py
+++ b/pytorch_widedeep/training/_feature_importance.py
@@ -24,7 +24,7 @@
SelfAttentionMLP,
ContextAttentionMLP,
)
-from pytorch_widedeep.utils.general_utils import alias
+from pytorch_widedeep.utils.general_utils import alias, to_device
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix
@@ -136,7 +136,7 @@ def _sample_data(self, loader: DataLoader) -> Tensor:
batches = []
for i, (data, _) in enumerate(loader):
if i < n_iterations:
- batches.append(data["deeptabular"].to(self.device))
+ batches.append(to_device(data["deeptabular"], self.device))
else:
break
@@ -307,7 +307,7 @@ def explain(
m_explain_l = []
for batch_nb, data in enumerate(loader):
- X = data["deeptabular"].to(self.device)
+ X = to_device(data["deeptabular"], self.device)
M_explain, masks = model_backbone.forward_masks(X) # type: ignore[operator]
m_explain_l.append(
csc_matrix.dot(M_explain.cpu().detach().numpy(), reducing_matrix)
@@ -363,7 +363,7 @@ def explain(
batch_feat_imp: Any = []
for _, data in enumerate(loader):
- X = data["deeptabular"].to(self.device)
+ X = to_device(data["deeptabular"], self.device)
_ = model.deeptabular(X)
feat_imp, _ = self._feature_importance(model)
diff --git a/pytorch_widedeep/training/_finetune.py b/pytorch_widedeep/training/_finetune.py
index 90579179..5567fc42 100644
--- a/pytorch_widedeep/training/_finetune.py
+++ b/pytorch_widedeep/training/_finetune.py
@@ -16,10 +16,9 @@
DataLoader,
LRScheduler,
)
+from pytorch_widedeep.utils.general_utils import to_device, setup_device
from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
-use_cuda = torch.cuda.is_available()
-
WDModel = Union[nn.Module, BaseWDModelComponent]
@@ -67,6 +66,8 @@ def __init__(
self.method = method
self.verbose = verbose
+ self.device = setup_device()
+
def finetune_all(
self,
model: Union[WDModel, nn.ModuleList],
@@ -432,20 +433,16 @@ def _train( # noqa: C901
data, target = packed_data
if idx is not None:
- X = (
- data[model_name][idx].cuda()
- if use_cuda
- else data[model_name][idx]
- )
+ X = to_device(data[model_name][idx], self.device)
else:
- X = data[model_name].cuda() if use_cuda else data[model_name]
+ X = to_device(data[model_name], self.device)
y = (
target.view(-1, 1).float()
if self.method not in ["multiclass", "qregression"]
else target
)
- y = y.cuda() if use_cuda else y
+ y = to_device(y, self.device)
optimizer.zero_grad()
y_pred = model(X)
diff --git a/pytorch_widedeep/training/bayesian_trainer.py b/pytorch_widedeep/training/bayesian_trainer.py
index 65020c23..3c314934 100644
--- a/pytorch_widedeep/training/bayesian_trainer.py
+++ b/pytorch_widedeep/training/bayesian_trainer.py
@@ -20,7 +20,7 @@
LRScheduler,
)
from pytorch_widedeep.callbacks import Callback
-from pytorch_widedeep.utils.general_utils import alias
+from pytorch_widedeep.utils.general_utils import alias, to_device
from pytorch_widedeep.training._trainer_utils import (
save_epoch_logs,
print_loss_and_metric,
@@ -397,12 +397,9 @@ def _train_step(
):
self.model.train()
- try:
- X = X_tab.to(self.device)
- except TypeError:
- X = X_tab.to(self.device, dtype=torch.float32)
+ X = to_device(X_tab, self.device)
y = target.view(-1, 1).float() if self.objective != "multiclass" else target
- y = y.to(self.device)
+ y = to_device(y, self.device)
self.optimizer.zero_grad()
y_pred, loss = self.model.sample_elbo(X, y, self.loss_fn, n_samples, n_batches) # type: ignore[arg-type]
@@ -428,12 +425,9 @@ def _eval_step(
):
self.model.eval()
with torch.no_grad():
- try:
- X = X_tab.to(self.device)
- except TypeError:
- X = X_tab.to(self.device, dtype=torch.float32)
+ X = to_device(X_tab, self.device)
y = target.view(-1, 1).float() if self.objective != "multiclass" else target
- y = y.to(self.device)
+ y = to_device(y, self.device)
y_pred, loss = self.model.sample_elbo(
X, # type: ignore[arg-type]
diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py
index 873a11b4..8ca7d459 100644
--- a/pytorch_widedeep/training/trainer.py
+++ b/pytorch_widedeep/training/trainer.py
@@ -26,7 +26,7 @@
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.initializers import Initializer
from pytorch_widedeep.training._finetune import FineTune
-from pytorch_widedeep.utils.general_utils import alias
+from pytorch_widedeep.utils.general_utils import alias, to_device
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training._base_trainer import BaseTrainer
from pytorch_widedeep.training._trainer_utils import (
@@ -914,15 +914,15 @@ def _train_step(
X: Dict[str, Union[Tensor, List[Tensor]]] = {}
for k, v in data.items():
if isinstance(v, list):
- X[k] = [i.to(self.device) for i in v]
+ X[k] = [to_device(i, self.device) for i in v]
else:
- X[k] = v.to(self.device)
+ X[k] = to_device(v, self.device)
y = (
target.view(-1, 1).float()
if self.method not in ["multiclass", "qregression", "multitarget"]
else target
)
- y = y.to(self.device)
+ y = to_device(y, self.device)
self.optimizer.zero_grad()
@@ -954,15 +954,15 @@ def _eval_step(
X: Dict[str, Union[Tensor, List[Tensor]]] = {}
for k, v in data.items():
if isinstance(v, list):
- X[k] = [i.to(self.device) for i in v]
+ X[k] = [to_device(i, self.device) for i in v]
else:
- X[k] = v.to(self.device)
+ X[k] = to_device(v, self.device)
y = (
target.view(-1, 1).float()
if self.method not in ["multiclass", "qregression", "multitarget"]
else target
)
- y = y.to(self.device)
+ y = to_device(y, self.device)
y_pred = self.model(X)
if self.model.is_tabnet:
@@ -1061,9 +1061,9 @@ def _predict( # type: ignore[override, return] # noqa: C901
X: Dict[str, Union[Tensor, List[Tensor]]] = {}
for k, v in data.items():
if isinstance(v, list):
- X[k] = [i.to(self.device) for i in v]
+ X[k] = [to_device(i, self.device) for i in v]
else:
- X[k] = v.to(self.device)
+ X[k] = to_device(v, self.device)
preds = (
self.model(X)
if not self.model.is_tabnet
diff --git a/pytorch_widedeep/training/trainer_from_folder.py b/pytorch_widedeep/training/trainer_from_folder.py
index 0641f3d8..1ebc2629 100644
--- a/pytorch_widedeep/training/trainer_from_folder.py
+++ b/pytorch_widedeep/training/trainer_from_folder.py
@@ -23,7 +23,7 @@
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.initializers import Initializer
from pytorch_widedeep.training._finetune import FineTune
-from pytorch_widedeep.utils.general_utils import alias
+from pytorch_widedeep.utils.general_utils import alias, to_device
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training._base_trainer import BaseTrainer
from pytorch_widedeep.training._trainer_utils import (
@@ -525,15 +525,15 @@ def _train_step(
X: Dict[str, Union[Tensor, List[Tensor]]] = {}
for k, v in data.items():
if isinstance(v, list):
- X[k] = [i.to(self.device) for i in v]
+ X[k] = [to_device(i, self.device) for i in v]
else:
- X[k] = v.to(self.device)
+ X[k] = to_device(v, self.device)
y = (
target.view(-1, 1).float()
if self.method not in ["multiclass", "qregression"]
else target
)
- y = y.to(self.device)
+ y = to_device(y, self.device)
self.optimizer.zero_grad()
@@ -560,15 +560,15 @@ def _eval_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
X: Dict[str, Union[Tensor, List[Tensor]]] = {}
for k, v in data.items():
if isinstance(v, list):
- X[k] = [i.to(self.device) for i in v]
+ X[k] = [to_device(i, self.device) for i in v]
else:
- X[k] = v.to(self.device)
+ X[k] = to_device(v, self.device)
y = (
target.view(-1, 1).float()
if self.method not in ["multiclass", "qregression"]
else target
)
- y = y.to(self.device)
+ y = to_device(y, self.device)
y_pred = self.model(X)
if self.model.is_tabnet: # pragma: no cover
@@ -670,9 +670,9 @@ def _predict( # noqa: C901
X: Dict[str, Union[Tensor, List[Tensor]]] = {}
for k, v in data.items():
if isinstance(v, list):
- X[k] = [i.to(self.device) for i in v]
+ X[k] = [to_device(i, self.device) for i in v]
else:
- X[k] = v.to(self.device)
+ X[k] = to_device(v, self.device)
preds = (
self.model(X)
if not self.model.is_tabnet
diff --git a/pytorch_widedeep/utils/general_utils.py b/pytorch_widedeep/utils/general_utils.py
index cac80e84..781439ef 100644
--- a/pytorch_widedeep/utils/general_utils.py
+++ b/pytorch_widedeep/utils/general_utils.py
@@ -1,15 +1,54 @@
from functools import wraps
import torch
+from torch import Tensor
-def setup_device():
+def setup_device() -> str:
if torch.cuda.is_available():
- return torch.device("cuda")
+ return "cuda"
elif torch.backends.mps.is_available():
- return torch.device("mps")
+ return "mps"
else:
- return torch.device("cpu")
+ return "cpu"
+
+
+def to_device(X: Tensor, device: str) -> Tensor:
+    # Adjustment in case the backend is mps, which does not support float64
+ if device == "mps" and X.dtype == torch.float64:
+ X = X.float()
+ return X.to(device)
+
+
+def to_device_model(model, device: str): # noqa: C901
+    # insistent transformation since in some cases blanket approaches such as
+ # model.to('mps') do not work
+
+ if device in ["cpu", "cuda"]:
+ return model.to(device)
+
+ if device == "mps":
+
+ try:
+ return model.to(device)
+ except (RuntimeError, TypeError):
+
+ def convert(t):
+ if isinstance(t, torch.Tensor):
+ return t.float().to(device)
+ return t
+
+ model.apply(lambda module: module._apply(convert))
+
+ for param in model.parameters():
+ if param.device.type != "mps":
+ param.data = param.data.float().to(device)
+
+ for buffer_name, buffer in model.named_buffers():
+ if buffer.device.type != "mps":
+ model._buffers[buffer_name] = buffer.float().to(device)
+
+ return model
def alias(original_name: str, alternative_names: list):
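A short usage sketch for the two helpers added above, assuming setup_device() picks "mps" on an Apple-silicon machine:

    import numpy as np
    import torch
    from torch import nn
    from pytorch_widedeep.utils.general_utils import setup_device, to_device, to_device_model

    device = setup_device()  # "cuda", "mps" or "cpu"

    X = torch.from_numpy(np.random.rand(16, 8))  # float64 by default
    X = to_device(X, device)                     # downcast to float32 only when device == "mps"

    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))
    model = to_device_model(model, device)       # plain .to() on cpu/cuda, parameter-wise fallback on mps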
diff --git a/tests/test_finetune/test_finetuning_routines.py b/tests/test_finetune/test_finetuning_routines.py
index b64eacc5..7e03f8c8 100644
--- a/tests/test_finetune/test_finetuning_routines.py
+++ b/tests/test_finetune/test_finetuning_routines.py
@@ -1,7 +1,6 @@
import string
import numpy as np
-import torch
import pytest
import torch.nn.functional as F
from torch import nn
@@ -10,10 +9,11 @@
from pytorch_widedeep.models import Wide, TabMlp
from pytorch_widedeep.metrics import Accuracy, MultipleMetrics
from pytorch_widedeep.training._finetune import FineTune
+from pytorch_widedeep.utils.general_utils import setup_device, to_device_model
from pytorch_widedeep.models.image._layers import conv_layer
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
-use_cuda = torch.cuda.is_available()
+device = setup_device()
# Define a series of simple models to quickly test the FineTune class
@@ -87,8 +87,7 @@ def loss_fn(y_pred, y_true):
# wide/linear
wide = Wide(np.unique(X_wide).size, 1)
-if use_cuda:
- wide.cuda()
+wide = to_device_model(wide, device)
# tabular
tab_mlp = TabMlp(
@@ -99,18 +98,15 @@ def loss_fn(y_pred, y_true):
mlp_dropout=0.2,
)
tab_mlp = nn.Sequential(tab_mlp, nn.Linear(8, 1)) # type: ignore[assignment]
-if use_cuda:
- tab_mlp.cuda()
+tab_mlp = to_device_model(tab_mlp, device)
# text
deeptext = TextModeTestClass()
-if use_cuda:
- deeptext.cuda()
+deeptext = to_device_model(deeptext, device)
# image
deepimage = ImageModeTestClass()
-if use_cuda:
- deepimage.cuda()
+deepimage = to_device_model(deepimage, device)
# Define the loader
mmset = WideDeepDataset(X_wide, X_tab, X_text, X_image, target)
diff --git a/tests/test_hf_integration/test_models.py b/tests/test_hf_integration/test_models.py
index 847731d8..0c86039e 100644
--- a/tests/test_hf_integration/test_models.py
+++ b/tests/test_hf_integration/test_models.py
@@ -161,6 +161,7 @@ def test_full_training_process(model_name):
model,
objective="binary",
verbose=0,
+ # device="cpu",
)
trainer.fit(
diff --git a/tests/test_tab2vec/test_t2v.py b/tests/test_tab2vec/test_t2v.py
index 5a92b9e7..cc9b2746 100644
--- a/tests/test_tab2vec/test_t2v.py
+++ b/tests/test_tab2vec/test_t2v.py
@@ -85,6 +85,7 @@ def test_non_transformer_models(deeptabular, return_dataframe):
# Let's assume the model is trained
t2v = Tab2Vec(
tab_preprocessor=tab_preprocessor,
+ device=device,
model=model,
return_dataframe=return_dataframe,
)
@@ -163,6 +164,7 @@ def test_tab_transformer_models(
t2v = Tab2Vec(
tab_preprocessor=tab_preprocessor,
model=model,
+ device=device,
)
x_vec = t2v.transform(df_t2v)
@@ -229,6 +231,7 @@ def test_attentive_mlp(
t2v = Tab2Vec(
tab_preprocessor=tab_preprocessor,
model=model,
+ device=device,
)
x_vec = t2v.transform(df_t2v)
@@ -305,6 +308,7 @@ def test_transformer_family_models(
tab_preprocessor=tab_preprocessor,
model=model,
return_dataframe=return_dataframe,
+ device=device,
)
x_vec = t2v.transform(df_t2v)
From 73894d4a45de727c47005a1b0a7a9252ece8a9b1 Mon Sep 17 00:00:00 2001
From: Javier
Date: Fri, 11 Oct 2024 18:25:58 +0100
Subject: [PATCH 03/24] The 1st version of the DINPreprocessor seems to work.
Need to test it
---
pytorch_widedeep/models/rec/din.py | 2 +-
pytorch_widedeep/preprocessing/__init__.py | 1 +
.../preprocessing/din_preprocessor.py | 451 ++++++++++++++++++
3 files changed, 453 insertions(+), 1 deletion(-)
create mode 100644 pytorch_widedeep/preprocessing/din_preprocessor.py
diff --git a/pytorch_widedeep/models/rec/din.py b/pytorch_widedeep/models/rec/din.py
index 1eba4f97..293f249e 100644
--- a/pytorch_widedeep/models/rec/din.py
+++ b/pytorch_widedeep/models/rec/din.py
@@ -242,8 +242,8 @@ def __init__(
self,
*,
column_idx: Dict[str, int],
- target_item_col: str,
user_behavior_confiq: Tuple[List[str], int, int],
+ target_item_col: str = "target_item",
action_seq_config: Optional[Tuple[List[str], int]] = None,
other_seq_cols_confiq: Optional[List[Tuple[List[str], int, int]]] = None,
attention_unit_activation: Literal["prelu", "dice"] = "prelu",
diff --git a/pytorch_widedeep/preprocessing/__init__.py b/pytorch_widedeep/preprocessing/__init__.py
index 59f8b787..b4c5bcee 100644
--- a/pytorch_widedeep/preprocessing/__init__.py
+++ b/pytorch_widedeep/preprocessing/__init__.py
@@ -5,6 +5,7 @@
from pytorch_widedeep.preprocessing.tab_preprocessor import (
TabPreprocessor,
ChunkTabPreprocessor,
+ embed_sz_rule,
)
from pytorch_widedeep.preprocessing.text_preprocessor import (
TextPreprocessor,
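embed_sz_rule is re-exported here so the upcoming DIN preprocessor can derive embedding sizes with the same rule of thumb TabPreprocessor uses. An illustrative call (the keyword name is assumed from the tab_preprocessor module):

    from pytorch_widedeep.preprocessing import embed_sz_rule

    # rule-of-thumb embedding size for a categorical feature with 1000 distinct values
    dim = embed_sz_rule(1000)                            # default "fastai_new" rule
    dim_google = embed_sz_rule(1000, embedding_rule="google")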
diff --git a/pytorch_widedeep/preprocessing/din_preprocessor.py b/pytorch_widedeep/preprocessing/din_preprocessor.py
new file mode 100644
index 00000000..04655a82
--- /dev/null
+++ b/pytorch_widedeep/preprocessing/din_preprocessor.py
@@ -0,0 +1,451 @@
+import re
+from typing import Dict, List, Tuple, Union, Literal, Optional
+
+import numpy as np
+import pandas as pd
+
+from pytorch_widedeep import Trainer
+from pytorch_widedeep.models import WideDeep
+from pytorch_widedeep.metrics import Accuracy
+from pytorch_widedeep.datasets import load_movielens100k
+from pytorch_widedeep.preprocessing import TabPreprocessor, embed_sz_rule
+from pytorch_widedeep.models.rec.din import DeepInterestNetwork
+from pytorch_widedeep.utils.text_utils import pad_sequences
+from pytorch_widedeep.utils.deeptabular_utils import LabelEncoder
+from pytorch_widedeep.preprocessing.base_preprocessor import (
+ BasePreprocessor,
+ check_is_fitted,
+)
+
+
+class DINPreprocessor(BasePreprocessor):
+ def __init__(
+ self,
+ *,
+ user_id_col: str,
+ target_col: str,
+ item_embed_col: Union[str, Tuple[str, int]],
+ max_seq_length: int,
+ action_col: Optional[str] = None,
+ other_seq_embed_cols: Optional[List[str] | List[Tuple[str, int]]] = None,
+ cat_embed_cols: Optional[Union[List[str], List[Tuple[str, int]]]] = None,
+ continuous_cols: Optional[List[str]] = None,
+ quantization_setup: Optional[
+ Union[int, Dict[str, Union[int, List[float]]]]
+ ] = None,
+ cols_to_scale: Optional[Union[List[str], str]] = None,
+ auto_embed_dim: bool = True,
+ embedding_rule: Literal["google", "fastai_old", "fastai_new"] = "fastai_new",
+ default_embed_dim: int = 16,
+ verbose: int = 1,
+ scale: bool = False,
+ already_standard: Optional[List[str]] = None,
+ **kwargs,
+ ):
+
+ self.user_id_col = user_id_col
+ self.item_embed_col = item_embed_col
+ self.max_seq_length = max_seq_length
+ self.target_col = target_col if target_col is not None else "target"
+ self.action_col = action_col
+ self.other_seq_embed_cols = other_seq_embed_cols
+ self.cat_embed_cols = cat_embed_cols
+ self.continuous_cols = continuous_cols
+ self.quantization_setup = quantization_setup
+ self.cols_to_scale = cols_to_scale
+ self.auto_embed_dim = auto_embed_dim
+ self.embedding_rule = embedding_rule
+ self.default_embed_dim = default_embed_dim
+ self.verbose = verbose
+ self.scale = scale
+ self.already_standard = already_standard
+ self.kwargs = kwargs
+
+ self.has_standard_tab_data = False
+ if self.cat_embed_cols or self.continuous_cols:
+ self.has_standard_tab_data = True
+ self.tab_preprocessor = TabPreprocessor(
+ cat_embed_cols=self.cat_embed_cols,
+ continuous_cols=self.continuous_cols,
+ quantization_setup=self.quantization_setup,
+ cols_to_scale=self.cols_to_scale,
+ auto_embed_dim=self.auto_embed_dim,
+ embedding_rule=self.embedding_rule,
+ default_embed_dim=self.default_embed_dim,
+ verbose=self.verbose,
+ scale=self.scale,
+ already_standard=self.already_standard,
+ **self.kwargs,
+ )
+
+ self.is_fitted = False
+
+ def fit(self, df: pd.DataFrame):
+
+ if self.has_standard_tab_data:
+ self.tab_preprocessor.fit(df)
+ self.din_columns_idx = {
+ col: i for i, col in enumerate(self.tab_preprocessor.column_idx.keys())
+ }
+ else:
+ self.din_columns_idx = {}
+
+ self.item_le = LabelEncoder(
+ columns_to_encode=[
+ (
+ self.item_embed_col[0]
+ if isinstance(self.item_embed_col, tuple)
+ else self.item_embed_col
+ )
+ ]
+ )
+ self.item_le.fit(df)
+ self.n_items = max(self.item_le.encoding_dict[self.item_embed_col[0]].values())
+ user_behaviour_embed_size = (
+ self.item_embed_col[1]
+ if isinstance(self.item_embed_col, tuple)
+ else embed_sz_rule(self.n_items)
+ )
+ self.user_behaviour_config: Tuple[List[str], int, int] = (
+ [f"item_{i+1}" for i in range(self.max_seq_length)],
+ self.n_items,
+ user_behaviour_embed_size,
+ )
+
+ _current_len = len(self.din_columns_idx)
+ self.din_columns_idx.update(
+ {f"item_{i+1}": i + _current_len for i in range(self.max_seq_length)}
+ )
+ self.din_columns_idx.update(
+ {
+ "target_item": len(self.din_columns_idx),
+ }
+ )
+
+ if self.action_col is not None:
+ self.action_le = LabelEncoder(columns_to_encode=[self.action_col])
+ self.action_le.fit(df)
+
+ self.n_actions = len(self.action_le.encoding_dict[self.action_col])
+ self.action_seq_config: Tuple[List[str], int] = (
+ [f"action_{i+1}" for i in range(self.max_seq_length)],
+ self.n_actions,
+ )
+
+ _current_len = len(self.din_columns_idx)
+ self.din_columns_idx.update(
+ {f"action_{i+1}": i + _current_len for i in range(self.max_seq_length)}
+ )
+
+ if self.other_seq_embed_cols is not None:
+ self.other_seq_le = LabelEncoder(
+ columns_to_encode=[
+ col[0] if isinstance(col, tuple) else col
+ for col in self.other_seq_embed_cols
+ ]
+ )
+ self.other_seq_le.fit(df)
+
+ other_seq_cols = [
+ col[0] if isinstance(col, tuple) else col
+ for col in self.other_seq_embed_cols
+ ]
+ self.n_other_seq_cols: Dict[str, int] = {
+ col: max(self.other_seq_le.encoding_dict[col].values())
+ for col in other_seq_cols
+ }
+
+ if isinstance(self.other_seq_embed_cols[0], tuple):
+ other_seq_embed_sizes: Dict[str, int] = {
+ tp[0]: tp[1] for tp in self.other_seq_embed_cols # type: ignore[misc]
+ }
+ else:
+ other_seq_embed_sizes = {
+ col: embed_sz_rule(self.n_other_seq_cols[col])
+ for col in other_seq_cols
+ }
+
+ self.other_seq_config: List[Tuple[List[str], int, int]] = [
+ (
+ [f"{col}_{i+1}" for i in range(self.max_seq_length)],
+ self.n_other_seq_cols[col],
+ other_seq_embed_sizes[col],
+ )
+ for col in other_seq_cols
+ ]
+
+ _current_len = len(self.din_columns_idx)
+ for col in other_seq_cols:
+ self.din_columns_idx.update(
+ {
+ f"{col}_{i+1}": i + _current_len
+ for i in range(self.max_seq_length)
+ }
+ )
+ self.din_columns_idx.update(
+ {f"target_{col}": len(self.din_columns_idx)}
+ )
+ _current_len = len(self.din_columns_idx)
+
+ self.is_fitted = True
+
+ return self
+
+ def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
+
+ _df = self._pre_transform(df)
+
+ df_w_sequences = self._build_sequences(_df)
+
+ user_behaviour_seq = df_w_sequences["items_sequence"].tolist()
+ X_user_behaviour = np.vstack(
+ [
+ pad_sequences(seq, self.max_seq_length, pad_first=False, pad_idx=0)
+ for seq in user_behaviour_seq
+ ]
+ )
+
+ X_target_item = np.array(df_w_sequences["target_item"].tolist()).reshape(-1, 1)
+
+ X_all = np.concatenate([X_user_behaviour, X_target_item], axis=1)
+
+ if self.has_standard_tab_data:
+ X_tab = self.tab_preprocessor.transform(df_w_sequences)
+ X_all = np.concatenate([X_tab, X_all], axis=1)
+
+ if self.action_col is not None:
+ action_seq = df_w_sequences["actions_sequence"].tolist()
+ X_actions = np.vstack(
+ [
+ pad_sequences(seq, self.max_seq_length, pad_first=False, pad_idx=0)
+ for seq in action_seq
+ ]
+ )
+ X_all = np.concatenate([X_all, X_actions], axis=1)
+
+ if self.other_seq_embed_cols is not None:
+ other_seq_cols = [
+ col[0] if isinstance(col, tuple) else col
+ for col in self.other_seq_embed_cols
+ ]
+ other_seq = {
+ col: df_w_sequences[f"{col}_sequence"].tolist()
+ for col in other_seq_cols
+ }
+ other_seq_target = {
+ col: df_w_sequences[f"target_{col}"].tolist() for col in other_seq_cols
+ }
+ X_other_seq_arrays: List[np.ndarray] = []
+ # to obsessively make sure that the order of the columns is the
+ # same
+ for col in other_seq_cols:
+ X_other_seq_arrays.append(
+ np.vstack(
+ [
+ pad_sequences(
+ s, self.max_seq_length, pad_first=False, pad_idx=0
+ )
+ for s in other_seq[col]
+ ]
+ )
+ )
+ X_other_seq_arrays.append(
+ np.array(other_seq_target[col]).reshape(-1, 1)
+ )
+
+ X_all = np.concatenate([X_all] + X_other_seq_arrays, axis=1)
+
+ assert len(self.din_columns_idx) == X_all.shape[1], (
+ f"Something went wrong. The number of columns in the final array "
+ f"({X_all.shape[1]}) is different from the number of columns in "
+ f"self.din_columns_idx ({len(self.din_columns_idx)})"
+ )
+
+ return X_all, np.array(df_w_sequences["target"].tolist())
+
+ def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
+ self.fit(df)
+ return self.transform(df)
+
+ def _pre_transform(self, df: pd.DataFrame) -> pd.DataFrame:
+
+ check_is_fitted(self, attributes=["din_columns_idx"])
+
+ df = self.item_le.transform(df)
+
+ if self.action_col is not None:
+ df = self.action_le.transform(df)
+
+ if self.other_seq_embed_cols is not None:
+ df = self.other_seq_le.transform(df)
+
+ return df
+
+ def _build_sequences(self, df: pd.DataFrame) -> pd.DataFrame:
+
+ res_df = (
+ df.groupby(self.user_id_col)
+ .apply(self._group_sequences)
+ .reset_index(drop=True)
+ )
+
+ return res_df
+
+ def _group_sequences(self, group: pd.DataFrame) -> pd.DataFrame:
+
+ item_col = (
+ self.item_embed_col[0]
+ if isinstance(self.item_embed_col, tuple)
+ else self.item_embed_col
+ )
+ items = group[item_col].tolist()
+
+ targets = group[self.target_col].tolist()
+ drop_cols = [item_col, self.target_col]
+
+ sequences: List[Dict[str, str | List[int]]] = []
+ for i in range(len(items) - self.max_seq_length):
+ item_sequences = items[i : i + self.max_seq_length]
+ target_item = items[i + self.max_seq_length]
+ target = targets[i + self.max_seq_length]
+
+ sequence = {
+ "user_id": group.name,
+ "items_sequence": item_sequences,
+ "target_item": target_item,
+ }
+
+ if self.action_col is not None:
+ if self.action_col != self.target_col:
+ actions = group[self.action_col].tolist()
+ action_sequences = actions[i : i + self.max_seq_length]
+ drop_cols.append(self.action_col)
+ else:
+ target -= (
+ 1 # the 'transform' method adds 1 as it saves 0 for padding
+ )
+ action_sequences = targets[i : i + self.max_seq_length]
+
+ sequence["target"] = target
+ sequence["actions_sequence"] = action_sequences
+
+ if self.other_seq_embed_cols is not None:
+ other_seq_cols = [
+ col[0] if isinstance(col, tuple) else col
+ for col in self.other_seq_embed_cols
+ ]
+ drop_cols += other_seq_cols
+ other_seqs: Dict[str, List[int]] = {
+ col: group[col].tolist() for col in other_seq_cols
+ }
+ other_seqs_sequences = {
+ col: other_seqs[col][i : i + self.max_seq_length]
+ for col in other_seq_cols
+ }
+
+ for col in other_seq_cols:
+ sequence[f"{col}_sequence"] = other_seqs_sequences[col]
+ sequence[f"target_{col}"] = other_seqs[col][i + self.max_seq_length]
+
+ sequences.append(sequence)
+
+ seq_df = pd.DataFrame(sequences)
+
+ non_seq_cols = group.drop_duplicates(["user_id"]).drop(drop_cols, axis=1)
+
+ return pd.merge(seq_df, non_seq_cols, on="user_id")
+
+ def inverse_transform(self, X: np.ndarray) -> pd.DataFrame:
+        # transform mutates the df and it is complex to revert the transformation
+ raise NotImplementedError(
+ "inverse_transform is not implemented for this preprocessor"
+ )
+
+
+if __name__ == "__main__":
+
+ def clean_genre_list(genre_list):
+ return "_".join(
+ sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list])
+ )
+
+ data, users, items = load_movielens100k(as_frame=True)
+
+ list_of_genres = [
+ "unknown",
+ "Action",
+ "Adventure",
+ "Animation",
+ "Children's",
+ "Comedy",
+ "Crime",
+ "Documentary",
+ "Drama",
+ "Fantasy",
+ "Film-Noir",
+ "Horror",
+ "Musical",
+ "Mystery",
+ "Romance",
+ "Sci-Fi",
+ "Thriller",
+ "War",
+ "Western",
+ ]
+
+ assert (
+ isinstance(items, pd.DataFrame)
+ and isinstance(data, pd.DataFrame)
+ and isinstance(users, pd.DataFrame)
+ )
+ items["genre_list"] = items[list_of_genres].apply(
+ lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1
+ )
+
+ items["genre_list"] = items["genre_list"].apply(clean_genre_list)
+
+ df = pd.merge(data, items[["movie_id", "genre_list"]], on="movie_id")
+ df = pd.merge(
+ df,
+ users[["user_id", "age", "gender", "occupation"]],
+ on="user_id",
+ )
+
+ df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0)
+
+ df = df.sort_values(by=["timestamp"]).reset_index(drop=True)
+ df = df.drop("timestamp", axis=1)
+
+ din_preprocessor = DINPreprocessor(
+ user_id_col="user_id",
+ target_col="rating",
+ item_embed_col=("movie_id", 16),
+ max_seq_length=5,
+ action_col="rating",
+ other_seq_embed_cols=[("genre_list", 16)],
+ cat_embed_cols=["user_id", "age", "gender", "occupation"],
+ )
+
+ X, y = din_preprocessor.fit_transform(df)
+
+ din = DeepInterestNetwork(
+ column_idx=din_preprocessor.din_columns_idx,
+ user_behavior_confiq=din_preprocessor.user_behaviour_config,
+ action_seq_config=din_preprocessor.action_seq_config,
+ other_seq_cols_confiq=din_preprocessor.other_seq_config,
+ cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore[attr-defined]
+ mlp_hidden_dims=[128, 64],
+ )
+
+ # And from here on, everything is standard
+ model = WideDeep(deeptabular=din)
+
+ trainer = Trainer(model=model, objective="binary", metrics=[Accuracy()])
+
+ # in the real world you would have to split the data into train, val and test
+ trainer.fit(
+ X_tab=X,
+ target=y,
+ n_epochs=5,
+ batch_size=512,
+ )
From 45fce4d207ecb70b339d4b9b4d04be8504af92b8 Mon Sep 17 00:00:00 2001
From: Javier
Date: Mon, 14 Oct 2024 09:55:33 +0100
Subject: [PATCH 04/24] First test added for the DIN preprocessor
---
pytorch_widedeep/preprocessing/__init__.py | 1 +
.../preprocessing/din_preprocessor.py | 768 ++++++++++--------
.../interactions_data.csv | 113 +++
.../generate_interactions_data.py | 194 +++++
.../test_data_utils/test_din_preprocessor.py | 87 ++
5 files changed, 842 insertions(+), 321 deletions(-)
create mode 100644 tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv
create mode 100644 tests/test_data_utils/generate_interactions_data.py
create mode 100644 tests/test_data_utils/test_din_preprocessor.py
diff --git a/pytorch_widedeep/preprocessing/__init__.py b/pytorch_widedeep/preprocessing/__init__.py
index b4c5bcee..f52a5e73 100644
--- a/pytorch_widedeep/preprocessing/__init__.py
+++ b/pytorch_widedeep/preprocessing/__init__.py
@@ -2,6 +2,7 @@
HFPreprocessor,
ChunkHFPreprocessor,
)
+from pytorch_widedeep.preprocessing.din_preprocessor import DINPreprocessor
from pytorch_widedeep.preprocessing.tab_preprocessor import (
TabPreprocessor,
ChunkTabPreprocessor,
diff --git a/pytorch_widedeep/preprocessing/din_preprocessor.py b/pytorch_widedeep/preprocessing/din_preprocessor.py
index 04655a82..d28e6926 100644
--- a/pytorch_widedeep/preprocessing/din_preprocessor.py
+++ b/pytorch_widedeep/preprocessing/din_preprocessor.py
@@ -1,17 +1,22 @@
-import re
+# import re
+import warnings
from typing import Dict, List, Tuple, Union, Literal, Optional
import numpy as np
import pandas as pd
-from pytorch_widedeep import Trainer
-from pytorch_widedeep.models import WideDeep
-from pytorch_widedeep.metrics import Accuracy
-from pytorch_widedeep.datasets import load_movielens100k
-from pytorch_widedeep.preprocessing import TabPreprocessor, embed_sz_rule
-from pytorch_widedeep.models.rec.din import DeepInterestNetwork
+# from pytorch_widedeep.models import WideDeep
+# from pytorch_widedeep.metrics import Accuracy
+# from pytorch_widedeep.datasets import load_movielens100k
+# from pytorch_widedeep.training import Trainer
+# from pytorch_widedeep.models.rec.din import DeepInterestNetwork
from pytorch_widedeep.utils.text_utils import pad_sequences
+from pytorch_widedeep.utils.general_utils import alias
from pytorch_widedeep.utils.deeptabular_utils import LabelEncoder
+from pytorch_widedeep.preprocessing.tab_preprocessor import (
+ TabPreprocessor,
+ embed_sz_rule,
+)
from pytorch_widedeep.preprocessing.base_preprocessor import (
BasePreprocessor,
check_is_fitted,
@@ -19,15 +24,94 @@
class DINPreprocessor(BasePreprocessor):
+ """
+ Preprocessor for Deep Interest Network (DIN) models.
+
+ This preprocessor handles the preparation of data for DIN models,
+ including sequence building, label encoding, and handling of various
+ types of input columns (categorical, continuous, and sequential).
+
+ Parameters:
+ -----------
+ user_id_col : str
+ Name of the column containing user IDs.
+ item_embed_col : Union[str, Tuple[str, int]]
+ Name of the column containing item IDs to be embedded, or a tuple of
+ (column_name, embedding_dim).
+ target_col : str
+ Name of the column containing the target variable.
+ max_seq_length : int
+ Maximum length of sequences to be created.
+ action_col : Optional[str], default=None
+ Name of the column containing user actions (if applicable).
+ other_seq_embed_cols : Optional[List[str] | List[Tuple[str, int]]], default=None
+ List of other columns to be treated as sequences.
+ cat_embed_cols : Optional[Union[List[str], List[Tuple[str, int]]]], default=None
+ List of categorical columns to be represented by embeddings.
+ continuous_cols: List, default = None
+ List with the name of the continuous cols.
+ quantization_setup: int or Dict[str, Union[int, List[float]]], default=None
+ Continuous columns can be turned into categorical via `pd.cut`.
+ cols_to_scale: List or str, default = None,
+ List with the names of the columns that will be standardized via
+ sklearn's `StandardScaler`.
+ auto_embed_dim: bool, default = True
+ Boolean indicating whether the embedding dimensions will be
+ automatically defined via rule of thumb.
+ embedding_rule: str, default = 'fastai_new'
+ Rule of thumb for embedding size.
+ default_embed_dim: int, default=16
+ Default dimension for the embeddings.
+ verbose : int, default=1
+ Verbosity level.
+ scale: bool, default = False
+ Boolean indicating whether or not to scale/standardize continuous cols.
+ already_standard: List, default = None
+ List with the name of the continuous cols that do not need to be
+ scaled/standardized.
+ **kwargs :
+ Additional keyword arguments to be passed to the TabPreprocessor.
+
+ Attributes:
+ -----------
+ is_fitted : bool
+ Whether the preprocessor has been fitted.
+ has_standard_tab_data : bool
+ Whether the data includes standard tabular data.
+ tab_preprocessor : TabPreprocessor
+ Preprocessor for standard tabular data.
+ din_columns_idx : Dict[str, int]
+ Dictionary mapping column names to their indices in the processed data.
+ item_le : LabelEncoder
+ Label encoder for item IDs.
+ n_items : int
+ Number of unique items.
+ user_behaviour_config : Tuple[List[str], int, int]
+ Configuration for user behavior sequences.
+ action_le : LabelEncoder
+ Label encoder for action column (if applicable).
+ n_actions : int
+ Number of unique actions (if applicable).
+ action_seq_config : Tuple[List[str], int]
+ Configuration for action sequences (if applicable).
+ other_seq_le : LabelEncoder
+ Label encoder for other sequence columns (if applicable).
+ n_other_seq_cols : Dict[str, int]
+ Number of unique values in each other sequence column.
+ other_seq_config : List[Tuple[List[str], int, int]]
+ Configuration for other sequence columns.
+ """
+
+ @alias("item_embed_col", ["item_id_col"])
def __init__(
self,
*,
user_id_col: str,
- target_col: str,
item_embed_col: Union[str, Tuple[str, int]],
+ target_col: str,
max_seq_length: int,
action_col: Optional[str] = None,
- other_seq_embed_cols: Optional[List[str] | List[Tuple[str, int]]] = None,
+ other_seq_embed_cols: Optional[Union[List[str], List[Tuple[str, int]]]] = None,
cat_embed_cols: Optional[Union[List[str], List[Tuple[str, int]]]] = None,
continuous_cols: Optional[List[str]] = None,
quantization_setup: Optional[
@@ -42,7 +126,6 @@ def __init__(
already_standard: Optional[List[str]] = None,
**kwargs,
):
-
self.user_id_col = user_id_col
self.item_embed_col = item_embed_col
self.max_seq_length = max_seq_length
@@ -61,9 +144,8 @@ def __init__(
self.already_standard = already_standard
self.kwargs = kwargs
- self.has_standard_tab_data = False
- if self.cat_embed_cols or self.continuous_cols:
- self.has_standard_tab_data = True
+ self.has_standard_tab_data = bool(self.cat_embed_cols or self.continuous_cols)
+ if self.has_standard_tab_data:
self.tab_preprocessor = TabPreprocessor(
cat_embed_cols=self.cat_embed_cols,
continuous_cols=self.continuous_cols,
@@ -80,8 +162,7 @@ def __init__(
self.is_fitted = False
- def fit(self, df: pd.DataFrame):
-
+ def fit(self, df: pd.DataFrame) -> "DINPreprocessor":
if self.has_standard_tab_data:
self.tab_preprocessor.fit(df)
self.din_columns_idx = {
@@ -90,362 +171,407 @@ def fit(self, df: pd.DataFrame):
else:
self.din_columns_idx = {}
- self.item_le = LabelEncoder(
- columns_to_encode=[
- (
- self.item_embed_col[0]
- if isinstance(self.item_embed_col, tuple)
- else self.item_embed_col
- )
- ]
+ self._fit_label_encoders(df)
+ self.is_fitted = True
+ return self
+
+ def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
+ _df = self._pre_transform(df)
+ df_w_sequences = self._build_sequences(_df)
+ X_all = self._concatenate_features(df_w_sequences)
+ return X_all, np.array(df_w_sequences["target"].tolist())
+
+ def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
+ self.fit(df)
+ return self.transform(df)
+
+ def inverse_transform(self, X: np.ndarray) -> pd.DataFrame:
+ raise NotImplementedError(
+ "inverse_transform is not implemented for this preprocessor"
)
+
+ def _fit_label_encoders(self, df: pd.DataFrame) -> None:
+ self.item_le = LabelEncoder(columns_to_encode=[self._get_item_col()])
self.item_le.fit(df)
- self.n_items = max(self.item_le.encoding_dict[self.item_embed_col[0]].values())
- user_behaviour_embed_size = (
- self.item_embed_col[1]
- if isinstance(self.item_embed_col, tuple)
- else embed_sz_rule(self.n_items)
+ self.n_items = max(self.item_le.encoding_dict[self._get_item_col()].values())
+ user_behaviour_embed_size = self._get_embed_size(
+ self.item_embed_col, self.n_items
)
- self.user_behaviour_config: Tuple[List[str], int, int] = (
+ self.user_behaviour_config = (
[f"item_{i+1}" for i in range(self.max_seq_length)],
self.n_items,
user_behaviour_embed_size,
)
- _current_len = len(self.din_columns_idx)
- self.din_columns_idx.update(
- {f"item_{i+1}": i + _current_len for i in range(self.max_seq_length)}
- )
- self.din_columns_idx.update(
- {
- "target_item": len(self.din_columns_idx),
- }
+ self._update_din_columns_idx("item", self.max_seq_length)
+ self.din_columns_idx["target_item"] = len(self.din_columns_idx)
+
+ if self.action_col:
+ self._fit_action_label_encoder(df)
+ if self.other_seq_embed_cols:
+ self._fit_other_seq_label_encoders(df)
+
+ def _fit_action_label_encoder(self, df: pd.DataFrame) -> None:
+ self.action_le = LabelEncoder(columns_to_encode=[self.action_col])
+ self.action_le.fit(df)
+ self.n_actions = len(self.action_le.encoding_dict[self.action_col])
+ self.action_seq_config = (
+ [f"action_{i+1}" for i in range(self.max_seq_length)],
+ self.n_actions,
)
+ self._update_din_columns_idx("action", self.max_seq_length)
+
+ def _other_seq_cols_float_warning(self, df: pd.DataFrame) -> None:
+ other_seq_cols = [
+ col[0] if isinstance(col, tuple) else col
+ for col in self.other_seq_embed_cols
+ ]
+ for col in other_seq_cols:
+ if df[col].dtype == float:
+ warnings.warn(
+ f"{col} is a float column. It will be converted to integers. "
+ "If this is not what you want, please convert it beforehand.",
+ UserWarning,
+ )
- if self.action_col is not None:
- self.action_le = LabelEncoder(columns_to_encode=[self.action_col])
- self.action_le.fit(df)
-
- self.n_actions = len(self.action_le.encoding_dict[self.action_col])
- self.action_seq_config: Tuple[List[str], int] = (
- [f"action_{i+1}" for i in range(self.max_seq_length)],
- self.n_actions,
- )
-
- _current_len = len(self.din_columns_idx)
- self.din_columns_idx.update(
- {f"action_{i+1}": i + _current_len for i in range(self.max_seq_length)}
- )
-
- if self.other_seq_embed_cols is not None:
- self.other_seq_le = LabelEncoder(
- columns_to_encode=[
- col[0] if isinstance(col, tuple) else col
- for col in self.other_seq_embed_cols
- ]
- )
- self.other_seq_le.fit(df)
-
- other_seq_cols = [
+ def _convert_other_seq_cols_to_int(self, df: pd.DataFrame) -> None:
+ other_seq_cols = [
+ col[0] if isinstance(col, tuple) else col
+ for col in self.other_seq_embed_cols
+ ]
+ for col in other_seq_cols:
+ if df[col].dtype == float:
+ df[col] = df[col].astype(int)
+
+ def _fit_other_seq_label_encoders(self, df: pd.DataFrame) -> None:
+ self._other_seq_cols_float_warning(df)
+ self._convert_other_seq_cols_to_int(df)
+ self.other_seq_le = LabelEncoder(
+ columns_to_encode=[
col[0] if isinstance(col, tuple) else col
for col in self.other_seq_embed_cols
]
- self.n_other_seq_cols: Dict[str, int] = {
- col: max(self.other_seq_le.encoding_dict[col].values())
- for col in other_seq_cols
- }
-
- if isinstance(self.other_seq_embed_cols[0], tuple):
- other_seq_embed_sizes: Dict[str, int] = {
- tp[0]: tp[1] for tp in self.other_seq_embed_cols # type: ignore[misc]
- }
- else:
- other_seq_embed_sizes = {
- col: embed_sz_rule(self.n_other_seq_cols[col])
- for col in other_seq_cols
- }
-
- self.other_seq_config: List[Tuple[List[str], int, int]] = [
- (
- [f"{col}_{i+1}" for i in range(self.max_seq_length)],
- self.n_other_seq_cols[col],
- other_seq_embed_sizes[col],
- )
- for col in other_seq_cols
- ]
-
- _current_len = len(self.din_columns_idx)
- for col in other_seq_cols:
- self.din_columns_idx.update(
- {
- f"{col}_{i+1}": i + _current_len
- for i in range(self.max_seq_length)
- }
- )
- self.din_columns_idx.update(
- {f"target_{col}": len(self.din_columns_idx)}
- )
- _current_len = len(self.din_columns_idx)
-
- self.is_fitted = True
-
- return self
-
- def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
-
- _df = self._pre_transform(df)
-
- df_w_sequences = self._build_sequences(_df)
-
- user_behaviour_seq = df_w_sequences["items_sequence"].tolist()
- X_user_behaviour = np.vstack(
- [
- pad_sequences(seq, self.max_seq_length, pad_first=False, pad_idx=0)
- for seq in user_behaviour_seq
- ]
)
-
- X_target_item = np.array(df_w_sequences["target_item"].tolist()).reshape(-1, 1)
-
- X_all = np.concatenate([X_user_behaviour, X_target_item], axis=1)
-
- if self.has_standard_tab_data:
- X_tab = self.tab_preprocessor.transform(df_w_sequences)
- X_all = np.concatenate([X_tab, X_all], axis=1)
-
- if self.action_col is not None:
- action_seq = df_w_sequences["actions_sequence"].tolist()
- X_actions = np.vstack(
- [
- pad_sequences(seq, self.max_seq_length, pad_first=False, pad_idx=0)
- for seq in action_seq
- ]
+ self.other_seq_le.fit(df)
+ other_seq_cols = [
+ col[0] if isinstance(col, tuple) else col
+ for col in self.other_seq_embed_cols
+ ]
+ self.n_other_seq_cols = {
+ col: max(self.other_seq_le.encoding_dict[col].values())
+ for col in other_seq_cols
+ }
+ other_seq_embed_sizes = {
+ col: self._get_embed_size(col, self.n_other_seq_cols[col])
+ for col in other_seq_cols
+ }
+ self.other_seq_config = [
+ (
+ self._get_seq_col_names(col),
+ self.n_other_seq_cols[col],
+ other_seq_embed_sizes[col],
)
- X_all = np.concatenate([X_all, X_actions], axis=1)
+ for col in other_seq_cols
+ ]
+ self._update_din_columns_idx_for_other_seq(other_seq_cols)
- if self.other_seq_embed_cols is not None:
- other_seq_cols = [
- col[0] if isinstance(col, tuple) else col
- for col in self.other_seq_embed_cols
- ]
- other_seq = {
- col: df_w_sequences[f"{col}_sequence"].tolist()
- for col in other_seq_cols
- }
- other_seq_target = {
- col: df_w_sequences[f"target_{col}"].tolist() for col in other_seq_cols
- }
- X_other_seq_arrays: List[np.ndarray] = []
- # to obsessively make sure that the order of the columns is the
- # same
- for col in other_seq_cols:
- X_other_seq_arrays.append(
- np.vstack(
- [
- pad_sequences(
- s, self.max_seq_length, pad_first=False, pad_idx=0
- )
- for s in other_seq[col]
- ]
- )
- )
- X_other_seq_arrays.append(
- np.array(other_seq_target[col]).reshape(-1, 1)
- )
+ def _get_item_col(self) -> str:
+ return (
+ self.item_embed_col[0]
+ if isinstance(self.item_embed_col, tuple)
+ else self.item_embed_col
+ )
- X_all = np.concatenate([X_all] + X_other_seq_arrays, axis=1)
+ def _get_embed_size(self, col: Union[str, Tuple[str, int]], n_unique: int) -> int:
+ return col[1] if isinstance(col, tuple) else embed_sz_rule(n_unique)
- assert len(self.din_columns_idx) == X_all.shape[1], (
- f"Something went wrong. The number of columns in the final array "
- f"({X_all.shape[1]}) is different from the number of columns in "
- f"self.din_columns_idx ({len(self.din_columns_idx)})"
- )
+ def _get_seq_col_names(self, col: str) -> List[str]:
+ return [f"{col}_{i+1}" for i in range(self.max_seq_length)]
- return X_all, np.array(df_w_sequences["target"].tolist())
+ def _update_din_columns_idx(self, prefix: str, length: int) -> None:
+ current_len = len(self.din_columns_idx)
+ self.din_columns_idx.update(
+ {f"{prefix}_{i+1}": i + current_len for i in range(length)}
+ )
- def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
- self.fit(df)
- return self.transform(df)
+ def _update_din_columns_idx_for_other_seq(self, other_seq_cols: List[str]) -> None:
+ current_len = len(self.din_columns_idx)
+ for col in other_seq_cols:
+ self.din_columns_idx.update(
+ {f"{col}_{i+1}": i + current_len for i in range(self.max_seq_length)}
+ )
+ self.din_columns_idx[f"target_{col}"] = len(self.din_columns_idx)
+ current_len = len(self.din_columns_idx)
def _pre_transform(self, df: pd.DataFrame) -> pd.DataFrame:
-
check_is_fitted(self, attributes=["din_columns_idx"])
-
df = self.item_le.transform(df)
-
- if self.action_col is not None:
+ if self.action_col:
df = self.action_le.transform(df)
-
- if self.other_seq_embed_cols is not None:
+ if self.other_seq_embed_cols:
df = self.other_seq_le.transform(df)
-
return df
def _build_sequences(self, df: pd.DataFrame) -> pd.DataFrame:
-
- res_df = (
+ # TO DO: do something to avoid warning here
+ return (
df.groupby(self.user_id_col)
.apply(self._group_sequences)
.reset_index(drop=True)
)
- return res_df
-
def _group_sequences(self, group: pd.DataFrame) -> pd.DataFrame:
-
- item_col = (
- self.item_embed_col[0]
- if isinstance(self.item_embed_col, tuple)
- else self.item_embed_col
- )
+ item_col = self._get_item_col()
items = group[item_col].tolist()
-
targets = group[self.target_col].tolist()
drop_cols = [item_col, self.target_col]
- sequences: List[Dict[str, str | List[int]]] = []
- for i in range(len(items) - self.max_seq_length):
- item_sequences = items[i : i + self.max_seq_length]
- target_item = items[i + self.max_seq_length]
- target = targets[i + self.max_seq_length]
+        # sequences cannot be built for users with only one
+        # interaction
+ if len(items) <= 1:
+ return pd.DataFrame()
- sequence = {
- "user_id": group.name,
- "items_sequence": item_sequences,
- "target_item": target_item,
- }
+ sequences = [
+ self._create_sequence(group, items, targets, i, drop_cols)
+ for i in range(max(1, len(items) - self.max_seq_length))
+ ]
- if self.action_col is not None:
- if self.action_col != self.target_col:
- actions = group[self.action_col].tolist()
- action_sequences = actions[i : i + self.max_seq_length]
- drop_cols.append(self.action_col)
- else:
- target -= (
- 1 # the 'transform' method adds 1 as it saves 0 for padding
- )
- action_sequences = targets[i : i + self.max_seq_length]
-
- sequence["target"] = target
- sequence["actions_sequence"] = action_sequences
-
- if self.other_seq_embed_cols is not None:
- other_seq_cols = [
- col[0] if isinstance(col, tuple) else col
- for col in self.other_seq_embed_cols
- ]
- drop_cols += other_seq_cols
- other_seqs: Dict[str, List[int]] = {
- col: group[col].tolist() for col in other_seq_cols
- }
- other_seqs_sequences = {
- col: other_seqs[col][i : i + self.max_seq_length]
- for col in other_seq_cols
- }
+ seq_df = pd.DataFrame(sequences)
+ non_seq_cols = group.drop_duplicates([self.user_id_col]).drop(drop_cols, axis=1)
+ return pd.merge(seq_df, non_seq_cols, on=self.user_id_col)
- for col in other_seq_cols:
- sequence[f"{col}_sequence"] = other_seqs_sequences[col]
- sequence[f"target_{col}"] = other_seqs[col][i + self.max_seq_length]
+ def _create_sequence(
+ self,
+ group: pd.DataFrame,
+ items: List[int],
+ targets: List[int],
+ i: int,
+ drop_cols: List[str],
+ ) -> Dict[str, Union[str, List[int]]]:
+ end_idx = min(i + self.max_seq_length, len(items) - 1)
+ item_sequences = items[i:end_idx]
+ target_item = items[end_idx]
+ target = targets[end_idx]
+
+ sequence = {
+ self.user_id_col: group.name,
+ "items_sequence": item_sequences,
+ "target_item": target_item,
+ }
+
+ if self.action_col:
+ sequence.update(
+ self._create_action_sequence(group, targets, i, target, drop_cols)
+ )
+ if self.other_seq_embed_cols:
+ sequence.update(self._create_other_seq_sequence(group, i, drop_cols))
- sequences.append(sequence)
+ return sequence
- seq_df = pd.DataFrame(sequences)
+ def _create_action_sequence(
+ self,
+ group: pd.DataFrame,
+ targets: List[int],
+ i: int,
+ target: int,
+ drop_cols: List[str],
+ ) -> Dict[str, Union[int, List[int]]]:
+ if self.action_col != self.target_col:
+ actions = group[self.action_col].tolist()
+ action_sequences = (
+ actions[i : i + self.max_seq_length]
+ if i + self.max_seq_length < len(actions)
+ else actions[i:-1]
+ )
+ drop_cols.append(self.action_col)
+ else:
+ target -= 1 # the 'transform' method adds 1 as it saves 0 for padding
+ action_sequences = (
+ targets[i : i + self.max_seq_length]
+ if i + self.max_seq_length < len(targets)
+ else targets[i:-1]
+ )
+
+ return {"target": target, "actions_sequence": action_sequences}
+
+ def _create_other_seq_sequence(
+ self, group: pd.DataFrame, i: int, drop_cols: List[str]
+ ) -> Dict[str, Union[int, List[int]]]:
+ other_seq_cols = [
+ col[0] if isinstance(col, tuple) else col
+ for col in self.other_seq_embed_cols
+ ]
+ drop_cols += other_seq_cols
+ other_seqs = {col: group[col].tolist() for col in other_seq_cols}
+
+ other_seqs_sequences: Dict[str, Union[int, List[int]]] = {}
+ for col in other_seq_cols:
+ other_seqs_sequences[col] = (
+ other_seqs[col][i : i + self.max_seq_length]
+ if i + self.max_seq_length < len(other_seqs[col])
+ else other_seqs[col][i:-1]
+ )
- non_seq_cols = group.drop_duplicates(["user_id"]).drop(drop_cols, axis=1)
+ sequence: Dict[str, Union[int, List[int]]] = {}
+ for col in other_seq_cols:
+ sequence[f"{col}_sequence"] = other_seqs_sequences[col]
+ sequence[f"target_{col}"] = (
+ other_seqs[col][i + self.max_seq_length]
+ if i + self.max_seq_length < len(other_seqs[col])
+ else other_seqs[col][-1]
+ )
- return pd.merge(seq_df, non_seq_cols, on="user_id")
+ return sequence
- def inverse_transform(self, X: np.ndarray) -> pd.DataFrame:
- # trasform mutates the df and is complex to revert the transformation
- raise NotImplementedError(
- "inverse_transform is not implemented for this preprocessor"
+ def _concatenate_features(self, df_w_sequences: pd.DataFrame) -> np.ndarray:
+ user_behaviour_seq = df_w_sequences["items_sequence"].tolist()
+ X_user_behaviour = np.vstack(
+ [
+ pad_sequences(seq, self.max_seq_length, pad_idx=0)
+ for seq in user_behaviour_seq
+ ]
)
+ X_target_item = np.array(df_w_sequences["target_item"].tolist()).reshape(-1, 1)
+ X_all = np.concatenate([X_user_behaviour, X_target_item], axis=1)
+ if self.has_standard_tab_data:
+ X_tab = self.tab_preprocessor.transform(df_w_sequences)
+ X_all = np.concatenate([X_tab, X_all], axis=1)
-if __name__ == "__main__":
+ if self.action_col:
+ action_seq = df_w_sequences["actions_sequence"].tolist()
+ X_actions = np.vstack(
+ [
+ pad_sequences(seq, self.max_seq_length, pad_idx=0)
+ for seq in action_seq
+ ]
+ )
+ X_all = np.concatenate([X_all, X_actions], axis=1)
+
+ if self.other_seq_embed_cols:
+ X_all = self._concatenate_other_seq_features(df_w_sequences, X_all)
- def clean_genre_list(genre_list):
- return "_".join(
- sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list])
+ assert len(self.din_columns_idx) == X_all.shape[1], (
+ f"Something went wrong. The number of columns in the final array "
+ f"({X_all.shape[1]}) is different from the number of columns in "
+ f"self.din_columns_idx ({len(self.din_columns_idx)})"
)
- data, users, items = load_movielens100k(as_frame=True)
-
- list_of_genres = [
- "unknown",
- "Action",
- "Adventure",
- "Animation",
- "Children's",
- "Comedy",
- "Crime",
- "Documentary",
- "Drama",
- "Fantasy",
- "Film-Noir",
- "Horror",
- "Musical",
- "Mystery",
- "Romance",
- "Sci-Fi",
- "Thriller",
- "War",
- "Western",
- ]
-
- assert (
- isinstance(items, pd.DataFrame)
- and isinstance(data, pd.DataFrame)
- and isinstance(users, pd.DataFrame)
- )
- items["genre_list"] = items[list_of_genres].apply(
- lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1
- )
-
- items["genre_list"] = items["genre_list"].apply(clean_genre_list)
-
- df = pd.merge(data, items[["movie_id", "genre_list"]], on="movie_id")
- df = pd.merge(
- df,
- users[["user_id", "age", "gender", "occupation"]],
- on="user_id",
- )
-
- df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0)
-
- df = df.sort_values(by=["timestamp"]).reset_index(drop=True)
- df = df.drop("timestamp", axis=1)
-
- din_preprocessor = DINPreprocessor(
- user_id_col="user_id",
- target_col="rating",
- item_embed_col=("movie_id", 16),
- max_seq_length=5,
- action_col="rating",
- other_seq_embed_cols=[("genre_list", 16)],
- cat_embed_cols=["user_id", "age", "gender", "occupation"],
- )
-
- X, y = din_preprocessor.fit_transform(df)
-
- din = DeepInterestNetwork(
- column_idx=din_preprocessor.din_columns_idx,
- user_behavior_confiq=din_preprocessor.user_behaviour_config,
- action_seq_config=din_preprocessor.action_seq_config,
- other_seq_cols_confiq=din_preprocessor.other_seq_config,
- cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore[attr-defined]
- mlp_hidden_dims=[128, 64],
- )
-
- # And from here on, everything is standard
- model = WideDeep(deeptabular=din)
-
- trainer = Trainer(model=model, objective="binary", metrics=[Accuracy()])
-
- # in the real world you would have to split the data into train, val and test
- trainer.fit(
- X_tab=X,
- target=y,
- n_epochs=5,
- batch_size=512,
- )
+ return X_all
+
+ def _concatenate_other_seq_features(
+ self, df_w_sequences: pd.DataFrame, X_all: np.ndarray
+ ) -> np.ndarray:
+ other_seq_cols = [
+ col[0] if isinstance(col, tuple) else col
+ for col in self.other_seq_embed_cols
+ ]
+ other_seq = {
+ col: df_w_sequences[f"{col}_sequence"].tolist() for col in other_seq_cols
+ }
+ other_seq_target = {
+ col: df_w_sequences[f"target_{col}"].tolist() for col in other_seq_cols
+ }
+ X_other_seq_arrays = []
+
+ for col in other_seq_cols:
+ X_other_seq_arrays.append(
+ np.vstack(
+ [
+ pad_sequences(s, self.max_seq_length, pad_idx=0)
+ for s in other_seq[col]
+ ]
+ )
+ )
+ X_other_seq_arrays.append(np.array(other_seq_target[col]).reshape(-1, 1))
+
+ return np.concatenate([X_all] + X_other_seq_arrays, axis=1)
+
+
+# if __name__ == "__main__":
+
+# def clean_genre_list(genre_list):
+# return "_".join(
+# sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list])
+# )
+
+# data, users, items = load_movielens100k(as_frame=True)
+
+# list_of_genres = [
+# "unknown",
+# "Action",
+# "Adventure",
+# "Animation",
+# "Children's",
+# "Comedy",
+# "Crime",
+# "Documentary",
+# "Drama",
+# "Fantasy",
+# "Film-Noir",
+# "Horror",
+# "Musical",
+# "Mystery",
+# "Romance",
+# "Sci-Fi",
+# "Thriller",
+# "War",
+# "Western",
+# ]
+
+# assert (
+# isinstance(items, pd.DataFrame)
+# and isinstance(data, pd.DataFrame)
+# and isinstance(users, pd.DataFrame)
+# )
+# items["genre_list"] = items[list_of_genres].apply(
+# lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1
+# )
+
+# items["genre_list"] = items["genre_list"].apply(clean_genre_list)
+
+# df = pd.merge(data, items[["movie_id", "genre_list"]], on="movie_id")
+# df = pd.merge(
+# df,
+# users[["user_id", "age", "gender", "occupation"]],
+# on="user_id",
+# )
+
+# df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0)
+
+# df = df.sort_values(by=["timestamp"]).reset_index(drop=True)
+# df = df.drop("timestamp", axis=1)
+
+# din_preprocessor = DINPreprocessor(
+# user_id_col="user_id",
+# target_col="rating",
+# item_embed_col=("movie_id", 16),
+# max_seq_length=5,
+# action_col="rating",
+# other_seq_embed_cols=[("genre_list", 16)],
+# cat_embed_cols=["user_id", "age", "gender", "occupation"],
+# )
+
+# X, y = din_preprocessor.fit_transform(df)
+
+# din = DeepInterestNetwork(
+# column_idx=din_preprocessor.din_columns_idx,
+# user_behavior_confiq=din_preprocessor.user_behaviour_config,
+# action_seq_config=din_preprocessor.action_seq_config,
+# other_seq_cols_confiq=din_preprocessor.other_seq_config,
+# cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore[attr-defined]
+# mlp_hidden_dims=[128, 64],
+# )
+
+# # And from here on, everything is standard
+# model = WideDeep(deeptabular=din)
+
+# trainer = Trainer(model=model, objective="binary", metrics=[Accuracy()])
+
+# # in the real world you would have to split the data into train, val and test
+# trainer.fit(
+# X_tab=X,
+# target=y,
+# n_epochs=5,
+# batch_size=512,
+# )
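A note on the `target -= 1` adjustment in `_create_action_sequence` above: `LabelEncoder` reserves 0 for padding, so encoded values start at 1, and when the action column doubles as the target column that shift has to be undone before the value is used as a label. A minimal sketch of that convention, with a made-up column name:

    import pandas as pd

    from pytorch_widedeep.utils.deeptabular_utils import LabelEncoder

    # hypothetical binary action/target column
    toy_df = pd.DataFrame({"interaction": [0, 1, 1, 0]})

    le = LabelEncoder(columns_to_encode=["interaction"])
    le.fit(toy_df)

    # 0 is kept for padding, so the raw values 0/1 are expected to map to 1/2
    print(le.encoding_dict["interaction"])

    encoded = le.transform(toy_df)["interaction"]
    labels = encoded - 1  # recovers the original 0/1 targets, as the preprocessor does
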
diff --git a/tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv b/tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv
new file mode 100644
index 00000000..e7587df7
--- /dev/null
+++ b/tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv
@@ -0,0 +1,113 @@
+user_id,item_id,age,gender,height,weight,price,category,timestamp,interaction
+5,2,25,M,182.5,99.6,65.05,C,2023-01-02,1
+3,1,32,M,166.7,96.9,65.57,B,2023-01-02,0
+5,6,25,M,182.5,99.6,45.99,A,2023-01-04,0
+3,7,32,M,166.7,96.9,14.2,C,2023-01-05,1
+3,8,32,M,166.7,96.9,97.64,C,2023-01-09,0
+4,2,60,F,157.1,50.0,65.05,C,2023-01-13,1
+2,6,46,M,173.0,86.1,45.99,A,2023-01-15,1
+5,1,25,M,182.5,99.6,65.57,B,2023-01-17,0
+5,7,25,M,182.5,99.6,14.2,C,2023-01-22,0
+5,6,25,M,182.5,99.6,45.99,A,2023-01-27,0
+4,2,60,F,157.1,50.0,65.05,C,2023-01-28,1
+4,8,60,F,157.1,50.0,97.64,C,2023-01-29,1
+3,2,32,M,166.7,96.9,65.05,C,2023-02-02,1
+3,6,32,M,166.7,96.9,45.99,A,2023-02-04,1
+3,9,32,M,166.7,96.9,30.95,A,2023-02-06,0
+2,7,46,M,173.0,86.1,14.2,C,2023-02-10,1
+3,4,32,M,166.7,96.9,12.08,A,2023-02-11,1
+4,2,60,F,157.1,50.0,65.05,C,2023-02-13,1
+4,10,60,F,157.1,50.0,18.15,A,2023-02-14,1
+5,10,25,M,182.5,99.6,18.15,A,2023-02-20,0
+5,2,25,M,182.5,99.6,65.05,C,2023-02-21,1
+1,2,56,M,155.0,52.8,65.05,C,2023-02-22,0
+3,1,32,M,166.7,96.9,65.57,B,2023-03-01,0
+2,6,46,M,173.0,86.1,45.99,A,2023-03-06,1
+4,2,60,F,157.1,50.0,65.05,C,2023-03-07,1
+2,5,46,M,173.0,86.1,57.23,C,2023-03-22,1
+2,9,46,M,173.0,86.1,30.95,A,2023-04-08,0
+3,5,32,M,166.7,96.9,57.23,C,2023-04-09,1
+3,1,32,M,166.7,96.9,65.57,B,2023-04-09,0
+3,3,32,M,166.7,96.9,10.64,C,2023-04-09,0
+4,1,60,F,157.1,50.0,65.57,B,2023-04-09,0
+4,5,60,F,157.1,50.0,57.23,C,2023-04-14,0
+4,4,60,F,157.1,50.0,12.08,A,2023-04-22,0
+3,3,32,M,166.7,96.9,10.64,C,2023-04-22,0
+4,4,60,F,157.1,50.0,12.08,A,2023-04-23,0
+5,9,25,M,182.5,99.6,30.95,A,2023-04-23,1
+5,3,25,M,182.5,99.6,10.64,C,2023-04-23,1
+3,1,32,M,166.7,96.9,65.57,B,2023-04-26,0
+1,1,56,M,155.0,52.8,65.57,B,2023-05-06,0
+5,6,25,M,182.5,99.6,45.99,A,2023-05-08,0
+2,8,46,M,173.0,86.1,97.64,C,2023-05-08,0
+3,5,32,M,166.7,96.9,57.23,C,2023-05-11,1
+4,6,60,F,157.1,50.0,45.99,A,2023-05-14,1
+5,7,25,M,182.5,99.6,14.2,C,2023-05-15,0
+3,2,32,M,166.7,96.9,65.05,C,2023-05-15,1
+2,3,46,M,173.0,86.1,10.64,C,2023-05-16,1
+5,8,25,M,182.5,99.6,97.64,C,2023-05-17,1
+4,9,60,F,157.1,50.0,30.95,A,2023-05-23,1
+3,9,32,M,166.7,96.9,30.95,A,2023-05-24,0
+2,4,46,M,173.0,86.1,12.08,A,2023-05-27,0
+2,8,46,M,173.0,86.1,97.64,C,2023-05-31,0
+3,8,32,M,166.7,96.9,97.64,C,2023-06-01,0
+5,5,25,M,182.5,99.6,57.23,C,2023-06-01,1
+5,8,25,M,182.5,99.6,97.64,C,2023-06-09,1
+5,7,25,M,182.5,99.6,14.2,C,2023-06-10,0
+4,7,60,F,157.1,50.0,14.2,C,2023-06-18,0
+5,1,25,M,182.5,99.6,65.57,B,2023-06-25,0
+5,2,25,M,182.5,99.6,65.05,C,2023-06-29,1
+5,2,25,M,182.5,99.6,65.05,C,2023-07-06,1
+2,8,46,M,173.0,86.1,97.64,C,2023-07-06,0
+2,7,46,M,173.0,86.1,14.2,C,2023-07-07,1
+2,7,46,M,173.0,86.1,14.2,C,2023-07-09,1
+5,6,25,M,182.5,99.6,45.99,A,2023-07-11,0
+2,1,46,M,173.0,86.1,65.57,B,2023-07-20,0
+2,10,46,M,173.0,86.1,18.15,A,2023-07-21,1
+2,8,46,M,173.0,86.1,97.64,C,2023-07-22,0
+1,10,56,M,155.0,52.8,18.15,A,2023-07-25,1
+3,6,32,M,166.7,96.9,45.99,A,2023-07-26,1
+4,4,60,F,157.1,50.0,12.08,A,2023-07-27,0
+2,6,46,M,173.0,86.1,45.99,A,2023-07-27,1
+2,10,46,M,173.0,86.1,18.15,A,2023-08-05,1
+1,4,56,M,155.0,52.8,12.08,A,2023-08-06,0
+5,3,25,M,182.5,99.6,10.64,C,2023-08-08,1
+5,7,25,M,182.5,99.6,14.2,C,2023-08-08,0
+5,9,25,M,182.5,99.6,30.95,A,2023-08-12,1
+5,9,25,M,182.5,99.6,30.95,A,2023-08-13,1
+3,7,32,M,166.7,96.9,14.2,C,2023-08-15,1
+5,8,25,M,182.5,99.6,97.64,C,2023-08-22,1
+2,4,46,M,173.0,86.1,12.08,A,2023-08-28,0
+4,8,60,F,157.1,50.0,97.64,C,2023-08-31,1
+2,8,46,M,173.0,86.1,97.64,C,2023-09-09,0
+4,5,60,F,157.1,50.0,57.23,C,2023-09-11,0
+3,4,32,M,166.7,96.9,12.08,A,2023-09-12,1
+2,4,46,M,173.0,86.1,12.08,A,2023-09-12,0
+5,2,25,M,182.5,99.6,65.05,C,2023-09-16,1
+2,9,46,M,173.0,86.1,30.95,A,2023-09-16,0
+1,5,56,M,155.0,52.8,57.23,C,2023-09-21,1
+2,1,46,M,173.0,86.1,65.57,B,2023-09-24,0
+2,7,46,M,173.0,86.1,14.2,C,2023-09-27,1
+5,4,25,M,182.5,99.6,12.08,A,2023-10-04,0
+2,6,46,M,173.0,86.1,45.99,A,2023-10-07,1
+5,1,25,M,182.5,99.6,65.57,B,2023-10-11,0
+4,2,60,F,157.1,50.0,65.05,C,2023-10-13,1
+5,7,25,M,182.5,99.6,14.2,C,2023-10-16,0
+5,5,25,M,182.5,99.6,57.23,C,2023-10-22,1
+3,8,32,M,166.7,96.9,97.64,C,2023-10-31,0
+3,1,32,M,166.7,96.9,65.57,B,2023-10-31,0
+5,10,25,M,182.5,99.6,18.15,A,2023-10-31,0
+4,3,60,F,157.1,50.0,10.64,C,2023-11-06,0
+5,1,25,M,182.5,99.6,65.57,B,2023-11-13,0
+5,2,25,M,182.5,99.6,65.05,C,2023-11-14,1
+2,3,46,M,173.0,86.1,10.64,C,2023-11-23,1
+3,7,32,M,166.7,96.9,14.2,C,2023-11-24,1
+4,3,60,F,157.1,50.0,10.64,C,2023-12-04,0
+2,6,46,M,173.0,86.1,45.99,A,2023-12-04,1
+1,9,56,M,155.0,52.8,30.95,A,2023-12-06,1
+2,3,46,M,173.0,86.1,10.64,C,2023-12-10,1
+2,4,46,M,173.0,86.1,12.08,A,2023-12-12,0
+4,1,60,F,157.1,50.0,65.57,B,2023-12-15,0
+2,8,46,M,173.0,86.1,97.64,C,2023-12-18,0
+3,2,32,M,166.7,96.9,65.05,C,2023-12-26,1
+2,2,46,M,173.0,86.1,65.05,C,2023-12-26,0
diff --git a/tests/test_data_utils/generate_interactions_data.py b/tests/test_data_utils/generate_interactions_data.py
new file mode 100644
index 00000000..df128e37
--- /dev/null
+++ b/tests/test_data_utils/generate_interactions_data.py
@@ -0,0 +1,194 @@
+import os
+from pathlib import Path
+from datetime import datetime, timedelta
+
+import numpy as np
+import pandas as pd
+
+full_path = os.path.realpath(__file__)
+path = os.path.split(full_path)[0]
+
+save_dir = Path(path) / "data_for_rec_preprocessor"
+
+
+def generate_sample_data(n_users=5, n_items=10, seed=42, return_df=False):
+ """
+ Generate a sample dataset for recommendation system testing.
+
+ Parameters:
+ -----------
+ - n_users: int, number of users to generate (default 5)
+ - n_items: int, number of items to generate (default 10)
+    - seed: int, random seed for reproducibility (default 42)
+    - return_df: bool, if True return the dataframe instead of saving it to disk (default False)
+    """
+ np.random.seed(seed)
+
+ # Generate user data
+ users = pd.DataFrame(
+ {
+ "user_id": range(1, n_users + 1),
+ "age": np.random.randint(18, 65, n_users),
+ "gender": np.random.choice(["M", "F"], n_users),
+ "height": np.random.uniform(150, 200, n_users).round(1),
+ "weight": np.random.uniform(50, 100, n_users).round(1),
+ }
+ )
+
+ # Generate item data
+ items = pd.DataFrame(
+ {
+ "item_id": range(1, n_items + 1),
+ "price": np.random.uniform(10, 100, n_items).round(2),
+ "category": np.random.choice(["A", "B", "C"], n_items),
+ }
+ )
+
+ # Generate positive interactions
+ positive_interactions = []
+ start_date = datetime(2023, 1, 1)
+
+ for user_id in users["user_id"]:
+ # Randomly select 5 items for this user to interact with
+ user_items = np.random.choice(items["item_id"], 5, replace=False)
+
+ if user_id == 1:
+ n_interactions = 3
+ else:
+ n_interactions = np.random.randint(4, 21)
+
+ for _ in range(n_interactions):
+ item_id = np.random.choice(user_items)
+ timestamp = start_date + timedelta(days=np.random.randint(0, 365))
+ positive_interactions.append(
+ {
+ "user_id": user_id,
+ "item_id": item_id,
+ "timestamp": timestamp,
+ "interaction": 1, # Positive interaction
+ }
+ )
+
+ positive_df = pd.DataFrame(positive_interactions)
+
+ # Generate negative interactions
+ negative_interactions = []
+ for user_id in users["user_id"]:
+ positive_items = positive_df[positive_df["user_id"] == user_id][
+ "item_id"
+ ].unique()
+ negative_items = items[~items["item_id"].isin(positive_items)]["item_id"]
+
+ n_negative = len(positive_df[positive_df["user_id"] == user_id])
+
+ for _ in range(n_negative):
+ item_id = np.random.choice(negative_items)
+ timestamp = start_date + timedelta(days=np.random.randint(0, 365))
+ negative_interactions.append(
+ {
+ "user_id": user_id,
+ "item_id": item_id,
+ "timestamp": timestamp,
+ "interaction": 0, # Negative interaction
+ }
+ )
+
+ negative_df = pd.DataFrame(negative_interactions)
+
+ # Combine positive and negative interactions
+ interactions_df = pd.concat([positive_df, negative_df], ignore_index=True)
+
+ # Merge all data
+ final_df = interactions_df.merge(users, on="user_id").merge(items, on="item_id")
+
+ # Sort by timestamp
+ final_df = final_df.sort_values("timestamp").reset_index(drop=True)
+
+ # Reorder columns
+ column_order = [
+ "user_id",
+ "item_id",
+ "age",
+ "gender",
+ "height",
+ "weight",
+ "price",
+ "category",
+ "timestamp",
+ "interaction",
+ ]
+ final_df = final_df[column_order]
+
+ if not save_dir.exists():
+ save_dir.mkdir(parents=True)
+
+ if not return_df:
+ final_df.to_csv(save_dir / "interactions_data.csv", index=False)
+ else:
+ return final_df
+
+
+def split_by_timestamp(df, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
+ """
+ Split the dataframe into train, validation, and test sets based on timestamp.
+
+ Parameters:
+ - df: pd.DataFrame, the input dataframe
+ - train_ratio: float, ratio of data for training (default 0.8)
+ - val_ratio: float, ratio of data for validation (default 0.1)
+ - test_ratio: float, ratio of data for testing (default 0.1)
+
+ Returns:
+ - train_df, val_df, test_df: pd.DataFrame, the split dataframes
+ """
+ assert np.isclose(train_ratio + val_ratio + test_ratio, 1.0), "Ratios must sum to 1"
+
+ df_sorted = df.sort_values("timestamp")
+ total_rows = len(df_sorted)
+ train_rows = int(total_rows * train_ratio)
+ val_rows = int(total_rows * val_ratio)
+
+ train_df = df_sorted.iloc[:train_rows]
+ val_df = df_sorted.iloc[train_rows : train_rows + val_rows]
+ test_df = df_sorted.iloc[train_rows + val_rows :]
+
+ return train_df, val_df, test_df
+
+
+def split_by_last_interactions(df):
+ """
+ Split the dataframe into train, validation, and test sets based on the last interactions.
+ Train: all interactions but last two
+ Validation: second to last interaction
+ Test: last interaction
+
+ Parameters:
+ - df: pd.DataFrame, the input dataframe
+
+ Returns:
+ - train_df, val_df, test_df: pd.DataFrame, the split dataframes
+ """
+ df_sorted = df.sort_values(["user_id", "timestamp"])
+
+ # Get the last two interactions for each user
+ last_two = df_sorted.groupby("user_id").tail(2)
+
+ # Split the last two interactions into validation and test
+ test_df = last_two.groupby("user_id").last().reset_index()
+ val_df = last_two.groupby("user_id").first().reset_index()
+
+ # Remove the last two interactions from the original dataframe to create the train set
+ train_df = df_sorted[~df_sorted.index.isin(last_two.index)].reset_index(drop=True)
+
+ return train_df, val_df, test_df
+
+
+if __name__ == "__main__":
+ generate_sample_data()
+
+ # df = generate_sample_data(return_df=True)
+
+ # # Split the data by timestamp
+ # train_df, val_df, test_df = split_by_timestamp(df)
+
+ # # Split the data by last interactions
+ # train_df2, val_df2, test_df2 = split_by_last_interactions(df)
diff --git a/tests/test_data_utils/test_din_preprocessor.py b/tests/test_data_utils/test_din_preprocessor.py
new file mode 100644
index 00000000..36523b77
--- /dev/null
+++ b/tests/test_data_utils/test_din_preprocessor.py
@@ -0,0 +1,87 @@
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from pytorch_widedeep.preprocessing import DINPreprocessor
+
+full_path = os.path.realpath(__file__)
+path = os.path.split(full_path)[0]
+
+save_dir = Path(path) / "data_for_rec_preprocessor"
+
+df_interactions = pd.read_csv(save_dir / "interactions_data.csv")
+
+df_interactions = df_interactions.sort_values(by=["user_id", "timestamp"]).reset_index(
+ drop=True
+)
+
+cat_embed_cols = ["user_id", "age", "gender"]
+continuous_cols = ["height", "weight"]
+
+
+def user_behaviour_well_encoded(
+ input_df: pd.DataFrame,
+ din_preprocessor: DINPreprocessor,
+ X: np.ndarray,
+ max_seq_length: int,
+):
+
+ user_id: int = 5
+
+ encoded_user_id = ( # type: ignore
+ din_preprocessor.tab_preprocessor.label_encoder.encoding_dict["user_id"][
+ user_id
+ ]
+ )
+
+ user_items = (
+ input_df[input_df["user_id"] == user_id]
+ .groupby("user_id")["item_id"]
+ .agg(list)
+ .values[0]
+ )[:max_seq_length]
+
+ encoded_items = [
+ din_preprocessor.item_le.encoding_dict["item_id"][item] for item in user_items
+ ]
+
+ rows = np.where(
+ X[:, din_preprocessor.din_columns_idx["user_id"]] == encoded_user_id
+ )[0]
+ X_user_id = X[rows][:1]
+
+ item_seq_cols = din_preprocessor.user_behaviour_config[0]
+ item_seq_cols_idx = [din_preprocessor.din_columns_idx[col] for col in item_seq_cols]
+
+ X_user_items = list(X_user_id[:, item_seq_cols_idx].astype(int)[0])
+
+ return X_user_items == encoded_items
+
+
+def test_din_preprocessor():
+
+ din_preprocessor = DINPreprocessor(
+ user_id_col="user_id",
+ item_embed_col="item_id",
+ target_col="interaction",
+ action_col="interaction",
+ other_seq_embed_cols=["category", "price"],
+ cat_embed_cols=cat_embed_cols,
+ continuous_cols=continuous_cols,
+ cols_to_scale=continuous_cols,
+ max_seq_length=5,
+ )
+
+ X, y = din_preprocessor.fit_transform(df_interactions)
+ # 5 items + 5 actions + 2 * 5 other_seq_embed_cols + (1 target item + 1
+ # target category + 1 target price) + 5 continuous_cols and
+ # cat_embed_cols
+ expected_n_cols = 5 + 5 + (2 * 5) + 3 + 5
+ assert X.shape[1] == expected_n_cols
+ assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5)
+
+
+if __name__ == "__main__":
+ test_din_preprocessor()
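For reference, the `expected_n_cols` assertion in the test above follows directly from the column layout the preprocessor builds: the item sequence plus its target, the action sequence, each extra sequence column plus its target, and finally the standard tabular columns. The same arithmetic, spelled out for the settings used in the test (plain Python, nothing library specific):

    # values taken from the test above
    max_seq_length = 5
    n_tab_cols = 3 + 2                           # cat_embed_cols + continuous_cols
    n_item_cols = max_seq_length + 1             # item_1..item_5 plus target_item
    n_action_cols = max_seq_length                # action_1..action_5
    n_other_seq_cols = 2 * (max_seq_length + 1)   # category/price sequences plus their targets

    expected_n_cols = n_item_cols + n_action_cols + n_other_seq_cols + n_tab_cols
    assert expected_n_cols == 28                  # i.e. 5 + 5 + (2 * 5) + 3 + 5 in the test
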
From ecc6ea0532cf9a1cb97abee134f1b72cf8921d74 Mon Sep 17 00:00:00 2001
From: Javier
Date: Mon, 14 Oct 2024 20:34:07 +0100
Subject: [PATCH 05/24] ready to test the whole thing. I would have to add an
 extra test combining the preprocessor and the model
---
.../preprocessing/din_preprocessor.py | 8 +-
.../generate_interactions_data.py | 2 +-
.../interactions_data.csv | 0
.../test_data_utils/test_din_preprocessor.py | 87 ----------
.../test_du_din_preprocessor.py | 162 ++++++++++++++++++
5 files changed, 169 insertions(+), 90 deletions(-)
rename tests/test_data_utils/{data_for_rec_preprocessor => interactions_df_for_rec_preprocessor}/interactions_data.csv (100%)
delete mode 100644 tests/test_data_utils/test_din_preprocessor.py
create mode 100644 tests/test_data_utils/test_du_din_preprocessor.py
diff --git a/pytorch_widedeep/preprocessing/din_preprocessor.py b/pytorch_widedeep/preprocessing/din_preprocessor.py
index d28e6926..5819875f 100644
--- a/pytorch_widedeep/preprocessing/din_preprocessor.py
+++ b/pytorch_widedeep/preprocessing/din_preprocessor.py
@@ -179,7 +179,7 @@ def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
_df = self._pre_transform(df)
df_w_sequences = self._build_sequences(_df)
X_all = self._concatenate_features(df_w_sequences)
- return X_all, np.array(df_w_sequences["target"].tolist())
+ return X_all, np.array(df_w_sequences[self.target_col].tolist())
def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
self.fit(df)
@@ -363,6 +363,10 @@ def _create_sequence(
sequence.update(
self._create_action_sequence(group, targets, i, target, drop_cols)
)
+ else:
+ sequence[self.target_col] = target
+ drop_cols.append(self.target_col)
+
if self.other_seq_embed_cols:
sequence.update(self._create_other_seq_sequence(group, i, drop_cols))
@@ -392,7 +396,7 @@ def _create_action_sequence(
else targets[i:-1]
)
- return {"target": target, "actions_sequence": action_sequences}
+ return {self.target_col: target, "actions_sequence": action_sequences}
def _create_other_seq_sequence(
self, group: pd.DataFrame, i: int, drop_cols: List[str]
diff --git a/tests/test_data_utils/generate_interactions_data.py b/tests/test_data_utils/generate_interactions_data.py
index df128e37..7df32497 100644
--- a/tests/test_data_utils/generate_interactions_data.py
+++ b/tests/test_data_utils/generate_interactions_data.py
@@ -8,7 +8,7 @@
full_path = os.path.realpath(__file__)
path = os.path.split(full_path)[0]
-save_dir = Path(path) / "data_for_rec_preprocessor"
+save_dir = Path(path) / "interactions_df_for_rec_preprocessor"
def generate_sample_data(n_users=5, n_items=10, seed=42, return_df=False):
diff --git a/tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv b/tests/test_data_utils/interactions_df_for_rec_preprocessor/interactions_data.csv
similarity index 100%
rename from tests/test_data_utils/data_for_rec_preprocessor/interactions_data.csv
rename to tests/test_data_utils/interactions_df_for_rec_preprocessor/interactions_data.csv
diff --git a/tests/test_data_utils/test_din_preprocessor.py b/tests/test_data_utils/test_din_preprocessor.py
deleted file mode 100644
index 36523b77..00000000
--- a/tests/test_data_utils/test_din_preprocessor.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import os
-from pathlib import Path
-
-import numpy as np
-import pandas as pd
-
-from pytorch_widedeep.preprocessing import DINPreprocessor
-
-full_path = os.path.realpath(__file__)
-path = os.path.split(full_path)[0]
-
-save_dir = Path(path) / "data_for_rec_preprocessor"
-
-df_interactions = pd.read_csv(save_dir / "interactions_data.csv")
-
-df_interactions = df_interactions.sort_values(by=["user_id", "timestamp"]).reset_index(
- drop=True
-)
-
-cat_embed_cols = ["user_id", "age", "gender"]
-continuous_cols = ["height", "weight"]
-
-
-def user_behaviour_well_encoded(
- input_df: pd.DataFrame,
- din_preprocessor: DINPreprocessor,
- X: np.ndarray,
- max_seq_length: int,
-):
-
- user_id: int = 5
-
- encoded_user_id = ( # type: ignore
- din_preprocessor.tab_preprocessor.label_encoder.encoding_dict["user_id"][
- user_id
- ]
- )
-
- user_items = (
- input_df[input_df["user_id"] == user_id]
- .groupby("user_id")["item_id"]
- .agg(list)
- .values[0]
- )[:max_seq_length]
-
- encoded_items = [
- din_preprocessor.item_le.encoding_dict["item_id"][item] for item in user_items
- ]
-
- rows = np.where(
- X[:, din_preprocessor.din_columns_idx["user_id"]] == encoded_user_id
- )[0]
- X_user_id = X[rows][:1]
-
- item_seq_cols = din_preprocessor.user_behaviour_config[0]
- item_seq_cols_idx = [din_preprocessor.din_columns_idx[col] for col in item_seq_cols]
-
- X_user_items = list(X_user_id[:, item_seq_cols_idx].astype(int)[0])
-
- return X_user_items == encoded_items
-
-
-def test_din_preprocessor():
-
- din_preprocessor = DINPreprocessor(
- user_id_col="user_id",
- item_embed_col="item_id",
- target_col="interaction",
- action_col="interaction",
- other_seq_embed_cols=["category", "price"],
- cat_embed_cols=cat_embed_cols,
- continuous_cols=continuous_cols,
- cols_to_scale=continuous_cols,
- max_seq_length=5,
- )
-
- X, y = din_preprocessor.fit_transform(df_interactions)
- # 5 items + 5 actions + 2 * 5 other_seq_embed_cols + (1 target item + 1
- # target category + 1 target price) + 5 continuous_cols and
- # cat_embed_cols
- expected_n_cols = 5 + 5 + (2 * 5) + 3 + 5
- assert X.shape[1] == expected_n_cols
- assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5)
-
-
-if __name__ == "__main__":
- test_din_preprocessor()
diff --git a/tests/test_data_utils/test_du_din_preprocessor.py b/tests/test_data_utils/test_du_din_preprocessor.py
new file mode 100644
index 00000000..40ea8dfd
--- /dev/null
+++ b/tests/test_data_utils/test_du_din_preprocessor.py
@@ -0,0 +1,162 @@
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from pytorch_widedeep.preprocessing import DINPreprocessor
+
+full_path = os.path.realpath(__file__)
+path = os.path.split(full_path)[0]
+
+save_dir = Path(path) / "interactions_df_for_rec_preprocessor"
+
+df_interactions = pd.read_csv(save_dir / "interactions_data.csv")
+
+df_interactions = df_interactions.sort_values(by=["user_id", "timestamp"]).reset_index(
+ drop=True
+)
+
+cat_embed_cols = ["user_id", "age", "gender"]
+continuous_cols = ["height", "weight"]
+
+
+def user_behaviour_well_encoded(
+ input_df: pd.DataFrame,
+ din_preprocessor: DINPreprocessor,
+ X: np.ndarray,
+ max_seq_length: int,
+):
+
+ user_id: int = 5
+
+ encoded_user_id = ( # type: ignore
+ din_preprocessor.tab_preprocessor.label_encoder.encoding_dict["user_id"][
+ user_id
+ ]
+ )
+
+ user_items = (
+ input_df[input_df["user_id"] == user_id]
+ .groupby("user_id")["item_id"]
+ .agg(list)
+ .values[0]
+ )[:max_seq_length]
+
+ encoded_items = [
+ din_preprocessor.item_le.encoding_dict["item_id"][item] for item in user_items
+ ]
+
+ rows = np.where(
+ X[:, din_preprocessor.din_columns_idx["user_id"]] == encoded_user_id
+ )[0]
+ X_user_id = X[rows][:1]
+
+ item_seq_cols = din_preprocessor.user_behaviour_config[0]
+ item_seq_cols_idx = [din_preprocessor.din_columns_idx[col] for col in item_seq_cols]
+
+ X_user_items = list(X_user_id[:, item_seq_cols_idx].astype(int)[0])
+
+ return X_user_items == encoded_items
+
+
+def test_din_preprocessor():
+
+ din_preprocessor = DINPreprocessor(
+ user_id_col="user_id",
+ item_embed_col="item_id",
+ target_col="interaction",
+ action_col="interaction",
+ other_seq_embed_cols=["category", "price"],
+ cat_embed_cols=cat_embed_cols,
+ continuous_cols=continuous_cols,
+ cols_to_scale=continuous_cols,
+ max_seq_length=5,
+ )
+
+ X, y = din_preprocessor.fit_transform(df_interactions)
+ # 5 items + 5 actions + 2 * 5 other_seq_embed_cols + (1 target item + 1
+ # target category + 1 target price) + 5 continuous_cols and
+ # cat_embed_cols
+ expected_n_cols = 5 + 5 + (2 * 5) + 3 + 5
+ assert X.shape[1] == expected_n_cols
+ assert din_preprocessor.is_fitted
+ assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5)
+
+
+@pytest.mark.parametrize(
+ "with_cat_embed_cols, with_continuous_cols, with_action_col, with_other_seq_embed_cols",
+ [
+ (True, True, True, True),
+ (True, True, True, False),
+ (True, True, False, True),
+ (True, True, False, False),
+ (True, False, True, True),
+ (True, False, True, False),
+ (True, False, False, True),
+ (True, False, False, False),
+ (False, True, True, True),
+ (False, True, True, False),
+ (False, True, False, True),
+ (False, True, False, False),
+ (False, False, True, True),
+ (False, False, True, False),
+ (False, False, False, True),
+ (False, False, False, False),
+ ],
+)
+def test_din_preprocessor_diff_input_params(
+ with_cat_embed_cols,
+ with_continuous_cols,
+ with_action_col,
+ with_other_seq_embed_cols,
+):
+
+ max_seq_length = 5
+
+ din_preprocessor = DINPreprocessor(
+ user_id_col="user_id",
+ item_embed_col="item_id",
+ target_col="interaction",
+ action_col="interaction" if with_action_col else None,
+ other_seq_embed_cols=(
+ ["category", "price"] if with_other_seq_embed_cols else None
+ ),
+ cat_embed_cols=cat_embed_cols if with_cat_embed_cols else None,
+ continuous_cols=continuous_cols if with_continuous_cols else None,
+ cols_to_scale=continuous_cols if with_continuous_cols else None,
+ max_seq_length=max_seq_length,
+ )
+
+ X, y = din_preprocessor.fit_transform(df_interactions)
+
+ expected_n_cols = max_seq_length + 1
+
+ assert din_preprocessor.user_behaviour_config[0] == [
+ f"item_{i+1}" for i in range(max_seq_length)
+ ]
+
+ if with_action_col:
+ expected_n_cols += 5
+ assert din_preprocessor.action_seq_config[0] == [
+ f"action_{i+1}" for i in range(max_seq_length)
+ ]
+
+ if with_other_seq_embed_cols:
+ for i, col in enumerate(["category", "price"]):
+ expected_n_cols += 5 + 1
+ assert din_preprocessor.other_seq_config[i][0] == [
+ f"{col}_{i+1}" for i in range(max_seq_length)
+ ]
+
+ if with_cat_embed_cols:
+ expected_n_cols += len(cat_embed_cols)
+
+ if with_continuous_cols:
+ expected_n_cols += len(continuous_cols)
+
+ assert din_preprocessor.is_fitted
+ assert X.shape[1] == expected_n_cols
+ if with_cat_embed_cols:
+ assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5)
From 112d8c317b804fe74c041de3521a8cc11bf2a627 Mon Sep 17 00:00:00 2001
From: Javier
Date: Mon, 14 Oct 2024 21:34:50 +0100
Subject: [PATCH 06/24] Added test for the full process. Ready to test the
whole library
---
pytorch_widedeep/models/rec/din.py | 3 +-
.../test_du_din_preprocessor.py | 42 +++++++++++++++++++
tests/test_rec/test_din.py | 2 +-
3 files changed, 45 insertions(+), 2 deletions(-)
diff --git a/pytorch_widedeep/models/rec/din.py b/pytorch_widedeep/models/rec/din.py
index 293f249e..571717a1 100644
--- a/pytorch_widedeep/models/rec/din.py
+++ b/pytorch_widedeep/models/rec/din.py
@@ -405,7 +405,8 @@ def forward(self, X: Tensor) -> Tensor:
[
self.other_seq_cols_embed[col]._get_embeddings(X_other_seq[col])
for col in self.other_seq_cols_indexes.keys()
- ]
+ ],
+ dim=-1,
).sum(1)
deep_out = torch.cat([deep_out, other_seq_embed], dim=1)
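The `dim=-1` added in this hunk matters because each extra sequence column contributes its own embedding tensor; concatenating along the last (embedding) dimension keeps the batch and sequence axes intact so that the subsequent `.sum(1)` still pools over the sequence. A minimal shape sketch, assuming each `_get_embeddings` call returns a `(batch, seq_len, embed_dim)` tensor (the sizes below are made up):

    import torch

    batch, seq_len, embed_dim = 4, 5, 8

    # two hypothetical extra sequence columns, already embedded
    cat_a = torch.randn(batch, seq_len, embed_dim)
    cat_b = torch.randn(batch, seq_len, embed_dim)

    other_seq_embed = torch.cat([cat_a, cat_b], dim=-1)  # (4, 5, 16): embeddings side by side
    pooled = other_seq_embed.sum(1)                       # (4, 16): summed over the sequence axis

    # with the default dim=0 the tensors would be stacked along the batch axis
    # instead, and the later concatenation with deep_out would fail on batch size
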
diff --git a/tests/test_data_utils/test_du_din_preprocessor.py b/tests/test_data_utils/test_du_din_preprocessor.py
index 40ea8dfd..b56cd04d 100644
--- a/tests/test_data_utils/test_du_din_preprocessor.py
+++ b/tests/test_data_utils/test_du_din_preprocessor.py
@@ -5,6 +5,9 @@
import pandas as pd
import pytest
+from pytorch_widedeep import Trainer
+from pytorch_widedeep.models import WideDeep
+from pytorch_widedeep.models.rec import DeepInterestNetwork
from pytorch_widedeep.preprocessing import DINPreprocessor
full_path = os.path.realpath(__file__)
@@ -160,3 +163,42 @@ def test_din_preprocessor_diff_input_params(
assert X.shape[1] == expected_n_cols
if with_cat_embed_cols:
assert user_behaviour_well_encoded(df_interactions, din_preprocessor, X, 5)
+
+
+def test_din_full_process_w_processor():
+
+ din_preprocessor = DINPreprocessor(
+ user_id_col="user_id",
+ item_embed_col="item_id",
+ target_col="interaction",
+ action_col="interaction",
+ other_seq_embed_cols=["category", "price"],
+ cat_embed_cols=cat_embed_cols,
+ continuous_cols=continuous_cols,
+ cols_to_scale=continuous_cols,
+ max_seq_length=5,
+ )
+
+ X, y = din_preprocessor.fit_transform(df_interactions)
+
+ din = DeepInterestNetwork(
+ column_idx=din_preprocessor.din_columns_idx,
+ user_behavior_confiq=din_preprocessor.user_behaviour_config,
+ target_item_col="target_item",
+ action_seq_config=din_preprocessor.action_seq_config,
+ other_seq_cols_confiq=din_preprocessor.other_seq_config,
+ cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore
+ continuous_cols=din_preprocessor.tab_preprocessor.continuous_cols,
+ mlp_hidden_dims=[16, 8],
+ )
+
+ model = WideDeep(deeptabular=din)
+
+ trainer = Trainer(model, objective="binary", verbose=0)
+
+ trainer.fit(X_tab=X, target=y, n_epochs=2)
+
+ preds = trainer.predict(X_tab=X)
+
+ assert preds.shape[0] == X.shape[0]
+ assert trainer.history is not None and "train_loss" in trainer.history
diff --git a/tests/test_rec/test_din.py b/tests/test_rec/test_din.py
index 92724791..dcdbde86 100644
--- a/tests/test_rec/test_din.py
+++ b/tests/test_rec/test_din.py
@@ -79,8 +79,8 @@ def test_din_with_params(
din = DeepInterestNetwork(
column_idx=column_idx,
- target_item_col="target_item",
user_behavior_confiq=item_seq_config,
+ target_item_col="target_item",
action_seq_config=item_purchased_seq_config,
other_seq_cols_confiq=item_category_seq_config,
cat_embed_input=cat_embed_input,
From adc77d7bf87fe80b94e4041058b29d3264c533ae Mon Sep 17 00:00:00 2001
From: Javier
Date: Tue, 15 Oct 2024 10:42:08 +0100
Subject: [PATCH 07/24] All tests passed. Need to increase coverage a bit
---
.../sources/pytorch-widedeep/preprocessing.md | 2 +
pytorch_widedeep/models/rec/din.py | 7 +-
.../preprocessing/din_preprocessor.py | 124 ++++--------------
pytorch_widedeep/tab2vec.py | 2 +-
4 files changed, 34 insertions(+), 101 deletions(-)
diff --git a/mkdocs/sources/pytorch-widedeep/preprocessing.md b/mkdocs/sources/pytorch-widedeep/preprocessing.md
index ab608846..520dd914 100644
--- a/mkdocs/sources/pytorch-widedeep/preprocessing.md
+++ b/mkdocs/sources/pytorch-widedeep/preprocessing.md
@@ -21,6 +21,8 @@ available: one for the case when no Hugging Face model is used
::: pytorch_widedeep.preprocessing.image_preprocessor.ImagePreprocessor
+::: pytorch_widedeep.preprocessing.din_preprocessor.DINPreprocessor
+
## Chunked versions
diff --git a/pytorch_widedeep/models/rec/din.py b/pytorch_widedeep/models/rec/din.py
index 571717a1..3c42a975 100644
--- a/pytorch_widedeep/models/rec/din.py
+++ b/pytorch_widedeep/models/rec/din.py
@@ -23,10 +23,9 @@ class DeepInterestNetwork(BaseWDModelComponent):
sequential columns and will be treated as standard tabular data.
This model requires some specific data preparation that allows for quite a
- lot of flexibility. Therefore, this library does not currently include a
- preprocessor designed specifically for this model. Please, see the
- example 'movielens_din.py' in the examples folder to understand the data
- preparation process.
+ lot of flexibility. Therefore, I have included a preprocessor
+ (`DINPreprocessor`) in the preprocessing module that will take care of
+ the data preparation.
Parameters
----------
diff --git a/pytorch_widedeep/preprocessing/din_preprocessor.py b/pytorch_widedeep/preprocessing/din_preprocessor.py
index 5819875f..f66bbdf7 100644
--- a/pytorch_widedeep/preprocessing/din_preprocessor.py
+++ b/pytorch_widedeep/preprocessing/din_preprocessor.py
@@ -5,11 +5,6 @@
import numpy as np
import pandas as pd
-# from pytorch_widedeep.models import WideDeep
-# from pytorch_widedeep.metrics import Accuracy
-# from pytorch_widedeep.datasets import load_movielens100k
-# from pytorch_widedeep.training import Trainer
-# from pytorch_widedeep.models.rec.din import DeepInterestNetwork
from pytorch_widedeep.utils.text_utils import pad_sequences
from pytorch_widedeep.utils.general_utils import alias
from pytorch_widedeep.utils.deeptabular_utils import LabelEncoder
@@ -100,6 +95,32 @@ class DINPreprocessor(BasePreprocessor):
Number of unique values in each other sequence column.
other_seq_config : List[Tuple[List[str], int, int]]
Configuration for other sequence columns.
+
+ Examples:
+ ---------
+ >>> import pandas as pd
+ >>> from pytorch_widedeep.preprocessing import DINPreprocessor
+ >>> data = {
+ ... 'user_id': [1, 1, 1, 2, 2, 2, 3, 3, 3],
+ ... 'item_id': [101, 102, 103, 101, 103, 104, 102, 103, 104],
+ ... 'timestamp': [1, 2, 3, 1, 2, 3, 1, 2, 3],
+ ... 'category': ['A', 'B', 'A', 'B', 'A', 'C', 'B', 'A', 'C'],
+ ... 'price': [10.5, 15.0, 12.0, 10.5, 12.0, 20.0, 15.0, 12.0, 20.0],
+ ... 'rating': [0, 1, 0, 1, 0, 1, 0, 1, 0]
+ ... }
+ >>> df = pd.DataFrame(data)
+ >>> din_preprocessor = DINPreprocessor(
+ ... user_id_col='user_id',
+ ... item_embed_col='item_id',
+ ... target_col='rating',
+ ... max_seq_length=2,
+ ... action_col='rating',
+ ... other_seq_embed_cols=['category'],
+ ... cat_embed_cols=['user_id'],
+ ... continuous_cols=['price'],
+ ... cols_to_scale=['price']
+ ... )
+ >>> X, y = din_preprocessor.fit_transform(df)
"""
@alias("item_embed_col", ["item_id_col"])
@@ -313,9 +334,9 @@ def _pre_transform(self, df: pd.DataFrame) -> pd.DataFrame:
return df
def _build_sequences(self, df: pd.DataFrame) -> pd.DataFrame:
- # TO DO: do something to avoid warning here
+ # df.columns.tolist() -> to avoid annoying pandas warning
return (
- df.groupby(self.user_id_col)
+ df.groupby(self.user_id_col)[df.columns.tolist()] # type: ignore[index]
.apply(self._group_sequences)
.reset_index(drop=True)
)
@@ -490,92 +511,3 @@ def _concatenate_other_seq_features(
X_other_seq_arrays.append(np.array(other_seq_target[col]).reshape(-1, 1))
return np.concatenate([X_all] + X_other_seq_arrays, axis=1)
-
-
-# if __name__ == "__main__":
-
-# def clean_genre_list(genre_list):
-# return "_".join(
-# sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list])
-# )
-
-# data, users, items = load_movielens100k(as_frame=True)
-
-# list_of_genres = [
-# "unknown",
-# "Action",
-# "Adventure",
-# "Animation",
-# "Children's",
-# "Comedy",
-# "Crime",
-# "Documentary",
-# "Drama",
-# "Fantasy",
-# "Film-Noir",
-# "Horror",
-# "Musical",
-# "Mystery",
-# "Romance",
-# "Sci-Fi",
-# "Thriller",
-# "War",
-# "Western",
-# ]
-
-# assert (
-# isinstance(items, pd.DataFrame)
-# and isinstance(data, pd.DataFrame)
-# and isinstance(users, pd.DataFrame)
-# )
-# items["genre_list"] = items[list_of_genres].apply(
-# lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1
-# )
-
-# items["genre_list"] = items["genre_list"].apply(clean_genre_list)
-
-# df = pd.merge(data, items[["movie_id", "genre_list"]], on="movie_id")
-# df = pd.merge(
-# df,
-# users[["user_id", "age", "gender", "occupation"]],
-# on="user_id",
-# )
-
-# df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0)
-
-# df = df.sort_values(by=["timestamp"]).reset_index(drop=True)
-# df = df.drop("timestamp", axis=1)
-
-# din_preprocessor = DINPreprocessor(
-# user_id_col="user_id",
-# target_col="rating",
-# item_embed_col=("movie_id", 16),
-# max_seq_length=5,
-# action_col="rating",
-# other_seq_embed_cols=[("genre_list", 16)],
-# cat_embed_cols=["user_id", "age", "gender", "occupation"],
-# )
-
-# X, y = din_preprocessor.fit_transform(df)
-
-# din = DeepInterestNetwork(
-# column_idx=din_preprocessor.din_columns_idx,
-# user_behavior_confiq=din_preprocessor.user_behaviour_config,
-# action_seq_config=din_preprocessor.action_seq_config,
-# other_seq_cols_confiq=din_preprocessor.other_seq_config,
-# cat_embed_input=din_preprocessor.tab_preprocessor.cat_embed_input, # type: ignore[attr-defined]
-# mlp_hidden_dims=[128, 64],
-# )
-
-# # And from here on, everything is standard
-# model = WideDeep(deeptabular=din)
-
-# trainer = Trainer(model=model, objective="binary", metrics=[Accuracy()])
-
-# # in the real world you would have to split the data into train, val and test
-# trainer.fit(
-# X_tab=X,
-# target=y,
-# n_epochs=5,
-# batch_size=512,
-# )
diff --git a/pytorch_widedeep/tab2vec.py b/pytorch_widedeep/tab2vec.py
index 5ba2ab48..c10ddb56 100644
--- a/pytorch_widedeep/tab2vec.py
+++ b/pytorch_widedeep/tab2vec.py
@@ -77,7 +77,7 @@ class Tab2Vec:
>>> # ...train the model...
>>>
>>> # vectorise the dataframe
- >>> t2v = Tab2Vec(tab_preprocessor, model)
+ >>> t2v = Tab2Vec(tab_preprocessor, model, device="cpu")
>>> X_vec = t2v.transform(df_t2v)
"""
From f076867711b34d065b250fb8223adde29c0aa35a Mon Sep 17 00:00:00 2001
From: Javier
Date: Fri, 18 Oct 2024 23:02:58 +0100
Subject: [PATCH 08/24] Added tests for the new metrics. Still need to test the
full functioning with the Trainer
---
pytorch_widedeep/metrics.py | 417 ++++++++++++++++++++-
pytorch_widedeep/training/_base_trainer.py | 33 +-
pytorch_widedeep/training/trainer.py | 45 ++-
tests/test_metrics/test_ranking_metrics.py | 188 ++++++++++
4 files changed, 661 insertions(+), 22 deletions(-)
create mode 100644 tests/test_metrics/test_ranking_metrics.py
diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py
index a2947380..e4b871ed 100644
--- a/pytorch_widedeep/metrics.py
+++ b/pytorch_widedeep/metrics.py
@@ -2,7 +2,8 @@
import torch
from torchmetrics import Metric as TorchMetric
-from pytorch_widedeep.wdtypes import Dict, List, Union, Tensor
+from pytorch_widedeep.wdtypes import Dict, List, Union, Tensor, Optional
+from pytorch_widedeep.utils.general_utils import alias
class Metric(object):
@@ -394,3 +395,417 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
y_true_avg = self.y_true_sum / self.num_examples
self.denominator += ((y_true - y_true_avg) ** 2).sum().item()
return np.array((1 - (self.numerator / self.denominator)))
+
+
+def reshape_1d_to_2d(tensor: Tensor, n_columns: int) -> Tensor:
+ if tensor.dim() != 1:
+ raise ValueError("Input tensor must be 1-dimensional")
+ if tensor.size(0) % n_columns != 0:
+ raise ValueError(
+ f"Tensor length ({tensor.size(0)}) must be divisible by n_columns ({n_columns})"
+ )
+ n_rows = tensor.size(0) // n_columns
+ return tensor.reshape(n_rows, n_columns)
+
+
+class NDCG_at_k(Metric):
+ r"""
+ Normalized Discounted Cumulative Gain (NDCG) at k.
+
+ Parameters
+ ----------
+ n_cols: int, default = 10
+        Number of columns in the input tensors. This parameter is necessary
+        because the input tensors are reshaped to 2D tensors. n_cols is the
+        number of columns in the reshaped tensor. Aliases for this parameter
+        are: 'n_items', 'n_items_per_query',
+        'n_items_per_id', 'n_items_per_user'
+    k: int, Optional, default = None
+        Number of top items to consider. It must be less than or equal to n_cols.
+        If it is None, k will be equal to n_cols.
+ eps: float, default = 1e-8
+ Small value to avoid division by zero.
+
+ Examples
+ --------
+ >>> import torch
+ >>> from pytorch_widedeep.metrics import NDCG_at_k
+ >>>
+ >>> ndcg = NDCG_at_k(k=10)
+    >>> y_pred = torch.rand(100)  # 10 queries x 10 items, flattened to 1D
+    >>> y_true = torch.randint(0, 5, (100,))
+ >>> score = ndcg(y_pred, y_true)
+ """
+
+ @alias(
+ "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"]
+ )
+ def __init__(self, n_cols: int = 10, k: Optional[int] = None, eps: float = 1e-8):
+ super(NDCG_at_k, self).__init__()
+
+ if k is not None and k > n_cols:
+ raise ValueError(
+ f"k must be less than or equal to n_cols. Got k: {k}, n_cols: {n_cols}"
+ )
+
+ self.n_cols = n_cols
+ self.k = k if k is not None else n_cols
+ self.eps = eps
+        self._name = f"ndcg@{self.k}"
+ self.reset()
+
+ def reset(self):
+ self.sum_ndcg = 0.0
+ self.count = 0
+
+ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+ device = y_pred.device
+
+ y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+
+ batch_size = y_true_reshaped.shape[0]
+
+ _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
+ top_k_relevance = y_true_reshaped.gather(1, top_k_indices)
+ discounts = 1.0 / torch.log2(
+ torch.arange(2, top_k_relevance.shape[1] + 2, device=device)
+ )
+
+ dcg = (torch.pow(2, top_k_relevance) - 1) * discounts.unsqueeze(0)
+ dcg = dcg.sum(dim=1)
+
+ sorted_relevance, _ = torch.sort(y_true_reshaped, dim=1, descending=True)
+ ideal_relevance = sorted_relevance[:, : self.k]
+
+ idcg = (torch.pow(2, ideal_relevance) - 1) * discounts[
+ : ideal_relevance.shape[1]
+ ].unsqueeze(0)
+
+ idcg = idcg.sum(dim=1)
+ ndcg = dcg / (idcg + self.eps)
+
+ self.sum_ndcg += ndcg.sum().item()
+ self.count += batch_size
+
+ return np.array(self.sum_ndcg / max(self.count, 1))
+
+
+class BinaryNDCG_at_k(Metric):
+ r"""
+ Normalized Discounted Cumulative Gain (NDCG) at k for binary relevance.
+
+ Parameters
+ ----------
+ n_cols: int, default = 10
+        Number of columns in the input tensors. This parameter is necessary
+        because the input tensors are reshaped to 2D tensors. n_cols is the
+        number of columns in the reshaped tensor. Aliases for this parameter
+        are: 'n_items', 'n_items_per_query',
+        'n_items_per_id', 'n_items_per_user'
+    k: int, Optional, default = None
+        Number of top items to consider. It must be less than or equal to n_cols.
+        If it is None, k will be equal to n_cols.
+ eps: float, default = 1e-8
+ Small value to avoid division by zero.
+
+ Examples
+ --------
+ >>> import torch
+ >>> from pytorch_widedeep.metrics import BinaryNDCG_at_k
+ >>>
+ >>> ndcg = BinaryNDCG_at_k(k=10)
+    >>> y_pred = torch.rand(100)  # 10 queries x 10 items, flattened to 1D
+    >>> y_true = torch.randint(0, 2, (100,))
+ >>> score = ndcg(y_pred, y_true)
+ """
+
+ @alias(
+ "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"]
+ )
+ def __init__(self, n_cols: int = 10, k: Optional[int] = None, eps: float = 1e-8):
+ super(BinaryNDCG_at_k, self).__init__()
+
+ if k is not None and k > n_cols:
+ raise ValueError(
+ f"k must be less than or equal to n_cols. Got k: {k}, n_cols: {n_cols}"
+ )
+
+ self.n_cols = n_cols
+ self.k = k if k is not None else n_cols
+ self.eps = eps
+        self._name = f"binary_ndcg@{self.k}"
+ self.reset()
+
+ def reset(self):
+ self.sum_ndcg = 0.0
+ self.count = 0
+
+ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+ device = y_pred.device
+
+ y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+
+ batch_size = y_pred_reshaped.shape[0]
+
+ _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
+ top_k_mask = torch.zeros_like(y_pred_reshaped, dtype=torch.bool).scatter_(
+ 1, top_k_indices, 1
+ )
+
+ _discounts = 1.0 / torch.log2(
+ torch.arange(2, self.k + 2, device=device).float()
+ )
+ expanded_discounts = _discounts.repeat(1, batch_size)
+ discounts = torch.zeros_like(top_k_mask, dtype=torch.float)
+ discounts[top_k_mask] = expanded_discounts
+
+ dcg = (y_true_reshaped * top_k_mask * discounts).sum(dim=1)
+ n_relevant = torch.minimum(
+ y_true_reshaped.sum(dim=1), torch.tensor(self.k, device=device)
+ ).int()
+ ideal_discounts = 1.0 / torch.log2(
+ torch.arange(2, y_true_reshaped.shape[1] + 2, device=device).float()
+ )
+
+ idcg = torch.zeros(batch_size, device=device)
+ for i in range(batch_size):
+ idcg[i] = ideal_discounts[: n_relevant[i]].sum()
+
+ ndcg = dcg / (idcg + self.eps)
+
+ self.sum_ndcg += ndcg.sum().item()
+ self.count += batch_size
+
+ return np.array(self.sum_ndcg / max(self.count, 1))
+
+
+class MAP_at_k(Metric):
+ r"""
+ Mean Average Precision (MAP) at k.
+
+ Parameters
+ ----------
+ n_cols: int, default = 10
+        Number of columns in the input tensors. This parameter is necessary
+        because the input tensors are reshaped to 2D tensors. n_cols is the
+        number of columns in the reshaped tensor. Aliases for this parameter
+        are: 'n_items', 'n_items_per_query', 'n_items_per_id', 'n_items_per_user'
+    k: int, Optional, default = None
+        Number of top items to consider. It must be less than or equal to n_cols.
+        If it is None, k will be equal to n_cols.
+
+ Examples
+ --------
+ >>> import torch
+ >>> from pytorch_widedeep.metrics import MAP_at_k
+ >>>
+ >>> map_at_k = MAP_at_k(k=10)
+ >>> y_pred = torch.rand(100, 10)
+ >>> y_true = torch.randint(0, 2, (100, 10))
+ >>> score = map_at_k(y_pred, y_true)
+ """
+
+ def __init__(self, n_cols: int = 10, k: Optional[int] = None):
+ super(MAP_at_k, self).__init__()
+
+ if k is not None and k > n_cols:
+ raise ValueError(
+ f"k must be less than or equal to n_cols. Got k: {k}, n_cols: {n_cols}"
+ )
+
+ self.n_cols = n_cols
+ self.k = k if k is not None else n_cols
+ self._name = f"map@{k}"
+ self.reset()
+
+ def reset(self):
+ self.sum_avg_precision = 0.0
+ self.count = 0
+
+ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+
+ y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+
+ batch_size = y_pred_reshaped.shape[0]
+ _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
+ batch_relevance = y_true_reshaped.gather(1, top_k_indices)
+ cumsum_relevance = torch.cumsum(batch_relevance, dim=1)
+ precision_at_i = cumsum_relevance / torch.arange(
+ 1, self.k + 1, device=y_pred_reshaped.device
+ ).float().unsqueeze(0)
+ avg_precision = (precision_at_i * batch_relevance).sum(dim=1) / torch.clamp(
+ y_true_reshaped.sum(dim=1), min=1
+ )
+ self.sum_avg_precision += avg_precision.sum().item()
+ self.count += batch_size
+ return np.array(self.sum_avg_precision / max(self.count, 1))
+
+
+class HitRatio_at_k(Metric):
+ r"""
+ Hit Ratio (HR) at k.
+
+ Parameters
+ ----------
+ n_cols: int, default = 10
+        Number of columns in the input tensors. This parameter is necessary
+        because the input tensors are reshaped to 2D tensors. n_cols is the
+        number of columns in the reshaped tensor. Aliases for this parameter
+        are: 'n_items', 'n_items_per_query', 'n_items_per_id', 'n_items_per_user'
+    k: int, Optional, default = None
+        Number of top items to consider. It must be less than or equal to n_cols.
+        If it is None, k will be equal to n_cols.
+
+ Examples
+ --------
+ >>> import torch
+ >>> from pytorch_widedeep.metrics import HitRatio_at_k
+ >>>
+ >>> hr_at_k = HitRatio_at_k(k=10)
+ >>> y_pred = torch.rand(100, 10)
+ >>> y_true = torch.randint(0, 2, (100, 10))
+    >>> score = hr_at_k(y_pred, y_true)
+ """
+
+ def __init__(self, n_cols: int = 10, k: Optional[int] = None):
+ super(HitRatio_at_k, self).__init__()
+
+ if k is not None and k > n_cols:
+ raise ValueError(
+ f"k must be less than or equal to n_cols. Got k: {k}, n_cols: {n_cols}"
+ )
+
+ self.n_cols = n_cols
+ self.k = k if k is not None else n_cols
+ self._name = f"hr@{k}"
+ self.reset()
+
+ def reset(self):
+ self.sum_hr = 0.0
+ self.count = 0
+
+ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+ y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+ batch_size = y_pred_reshaped.shape[0]
+ _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
+ batch_relevance = y_true_reshaped.gather(1, top_k_indices)
+ hit = (batch_relevance.sum(dim=1) > 0).float()
+ self.sum_hr += hit.sum().item()
+ self.count += batch_size
+ return np.array(self.sum_hr / max(self.count, 1))
+
+
+class Precision_at_k(Metric):
+ r"""
+ Precision at k.
+
+ Parameters
+ ----------
+ n_cols: int, default = 10
+        Number of columns in the input tensors. This parameter is necessary
+        because the input tensors are reshaped to 2D tensors. n_cols is the
+        number of columns in the reshaped tensor. Aliases for this parameter
+        are: 'n_items', 'n_items_per_query',
+        'n_items_per_id', 'n_items_per_user'
+    k: int, Optional, default = None
+        Number of top items to consider. It must be less than or equal to n_cols.
+        If it is None, k will be equal to n_cols.
+
+ Examples
+ --------
+ >>> import torch
+ >>> from pytorch_widedeep.metrics import Precision_at_k
+ >>>
+ >>> prec_at_k = Precision_at_k(k=10)
+ >>> y_pred = torch.rand(100, 10)
+ >>> y_true = torch.randint(0, 2, (100, 10))
+ >>> score = prec_at_k(y_pred, y_true)
+ """
+
+ def __init__(self, n_cols: int = 10, k: Optional[int] = None):
+ super(Precision_at_k, self).__init__()
+
+ if k is not None and k > n_cols:
+ raise ValueError(
+ f"k must be less than or equal to n_cols. Got k: {k}, n_cols: {n_cols}"
+ )
+
+ self.n_cols = n_cols
+ self.k = k if k is not None else n_cols
+ self._name = f"precision@{k}"
+ self.reset()
+
+ def reset(self):
+ self.sum_precision = 0.0
+ self.count = 0
+
+ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+ y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+ batch_size = y_pred_reshaped.shape[0]
+ _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
+ batch_relevance = y_true_reshaped.gather(1, top_k_indices)
+ precision = batch_relevance.sum(dim=1) / self.k
+ self.sum_precision += precision.sum().item()
+ self.count += batch_size
+ return np.array(self.sum_precision / max(self.count, 1))
+
+
+class Recall_at_k(Metric):
+ r"""
+ Recall at k.
+
+ Parameters
+ ----------
+ k: int, default = 10
+ Number of top items to consider.
+
+ Examples
+ --------
+ >>> import torch
+ >>> from pytorch_widedeep.metrics import Recall_at_k
+ >>>
+ >>> rec_at_k = Recall_at_k(k=10)
+ >>> y_pred = torch.rand(100, 10)
+ >>> y_true = torch.randint(0, 2, (100, 10))
+ >>> score = rec_at_k(y_pred, y_true)
+ """
+
+ def __init__(self, n_cols: int = 10, k: Optional[int] = None):
+ super(Recall_at_k, self).__init__()
+
+ if k is not None and k > n_cols:
+ raise ValueError(
+ f"k must be less than or equal to n_cols. Got k: {k}, n_cols: {n_cols}"
+ )
+
+ self.n_cols = n_cols
+ self.k = k if k is not None else n_cols
+ self._name = f"recall@{k}"
+ self.reset()
+
+ def reset(self):
+ self.sum_recall = 0.0
+ self.count = 0
+
+ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+ y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+ batch_size = y_pred_reshaped.shape[0]
+ _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
+ batch_relevance = y_true_reshaped.gather(1, top_k_indices)
+ recall = batch_relevance.sum(dim=1) / torch.clamp(
+ y_true_reshaped.sum(dim=1), min=1
+ )
+ self.sum_recall += recall.sum().item()
+ self.count += batch_size
+ return np.array(self.sum_recall / max(self.count, 1))
+
+
+RankingMetrics = Union[
+ BinaryNDCG_at_k, MAP_at_k, HitRatio_at_k, Precision_at_k, Recall_at_k
+]
diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py
index 45cb2e9e..21d8985b 100644
--- a/pytorch_widedeep/training/_base_trainer.py
+++ b/pytorch_widedeep/training/_base_trainer.py
@@ -10,7 +10,7 @@
from torchmetrics import Metric as TorchMetric
from torch.optim.lr_scheduler import ReduceLROnPlateau
-from pytorch_widedeep.metrics import Metric, MultipleMetrics
+from pytorch_widedeep.metrics import Metric, RankingMetrics, MultipleMetrics
from pytorch_widedeep.wdtypes import (
Any,
Dict,
@@ -285,13 +285,14 @@ def _set_transforms(transforms):
else:
return None
- # this needs type fixing to adjust for the fact that the main class can
- # take an 'object', a non-instastiated Class, so, should be something like:
- # callbacks: Optional[List[Union[object, Callback]]] in all places
+ # TO DO: this needs type fixing to adjust for the fact that the main class
+ # can take an 'object', a non-instantiated Class, so it should be something
+ # like: callbacks: Optional[List[Union[object, Callback]]] in all places
def _set_callbacks_and_metrics(
self,
callbacks: Any,
- metrics: Any,
+ metrics: Any, # Union[List[Metric], List[TorchMetric]],
+ eval_metrics: Optional[Any] = None, # Union[List[Metric], List[TorchMetric]],
):
self.callbacks: List = [History(), LRShedulerCallback()]
if callbacks is not None:
@@ -301,9 +302,31 @@ def _set_callbacks_and_metrics(
self.callbacks.append(callback)
if metrics is not None:
self.metric = MultipleMetrics(metrics)
+ # assert that no ranking metric is used during training assertion
+ # here so all metrics are instanciated
+ assert not any(
+ [isinstance(m, RankingMetrics) for m in self.metric._metrics] # type: ignore[arg-type, misc]
+ ), "Currently, ranking metrics are not supported during training"
+
self.callbacks += [MetricCallback(self.metric)]
else:
self.metric = None
+ if eval_metrics is not None:
+ self.eval_metric = MultipleMetrics(eval_metrics)
+ # assert that if any of the metrics is a ranking metric, all metrics
+ # must be ranking metrics
+ if any(
+ [isinstance(m, RankingMetrics) for m in self.eval_metric._metrics] # type: ignore[arg-type, misc]
+ ):
+ assert all(
+ [isinstance(m, RankingMetrics) for m in self.eval_metric._metrics] # type: ignore[arg-type, misc]
+ ), (
+ "All eval metrics must be ranking metrics if any of the eval"
+ " metrics is a ranking metric"
+ )
+ self.callbacks += [MetricCallback(self.eval_metric)]
+ else:
+ self.eval_metric = None
self.callback_container = CallbackContainer(self.callbacks)
self.callback_container.set_model(self.model)
self.callback_container.set_trainer(self)
diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py
index 8ca7d459..cff69bfd 100644
--- a/pytorch_widedeep/training/trainer.py
+++ b/pytorch_widedeep/training/trainer.py
@@ -228,6 +228,7 @@ class Trainer(BaseTrainer):
"objective",
["loss_function", "loss_fn", "loss", "cost_function", "cost_fn", "cost"],
)
+ @alias("metrics", ["train_metrics"])
def __init__(
self,
model: WideDeep,
@@ -245,6 +246,7 @@ def __init__(
transforms: Optional[List[Transforms]] = None,
callbacks: Optional[List[Callback]] = None,
metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None,
+ eval_metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None,
verbose: int = 1,
seed: int = 1,
**kwargs,
@@ -259,6 +261,7 @@ def __init__(
transforms=transforms,
callbacks=callbacks,
metrics=metrics,
+ eval_metrics=eval_metrics,
verbose=verbose,
seed=seed,
**kwargs,
@@ -600,7 +603,7 @@ def predict_uncertainty( # type: ignore[return]
X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None,
batch_size: Optional[int] = None,
- uncertainty_granularity=1000,
+ uncertainty_granularity: int = 1000,
) -> np.ndarray:
r"""Returns the predicted ucnertainty of the model for the test dataset
using a Monte Carlo method during which dropout layers are activated
@@ -930,10 +933,10 @@ def _train_step(
if self.model.is_tabnet:
loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
- score = self._get_score(y_pred[0], y)
+ score = self._get_score(y_pred[0], y, is_train=True)
else:
loss = self.loss_fn(y_pred, y)
- score = self._get_score(y_pred, y)
+ score = self._get_score(y_pred, y, is_train=True)
loss.backward()
self.optimizer.step()
@@ -967,9 +970,9 @@ def _eval_step(
y_pred = self.model(X)
if self.model.is_tabnet:
loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
- score = self._get_score(y_pred[0], y)
+ score = self._get_score(y_pred[0], y, is_train=False)
else:
- score = self._get_score(y_pred, y)
+ score = self._get_score(y_pred, y, is_train=False)
loss = self.loss_fn(y_pred, y)
self.valid_running_loss += loss.item()
@@ -978,20 +981,30 @@ def _eval_step(
self.model.train()
return score, avg_loss
- def _get_score(self, y_pred, y):
- if self.metric is not None:
+ def _get_score(
+ self, y_pred: Tensor, y: Tensor, is_train: bool
+ ) -> Optional[Dict[str, float]]:
+
+ score = None
+ metric = None
+
+ if hasattr(self, "metric") and not hasattr(self, "eval_metric"):
+ metric = self.metric
+ elif hasattr(self, "metric") and hasattr(self, "eval_metric"):
+ metric = self.metric if is_train else self.eval_metric
+ elif not hasattr(self, "metric") and hasattr(self, "eval_metric"):
+ metric = None if is_train else self.eval_metric
+
+ if metric is not None:
if self.method == "regression":
- score = self.metric(y_pred, y)
+ score = metric(y_pred, y)
if self.method == "binary":
- score = self.metric(torch.sigmoid(y_pred), y)
+ score = metric(torch.sigmoid(y_pred), y)
if self.method == "qregression":
- score = self.metric(y_pred, y)
+ score = metric(y_pred, y)
if self.method == "multiclass":
- score = self.metric(F.softmax(y_pred, dim=1), y)
- # TO DO: handle multitarget
- return score
- else:
- return None
+ score = metric(F.softmax(y_pred, dim=1), y)
+ return score
def _predict( # type: ignore[override, return] # noqa: C901
self,
@@ -1001,7 +1014,7 @@ def _predict( # type: ignore[override, return] # noqa: C901
X_img: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None,
batch_size: Optional[int] = None,
- uncertainty_granularity=1000,
+ uncertainty_granularity: int = 1000,
uncertainty: bool = False,
) -> List:
r"""Private method to avoid code repetition in predict and
diff --git a/tests/test_metrics/test_ranking_metrics.py b/tests/test_metrics/test_ranking_metrics.py
new file mode 100644
index 00000000..b08ecb4a
--- /dev/null
+++ b/tests/test_metrics/test_ranking_metrics.py
@@ -0,0 +1,188 @@
+# Thanks to Claude 3.5 for the test cases.
+import numpy as np
+import torch
+import pytest
+
+from pytorch_widedeep.metrics import (
+ MAP_at_k,
+ NDCG_at_k,
+ Recall_at_k,
+ HitRatio_at_k,
+ Precision_at_k,
+ BinaryNDCG_at_k,
+)
+
+
+@pytest.fixture
+def setup_data():
+ y_pred = torch.tensor(
+ [
+ [0.7, 0.9, 0.5, 0.8, 0.6],
+ [0.3, 0.1, 0.5, 0.2, 0.4],
+ ]
+ )
+ y_true = torch.tensor(
+ [
+ [0, 1, 0, 1, 1],
+ [1, 1, 0, 0, 0],
+ ]
+ )
+ return y_pred, y_true
+
+
+@pytest.fixture
+def setup_data_ndcg():
+ y_pred = torch.tensor(
+ [
+ [0.7, 0.9, 0.5, 0.8, 0.6],
+ [0.3, 0.1, 0.5, 0.2, 0.4],
+ ]
+ )
+ # Using relevance scores from 0 to 3
+ y_true = torch.tensor(
+ [
+ [1, 3, 1, 2, 0],
+ [3, 1, 0, 2, 1],
+ ]
+ )
+ return y_pred, y_true
+
+
+def test_binary_ndcg_at_k(setup_data):
+ y_pred, y_true = setup_data
+ binary_ndcg = BinaryNDCG_at_k(n_cols=5, k=3)
+ result = binary_ndcg(y_pred.flatten(), y_true.flatten())
+ expected = np.array(0.5719)
+ np.testing.assert_almost_equal(result, expected, decimal=4)
+
+
+def test_map_at_k(setup_data):
+ y_pred, y_true = setup_data
+ map_at_k = MAP_at_k(n_cols=5, k=3)
+ result = map_at_k(y_pred.flatten(), y_true.flatten())
+ expected = np.array(0.4166)
+ np.testing.assert_almost_equal(result, expected, decimal=4)
+
+
+def test_hit_ratio_at_k(setup_data):
+ y_pred, y_true = setup_data
+ hr_at_k = HitRatio_at_k(n_cols=5, k=3)
+ result = hr_at_k(y_pred.flatten(), y_true.flatten())
+ expected = np.array(1.0)
+ np.testing.assert_almost_equal(result, expected, decimal=4)
+
+
+def test_precision_at_k(setup_data):
+ y_pred, y_true = setup_data
+ prec_at_k = Precision_at_k(n_cols=5, k=3)
+ result = prec_at_k(y_pred.flatten(), y_true.flatten())
+ expected = np.array(0.5)
+ np.testing.assert_almost_equal(result, expected, decimal=4)
+
+
+def test_recall_at_k(setup_data):
+ y_pred, y_true = setup_data
+ rec_at_k = Recall_at_k(n_cols=5, k=3)
+ result = rec_at_k(y_pred.flatten(), y_true.flatten())
+ expected = np.array(0.5833)
+ np.testing.assert_almost_equal(result, expected, decimal=4)
+
+
+def test_edge_cases_all_relevant_items():
+ # Test with all relevant items
+ y_pred = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5])
+ y_true = torch.tensor([1, 1, 1, 1, 1])
+
+ ndcg = NDCG_at_k(n_cols=5, k=3)
+ assert ndcg(y_pred, y_true) == 1.0
+
+ binary_ndcg = BinaryNDCG_at_k(n_cols=5, k=3)
+ assert binary_ndcg(y_pred, y_true) == 1.0
+
+ map_at_k = MAP_at_k(n_cols=5, k=3)
+ assert np.isclose(map_at_k(y_pred, y_true), 0.6)
+
+ hr_at_k = HitRatio_at_k(n_cols=5, k=3)
+ assert hr_at_k(y_pred, y_true) == 1.0
+
+ prec_at_k = Precision_at_k(n_cols=5, k=3)
+ assert prec_at_k(y_pred, y_true) == 1.0
+
+ rec_at_k = Recall_at_k(n_cols=5, k=3)
+ assert np.isclose(rec_at_k(y_pred, y_true), 0.6)
+
+
+def test_edge_cases_no_relevant_items():
+
+ # Test with no relevant items
+ y_pred = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5])
+ y_true = torch.tensor([0, 0, 0, 0, 0])
+
+ ndcg = NDCG_at_k(n_cols=5, k=3)
+ assert ndcg(y_pred, y_true) == 0.0
+
+ binary_ndcg = BinaryNDCG_at_k(n_cols=5, k=3)
+ assert binary_ndcg(y_pred, y_true) == 0.0
+
+ map_at_k = MAP_at_k(n_cols=5, k=3)
+ assert map_at_k(y_pred, y_true) == 0.0
+
+ hr_at_k = HitRatio_at_k(n_cols=5, k=3)
+ assert hr_at_k(y_pred, y_true) == 0.0
+
+ prec_at_k = Precision_at_k(n_cols=5, k=3)
+ assert prec_at_k(y_pred, y_true) == 0.0
+
+ rec_at_k = Recall_at_k(n_cols=5, k=3)
+ assert rec_at_k(y_pred, y_true) == 0.0
+
+
+def test_k_greater_than_n_cols():
+ with pytest.raises(ValueError):
+ NDCG_at_k(n_cols=5, k=10)
+ with pytest.raises(ValueError):
+ BinaryNDCG_at_k(n_cols=5, k=10)
+ with pytest.raises(ValueError):
+ MAP_at_k(n_cols=5, k=10)
+ with pytest.raises(ValueError):
+ HitRatio_at_k(n_cols=5, k=10)
+ with pytest.raises(ValueError):
+ Precision_at_k(n_cols=5, k=10)
+ with pytest.raises(ValueError):
+ Recall_at_k(n_cols=5, k=10)
+
+
+def test_ndcg_at_k(setup_data_ndcg):
+ y_pred, y_true = setup_data_ndcg
+ ndcg = NDCG_at_k(n_cols=5, k=3)
+ result = ndcg(y_pred.flatten(), y_true.flatten())
+ expected = np.array(0.7198)
+ np.testing.assert_almost_equal(result, expected, decimal=4)
+
+
+def test_ndcg_at_k_edge_cases():
+ # Test with non-decreasing ranking
+ y_pred = torch.tensor([0.5, 0.6, 0.7, 0.8, 0.9])
+ y_true = torch.tensor([0, 1, 2, 3, 4])
+
+ ndcg = NDCG_at_k(n_cols=5, k=3)
+ result = ndcg(y_pred, y_true)
+ assert result == 1.0
+
+ # Test with non-increasing ranking
+ ndcg = NDCG_at_k(n_cols=5, k=3)
+ y_pred = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5])
+ result = ndcg(y_pred, y_true)
+ expected = np.array(0.1019)
+ np.testing.assert_almost_equal(result, expected, decimal=4)
+
+ # Test with all zero relevance
+ ndcg = NDCG_at_k(n_cols=5, k=3)
+ y_true = torch.tensor([0, 0, 0, 0, 0])
+ result = ndcg(y_pred, y_true)
+ assert result == 0.0
+
+
+def test_ndcg_at_k_k_greater_than_n_cols():
+ with pytest.raises(ValueError):
+ NDCG_at_k(n_cols=5, k=10)
From d32397a36c2049ba0817f477bb1a9f0a000f4f39 Mon Sep 17 00:00:00 2001
From: Javier
Date: Sat, 19 Oct 2024 16:15:48 +0100
Subject: [PATCH 09/24] Added explanation to the metrics and a bit of
refactoring
---
pytorch_widedeep/metrics.py | 76 ++++----
pytorch_widedeep/training/_base_trainer.py | 35 ++--
pytorch_widedeep/training/trainer.py | 192 ++++++++++++++-------
tests/test_metrics/test_ranking_metrics.py | 42 +++++
4 files changed, 224 insertions(+), 121 deletions(-)
diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py
index e4b871ed..5286952b 100644
--- a/pytorch_widedeep/metrics.py
+++ b/pytorch_widedeep/metrics.py
@@ -461,13 +461,13 @@ def reset(self):
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
device = y_pred.device
- y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+ y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
- batch_size = y_true_reshaped.shape[0]
+ batch_size = y_true_2d.shape[0]
- _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
- top_k_relevance = y_true_reshaped.gather(1, top_k_indices)
+ _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
+ top_k_relevance = y_true_2d.gather(1, top_k_indices)
discounts = 1.0 / torch.log2(
torch.arange(2, top_k_relevance.shape[1] + 2, device=device)
)
@@ -475,7 +475,7 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
dcg = (torch.pow(2, top_k_relevance) - 1) * discounts.unsqueeze(0)
dcg = dcg.sum(dim=1)
- sorted_relevance, _ = torch.sort(y_true_reshaped, dim=1, descending=True)
+ sorted_relevance, _ = torch.sort(y_true_2d, dim=1, descending=True)
ideal_relevance = sorted_relevance[:, : self.k]
idcg = (torch.pow(2, ideal_relevance) - 1) * discounts[
@@ -544,13 +544,13 @@ def reset(self):
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
device = y_pred.device
- y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+ y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
- batch_size = y_pred_reshaped.shape[0]
+ batch_size = y_pred_2d.shape[0]
- _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
- top_k_mask = torch.zeros_like(y_pred_reshaped, dtype=torch.bool).scatter_(
+ _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
+ top_k_mask = torch.zeros_like(y_pred_2d, dtype=torch.bool).scatter_(
1, top_k_indices, 1
)
@@ -561,12 +561,12 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
discounts = torch.zeros_like(top_k_mask, dtype=torch.float)
discounts[top_k_mask] = expanded_discounts
- dcg = (y_true_reshaped * top_k_mask * discounts).sum(dim=1)
+ dcg = (y_true_2d * top_k_mask * discounts).sum(dim=1)
n_relevant = torch.minimum(
- y_true_reshaped.sum(dim=1), torch.tensor(self.k, device=device)
+ y_true_2d.sum(dim=1), torch.tensor(self.k, device=device)
).int()
ideal_discounts = 1.0 / torch.log2(
- torch.arange(2, y_true_reshaped.shape[1] + 2, device=device).float()
+ torch.arange(2, y_true_2d.shape[1] + 2, device=device).float()
)
idcg = torch.zeros(batch_size, device=device)
@@ -626,18 +626,18 @@ def reset(self):
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
- y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
+ y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
- batch_size = y_pred_reshaped.shape[0]
- _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
- batch_relevance = y_true_reshaped.gather(1, top_k_indices)
+ batch_size = y_pred_2d.shape[0]
+ _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
+ batch_relevance = y_true_2d.gather(1, top_k_indices)
cumsum_relevance = torch.cumsum(batch_relevance, dim=1)
precision_at_i = cumsum_relevance / torch.arange(
- 1, self.k + 1, device=y_pred_reshaped.device
+ 1, self.k + 1, device=y_pred_2d.device
).float().unsqueeze(0)
avg_precision = (precision_at_i * batch_relevance).sum(dim=1) / torch.clamp(
- y_true_reshaped.sum(dim=1), min=1
+ y_true_2d.sum(dim=1), min=1
)
self.sum_avg_precision += avg_precision.sum().item()
self.count += batch_size
@@ -688,11 +688,11 @@ def reset(self):
self.count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
- y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
- batch_size = y_pred_reshaped.shape[0]
- _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
- batch_relevance = y_true_reshaped.gather(1, top_k_indices)
+ y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
+ batch_size = y_pred_2d.shape[0]
+ _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
+ batch_relevance = y_true_2d.gather(1, top_k_indices)
hit = (batch_relevance.sum(dim=1) > 0).float()
self.sum_hr += hit.sum().item()
self.count += batch_size
@@ -744,11 +744,11 @@ def reset(self):
self.count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
- y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
- batch_size = y_pred_reshaped.shape[0]
- _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
- batch_relevance = y_true_reshaped.gather(1, top_k_indices)
+ y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
+ batch_size = y_pred_2d.shape[0]
+ _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
+ batch_relevance = y_true_2d.gather(1, top_k_indices)
precision = batch_relevance.sum(dim=1) / self.k
self.sum_precision += precision.sum().item()
self.count += batch_size
@@ -793,14 +793,12 @@ def reset(self):
self.count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
- y_pred_reshaped = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_reshaped = reshape_1d_to_2d(y_true, self.n_cols)
- batch_size = y_pred_reshaped.shape[0]
- _, top_k_indices = torch.topk(y_pred_reshaped, self.k, dim=1)
- batch_relevance = y_true_reshaped.gather(1, top_k_indices)
- recall = batch_relevance.sum(dim=1) / torch.clamp(
- y_true_reshaped.sum(dim=1), min=1
- )
+ y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
+ batch_size = y_pred_2d.shape[0]
+ _, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
+ batch_relevance = y_true_2d.gather(1, top_k_indices)
+ recall = batch_relevance.sum(dim=1) / torch.clamp(y_true_2d.sum(dim=1), min=1)
self.sum_recall += recall.sum().item()
self.count += batch_size
return np.array(self.sum_recall / max(self.count, 1))
diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py
index 21d8985b..bf2c3875 100644
--- a/pytorch_widedeep/training/_base_trainer.py
+++ b/pytorch_widedeep/training/_base_trainer.py
@@ -302,27 +302,32 @@ def _set_callbacks_and_metrics(
self.callbacks.append(callback)
if metrics is not None:
self.metric = MultipleMetrics(metrics)
- # assert that no ranking metric is used during training assertion
- # here so all metrics are instanciated
- assert not any(
- [isinstance(m, RankingMetrics) for m in self.metric._metrics] # type: ignore[arg-type, misc]
- ), "Currently, ranking metrics are not supported during training"
-
+ if (
+ any(
+ [isinstance(m, RankingMetrics) for m in self.metric._metrics] # type: ignore[arg-type, misc]
+ )
+ and self.verbose
+ ):
+ UserWarning(
+ "There are ranking metrics in the 'metrics' list. The implementation "
+ "in this library requires that all query or user ids must have the "
+ "same number of entries or items."
+ )
self.callbacks += [MetricCallback(self.metric)]
else:
self.metric = None
if eval_metrics is not None:
self.eval_metric = MultipleMetrics(eval_metrics)
- # assert that if any of the metrics is a ranking metric, all metrics
- # must be ranking metrics
- if any(
- [isinstance(m, RankingMetrics) for m in self.eval_metric._metrics] # type: ignore[arg-type, misc]
- ):
- assert all(
+ if (
+ any(
[isinstance(m, RankingMetrics) for m in self.eval_metric._metrics] # type: ignore[arg-type, misc]
- ), (
- "All eval metrics must be ranking metrics if any of the eval"
- " metrics is a ranking metric"
+ )
+ and self.verbose
+ ):
+ UserWarning(
+ "There are ranking metrics in the 'eval_metric' list. The implementation "
+ "in this library requires that all query or user ids must have the "
+ "same number of entries or items."
)
self.callbacks += [MetricCallback(self.eval_metric)]
else:
diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py
index cff69bfd..7400409b 100644
--- a/pytorch_widedeep/training/trainer.py
+++ b/pytorch_widedeep/training/trainer.py
@@ -12,8 +12,10 @@
from pytorch_widedeep.losses import ZILNLoss
from pytorch_widedeep.metrics import Metric
from pytorch_widedeep.wdtypes import (
+ Any,
Dict,
List,
+ Tuple,
Union,
Tensor,
Literal,
@@ -268,6 +270,8 @@ def __init__(
)
@alias("finetune", ["warmup"])
+ @alias("train_dataloader", ["custom_dataloader"])
+ @alias("eval_dataloader", ["custom_eval_dataloader"])
def fit( # noqa: C901
self,
X_wide: Optional[np.ndarray] = None,
@@ -281,7 +285,8 @@ def fit( # noqa: C901
n_epochs: int = 1,
validation_freq: int = 1,
batch_size: int = 32,
- custom_dataloader: Optional[DataLoader] = None,
+ train_dataloader: Optional[DataLoader] = None,
+ eval_dataloader: Optional[DataLoader] = None,
feature_importance_sample_size: Optional[int] = None,
finetune: bool = False,
**kwargs,
@@ -418,11 +423,8 @@ def fit( # noqa: C901
[Examples](https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples)
folder in the repo
"""
-
dataloader_args, finetune_args = self._extract_kwargs(kwargs)
-
self.batch_size = batch_size
-
train_set, eval_set = wd_train_val_split(
self.seed,
self.method, # type: ignore
@@ -436,35 +438,18 @@ def fit( # noqa: C901
target,
self.transforms,
)
- if custom_dataloader is not None:
- # make sure is callable (and HAS to be an subclass of DataLoader)
- assert isinstance(custom_dataloader, type)
- train_loader = custom_dataloader( # type: ignore[misc]
- dataset=train_set,
- batch_size=batch_size,
- num_workers=self.num_workers,
- **dataloader_args,
- )
- else:
- train_loader = DataLoader(
- dataset=train_set,
- batch_size=batch_size,
- num_workers=self.num_workers,
- **dataloader_args,
- )
- train_steps = len(train_loader)
- if eval_set is not None:
- eval_loader = DataLoader(
- dataset=eval_set,
- batch_size=batch_size,
- num_workers=self.num_workers,
- shuffle=False,
- )
- eval_steps = len(eval_loader)
+ train_loader, eval_loader = self._get_dataloaders(
+ train_dataloader,
+ eval_dataloader,
+ train_set,
+ eval_set,
+ batch_size,
+ dataloader_args,
+ )
if finetune:
self.with_finetuning: bool = True
- self._finetune(train_loader, **finetune_args)
+ self._do_finetune(train_loader, **finetune_args)
if self.verbose:
print(
"Fine-tuning (or warmup) of individual components completed. "
@@ -474,54 +459,32 @@ def fit( # noqa: C901
self.with_finetuning = False
self.callback_container.on_train_begin(
- {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
+ {
+ "batch_size": batch_size,
+ "train_steps": len(train_loader),
+ "n_epochs": n_epochs,
+ }
)
for epoch in range(n_epochs):
- epoch_logs: Dict[str, float] = {}
- self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
-
- self.train_running_loss = 0.0
- with trange(train_steps, disable=self.verbose != 1) as t:
- for batch_idx, (data, targett) in zip(t, train_loader):
- t.set_description("epoch %i" % (epoch + 1))
- train_score, train_loss = self._train_step(data, targett, batch_idx)
- print_loss_and_metric(t, train_loss, train_score)
- self.callback_container.on_batch_end(batch=batch_idx)
- epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train")
-
- on_epoch_end_metric = None
- if eval_set is not None and epoch % validation_freq == (
+ epoch_logs = self._train_epoch(train_loader, epoch)
+ if eval_loader is not None and epoch % validation_freq == (
validation_freq - 1
):
- self.callback_container.on_eval_begin()
- self.valid_running_loss = 0.0
- with trange(eval_steps, disable=self.verbose != 1) as v:
- for i, (data, targett) in zip(v, eval_loader):
- v.set_description("valid")
- val_score, val_loss = self._eval_step(data, targett, i)
- print_loss_and_metric(v, val_loss, val_score)
- epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val")
-
- if self.reducelronplateau:
- if self.reducelronplateau_criterion == "loss":
- on_epoch_end_metric = val_loss
- else:
- on_epoch_end_metric = val_score[
- self.reducelronplateau_criterion
- ]
+ epoch_logs, on_epoch_end_metric = self._eval_epoch(
+ eval_loader, epoch_logs
+ )
else:
+ on_epoch_end_metric = None
if self.reducelronplateau:
raise NotImplementedError(
"ReduceLROnPlateau scheduler can be used only with validation data."
)
self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric)
-
if self.early_stop:
# self.callback_container.on_train_end(epoch_logs)
break
self.callback_container.on_train_end(epoch_logs)
-
if feature_importance_sample_size is not None:
self.feature_importance = FeatureImportance(
self.device, feature_importance_sample_size
@@ -817,9 +780,55 @@ def save(
with open(Path(path) / "feature_importance.json", "w") as fi:
json.dump(self.feature_importance, fi)
+ def _get_dataloaders(
+ self,
+ train_dataloader: Optional[DataLoader],
+ eval_dataloader: Optional[DataLoader],
+ train_set: WideDeepDataset,
+ eval_set: Optional[WideDeepDataset],
+ batch_size: int,
+ dataloader_args: Dict[str, Any],
+ ) -> Tuple[DataLoader, Optional[DataLoader]]:
+ if train_dataloader is not None:
+ # make sure is callable (and HAS to be an subclass of DataLoader)
+ assert isinstance(train_dataloader, type)
+ train_loader = train_dataloader( # type: ignore[misc]
+ dataset=train_set,
+ batch_size=batch_size,
+ num_workers=self.num_workers,
+ **dataloader_args,
+ )
+ else:
+ train_loader = DataLoader(
+ dataset=train_set,
+ batch_size=batch_size,
+ num_workers=self.num_workers,
+ **dataloader_args,
+ )
+
+ eval_loader = None
+ if eval_set is not None:
+ if eval_dataloader is not None:
+ assert isinstance(eval_dataloader, type)
+ eval_loader = eval_dataloader( # type: ignore[misc]
+ dataset=eval_set,
+ batch_size=batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
+ )
+ else:
+ eval_loader = DataLoader(
+ dataset=eval_set,
+ batch_size=batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
+ )
+
+ return train_loader, eval_loader
+
@alias("n_epochs", ["finetune_epochs", "warmup_epochs"])
@alias("max_lr", ["finetune_max_lr", "warmup_max_lr"])
- def _finetune(
+ def _do_finetune(
self,
loader: DataLoader,
n_epochs: int = 5,
@@ -905,6 +914,27 @@ def _finetune(
self.model.deepimage, "deepimage", loader, n_epochs, max_lr
)
+ def _train_epoch(
+ self,
+ train_loader: DataLoader,
+ epoch: int,
+ ):
+ epoch_logs: Dict[str, float] = {}
+ self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
+ self.train_running_loss = 0.0
+
+ train_steps = len(train_loader)
+ with trange(train_steps, disable=self.verbose != 1) as t:
+ for batch_idx, (data, targett) in zip(t, train_loader):
+ t.set_description("epoch %i" % (epoch + 1))
+ train_score, train_loss = self._train_step(data, targett, batch_idx)
+ print_loss_and_metric(t, train_loss, train_score)
+ self.callback_container.on_batch_end(batch=batch_idx)
+
+ epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train")
+
+ return epoch_logs
+
def _train_step(
self,
data: Dict[str, Union[Tensor, List[Tensor]]],
@@ -946,6 +976,33 @@ def _train_step(
return score, avg_loss
+ def _eval_epoch(
+ self,
+ eval_loader: DataLoader,
+ epoch_logs: Dict[str, float],
+ ) -> Tuple[Dict[str, float], Optional[float]]:
+ self.callback_container.on_eval_begin()
+ self.valid_running_loss = 0.0
+
+ eval_steps = len(eval_loader)
+ with trange(eval_steps, disable=self.verbose != 1) as v:
+ for i, (data, targett) in zip(v, eval_loader):
+ v.set_description("valid")
+ val_score, val_loss = self._eval_step(data, targett, i)
+ print_loss_and_metric(v, val_loss, val_score)
+
+ epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val")
+
+ if self.reducelronplateau:
+ if self.reducelronplateau_criterion == "loss":
+ on_epoch_end_metric = val_loss
+ else:
+ on_epoch_end_metric = val_score[self.reducelronplateau_criterion]
+ else:
+ on_epoch_end_metric = None
+
+ return epoch_logs, on_epoch_end_metric
+
def _eval_step(
self,
data: Dict[str, Union[Tensor, List[Tensor]]],
@@ -988,11 +1045,11 @@ def _get_score(
score = None
metric = None
- if hasattr(self, "metric") and not hasattr(self, "eval_metric"):
+ if self.metric and not self.eval_metric:
metric = self.metric
- elif hasattr(self, "metric") and hasattr(self, "eval_metric"):
+ elif self.metric and self.eval_metric:
metric = self.metric if is_train else self.eval_metric
- elif not hasattr(self, "metric") and hasattr(self, "eval_metric"):
+ elif not self.metric and self.eval_metric:
metric = None if is_train else self.eval_metric
if metric is not None:
@@ -1097,6 +1154,7 @@ def _predict( # type: ignore[override, return] # noqa: C901
@staticmethod
def _predict_ziln(preds: Tensor) -> Tensor:
+ # Legacy implementation. It will be removed in future versions
"""Calculates predicted mean of zero inflated lognormal logits.
Adjusted implementaion of `code
diff --git a/tests/test_metrics/test_ranking_metrics.py b/tests/test_metrics/test_ranking_metrics.py
index b08ecb4a..b221658e 100644
--- a/tests/test_metrics/test_ranking_metrics.py
+++ b/tests/test_metrics/test_ranking_metrics.py
@@ -55,6 +55,15 @@ def test_binary_ndcg_at_k(setup_data):
expected = np.array(0.5719)
np.testing.assert_almost_equal(result, expected, decimal=4)
+ # Explaining the expected value (@k=3):
+ # DCG_row1 = 0*1 + 1*0.6309 + 1*0.5 = 1.1309 -> (0.7 -> 0, 0.9 -> 1, 0.8 -> 1)
+ # DCG_row2 = 1*1 + 0*0.6309 + 0*0.5 = 1.0 -> (0.3 -> 1, 0.5 -> 0, 0.4 -> 0)
+ # IDCG_row1 = 1*1 + 1*0.6309 + 1*0.5 = 2.1309 (3 relevant items)
+ # IDCG_row2 = 1*1 + 1*0.6309 = 1.6309 (2 relevant items)
+    # NDCG_row1 = DCG_row1 / IDCG_row1 = 1.1309 / 2.1309 = 0.5307
+    # NDCG_row2 = DCG_row2 / IDCG_row2 = 1.0 / 1.6309 = 0.6131
+    # binary_NDCG = (NDCG_row1 + NDCG_row2) / 2 = (0.5307 + 0.6131) / 2 = 0.5719
+
def test_map_at_k(setup_data):
y_pred, y_true = setup_data
@@ -63,6 +72,12 @@ def test_map_at_k(setup_data):
expected = np.array(0.4166)
np.testing.assert_almost_equal(result, expected, decimal=4)
+ # Explaining the expected value (@k=3):
+ # batch relevance for row top 3 predictions: [[1, 1, 0], [0, 0, 1]]
+ # AP_row1 = (1/1 + 2/2 + 0/3) / 3 = 0.6666 -> [(1/1) * 1 + (2/2) * 1 + (2/3) * 0] / 3
+ # AP_row2 = (0/1 + 0/2 + 1/3) / 2 = 0.1666 -> [(0/1) * 0 + (0/2) * 0 + (1/3) * 1] / 2
+ # MAP = (AP_row1 + AP_row2) / 2 = (0.6666 + 0.1666) / 2 = 0.4166
+
def test_hit_ratio_at_k(setup_data):
y_pred, y_true = setup_data
@@ -79,6 +94,14 @@ def test_precision_at_k(setup_data):
expected = np.array(0.5)
np.testing.assert_almost_equal(result, expected, decimal=4)
+ # Explaining the expected value (@k=3):
+ # batch relevance for row top 3 predictions: [[1, 1, 0], [0, 0, 1]]
+ # Precision_row1 = 2 / 3 = 0.666
+ # Precision_row2 = 1 / 3 = 0.333
+ # Precision = (Precision_row1 + Precision_row2) / 2 = (0.666 + 0.333) / 2 = 0.5
+ # caveat: if k is higher than the number of relevant items for a given row
+ # this metric will never be 1.0
+
def test_recall_at_k(setup_data):
y_pred, y_true = setup_data
@@ -87,6 +110,12 @@ def test_recall_at_k(setup_data):
expected = np.array(0.5833)
np.testing.assert_almost_equal(result, expected, decimal=4)
+ # Explaining the expected value (@k=3):
+ # batch relevance for row top 3 predictions: [[1, 1, 0], [0, 0, 1]]
+ # Recall_row1 = 2 / 3 = 0.666
+ # Recall_row2 = 1 / 2 = 0.5
+ # Recall = (Recall_row1 + Recall_row2) / 2 = (0.666 + 0.5) / 2 = 0.5833
+
def test_edge_cases_all_relevant_items():
# Test with all relevant items
@@ -159,6 +188,19 @@ def test_ndcg_at_k(setup_data_ndcg):
expected = np.array(0.7198)
np.testing.assert_almost_equal(result, expected, decimal=4)
+ # Explaining the expected value (@k=3):
+ # top 3 predictions for row 1: [0.9, 0.8, 0.7] -> [3, 2, 1]
+ # top 3 predictions for row 2: [0.5, 0.4, 0.3] -> [0, 1, 3]
+ # DCG = sum[(2^rel - 1) / log2(rank + 1)]
+    # DCG_row1 = (2^3 - 1) / log2(1+1) + (2^2 - 1) / log2(2+1) + (2^1 - 1) / log2(3+1) = 7.0 + 1.8928 + 0.5000 = 9.3928
+    # DCG_row2 = (2^0 - 1) / log2(1+1) + (2^1 - 1) / log2(2+1) + (2^3 - 1) / log2(3+1) = 0.0 + 0.6309 + 3.5 = 4.1309
+    # IDCG = sum[(2^rel - 1) / log2(rank + 1)] for the ideal ranking at k=3 [[3, 2, 1], [3, 2, 1]]
+    # IDCG_row1 = (2^3 - 1) / log2(1+1) + (2^2 - 1) / log2(2+1) + (2^1 - 1) / log2(3+1) = 7.0 + 1.8928 + 0.5000 = 9.3928
+    # IDCG_row2 = IDCG_row1 = 9.3928
+    # NDCG_row1 = DCG_row1 / IDCG_row1 = 9.3928 / 9.3928 = 1.0
+    # NDCG_row2 = DCG_row2 / IDCG_row2 = 4.1309 / 9.3928 = 0.4398
+    # NDCG = (NDCG_row1 + NDCG_row2) / 2 = (1.0 + 0.4398) / 2 ≈ 0.7198
+
def test_ndcg_at_k_edge_cases():
# Test with non-decreasing ranking
From f091194bf69be95558be667ea337f2ff518f1bda Mon Sep 17 00:00:00 2001
From: Javier
Date: Sat, 19 Oct 2024 23:11:08 +0100
Subject: [PATCH 10/24] First integration test run. I will have to adjust all
examples containing the class
---
pytest.ini | 2 +
pytorch_widedeep/dataloaders.py | 79 +++++++--
pytorch_widedeep/metrics.py | 66 +++++---
pytorch_widedeep/training/trainer.py | 67 +++-----
.../test_ranking_metrics_integration.py | 153 ++++++++++++++++++
5 files changed, 285 insertions(+), 82 deletions(-)
create mode 100644 pytest.ini
create mode 100644 tests/test_rec/test_ranking_metrics_integration.py
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 00000000..c1fa8785
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts = -p no:warnings
\ No newline at end of file
diff --git a/pytorch_widedeep/dataloaders.py b/pytorch_widedeep/dataloaders.py
index 3c17bc5c..db53ace0 100644
--- a/pytorch_widedeep/dataloaders.py
+++ b/pytorch_widedeep/dataloaders.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Tuple, Optional
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler
@@ -6,6 +6,45 @@
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
+class DatasetAlreadySetError(Exception):
+ """Exception raised when attempting to set a dataset that has already been set."""
+
+ def __init__(self, message="Dataset has already been set and cannot be changed."):
+ self.message = message
+ super().__init__(self.message)
+
+
+class CustomDataLoader(DataLoader):
+ r"""
+    Wrapper around the `torch.utils.data.DataLoader` class that allows
+    setting the dataset after the class has been instantiated.
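+
+    A minimal usage sketch (illustrative only; it assumes `train_set` is an
+    already built `WideDeepDataset` instance):
+
+        loader = CustomDataLoader(batch_size=32, shuffle=True)
+        loader.set_dataset(train_set)
+        for data, target in loader:
+            pass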
+ """
+
+ def __init__(self, dataset: Optional[WideDeepDataset] = None, *args, **kwargs):
+ self.dataset_set = dataset is not None
+ if self.dataset_set:
+ super().__init__(dataset, *args, **kwargs)
+ else:
+ self.args = args
+ self.kwargs = kwargs
+
+ def set_dataset(self, dataset: WideDeepDataset):
+ if self.dataset_set:
+ raise DatasetAlreadySetError()
+
+ self.dataset_set = True
+ super().__init__(dataset, *self.args, **self.kwargs)
+
+ def __iter__(self):
+ if not self.dataset_set:
+ raise ValueError(
+ "Dataset has not been set. Use set_dataset method to set a dataset."
+ )
+ return super().__iter__()
+
+
+# From here on, the code is legacy and there are better ways to do it. It will
+# be removed in the next version.
def get_class_weights(dataset: WideDeepDataset) -> Tuple[np.ndarray, int, int]:
"""Helper function to get weights of classes in the imbalanced dataset.
@@ -37,7 +76,7 @@ def get_class_weights(dataset: WideDeepDataset) -> Tuple[np.ndarray, int, int]:
return weights, minor_class_count, num_classes
-class DataLoaderImbalanced(DataLoader):
+class DataLoaderImbalanced(CustomDataLoader):
r"""Class to load and shuffle batches with adjusted weights for imbalanced
datasets. If the classes do not begin from 0 remapping is necessary. See
[here](https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab).
@@ -46,13 +85,11 @@ class DataLoaderImbalanced(DataLoader):
----------
dataset: `WideDeepDataset`
see `pytorch_widedeep.training._wd_dataset`
- batch_size: int
- size of batch
- num_workers: int
- number of workers
Other Parameters
----------------
+ *args: Any
+ Positional arguments to be passed to the parent CustomDataLoader.
**kwargs: Dict
This can include any parameter that can be passed to the _'standard'_
pytorch
@@ -69,23 +106,31 @@ class DataLoaderImbalanced(DataLoader):
$$
"""
- def __init__(
- self, dataset: WideDeepDataset, batch_size: int, num_workers: int, **kwargs
- ):
+ def __init__(self, dataset: Optional[WideDeepDataset] = None, *args, **kwargs):
+ self.args = args
+ self.kwargs = kwargs
+
+ if dataset is not None:
+            sampler = self._setup_sampler(dataset)
+            super().__init__(dataset, *args, sampler=sampler, **kwargs)
+ else:
+ super().__init__()
+
+ def set_dataset(self, dataset: WideDeepDataset):
+ sampler = self._setup_sampler(dataset)
+ # update the kwargs with the new sampler
+ self.kwargs["sampler"] = sampler
+ super().set_dataset(dataset)
+
+ def _setup_sampler(self, dataset: WideDeepDataset) -> WeightedRandomSampler:
assert dataset.Y is not None, (
"The 'dataset' instance of WideDeepDataset must contain a "
"target array 'Y'"
)
- if "oversample_mul" in kwargs:
- oversample_mul = kwargs["oversample_mul"]
- del kwargs["oversample_mul"]
- else:
- oversample_mul = 1
+ oversample_mul = self.kwargs.pop("oversample_mul", 1)
weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
num_samples = int(minor_cls_cnt * num_clss * oversample_mul)
samples_weight = list(np.array([weights[i] for i in dataset.Y]))
sampler = WeightedRandomSampler(samples_weight, num_samples, replacement=True)
- super().__init__(
- dataset, batch_size, num_workers=num_workers, sampler=sampler, **kwargs
- )
+ return sampler
diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py
index 5286952b..c708ff9b 100644
--- a/pytorch_widedeep/metrics.py
+++ b/pytorch_widedeep/metrics.py
@@ -397,15 +397,25 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
return np.array((1 - (self.numerator / self.denominator)))
-def reshape_1d_to_2d(tensor: Tensor, n_columns: int) -> Tensor:
- if tensor.dim() != 1:
- raise ValueError("Input tensor must be 1-dimensional")
- if tensor.size(0) % n_columns != 0:
+def reshape_to_2d(tensor: Tensor, n_columns: int) -> Tensor:
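+    # Note: predictions/targets may arrive either as flat 1D tensors or as
+    # (N, 1) column tensors; in both cases they are reshaped to
+    # (n_rows, n_columns), i.e. one row per query/user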
+    if tensor.dim() == 1 or (tensor.dim() == 2 and tensor.size(1) == 1):
+        if tensor.size(0) % n_columns != 0:
+            raise ValueError(
+                f"Tensor length ({tensor.size(0)}) must be divisible by n_columns ({n_columns})"
+            )
+        n_rows = tensor.size(0) // n_columns
+        return tensor.reshape(n_rows, n_columns)
+ else:
raise ValueError(
- f"Tensor length ({tensor.size(0)}) must be divisible by n_columns ({n_columns})"
+ "Input tensor must be 1-dimensional or 2-dimensional with one column"
)
- n_rows = tensor.size(0) // n_columns
- return tensor.reshape(n_rows, n_columns)
class NDCG_at_k(Metric):
@@ -461,8 +471,8 @@ def reset(self):
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
device = y_pred.device
- y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
+ y_pred_2d = reshape_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_to_2d(y_true, self.n_cols)
batch_size = y_true_2d.shape[0]
@@ -544,8 +554,8 @@ def reset(self):
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
device = y_pred.device
- y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
+ y_pred_2d = reshape_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_to_2d(y_true, self.n_cols)
batch_size = y_pred_2d.shape[0]
@@ -607,6 +617,9 @@ class MAP_at_k(Metric):
>>> score = map_at_k(y_pred, y_true)
"""
+ @alias(
+ "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"]
+ )
def __init__(self, n_cols: int = 10, k: Optional[int] = None):
super(MAP_at_k, self).__init__()
@@ -626,8 +639,8 @@ def reset(self):
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
- y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
+ y_pred_2d = reshape_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_to_2d(y_true, self.n_cols)
batch_size = y_pred_2d.shape[0]
_, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
@@ -670,6 +683,9 @@ class HitRatio_at_k(Metric):
 >>> score = hr_at_k(y_pred, y_true)
"""
+ @alias(
+ "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"]
+ )
def __init__(self, n_cols: int = 10, k: Optional[int] = None):
super(HitRatio_at_k, self).__init__()
@@ -688,8 +704,8 @@ def reset(self):
self.count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
- y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
+ y_pred_2d = reshape_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_to_2d(y_true, self.n_cols)
batch_size = y_pred_2d.shape[0]
_, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
batch_relevance = y_true_2d.gather(1, top_k_indices)
@@ -726,6 +742,9 @@ class Precision_at_k(Metric):
>>> score = prec_at_k(y_pred, y_true)
"""
+ @alias(
+ "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"]
+ )
def __init__(self, n_cols: int = 10, k: Optional[int] = None):
super(Precision_at_k, self).__init__()
@@ -744,8 +763,8 @@ def reset(self):
self.count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
- y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
+ y_pred_2d = reshape_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_to_2d(y_true, self.n_cols)
batch_size = y_pred_2d.shape[0]
_, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
batch_relevance = y_true_2d.gather(1, top_k_indices)
@@ -761,6 +780,12 @@ class Recall_at_k(Metric):
Parameters
----------
+ n_cols: int, default = 10
+        Number of columns in the input tensors. This parameter is necessary
+        because the input tensors are reshaped to 2D tensors. n_cols is the
+        number of columns in the reshaped tensor. Aliases for this parameter
+        are: 'n_items', 'n_items_per_query',
+ 'n_items_per_id', 'n_items_per_user'
k: int, default = 10
Number of top items to consider.
@@ -775,6 +800,9 @@ class Recall_at_k(Metric):
>>> score = rec_at_k(y_pred, y_true)
"""
+ @alias(
+ "n_cols", ["n_items", "n_items_per_query", "n_items_per_id", "n_items_per_user"]
+ )
def __init__(self, n_cols: int = 10, k: Optional[int] = None):
super(Recall_at_k, self).__init__()
@@ -793,8 +821,8 @@ def reset(self):
self.count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
- y_pred_2d = reshape_1d_to_2d(y_pred, self.n_cols)
- y_true_2d = reshape_1d_to_2d(y_true, self.n_cols)
+ y_pred_2d = reshape_to_2d(y_pred, self.n_cols)
+ y_true_2d = reshape_to_2d(y_true, self.n_cols)
batch_size = y_pred_2d.shape[0]
_, top_k_indices = torch.topk(y_pred_2d, self.k, dim=1)
batch_relevance = y_true_2d.gather(1, top_k_indices)
diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py
index 7400409b..bd871f3b 100644
--- a/pytorch_widedeep/training/trainer.py
+++ b/pytorch_widedeep/training/trainer.py
@@ -26,6 +26,7 @@
LRScheduler,
)
from pytorch_widedeep.callbacks import Callback
+from pytorch_widedeep.dataloaders import CustomDataLoader, DatasetAlreadySetError
from pytorch_widedeep.initializers import Initializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.utils.general_utils import alias, to_device
@@ -285,8 +286,8 @@ def fit( # noqa: C901
n_epochs: int = 1,
validation_freq: int = 1,
batch_size: int = 32,
- train_dataloader: Optional[DataLoader] = None,
- eval_dataloader: Optional[DataLoader] = None,
+ train_dataloader: Optional[CustomDataLoader] = None,
+ eval_dataloader: Optional[CustomDataLoader] = None,
feature_importance_sample_size: Optional[int] = None,
finetune: bool = False,
**kwargs,
@@ -438,13 +439,13 @@ def fit( # noqa: C901
target,
self.transforms,
)
- train_loader, eval_loader = self._get_dataloaders(
- train_dataloader,
- eval_dataloader,
- train_set,
- eval_set,
- batch_size,
- dataloader_args,
+ train_loader = self._set_dataloader(
+ train_dataloader, train_set, batch_size, dataloader_args
+ )
+ eval_loader = (
+ self._set_dataloader(eval_dataloader, eval_set, batch_size, dataloader_args)
+ if eval_set is not None
+ else None
)
if finetune:
@@ -780,51 +781,25 @@ def save(
with open(Path(path) / "feature_importance.json", "w") as fi:
json.dump(self.feature_importance, fi)
- def _get_dataloaders(
+ def _set_dataloader(
self,
- train_dataloader: Optional[DataLoader],
- eval_dataloader: Optional[DataLoader],
- train_set: WideDeepDataset,
- eval_set: Optional[WideDeepDataset],
+ dataloader: Optional[CustomDataLoader],
+ dataset: WideDeepDataset,
batch_size: int,
dataloader_args: Dict[str, Any],
- ) -> Tuple[DataLoader, Optional[DataLoader]]:
- if train_dataloader is not None:
- # make sure is callable (and HAS to be an subclass of DataLoader)
- assert isinstance(train_dataloader, type)
- train_loader = train_dataloader( # type: ignore[misc]
- dataset=train_set,
- batch_size=batch_size,
- num_workers=self.num_workers,
- **dataloader_args,
- )
+ ) -> DataLoader | CustomDataLoader:
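+        # a user-provided CustomDataLoader is already configured (batch size,
+        # sampler, shuffling, ...); it only needs the dataset injected via
+        # set_dataset before it can be iterated over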
+ if dataloader is not None:
+ dataloader.set_dataset(dataset)
+ return dataloader
else:
- train_loader = DataLoader(
- dataset=train_set,
+            # var name 'loader' to avoid reassignment and type errors
+ loader = DataLoader(
+ dataset=dataset,
batch_size=batch_size,
num_workers=self.num_workers,
**dataloader_args,
)
-
- eval_loader = None
- if eval_set is not None:
- if eval_dataloader is not None:
- assert isinstance(eval_dataloader, type)
- eval_loader = eval_dataloader( # type: ignore[misc]
- dataset=eval_set,
- batch_size=batch_size,
- num_workers=self.num_workers,
- shuffle=False,
- )
- else:
- eval_loader = DataLoader(
- dataset=eval_set,
- batch_size=batch_size,
- num_workers=self.num_workers,
- shuffle=False,
- )
-
- return train_loader, eval_loader
+ return loader
@alias("n_epochs", ["finetune_epochs", "warmup_epochs"])
@alias("max_lr", ["finetune_max_lr", "warmup_max_lr"])
diff --git a/tests/test_rec/test_ranking_metrics_integration.py b/tests/test_rec/test_ranking_metrics_integration.py
new file mode 100644
index 00000000..1239ae54
--- /dev/null
+++ b/tests/test_rec/test_ranking_metrics_integration.py
@@ -0,0 +1,153 @@
+# silence DeprecationWarning
+import warnings
+from typing import Tuple
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from pytorch_widedeep.models import TabMlp, WideDeep
+from pytorch_widedeep.metrics import (
+ F1Score,
+ Accuracy,
+ MAP_at_k,
+ NDCG_at_k,
+ Recall_at_k,
+ HitRatio_at_k,
+ Precision_at_k,
+ BinaryNDCG_at_k,
+)
+from pytorch_widedeep.training import Trainer
+from pytorch_widedeep.preprocessing import TabPreprocessor
+from pytorch_widedeep.training._wd_dataset import WideDeepDataset
+
+
+def generate_user_item_interactions(
+ n_users: int = 10,
+ n_interactions_per_user: int = 10,
+ n_items: int = 5,
+ random_seed: int = 42,
+) -> pd.DataFrame:
+ np.random.seed(random_seed)
+
+ user_ids = np.repeat(range(1, n_users + 1), n_interactions_per_user)
+ item_ids = np.random.randint(1, n_items + 1, size=n_users * n_interactions_per_user)
+ user_categories = np.random.choice(
+ ["A", "B", "C"], size=n_users * n_interactions_per_user
+ )
+ item_categories = np.random.choice(
+ ["X", "Y", "Z"], size=n_users * n_interactions_per_user
+ )
+ binary_target = np.random.randint(0, 2, size=n_users * n_interactions_per_user)
+ categorical_target = np.random.randint(0, 3, size=n_users * n_interactions_per_user)
+
+ df = pd.DataFrame(
+ {
+ "user_id": user_ids,
+ "item_id": item_ids,
+ "user_category": user_categories,
+ "item_category": item_categories,
+ "liked": binary_target,
+ "rating": categorical_target,
+ }
+ )
+
+ return df
+
+
+def split_train_validation(
+ df: pd.DataFrame, validation_interactions_per_user: int
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
+ grouped = (
+ df.groupby("user_id")
+ .apply(lambda x: x.sample(frac=1, random_state=42))
+ .reset_index(drop=True)
+ )
+
+ train_df = (
+ grouped.groupby("user_id")
+ .apply(lambda x: x.iloc[:-validation_interactions_per_user])
+ .reset_index(drop=True)
+ )
+ val_df = (
+ grouped.groupby("user_id")
+ .apply(lambda x: x.iloc[-validation_interactions_per_user:])
+ .reset_index(drop=True)
+ )
+
+ return train_df, val_df
+
+
+@pytest.fixture
+def user_item_data(request):
+ validation_interactions_per_user = request.param.get(
+ "validation_interactions_per_user", 5
+ )
+ df = generate_user_item_interactions()
+ train_df, val_df = split_train_validation(df, validation_interactions_per_user)
+ return train_df, val_df
+
+
+# strategy 1:
+# * ranking metrics for both training and validation sets.
+# * Same number of items per user in both sets
+
+
+@pytest.mark.parametrize(
+ "user_item_data", [{"validation_interactions_per_user": 5}], indirect=True
+)
+@pytest.mark.parametrize(
+ "metric",
+ [
+ MAP_at_k(n_cols=5, k=3),
+ NDCG_at_k(n_cols=5, k=3),
+ Recall_at_k(n_cols=5, k=3),
+ HitRatio_at_k(n_cols=5, k=3),
+ Precision_at_k(n_cols=5, k=3),
+ BinaryNDCG_at_k(n_cols=5, k=3),
+ ],
+)
+def test_binary_classification_strategy_1(user_item_data, metric):
+
+ train_df, val_df = user_item_data
+
+ categorical_cols = ["user_id", "item_id", "user_category", "item_category"]
+
+ target_col = "liked"
+
+ tab_preprocessor = TabPreprocessor(
+ embed_cols=categorical_cols,
+ for_transformer=False,
+ )
+
+ tab_preprocessor.fit(train_df)
+
+ X_train = tab_preprocessor.transform(train_df)
+ X_val = tab_preprocessor.transform(val_df)
+
+ tab_mlp = TabMlp(
+ column_idx=tab_preprocessor.column_idx,
+ cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type]
+ )
+ model = WideDeep(deeptabular=tab_mlp)
+
+ trainer = Trainer(model, objective="binary", metrics=[metric])
+
+ trainer.fit(
+ X_train={"X_tab": X_train, "target": train_df[target_col].values},
+ X_val={"X_tab": X_val, "target": val_df[target_col].values},
+ n_epochs=1,
+ batch_size=2 * 5,
+ )
+
+ # predict on validation, this is just a test...
+ preds = trainer.predict(X_tab=X_val)
+
+ assert preds.shape[0] == X_val.shape[0]
+ assert (
+ trainer.history is not None
+ and "train_loss" in trainer.history
+ and "val_loss" in trainer.history
+ and f"train_{metric._name}" in trainer.history
+ and f"val_{metric._name}" in trainer.history
+ )
From 62db3ca23cc540dfec6de0afba583afd0d441942 Mon Sep 17 00:00:00 2001
From: Javier
Date: Mon, 21 Oct 2024 09:57:14 +0100
Subject: [PATCH 11/24] All tests passed. I need to update examples for the
 Imbalanced Data Loader. Then review all the docs and clean/remove the docs
 folder
---
.gitignore | 11 +-
.isort.cfg | 4 -
.travis.yml | 34 --
code_style.sh | 8 -
pytest.ini | 2 -
pytorch_widedeep/metrics.py | 10 +
pytorch_widedeep/training/_base_trainer.py | 3 +-
pytorch_widedeep/training/trainer.py | 2 +-
.../training/trainer_from_folder.py | 330 ++---------------
.../test_fit_methods.py | 4 +-
.../test_ranking_metrics_integration.py | 339 +++++++++++++++++-
11 files changed, 382 insertions(+), 365 deletions(-)
delete mode 100644 .isort.cfg
delete mode 100644 .travis.yml
delete mode 100755 code_style.sh
delete mode 100644 pytest.ini
diff --git a/.gitignore b/.gitignore
index 06c4cc68..710f3c01 100644
--- a/.gitignore
+++ b/.gitignore
@@ -60,4 +60,13 @@ checkpoints
# wnb
wandb/
-wandb_api.key
\ No newline at end of file
+wandb_api.key
+
+# pytest
+pytest.ini
+
+# code style
+code_style.sh
+
+# isort
+.isort.cfg
diff --git a/.isort.cfg b/.isort.cfg
deleted file mode 100644
index 023dbdb4..00000000
--- a/.isort.cfg
+++ /dev/null
@@ -1,4 +0,0 @@
-[settings]
-profile=black
-multi_line_output=3
-length_sort=1
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 75550745..00000000
--- a/.travis.yml
+++ /dev/null
@@ -1,34 +0,0 @@
-dist: xenial
-language: python
-python:
- - "3.7.9"
- - "3.8"
- - "3.9"
-matrix:
- fast_finish: true
- include:
- - name: "Code Style (Black/Flake8)"
- install:
- - pip install black
- - pip install flake8
- script:
- # Black code style
- - black --check --diff pytorch_widedeep tests examples setup.py
- # Stop the build if there are Python syntax errors or undefined names
- - flake8 . --count --select=E901,E999,F821,F822,F823 --ignore=E266 --show-source --statistics
- # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --ignore=E203,E266,E501,E722,F401,F403,F405,F811,W503,C901 --statistics
- after_success: skip
-install:
- - pip install --upgrade pip
- - pip install pytest-cov
- - pip install codecov
- - pip install .
-script:
- - pytest --doctest-modules pytorch_widedeep --cov-report html --cov-report term --disable-pytest-warnings --cov=pytorch_widedeep tests/
- # - pytest tests
-after_success:
- # Combine coverage reports from Travis jobs
- - coverage combine
- # Send reports to Codecov. It automatically handles merging coverage results
- - codecov
\ No newline at end of file
diff --git a/code_style.sh b/code_style.sh
deleted file mode 100755
index 52eb2d64..00000000
--- a/code_style.sh
+++ /dev/null
@@ -1,8 +0,0 @@
-# sort imports
-isort --quiet . pytorch_widedeep tests examples setup.py
-# Black code style
-black . pytorch_widedeep tests examples setup.py
-# flake8 standards
-flake8 . --max-complexity=10 --max-line-length=127 --ignore=E203,E266,E501,E722,E721,F401,F403,F405,W503,C901,F811
-# mypy
-mypy pytorch_widedeep --ignore-missing-imports --no-strict-optional
\ No newline at end of file
diff --git a/pytest.ini b/pytest.ini
deleted file mode 100644
index c1fa8785..00000000
--- a/pytest.ini
+++ /dev/null
@@ -1,2 +0,0 @@
-[pytest]
-addopts = -p no:warnings
\ No newline at end of file
diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py
index c708ff9b..772f4c57 100644
--- a/pytorch_widedeep/metrics.py
+++ b/pytorch_widedeep/metrics.py
@@ -469,8 +469,18 @@ def reset(self):
self.count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
+        # NDCG@k is meant to be used when the output reflects interest
+        # scores, i.e., it can be used in a regression or a multiclass
+        # problem. For regression, y_pred will be a 1D float tensor; for
+        # multiclass, y_pred will be a 2D float tensor with the output of a
+        # softmax activation and we need to turn it into a 1D tensor with
+        # the predicted class. For binary problems, please use BinaryNDCG_at_k
device = y_pred.device
+ if y_pred.ndim > 1 and y_pred.size(1) > 1:
+ # multiclass
+ y_pred = y_pred.topk(1, 1)[1]
+
y_pred_2d = reshape_to_2d(y_pred, self.n_cols)
y_true_2d = reshape_to_2d(y_true, self.n_cols)
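
A minimal sketch (not part of the patch, values made up for illustration) of what the multiclass branch added above does: softmax scores are collapsed to the index of the highest-scoring class, so that NDCG_at_k works with a 1D tensor of predicted labels before reshape_to_2d groups them per user.

import torch

# hypothetical softmax output for 4 predictions over 3 classes (multiclass case)
y_pred = torch.tensor(
    [
        [0.1, 0.7, 0.2],
        [0.5, 0.3, 0.2],
        [0.2, 0.2, 0.6],
        [0.3, 0.4, 0.3],
    ]
)

# same reduction as in the hunk above: keep the index of the top class
if y_pred.ndim > 1 and y_pred.size(1) > 1:
    y_pred = y_pred.topk(1, 1)[1]

print(y_pred.squeeze().tolist())  # [1, 0, 2, 1] -> 1D class labels, ready for reshape_to_2d
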
diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py
index bf2c3875..2f15a3bb 100644
--- a/pytorch_widedeep/training/_base_trainer.py
+++ b/pytorch_widedeep/training/_base_trainer.py
@@ -63,6 +63,7 @@ def __init__(
transforms: Optional[List[Transforms]],
callbacks: Optional[List[Callback]],
metrics: Optional[Union[List[Metric], List[TorchMetric]]],
+ eval_metrics: Optional[Union[List[Metric], List[TorchMetric]]],
verbose: int,
seed: int,
**kwargs,
@@ -90,7 +91,7 @@ def __init__(
self.optimizer = self._set_optimizer(optimizers)
self.lr_scheduler = self._set_lr_scheduler(lr_schedulers, **kwargs)
self.transforms = self._set_transforms(transforms)
- self._set_callbacks_and_metrics(callbacks, metrics)
+ self._set_callbacks_and_metrics(callbacks, metrics, eval_metrics)
@abstractmethod
def fit(self, **kwargs):
diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py
index bd871f3b..e259268d 100644
--- a/pytorch_widedeep/training/trainer.py
+++ b/pytorch_widedeep/training/trainer.py
@@ -26,7 +26,7 @@
LRScheduler,
)
from pytorch_widedeep.callbacks import Callback
-from pytorch_widedeep.dataloaders import CustomDataLoader, DatasetAlreadySetError
+from pytorch_widedeep.dataloaders import CustomDataLoader
from pytorch_widedeep.initializers import Initializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.utils.general_utils import alias, to_device
diff --git a/pytorch_widedeep/training/trainer_from_folder.py b/pytorch_widedeep/training/trainer_from_folder.py
index 1ebc2629..fac5a925 100644
--- a/pytorch_widedeep/training/trainer_from_folder.py
+++ b/pytorch_widedeep/training/trainer_from_folder.py
@@ -13,7 +13,6 @@
List,
Union,
Tensor,
- Literal,
Optional,
WideDeep,
Optimizer,
@@ -22,30 +21,12 @@
)
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.initializers import Initializer
-from pytorch_widedeep.training._finetune import FineTune
+from pytorch_widedeep.training.trainer import Trainer
from pytorch_widedeep.utils.general_utils import alias, to_device
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
-from pytorch_widedeep.training._base_trainer import BaseTrainer
-from pytorch_widedeep.training._trainer_utils import (
- save_epoch_logs,
- print_loss_and_metric,
-)
-
-# Observation 1: I am annoyed by sublime highlighting an override issue with
-# the abstractmethods. There is no override issue. The Signature of
-# the 'predict' method is compatible with the supertype. Buy for whatever
-# issue sublime highlights this as an error (not vscode and is not returned
-# as an error when running mypy). I am ignoring it
-# There is a lot of code repetition between this class and the 'Trainer'
-# class (and in consquence a lot of ignore methods for test coverage). Maybe
-# in the future I decided to merge the two of them and offer the ability to
-# laod from folder based on the input parameters. For now, I'll leave it like
-# this, separated, since it is the easiest and most manageable(i.e. easier to
-# debug) implementation
-
-class TrainerFromFolder(BaseTrainer):
+class TrainerFromFolder(Trainer):
r"""Class to set the of attributes that will be used during the
training process.
@@ -187,6 +168,7 @@ class TrainerFromFolder(BaseTrainer):
"objective",
["loss_function", "loss_fn", "loss", "cost_function", "cost_fn", "cost"],
)
+ @alias("metrics", ["train_metrics"])
def __init__(
self,
model: WideDeep,
@@ -204,6 +186,7 @@ def __init__(
transforms: Optional[List[Transforms]] = None,
callbacks: Optional[List[Callback]] = None,
metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None,
+ eval_metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None,
verbose: int = 1,
seed: int = 1,
**kwargs,
@@ -218,13 +201,19 @@ def __init__(
transforms=transforms,
callbacks=callbacks,
metrics=metrics,
+ eval_metrics=eval_metrics,
verbose=verbose,
seed=seed,
**kwargs,
)
+ if self.method == "multitarget":
+ raise NotImplementedError(
+ "Training from folder is not supported for multitarget models"
+ )
+
@alias("finetune", ["warmup"])
- def fit( # noqa: C901
+ def fit( # type: ignore[override] # noqa: C901
self,
train_loader: DataLoader,
eval_loader: Optional[DataLoader] = None,
@@ -233,68 +222,44 @@ def fit( # noqa: C901
finetune: bool = False,
**kwargs,
):
- finetune_args = self._extract_kwargs(kwargs)
- train_steps = len(train_loader)
+        # There will never be dataloader_args when using 'TrainerFromFolder'
+        # as the loaders are passed directly as arguments
+ _, finetune_args = self._extract_kwargs(kwargs)
if finetune:
- self._finetune(train_loader, **finetune_args)
+ self.with_finetuning: bool = True
+ self._do_finetune(train_loader, **finetune_args)
if self.verbose:
print(
"Fine-tuning (or warmup) of individual components completed. "
"Training the whole model for {} epochs".format(n_epochs)
)
+ else:
+ self.with_finetuning = False
self.callback_container.on_train_begin(
{
"batch_size": train_loader.batch_size,
- "train_steps": train_steps,
+ "train_steps": len(train_loader),
"n_epochs": n_epochs,
}
)
for epoch in range(n_epochs):
- epoch_logs: Dict[str, float] = {}
- self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
-
- self.train_running_loss = 0.0
- with trange(train_steps, disable=self.verbose != 1) as t:
- for batch_idx, (data, targett) in zip(t, train_loader):
- t.set_description("epoch %i" % (epoch + 1))
- train_score, train_loss = self._train_step(
- data, targett, batch_idx, epoch
- )
- print_loss_and_metric(t, train_loss, train_score)
- self.callback_container.on_batch_end(batch=batch_idx)
- epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train")
-
- on_epoch_end_metric = None
+ epoch_logs = self._train_epoch(train_loader, epoch)
if eval_loader is not None and epoch % validation_freq == (
validation_freq - 1
):
- eval_steps = len(eval_loader)
- self.callback_container.on_eval_begin()
- self.valid_running_loss = 0.0
- with trange(eval_steps, disable=self.verbose != 1) as v:
- for i, (data, targett) in zip(v, eval_loader):
- v.set_description("valid")
- val_score, val_loss = self._eval_step(data, targett, i)
- print_loss_and_metric(v, val_loss, val_score)
- epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val")
-
- if self.reducelronplateau: # pragma: no cover
- if self.reducelronplateau_criterion == "loss":
- on_epoch_end_metric = val_loss
- else:
- on_epoch_end_metric = val_score[
- self.reducelronplateau_criterion
- ]
+ epoch_logs, on_epoch_end_metric = self._eval_epoch(
+ eval_loader, epoch_logs
+ )
else:
+ on_epoch_end_metric = None
if self.reducelronplateau:
raise NotImplementedError(
"ReduceLROnPlateau scheduler can be used only with validation data."
)
self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric)
-
if self.early_stop:
# self.callback_container.on_train_end(epoch_logs)
break
@@ -327,7 +292,7 @@ def predict( # type: ignore[override, return]
preds = np.vstack(preds_l)
return np.argmax(preds, 1)
- def predict_uncertainty( # type: ignore[return]
+ def predict_uncertainty( # type: ignore[override, return]
self,
X_wide: Optional[np.ndarray] = None,
X_tab: Optional[np.ndarray] = None,
@@ -402,202 +367,7 @@ def predict_proba( # type: ignore[override, return]
if self.method == "multiclass":
return np.vstack(preds_l)
- def save(
- self,
- path: str,
- save_state_dict: bool = False,
- save_optimizer: bool = False,
- model_filename: str = "wd_model.pt",
- ): # pragma: no cover
- """
- Parameters
- ----------
- path: str
- path to the directory where the model and the feature importance
- attribute will be saved.
- save_state_dict: bool, default = False
- Boolean indicating whether to save directly the model
- (and optimizer) or the model's (and optimizer's) state
- dictionary
- save_optimizer: bool, default = False
- Boolean indicating whether to save the optimizer
- model_filename: str, Optional, default = "wd_model.pt"
- filename where the model weights will be store
- """
- self._save_history(path)
-
- self._save_model_and_optimizer(
- path, save_state_dict, save_optimizer, model_filename
- )
-
- @alias("n_epochs", ["finetune_epochs", "warmup_epochs"])
- @alias("max_lr", ["finetune_max_lr", "warmup_max_lr"])
- def _finetune(
- self,
- loader: DataLoader,
- n_epochs: int = 5,
- max_lr: float = 0.01,
- routine: Literal["howard", "felbo"] = "howard",
- deeptabular_gradual: bool = False,
- deeptabular_layers: Optional[List[nn.Module]] = None,
- deeptabular_max_lr: float = 0.01,
- deeptext_gradual: bool = False,
- deeptext_layers: Optional[List[nn.Module]] = None,
- deeptext_max_lr: float = 0.01,
- deepimage_gradual: bool = False,
- deepimage_layers: Optional[List[nn.Module]] = None,
- deepimage_max_lr: float = 0.01,
- ):
- r"""
- Simple wrap-up to individually fine-tune model components
- """
- if self.model.deephead is not None:
- raise ValueError(
- "Currently warming up is only supported without a fully connected 'DeepHead'"
- )
-
- finetuner = FineTune(self.loss_fn, self.metric, self.method, self.verbose) # type: ignore[arg-type]
- if self.model.wide:
- finetuner.finetune_all(self.model.wide, "wide", loader, n_epochs, max_lr)
-
- if self.model.deeptabular:
- if deeptabular_gradual:
- assert (
- deeptabular_layers is not None
- ), "deeptabular_layers must be passed if deeptabular_gradual=True"
- finetuner.finetune_gradual(
- self.model.deeptabular,
- "deeptabular",
- loader,
- deeptabular_max_lr,
- deeptabular_layers,
- routine,
- )
- else:
- finetuner.finetune_all(
- self.model.deeptabular, "deeptabular", loader, n_epochs, max_lr
- )
-
- if self.model.deeptext:
- if deeptext_gradual:
- assert (
- deeptext_layers is not None
- ), "deeptext_layers must be passed if deeptabular_gradual=True"
- finetuner.finetune_gradual(
- self.model.deeptext,
- "deeptext",
- loader,
- deeptext_max_lr,
- deeptext_layers,
- routine,
- )
- else:
- finetuner.finetune_all(
- self.model.deeptext, "deeptext", loader, n_epochs, max_lr
- )
-
- if self.model.deepimage:
- if deepimage_gradual:
- assert (
- deepimage_layers is not None
- ), "deepimage_layers must be passed if deeptabular_gradual=True"
- finetuner.finetune_gradual(
- self.model.deepimage,
- "deepimage",
- loader,
- deepimage_max_lr,
- deepimage_layers,
- routine,
- )
- else:
- finetuner.finetune_all(
- self.model.deepimage, "deepimage", loader, n_epochs, max_lr
- )
-
- def _train_step(
- self,
- data: Dict[str, Union[Tensor, List[Tensor]]],
- target: Tensor,
- batch_idx: int,
- epoch: int,
- ):
- self.model.train()
- X: Dict[str, Union[Tensor, List[Tensor]]] = {}
- for k, v in data.items():
- if isinstance(v, list):
- X[k] = [to_device(i, self.device) for i in v]
- else:
- X[k] = to_device(v, self.device)
- y = (
- target.view(-1, 1).float()
- if self.method not in ["multiclass", "qregression"]
- else target
- )
- y = to_device(y, self.device)
-
- self.optimizer.zero_grad()
-
- y_pred = self.model(X)
-
- if self.model.is_tabnet: # pragma: no cover
- loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
- score = self._get_score(y_pred[0], y)
- else:
- loss = self.loss_fn(y_pred, y)
- score = self._get_score(y_pred, y)
-
- loss.backward()
- self.optimizer.step()
-
- self.train_running_loss += loss.item()
- avg_loss = self.train_running_loss / (batch_idx + 1)
-
- return score, avg_loss
-
- def _eval_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
- self.model.eval()
- with torch.no_grad():
- X: Dict[str, Union[Tensor, List[Tensor]]] = {}
- for k, v in data.items():
- if isinstance(v, list):
- X[k] = [to_device(i, self.device) for i in v]
- else:
- X[k] = to_device(v, self.device)
- y = (
- target.view(-1, 1).float()
- if self.method not in ["multiclass", "qregression"]
- else target
- )
- y = to_device(y, self.device)
-
- y_pred = self.model(X)
- if self.model.is_tabnet: # pragma: no cover
- loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
- score = self._get_score(y_pred[0], y)
- else:
- score = self._get_score(y_pred, y)
- loss = self.loss_fn(y_pred, y)
-
- self.valid_running_loss += loss.item()
- avg_loss = self.valid_running_loss / (batch_idx + 1)
-
- return score, avg_loss
-
- def _get_score(self, y_pred, y): # pragma: no cover
- if self.metric is not None:
- if self.method == "regression":
- score = self.metric(y_pred, y)
- if self.method == "binary":
- score = self.metric(torch.sigmoid(y_pred), y)
- if self.method == "qregression":
- score = self.metric(y_pred, y)
- if self.method == "multiclass":
- score = self.metric(F.softmax(y_pred, dim=1), y)
- return score
- else:
- return None
-
- def _predict( # noqa: C901
+ def _predict( # type: ignore[override] # noqa: C901
self,
X_wide: Optional[np.ndarray] = None,
X_tab: Optional[np.ndarray] = None,
@@ -606,7 +376,7 @@ def _predict( # noqa: C901
X_test: Optional[Dict[str, Union[np.ndarray, List[np.ndarray]]]] = None,
test_loader: Optional[DataLoader] = None,
batch_size: Optional[int] = None,
- uncertainty_granularity=1000,
+ uncertainty_granularity: int = 1000,
uncertainty: bool = False,
) -> List:
r"""Private method to avoid code repetition in predict and
@@ -690,49 +460,3 @@ def _predict( # noqa: C901
preds_l.append(preds)
self.model.train()
return preds_l
-
- @staticmethod
- def _predict_ziln(preds: Tensor) -> Tensor: # pragma: no cover
- """Calculates predicted mean of zero inflated lognormal logits.
-
- Adjusted implementaion of `code
- `
-
- Arguments:
- preds: [batch_size, 3] tensor of logits.
- Returns:
- ziln_preds: [batch_size, 1] tensor of predicted mean.
- """
- positive_probs = torch.sigmoid(preds[..., :1])
- loc = preds[..., 1:2]
- scale = F.softplus(preds[..., 2:])
- ziln_preds = positive_probs * torch.exp(loc + 0.5 * torch.square(scale))
- return ziln_preds
-
- @staticmethod
- def _extract_kwargs(kwargs):
- finetune_params = [
- "n_epochs",
- "finetune_epochs",
- "warmup_epochs",
- "max_lr",
- "finetune_max_lr",
- "warmup_max_lr",
- "routine",
- "deeptabular_gradual",
- "deeptabular_layers",
- "deeptabular_max_lr",
- "deeptext_gradual",
- "deeptext_layers",
- "deeptext_max_lr",
- "deepimage_gradual",
- "deepimage_layers",
- "deepimage_max_lr",
- ]
-
- finetune_args = {}
- for k, v in kwargs.items():
- if k in finetune_params:
- finetune_args[k] = v
-
- return finetune_args
diff --git a/tests/test_model_functioning/test_fit_methods.py b/tests/test_model_functioning/test_fit_methods.py
index f873f75d..a1188c02 100644
--- a/tests/test_model_functioning/test_fit_methods.py
+++ b/tests/test_model_functioning/test_fit_methods.py
@@ -286,12 +286,12 @@ def test_custom_dataloader():
)
model = WideDeep(wide=wide, deeptabular=deeptabular)
trainer = Trainer(model, loss="binary", verbose=0)
+
trainer.fit(
X_wide=X_wide,
X_tab=X_tab,
target=target_binary_imbalanced,
- batch_size=16,
- custom_dataloader=DataLoaderImbalanced,
+ custom_dataloader=DataLoaderImbalanced(batch_size=16),
)
# simply checking that runs with DataLoaderImbalanced
assert "train_loss" in trainer.history.keys()
diff --git a/tests/test_rec/test_ranking_metrics_integration.py b/tests/test_rec/test_ranking_metrics_integration.py
index 1239ae54..81f8881b 100644
--- a/tests/test_rec/test_ranking_metrics_integration.py
+++ b/tests/test_rec/test_ranking_metrics_integration.py
@@ -1,5 +1,3 @@
-# silence DeprecationWarning
-import warnings
from typing import Tuple
import numpy as np
@@ -18,8 +16,8 @@
BinaryNDCG_at_k,
)
from pytorch_widedeep.training import Trainer
+from pytorch_widedeep.dataloaders import CustomDataLoader
from pytorch_widedeep.preprocessing import TabPreprocessor
-from pytorch_widedeep.training._wd_dataset import WideDeepDataset
def generate_user_item_interactions(
@@ -89,8 +87,8 @@ def user_item_data(request):
# strategy 1:
-# * ranking metrics for both training and validation sets.
-# * Same number of items per user in both sets
+# * Same ranking metrics for both training and validation sets.
+# * Same number of items per user in both sets
@pytest.mark.parametrize(
@@ -100,11 +98,14 @@ def user_item_data(request):
"metric",
[
MAP_at_k(n_cols=5, k=3),
- NDCG_at_k(n_cols=5, k=3),
Recall_at_k(n_cols=5, k=3),
HitRatio_at_k(n_cols=5, k=3),
Precision_at_k(n_cols=5, k=3),
BinaryNDCG_at_k(n_cols=5, k=3),
+ [Accuracy(), MAP_at_k(n_cols=5, k=3)],
+ [F1Score(), MAP_at_k(n_cols=5, k=3)],
+ [Accuracy(), BinaryNDCG_at_k(n_cols=5, k=3)],
+ [F1Score(), BinaryNDCG_at_k(n_cols=5, k=3)],
],
)
def test_binary_classification_strategy_1(user_item_data, metric):
@@ -131,13 +132,17 @@ def test_binary_classification_strategy_1(user_item_data, metric):
)
model = WideDeep(deeptabular=tab_mlp)
- trainer = Trainer(model, objective="binary", metrics=[metric])
+ if isinstance(metric, list):
+ trainer = Trainer(model, objective="binary", metrics=metric)
+ else:
+ trainer = Trainer(model, objective="binary", metrics=[metric])
trainer.fit(
X_train={"X_tab": X_train, "target": train_df[target_col].values},
X_val={"X_tab": X_val, "target": val_df[target_col].values},
n_epochs=1,
batch_size=2 * 5,
+ verbose=0,
)
# predict on validation, this is just a test...
@@ -148,6 +153,322 @@ def test_binary_classification_strategy_1(user_item_data, metric):
trainer.history is not None
and "train_loss" in trainer.history
and "val_loss" in trainer.history
- and f"train_{metric._name}" in trainer.history
- and f"val_{metric._name}" in trainer.history
)
+ if not isinstance(metric, list):
+ metric = [metric]
+ for m in metric:
+ assert f"train_{m._name}" in trainer.history
+ assert f"val_{m._name}" in trainer.history
+
+
+@pytest.mark.parametrize(
+ "user_item_data", [{"validation_interactions_per_user": 5}], indirect=True
+)
+def test_multiclass_classification_strategy_1(user_item_data):
+
+ train_df, val_df = user_item_data
+
+ categorical_cols = ["user_id", "item_id", "user_category", "item_category"]
+
+ target_col = "rating"
+
+ tab_preprocessor = TabPreprocessor(
+ embed_cols=categorical_cols,
+ for_transformer=False,
+ )
+
+ tab_preprocessor.fit(train_df)
+
+ X_train = tab_preprocessor.transform(train_df)
+ X_val = tab_preprocessor.transform(val_df)
+
+ tab_mlp = TabMlp(
+ column_idx=tab_preprocessor.column_idx,
+ cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type]
+ )
+ model = WideDeep(deeptabular=tab_mlp, pred_dim=3)
+
+ trainer = Trainer(
+ model,
+ objective="multiclass",
+ metrics=[Accuracy(), NDCG_at_k(n_cols=5, k=3)],
+ verbose=0,
+ )
+
+ trainer.fit(
+ X_train={"X_tab": X_train, "target": train_df[target_col].values},
+ X_val={"X_tab": X_val, "target": val_df[target_col].values},
+ n_epochs=1,
+ batch_size=2 * 5, # 2 * n_items
+ )
+
+ # predict on validation, this is just a test...
+ preds = trainer.predict(X_tab=X_val)
+
+ assert preds.shape[0] == X_val.shape[0]
+ assert (
+ trainer.history is not None
+ and "train_loss" in trainer.history
+ and "val_loss" in trainer.history
+ )
+
+
+# strategy 2:
+# * Different metrics for the training and validation sets.
+# * Same number of items per user in both sets
+
+
+@pytest.mark.parametrize(
+ "user_item_data", [{"validation_interactions_per_user": 5}], indirect=True
+)
+@pytest.mark.parametrize("target_col", ["liked", "rating"])
+def test_strategy_2(user_item_data, target_col):
+
+ train_df, val_df = user_item_data
+
+ categorical_cols = ["user_id", "item_id", "user_category", "item_category"]
+
+ tab_preprocessor = TabPreprocessor(
+ embed_cols=categorical_cols,
+ for_transformer=False,
+ )
+
+ tab_preprocessor.fit(train_df)
+
+ X_train = tab_preprocessor.transform(train_df)
+ X_val = tab_preprocessor.transform(val_df)
+
+ tab_mlp = TabMlp(
+ column_idx=tab_preprocessor.column_idx,
+ cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type]
+ )
+ model = WideDeep(deeptabular=tab_mlp, pred_dim=1 if target_col == "liked" else 3)
+
+ eval_metrics = (
+ [BinaryNDCG_at_k(n_cols=5, k=3)]
+ if target_col == "liked"
+ else [NDCG_at_k(n_cols=5, k=3)]
+ )
+ trainer = Trainer(
+ model,
+ objective="binary" if target_col == "liked" else "multiclass",
+ metrics=[Accuracy()],
+ eval_metrics=eval_metrics,
+ verbose=0,
+ )
+
+ trainer.fit(
+ X_train={"X_tab": X_train, "target": train_df[target_col].values},
+ X_val={"X_tab": X_val, "target": val_df[target_col].values},
+ n_epochs=1,
+ batch_size=2 * 5,
+ )
+
+ # predict on validation, this is just a test...
+ preds = trainer.predict(X_tab=X_val)
+
+ assert preds.shape[0] == X_val.shape[0]
+    assert (
+        trainer.history is not None
+        and "train_loss" in trainer.history
+        and "val_loss" in trainer.history
+        and "train_acc" in trainer.history
+        and ("val_binary_ndcg@3" if target_col == "liked" else "val_ndcg@3")
+        in trainer.history
+    )
+
+
+# strategy 3:
+# * Different number of items per user in the two sets, which requires CustomDataLoaders
+# * Everything else as in the previous strategies
+
+
+@pytest.mark.parametrize(
+ "user_item_data", [{"validation_interactions_per_user": 6}], indirect=True
+)
+@pytest.mark.parametrize(
+ "metrics",
+ [
+ [[BinaryNDCG_at_k(n_cols=4, k=3)], [BinaryNDCG_at_k(n_cols=6, k=3)]],
+ [[Accuracy(), F1Score()], [F1Score(), BinaryNDCG_at_k(n_cols=6, k=3)]],
+ [[Accuracy(), F1Score()], [F1Score(), MAP_at_k(n_cols=6, k=3)]],
+ [[Accuracy(), F1Score()], [F1Score(), Precision_at_k(n_cols=6, k=3)]],
+ [[Accuracy(), F1Score()], [F1Score(), Recall_at_k(n_cols=6, k=3)]],
+ [[Accuracy(), F1Score()], [F1Score(), HitRatio_at_k(n_cols=6, k=3)]],
+ ],
+)
+@pytest.mark.parametrize("with_train_data_loader", [True, False])
+def test_binary_classification_strategy_3(
+ user_item_data, metrics, with_train_data_loader
+):
+
+ train_metrics, val_metrics = metrics[0], metrics[1]
+
+ train_df, val_df = user_item_data
+
+ categorical_cols = ["user_id", "item_id", "user_category", "item_category"]
+
+ target_col = "liked"
+
+ tab_preprocessor = TabPreprocessor(
+ embed_cols=categorical_cols,
+ for_transformer=False,
+ )
+
+ tab_preprocessor.fit(train_df)
+
+ X_train = tab_preprocessor.transform(train_df)
+ X_val = tab_preprocessor.transform(val_df)
+
+ tab_mlp = TabMlp(
+ column_idx=tab_preprocessor.column_idx,
+ cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type]
+ )
+ model = WideDeep(deeptabular=tab_mlp)
+
+ trainer = Trainer(
+ model,
+ objective="binary",
+ metrics=train_metrics,
+ eval_metrics=val_metrics,
+ )
+
+ if with_train_data_loader:
+ train_dl = CustomDataLoader(
+ batch_size=2 * 4, shuffle=False # in case there is a ranking metric
+ )
+ else:
+ train_dl = None
+
+ valid_dl = CustomDataLoader(
+ batch_size=2 * 6,
+ shuffle=False,
+ )
+
+ trainer.fit(
+ X_train={"X_tab": X_train, "target": train_df[target_col].values},
+ X_val={"X_tab": X_val, "target": val_df[target_col].values},
+ n_epochs=1,
+ batch_size=2
+ * 4, # it only applies to the training set. Will be ignored if train_dl is not None
+ train_dataloader=train_dl,
+ eval_dataloader=valid_dl,
+ )
+
+ train_metric_names = [m._name for m in train_metrics]
+ val_metric_names = [m._name for m in val_metrics]
+
+ # predict on validation, this is just a test...
+ preds = trainer.predict(X_tab=X_val)
+
+ assert valid_dl.dataset.X_tab.shape[0] == X_val.shape[0] == 60
+
+ if with_train_data_loader:
+ assert train_dl.dataset.X_tab.shape[0] == X_train.shape[0] == 40
+
+ assert preds.shape[0] == X_val.shape[0]
+ assert (
+ trainer.history is not None
+ and "train_loss" in trainer.history
+ and "val_loss" in trainer.history
+ )
+
+ for train_metric_name in train_metric_names:
+ assert f"train_{train_metric_name}" in trainer.history
+
+ for val_metric_name in val_metric_names:
+ assert f"val_{val_metric_name}" in trainer.history
+
+
+@pytest.mark.parametrize(
+ "user_item_data", [{"validation_interactions_per_user": 6}], indirect=True
+)
+@pytest.mark.parametrize(
+ "metrics",
+ [
+ [[NDCG_at_k(n_cols=4, k=3)], [NDCG_at_k(n_cols=6, k=3)]],
+ [[Accuracy(), F1Score()], [F1Score(), NDCG_at_k(n_cols=6, k=3)]],
+ ],
+)
+@pytest.mark.parametrize("with_train_data_loader", [True, False])
+def test_multiclass_classification_strategy_3(
+ user_item_data, metrics, with_train_data_loader
+):
+
+ train_metrics, val_metrics = metrics[0], metrics[1]
+
+ train_df, val_df = user_item_data
+
+ categorical_cols = ["user_id", "item_id", "user_category", "item_category"]
+
+ target_col = "rating"
+
+ tab_preprocessor = TabPreprocessor(
+ embed_cols=categorical_cols,
+ for_transformer=False,
+ )
+
+ tab_preprocessor.fit(train_df)
+
+ X_train = tab_preprocessor.transform(train_df)
+ X_val = tab_preprocessor.transform(val_df)
+
+ tab_mlp = TabMlp(
+ column_idx=tab_preprocessor.column_idx,
+ cat_embed_input=tab_preprocessor.cat_embed_input, # type: ignore[arg-type]
+ )
+ model = WideDeep(deeptabular=tab_mlp, pred_dim=3)
+
+ trainer = Trainer(
+ model,
+ objective="multiclass",
+ metrics=train_metrics,
+ eval_metrics=val_metrics,
+ )
+
+ if with_train_data_loader:
+ train_dl = CustomDataLoader(
+ batch_size=2 * 4, shuffle=False # in case there is a ranking metric
+ )
+ else:
+ train_dl = None
+
+ valid_dl = CustomDataLoader(
+ batch_size=2 * 6,
+ shuffle=False,
+ )
+
+ trainer.fit(
+ X_train={"X_tab": X_train, "target": train_df[target_col].values},
+ X_val={"X_tab": X_val, "target": val_df[target_col].values},
+ n_epochs=1,
+ batch_size=2
+ * 4, # it only applies to the training set. Will be ignored if train_dl is not None
+ train_dataloader=train_dl,
+ eval_dataloader=valid_dl,
+ )
+
+ train_metric_names = [m._name for m in train_metrics]
+ val_metric_names = [m._name for m in val_metrics]
+
+ # predict on validation, this is just a test...
+ preds = trainer.predict(X_tab=X_val)
+
+ assert valid_dl.dataset.X_tab.shape[0] == X_val.shape[0] == 60
+
+ if with_train_data_loader:
+ assert train_dl.dataset.X_tab.shape[0] == X_train.shape[0] == 40
+
+ assert preds.shape[0] == X_val.shape[0]
+ assert (
+ trainer.history is not None
+ and "train_loss" in trainer.history
+ and "val_loss" in trainer.history
+ )
+
+ for train_metric_name in train_metric_names:
+ assert f"train_{train_metric_name}" in trainer.history
+
+ for val_metric_name in val_metric_names:
+ assert f"val_{val_metric_name}" in trainer.history
From 808a7bbbedbc6908ca8c61966ce5e0488c391e6b Mon Sep 17 00:00:00 2001
From: Javier
Date: Mon, 21 Oct 2024 10:33:48 +0100
Subject: [PATCH 12/24] remove the unnecessary legacy docs folder and modify
 the figures path in README
---
mkdocs/site/index.html | 3 +-
mkdocs/site/objects.inv | Bin 1996 -> 2007 bytes
.../pytorch-widedeep/bayesian_trainer.html | 33 +-
mkdocs/site/pytorch-widedeep/dataloaders.html | 87 ++--
mkdocs/site/pytorch-widedeep/losses.html | 432 +++++++++--------
mkdocs/site/pytorch-widedeep/metrics.html | 72 +--
.../site/pytorch-widedeep/preprocessing.html | 297 ++++++++++++
mkdocs/site/pytorch-widedeep/preprocessing.md | 2 +
mkdocs/site/pytorch-widedeep/tab2vec.html | 73 +--
mkdocs/site/pytorch-widedeep/trainer.html | 435 ++++++++----------
mkdocs/site/search/search_index.json | 2 +-
mkdocs/site/sitemap.xml | 88 ++--
mkdocs/site/sitemap.xml.gz | Bin 796 -> 795 bytes
13 files changed, 869 insertions(+), 655 deletions(-)
diff --git a/mkdocs/site/index.html b/mkdocs/site/index.html
index 1f68a10a..666a04ab 100644
--- a/mkdocs/site/index.html
+++ b/mkdocs/site/index.html
@@ -2382,8 +2382,7 @@ APAJavier,
Pavol Mulinka,
- Javier Rodriguez Zaurin,
- Not Committed Yet
+ Javier Rodriguez Zaurin
diff --git a/mkdocs/site/objects.inv b/mkdocs/site/objects.inv
index 2818ee0861c603228b2f2d166adb8ebdc801c0d6..2208df9aa6290a7b647eaabf35faa393dd5d954b 100644
GIT binary patch
delta 1881
zcmV-f2d4PU57!TnlYdQ)d&?x_N~#=J)=teV`~F~l|(+~*RLeNU)V-CYB84t
zWa#&4b+-TsyL?xLpzHXRC6EA=(UQD_WQ50gkw7LRMV3`Y#I`+S9Cr83nbEgq?{B^H
zseJ#FW<<(QU-hOn|G72)<F3KYjOT^TD`vvj+NH7L2P8TTX}~SwRvIGP>2<
zQ@eeiFG)r?h2*Tj?#ZVm%n@-aiabt>OdlJuoUEY!yVrhiS|++-V)xHVcZ?3uZ`fH!6T?7OcU0*JpRvB@0?s^j{tKH;|h3kZW
zDJ!4ZiIkNOi>Xh>DGKCY72%3y;0<)wpl?{OV-nV2H1`_!Iac7GUI7y3FYY_`e53wz
zKb?8wY>jz_jPoi}OcAjH6+M{kC%b)Pcv;FQF#^84IHfj+l>n`Wrlc$TQ0bZ{zMafL
z2}WgfO@C=62t7ZA&>rk=F4E+n`U3B6ck*>X3jv-e8*APnzkwoNIX4DNRaPoa
zSZ>6KQo@f`eMVYY5j$f!sbtSAjyILiUPYxS{(lA9p9_Npb$RUV-?euc9s$&;(dKYnq=jSOI*Za!L70sd$Ex4V&
z)_(=APu3E4spzI|y>1&HrjJNo4`PV6N4x@Ct(EjUHeCjVHB#=xwV
z4nC@d-nBzf??%_{-h4L0@V`eAJsCF)^@O3(|7Nx@-R(-@Cb$)qiqL0B;{Co})U83H>`V|Y
z=7qyjA5A)CYJV)bbVy7jArKOuu;N;*GB^%5mtdI%Ehm#`bbx
z>zmccBju%=vGcE^{7Z-;roEGP96aD1icqQFPx1>9px$QKC$w%An#VQdf`4#S#;Xcf
zfRYHsNEQbK8KwKj?D=Eb*rR|0MWFG>fy6t3ed$D2C5RU$Au%p8ZcZ_;1PLO|)X?VZ
zpikz*p4-nugBoOG(=SHUn(pfh&q({IB6qPf2^ULvgqmpNI!`j
z#L?rC8O`Y}&JgeGX*LnpGk<+z<9AOn;Cgvn9L=pU0@(v{#8>@5cl83)JjY%A0DAu2
z;pW)Wb$?*ZjNYRoZZ?jsd>hvRgbqU|(AWq0Eb5?+7X`Nix7}Cx4$#jvH65$>=^N
zeFUtNe*XEa#D^G+?i2(9uGKh4+1hof8b=0~A~}$Csfl6g
zy(Mr+;F$3PDO?i#1c_TdU&iB*zNe^8T_pUm?J5i*bg8F-NTjq7JK-Yu782k7d{tBmpdCgW%maSB!47PIs{xGZ*G*RMO(+t;!Ls_btpIIjWLx>CDVJSgt0Qn~qkss4hQ
z1@1u3cEwQrv@8Af=CWd05<|X(B+<8VGR}z-?9EMc$IbxB6(~FMOH7fHB(a45D$sYv
Ts6x_)E*=`LV`BdUpl4e@;*O~J
delta 1870
zcmV-U2eJ6q56lmclYb@0y=9WPk}4;awNtZId!b~}*xG<#C6SN$^(zVR7q$_OTFfN@
z8Tx%%-7P@EE+cK8R2nWB#_BSkwujev1!j3hwXiHX7s(;`&;jP
zDL?+C8IkhKh2FI0zqaPT?VDezJp28RFJFHv2ZQ9<)(<^&8*?lX`UW#=B$DH|{N`dYBS7VE=hz?wa)tJ}LB=>5X@7
z4Fz)+PcQtUwic0aA3T=Zu1O@|83OSP3Hh#Cc&xOwHGkij_dG2gul@DvmObC79;*wd
ztbArCQdU09r#=~{D3E_uge#VTH_&Z^zG1zNNmzr?%xm0dSb=|j14x*^xNq6>jr!01
z^wAqi1!z4qC0*J3N>@Db?PLy0
zFe;-fN`EUs=;xtXHd
zY@h5+g(v5NuJ;trr
zJy)P>Sv`V!E5wU;o%hDE+jk`(hamp7)631Oj$O>@UBnrlpT}fe?<>z&G>byC;CA|2
z7k{`uSxeZZqU*Z#s%?CjJ|THMh#}e>@dj+UQqu3(bRHDeNV!cTAxrmA=DI0RfE?6A
zhQR7@I0+OyG$jZM?6n4YYIi~f=CJx`VL5^mJSaQ!
z5@ZqhqH&sFwBVdTLJN?2nQK)omy9n{LLuIPTBY^dkkQ@5PFwko%^>kkSP*(baDOga
zomYtUZgkb|%^n{y{O^%OPsTMvJz;3{znd*gce_%!32sHDBJ>54c(K1GQ}^V9!rv!8Bz+?-%8lwuwP3+u6GyN5P9gt
zfy`hZa5`hWdC?;aG(e@9yySBC$MjT$f^YK$|NMlMaIo3=KUN&q?xB#e>d~leAsjQxo=Q|
ztZn+mh+5Nqec>5t_iLXc^vOYhv3i4U0u(20gzAqqlv5U6o&o76k%KsTI5ML-y}=pc
zeLc-4;(De}to`mO23#)>i+`iJHAWzNKo0n-ALuS$ftu&I%O5~bzsTDhd%W)Vthw9_
zoMqJg+B)34Pt$E&);|H`F#KEjVDP0`Jf`ky1+{m;8G-evMx0-Oue@R_6Zy_ejtTQ
zf*&Dq%jffW9Mbm~)v=3&Kela!A%rgVG!W@n`Iw&ukoJp=?eDtPEmN|9ta-3FdDmhO
z|JF5d{KDrV>2ha-r+@7Jpn=3Ncrp=*W59ei5&?P-n773V+w`=Io{JaYy-Vn|Uia)z
z=ta$JY}|bDzu#RPd9&gaPX9+i$s$l!Tdjx+APNF|=su$PlRR@6@dv=G90a2z)A(^b
diff --git a/mkdocs/site/pytorch-widedeep/bayesian_trainer.html b/mkdocs/site/pytorch-widedeep/bayesian_trainer.html
index 51e79f11..2b758b4b 100644
--- a/mkdocs/site/pytorch-widedeep/bayesian_trainer.html
+++ b/mkdocs/site/pytorch-widedeep/bayesian_trainer.html
@@ -1755,7 +1755,8 @@
num_workers: int
@@ -1794,8 +1795,7 @@
Source code in pytorch_widedeep/training/bayesian_trainer.py
- 111
-112
+ 112
113
114
115
@@ -1822,7 +1822,8 @@
136
137
138
-139 | @alias( # noqa: C901
+139
+140
| @alias( # noqa: C901
"objective",
["loss_function", "loss_fn", "loss", "cost_function", "cost_fn", "cost"],
)
@@ -1998,8 +1999,7 @@
Source code in pytorch_widedeep/training/bayesian_trainer.py
- 141
-142
+ 142
143
144
145
@@ -2112,7 +2112,8 @@
252
253
254
-255 | def fit( # noqa: C901
+255
+256
| def fit( # noqa: C901
self,
X_tab: np.ndarray,
target: np.ndarray,
@@ -2308,8 +2309,7 @@
Source code in pytorch_widedeep/training/bayesian_trainer.py
- 257
-258
+ 258
259
260
261
@@ -2344,7 +2344,8 @@ 290
291
292
-293 | def predict( # type: ignore[return]
+293
+294
| def predict( # type: ignore[return]
self,
X_tab: np.ndarray,
n_samples: int = 5,
@@ -2462,8 +2463,7 @@
Source code in pytorch_widedeep/training/bayesian_trainer.py
- 295
-296
+ 296
297
298
299
@@ -2505,7 +2505,8 @@ 335
336
337
-338 | def predict_proba( # type: ignore[return]
+338
+339
| def predict_proba( # type: ignore[return]
self,
X_tab: np.ndarray,
n_samples: int = 5,
@@ -2618,8 +2619,7 @@
Source code in pytorch_widedeep/training/bayesian_trainer.py
- 340
-341
+ 341
342
343
344
@@ -2665,7 +2665,8 @@
384
385
386
-387 | | def save(
self,
path: str,
save_state_dict: bool = False,
diff --git a/mkdocs/site/pytorch-widedeep/dataloaders.html b/mkdocs/site/pytorch-widedeep/dataloaders.html
index 45f846f4..3453782e 100644
--- a/mkdocs/site/pytorch-widedeep/dataloaders.html
+++ b/mkdocs/site/pytorch-widedeep/dataloaders.html
@@ -1610,14 +1610,12 @@ ¶
-DataLoaderImbalanced(
- dataset, batch_size, num_workers, **kwargs
-)
+DataLoaderImbalanced(dataset=None, *args, **kwargs)
- Bases: DataLoader
+ Bases: CustomDataLoader
Class to load and shuffle batches with adjusted weights for imbalanced
@@ -1629,33 +1627,26 @@
dataset
- (WideDeepDataset )
+ (Optional[WideDeepDataset] , default:
+ None
+)
–
see pytorch_widedeep.training._wd_dataset
-
- batch_size
- (int )
- –
-
-
-
- num_workers
- (int )
- –
-
-
Other Parameters:
+ -
+
*args
+ –
+
+ Positional arguments to be passed to the parent CustomDataLoader.
+
+
-
**kwargs
–
@@ -1677,45 +1668,23 @@
Source code in pytorch_widedeep/dataloaders.py
- 72
-73
-74
-75
-76
-77
-78
-79
-80
-81
-82
-83
-84
-85
-86
-87
-88
-89
-90
-91 | def __init__(
- self, dataset: WideDeepDataset, batch_size: int, num_workers: int, **kwargs
-):
- assert dataset.Y is not None, (
- "The 'dataset' instance of WideDeepDataset must contain a "
- "target array 'Y'"
- )
-
- if "oversample_mul" in kwargs:
- oversample_mul = kwargs["oversample_mul"]
- del kwargs["oversample_mul"]
+ 109
+110
+111
+112
+113
+114
+115
+116
+117 | def __init__(self, dataset: Optional[WideDeepDataset] = None, *args, **kwargs):
+ self.args = args
+ self.kwargs = kwargs
+
+ if dataset is not None:
+ self._setup_sampler(dataset)
+ super().__init__(dataset, *args, sampler=self.sampler, **kwargs)
else:
- oversample_mul = 1
- weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
- num_samples = int(minor_cls_cnt * num_clss * oversample_mul)
- samples_weight = list(np.array([weights[i] for i in dataset.Y]))
- sampler = WeightedRandomSampler(samples_weight, num_samples, replacement=True)
- super().__init__(
- dataset, batch_size, num_workers=num_workers, sampler=sampler, **kwargs
- )
+ super().__init__()
|
diff --git a/mkdocs/site/pytorch-widedeep/losses.html b/mkdocs/site/pytorch-widedeep/losses.html
index 39ad8311..91b430e8 100644
--- a/mkdocs/site/pytorch-widedeep/losses.html
+++ b/mkdocs/site/pytorch-widedeep/losses.html
@@ -2756,8 +2756,7 @@
278
279
280
-281
-282
| def forward(self, input: Tensor, target: Tensor) -> Tensor:
+281
| def forward(self, input: Tensor, target: Tensor) -> Tensor:
r"""
Parameters
----------
@@ -2789,8 +2788,7 @@
else:
num_class = input_prob.size(1)
binary_target = torch.eye(num_class)[target.squeeze().cpu().long()]
- if use_cuda:
- binary_target = binary_target.cuda()
+ binary_target = binary_target.to(input.device)
binary_target = binary_target.contiguous()
weight = self._get_weight(input_prob, binary_target)
@@ -2834,8 +2832,8 @@
Source code in pytorch_widedeep/losses.py
- | def __init__(self):
+ | def __init__(self):
super().__init__()
|
@@ -2898,7 +2896,8 @@
Source code in pytorch_widedeep/losses.py
- 314
+ 313
+314
315
316
317
@@ -2914,8 +2913,7 @@
327
328
329
-330
-331 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
+330
| def forward(self, input: Tensor, target: Tensor) -> Tensor:
r"""
Parameters
----------
@@ -2971,8 +2969,8 @@
Source code in pytorch_widedeep/losses.py
- | def __init__(self):
+ | def __init__(self):
super().__init__()
|
@@ -3047,7 +3045,8 @@
Source code in pytorch_widedeep/losses.py
- 346
+ 345
+346
347
348
349
@@ -3082,8 +3081,7 @@
378
379
380
-381
-382 | | def forward(
self,
input: Tensor,
target: Tensor,
@@ -3158,8 +3156,8 @@
Source code in pytorch_widedeep/losses.py
- | def __init__(self):
+ | def __init__(self):
super().__init__()
|
@@ -3223,7 +3221,8 @@
Source code in pytorch_widedeep/losses.py
- 396
+ 395
+396
397
398
399
@@ -3271,10 +3270,7 @@
441
442
443
-444
-445
-446
-447 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
+444
| def forward(self, input: Tensor, target: Tensor) -> Tensor:
r"""
Parameters
----------
@@ -3311,9 +3307,7 @@
# when using max the two input tensors (input and other) have to be of
# the same type
max_input = F.softplus(input[..., 2:])
- max_other = torch.sqrt(torch.Tensor([torch.finfo(torch.double).eps])).type(
- max_input.type()
- )
+ max_other = self.get_eps(max_input)
scale = torch.max(max_input, max_other)
safe_labels = positive * target + (1 - positive) * torch.ones_like(target)
@@ -3361,8 +3355,8 @@
Source code in pytorch_widedeep/losses.py
- | def __init__(self):
+ | def __init__(self):
super().__init__()
|
@@ -3427,27 +3421,27 @@
Source code in pytorch_widedeep/losses.py
- 459
-460
-461
-462
-463
-464
-465
-466
-467
-468
-469
-470
-471
-472
-473
-474
-475
+ | def forward(self, input: Tensor, target: Tensor) -> Tensor:
+479
+480
+481
+482
+483
+484
+485
+486
+487
+488
+489
+490
+491
+492
+493
+494
+495
| def forward(self, input: Tensor, target: Tensor) -> Tensor:
r"""
Parameters
----------
@@ -3541,16 +3535,16 @@
Source code in pytorch_widedeep/losses.py
- 499
-500
-501
-502
-503
-504
-505
-506
-507
-508 | def __init__(
+ 515
+516
+517
+518
+519
+520
+521
+522
+523
+524 | def __init__(
self,
beta: float = 0.2,
gamma: float = 1.0,
@@ -3623,23 +3617,7 @@
Source code in pytorch_widedeep/losses.py
- 510
-511
-512
-513
-514
-515
-516
-517
-518
-519
-520
-521
-522
-523
-524
-525
-526
+ 526
527
528
529
@@ -3657,7 +3635,23 @@
541
542
543
-544 | def forward(
+544
+545
+546
+547
+548
+549
+550
+551
+552
+553
+554
+555
+556
+557
+558
+559
+560
| def forward(
self,
input: Tensor,
target: Tensor,
@@ -3767,16 +3761,16 @@
Source code in pytorch_widedeep/losses.py
- 564
-565
-566
-567
-568
-569
-570
-571
-572
-573 | def __init__(
+ 580
+581
+582
+583
+584
+585
+586
+587
+588
+589 | def __init__(
self,
beta: float = 0.2,
gamma: float = 1.0,
@@ -3849,23 +3843,7 @@
Source code in pytorch_widedeep/losses.py
- 575
-576
-577
-578
-579
-580
-581
-582
-583
-584
-585
-586
-587
-588
-589
-590
-591
+ 591
592
593
594
@@ -3883,7 +3861,23 @@
606
607
608
-609 | def forward(
+609
+610
+611
+612
+613
+614
+615
+616
+617
+618
+619
+620
+621
+622
+623
+624
+625
| def forward(
self,
input: Tensor,
target: Tensor,
@@ -3993,16 +3987,16 @@
Source code in pytorch_widedeep/losses.py