
Commit c3fa5bc

Merge pull request #32 from hand10ryo/develop
Develop
2 parents: f653b9f + c1ca9ba

35 files changed: +1693 −694 lines

PyTorchCML/adaptors/BaseAdaptor.py

Lines changed: 26 additions & 0 deletions (new file)

```python
from torch import nn


class BaseAdaptor(nn.Module):
    """Abstract class of module for domain adaptation."""

    def __init__(self, weight):
        """Set some parameters.

        Args:
            weight (float): Loss weight for domain adaptation.
        """
        super().__init__()
        self.weight = weight

    def forward(self, indices, embeddings):
        """Method to calculate loss for domain adaptation.

        Args:
            indices (torch.Tensor): Indices of users or items. size = (n_user, n_sample)
            embeddings (torch.Tensor): The embeddings corresponding to indices. size = (n_user, n_sample, n_dim)

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError
```

PyTorchCML/adaptors/MLPAdaptor.py

Lines changed: 55 additions & 0 deletions (new file)

```python
import torch
from torch import nn

from .BaseAdaptor import BaseAdaptor


class MLPAdaptor(BaseAdaptor):
    """Class of module for domain adaptation with MLP."""

    def __init__(
        self,
        features: torch.Tensor,
        n_dim: int = 20,
        n_hidden: list = [100],
        weight: float = 1e-3,
    ):
        """Set MLP model for domain adaptation.

        Args:
            features (torch.Tensor): Features of users or items. size = (n_user, n_feature)
            n_dim (int, optional): Number of dimensions of the embeddings. Defaults to 20.
            n_hidden (list, optional): Number of neurons in each hidden layer. Defaults to [100].
            weight (float, optional): Loss weight for domain adaptation. Defaults to 1e-3.
        """
        super().__init__(weight)
        # Freeze the feature table; it is a fixed lookup, not a learned embedding.
        self.features_embeddings = nn.Embedding.from_pretrained(features)
        self.features_embeddings.weight.requires_grad = False

        self.n_input = features.shape[1]
        self.n_hidden = n_hidden
        self.n_output = n_dim

        # Build the MLP: input -> hidden layers with ReLU -> embedding space.
        projection_layers = [nn.Linear(self.n_input, self.n_hidden[0]), nn.ReLU()]
        for i in range(len(self.n_hidden) - 1):
            layer = [nn.Linear(self.n_hidden[i], self.n_hidden[i + 1]), nn.ReLU()]
            projection_layers += layer
        projection_layers += [nn.Linear(self.n_hidden[-1], self.n_output)]

        self.projector = nn.Sequential(*projection_layers)

    def forward(self, indices: torch.Tensor, embeddings: torch.Tensor):
        """Method to calculate loss for domain adaptation.

        Args:
            indices (torch.Tensor): Indices of users or items. size = (n_user, n_sample)
            embeddings (torch.Tensor): The embeddings corresponding to indices. size = (n_user, n_sample, n_dim)

        Returns:
            torch.Tensor: Loss for domain adaptation. dim = 0.
        """
        # Project the raw features into the embedding space and penalize the
        # Euclidean distance between the projection and the learned embeddings.
        features = self.features_embeddings(indices)
        projection = self.projector(features)
        dist = torch.sqrt(torch.pow(projection - embeddings, 2).sum(axis=2))
        return self.weight * dist.sum()
```
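A minimal usage sketch of the adaptor (the sizes and random tensors below are illustrative assumptions, not values from the commit):

```python
import torch

from PyTorchCML.adaptors import MLPAdaptor

# Hypothetical side information: 100 users, each with a 30-dim feature vector.
features = torch.randn(100, 30)
adaptor = MLPAdaptor(features, n_dim=20, n_hidden=[100], weight=1e-3)

# Indices sized (n_user, n_sample); embeddings sized (n_user, n_sample, n_dim).
indices = torch.randint(0, 100, (8, 2))
embeddings = torch.randn(8, 2, 20)

loss = adaptor(indices, embeddings)  # scalar penalty tying embeddings to features
print(loss.item())
```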

PyTorchCML/adaptors/__init__.py

Lines changed: 3 additions & 0 deletions (new file)

```python
# flake8: noqa
from .BaseAdaptor import BaseAdaptor
from .MLPAdaptor import MLPAdaptor
```

PyTorchCML/losses/BaseLoss.py

Lines changed: 65 additions & 0 deletions (new file)

```python
import torch
from torch import nn


class BaseLoss(nn.Module):
    """Abstract loss module for pairwise losses such as matrix factorization."""

    def __init__(self, regularizers: list = []):
        super().__init__()
        self.regularizers = regularizers

    def forward(
        self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
    ) -> torch.Tensor:
        # Template method: subclasses implement main(); regularization is added here.
        loss = self.main(embeddings_dict, batch, column_names)
        loss += self.regularize(embeddings_dict)
        return loss

    def main(
        self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
    ) -> torch.Tensor:
        """
        Args:
            embeddings_dict (dict): A dictionary of embeddings.
                (e.g. it has the following keys and values.)
                user_embedding : embeddings of user, size (n_batch, 1, d)
                pos_item_embedding : embeddings of positive item, size (n_batch, 1, d)
                neg_item_embedding : embeddings of negative item, size (n_batch, n_neg_samples, d)
                user_bias : bias of user, size (n_batch, 1)
                pos_item_bias : bias of positive item, size (n_batch, 1)
                neg_item_bias : bias of negative item, size (n_batch, n_neg_samples)

            batch (torch.Tensor) : A tensor of batch, size (n_batch, *).
            column_names (dict) : A dictionary that maps names to indices of rows of batch.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.

        Returns:
            torch.Tensor: The main (unregularized) loss.

        --- example code ---

        embeddings_dict = {
            "user_embedding": user_embedding,
            "pos_item_embedding": pos_item_embedding,
            "neg_item_embedding": neg_item_embedding,
            "user_bias": user_bias,
            "pos_item_bias": pos_item_bias,
            "neg_item_bias": neg_item_bias,
        }

        loss = loss_function(embeddings_dict, batch, column_names)

        return loss
        """
        raise NotImplementedError

    def regularize(self, embeddings_dict: dict):
        reg = 0
        for regularizer in self.regularizers:
            reg += regularizer(embeddings_dict)
        return reg
```
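To illustrate the template-method design above (`forward` = `main` + `regularize`), here is a minimal hypothetical subclass; `DotProductLoss` and its softplus objective are illustrative only, not part of this commit:

```python
import torch
import torch.nn.functional as F

from PyTorchCML.losses.BaseLoss import BaseLoss


class DotProductLoss(BaseLoss):
    """Hypothetical loss: push positive inner products above negative ones."""

    def main(self, embeddings_dict, batch, column_names):
        # Shapes follow the docstring: (n_batch, 1, d) and (n_batch, n_neg_samples, d).
        pos = torch.einsum(
            "nid,nid->n",
            embeddings_dict["user_embedding"],
            embeddings_dict["pos_item_embedding"],
        )
        neg = torch.einsum(
            "nid,njd->nj",
            embeddings_dict["user_embedding"],
            embeddings_dict["neg_item_embedding"],
        )
        # Softplus of (neg - pos) is small when positives outscore negatives;
        # BaseLoss.forward adds the regularization terms on top of this.
        return F.softplus(neg - pos[:, None]).mean()
```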

PyTorchCML/losses/BasePairwiseLoss.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

PyTorchCML/losses/BaseTripletLoss.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

PyTorchCML/losses/LogitPairwiseLoss.py

Lines changed: 9 additions & 10 deletions

```diff
@@ -1,25 +1,25 @@
 import torch
 from torch import nn
 
-from .BasePairwiseLoss import BasePairwiseLoss
+from .BaseLoss import BaseLoss
 
 
-class LogitPairwiseLoss(BasePairwiseLoss):
+class LogitPairwiseLoss(BaseLoss):
     """Class of pairwise logit loss for Logistic Matrix Factorization"""
 
     def __init__(self, regularizers: list = []):
         super().__init__(regularizers)
         self.LogSigmoid = nn.LogSigmoid()
 
-    def forward(
+    def main(
         self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
     ) -> torch.Tensor:
-        """Method of forwarding loss
+        """Method of forwarding main loss
 
         Args:
             embeddings_dict (dict): A dictionary of embeddings which has the following keys and values.
-                user_embedding : embeddings of user, size (n_batch, d)
-                pos_item_embedding : embeddings of positive item, size (n_batch, d)
+                user_embedding : embeddings of user, size (n_batch, 1, d)
+                pos_item_embedding : embeddings of positive item, size (n_batch, 1, d)
                 neg_item_embedding : embeddings of negative item, size (n_batch, n_neg_samples, d)
                 user_bias : bias of user, size (n_batch, 1)
                 pos_item_bias : bias of positive item, size (n_batch, 1)
@@ -34,13 +34,13 @@ def forward(
         n_pos = 1
 
         pos_inner = torch.einsum(
-            "nd,nd->n",
+            "nid,nid->n",
             embeddings_dict["user_embedding"],
             embeddings_dict["pos_item_embedding"],
         )
 
         neg_inner = torch.einsum(
-            "nd,njd->nj",
+            "nid,njd->nj",
             embeddings_dict["user_embedding"],
             embeddings_dict["neg_item_embedding"],
         )
@@ -55,6 +55,5 @@ def forward(
         neg_loss = -nn.LogSigmoid()(-neg_y_hat).sum()
 
         loss = (pos_loss + neg_loss) / (n_batch * (n_pos + n_neg))
-        reg = self.regularize(embeddings_dict)
 
-        return loss + reg
+        return loss
```
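The einsum changes track the new embedding shapes: user and positive-item embeddings now carry a singleton sample axis, (n_batch, 1, d) instead of (n_batch, d). A quick shape check with dummy tensors (sizes are illustrative):

```python
import torch

n_batch, n_neg, d = 4, 3, 8
user = torch.randn(n_batch, 1, d)        # previously (n_batch, d)
pos_item = torch.randn(n_batch, 1, d)
neg_item = torch.randn(n_batch, n_neg, d)

pos_inner = torch.einsum("nid,nid->n", user, pos_item)   # -> (n_batch,)
neg_inner = torch.einsum("nid,njd->nj", user, neg_item)  # -> (n_batch, n_neg)
print(pos_inner.shape, neg_inner.shape)  # torch.Size([4]) torch.Size([4, 3])
```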

PyTorchCML/losses/MSEPairwiseLoss.py

Lines changed: 9 additions & 10 deletions

```diff
@@ -1,20 +1,20 @@
 import torch
 
-from .BasePairwiseLoss import BasePairwiseLoss
+from .BaseLoss import BaseLoss
 
 
-class MSEPairwiseLoss(BasePairwiseLoss):
+class MSEPairwiseLoss(BaseLoss):
     """Class of loss for MSE in implicit feedback"""
 
-    def forward(
+    def main(
         self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
     ) -> torch.Tensor:
-        """Method of forwarding loss
+        """Method of forwarding main loss
 
         Args:
             embeddings_dict (dict): A dictionary of embeddings which has the following keys and values.
-                "user_embedding" : embeddings of user, size (n_batch, d)
-                "pos_item_embedding" : embeddings of positive item, size (n_batch, d)
+                "user_embedding" : embeddings of user, size (n_batch, 1, d)
+                "pos_item_embedding" : embeddings of positive item, size (n_batch, 1, d)
                 "neg_item_embedding" : embeddings of negative item, size (n_batch, n_neg_samples, d)
                 "user_bias" : bias of user, size (n_batch, 1)
                 "pos_item_bias" : bias of positive item, size (n_batch, 1)
@@ -32,13 +32,13 @@ def forward(
         n_pos = 1
 
         pos_inner = torch.einsum(
-            "nd,nd->n",
+            "nid,nid->n",
             embeddings_dict["user_embedding"],
             embeddings_dict["pos_item_embedding"],
         )
 
         neg_inner = torch.einsum(
-            "nd,njd->nj",
+            "nid,njd->nj",
             embeddings_dict["user_embedding"],
             embeddings_dict["neg_item_embedding"],
         )
@@ -53,6 +53,5 @@ def forward(
         neg_loss = (torch.sigmoid(neg_r_hat) ** 2).sum()
 
         loss = (pos_loss + neg_loss) / (n_batch * (n_pos + n_neg))
-        reg = self.regularize(embeddings_dict)
 
-        return loss + reg
+        return loss
```
