Skip to content

Commit 7b17e4d

Browse files
committed
add metrics
1 parent 4b4b30a commit 7b17e4d

File tree

3 files changed

+102
-9
lines changed

3 files changed

+102
-9
lines changed

torchdrug/metrics/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from .metric import area_under_roc, area_under_prc, r2, QED, logP, penalized_logP, SA, chemical_validity, \
2-
variadic_accuracy
2+
accuracy, matthews_corrcoef, pearsonr, spearmanr, variadic_accuracy
33

44
# alias
55
AUROC = area_under_roc
66
AUPRC = area_under_prc
77

88
__all__ = [
99
"area_under_roc", "area_under_prc", "r2", "QED", "logP", "penalized_logP", "SA", "chemical_validity",
10+
"accuracy", "matthews_corrcoef", "pearsonr", "spearmanr",
1011
"variadic_accuracy",
1112
"AUROC", "AUPRC",
12-
]
13+
]

torchdrug/metrics/metric.py

Lines changed: 98 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
import os
2-
import sys
3-
41
import torch
52
from torch.nn import functional as F
6-
from torch_scatter import scatter_max
3+
from torch_scatter import scatter_add, scatter_mean, scatter_max
74
import networkx as nx
85
from rdkit import Chem
9-
from rdkit.Chem import RDConfig, Descriptors
6+
from rdkit.Chem import Descriptors
107

118
from torchdrug import utils
129
from torchdrug.layers import functional
@@ -23,6 +20,8 @@ def area_under_roc(pred, target):
2320
pred (Tensor): predictions of shape :math:`(n,)`
2421
target (Tensor): binary targets of shape :math:`(n,)`
2522
"""
23+
if target.dtype != torch.long:
24+
raise TypeError("Expect `target` to be torch.long, but found %s" % target.dtype)
2625
order = pred.argsort(descending=True)
2726
target = target[order]
2827
hit = target.cumsum(0)
@@ -40,6 +39,8 @@ def area_under_prc(pred, target):
4039
pred (Tensor): predictions of shape :math:`(n,)`
4140
target (Tensor): binary targets of shape :math:`(n,)`
4241
"""
42+
if target.dtype != torch.long:
43+
raise TypeError("Expect `target` to be torch.long, but found %s" % target.dtype)
4344
order = pred.argsort(descending=True)
4445
target = target[order]
4546
precision = target.cumsum(0) / torch.arange(1, len(target) + 1, device=target.device)
@@ -178,13 +179,103 @@ def chemical_validity(pred):
178179
validity.append(1 if mol else 0)
179180

180181
return torch.tensor(validity, dtype=torch.float, device=pred.device)
182+
183+
184+
@R.register("metrics.accuracy")
def accuracy(pred, target):
    """
    Classification accuracy for sets that all share the same number of categories.

    Suppose there are :math:`N` sets and :math:`C` categories. The predicted category
    of a set is the index of its largest score along the last dimension, and the
    result is the fraction of sets whose predicted category equals the target.

    Parameters:
        pred (Tensor): prediction of shape :math:`(N, C)`
        target (Tensor): target of shape :math:`(N,)`
    """
    predicted_class = pred.argmax(dim=-1)
    is_correct = predicted_class == target
    return is_correct.float().mean()
196+
197+
198+
@R.register("metrics.mcc")
def matthews_corrcoef(pred, target, eps=1e-6):
    """
    Matthews correlation coefficient between target and prediction.

    Definition follows matthews_corrcoef for K classes in sklearn.
    For details, see: 'https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef'

    Suppose there are :math:`N` samples and :math:`C` categories.

    Parameters:
        pred (Tensor): prediction of shape :math:`(N, C)`; the predicted category of
            a sample is the index of its largest entry along the last dimension
        target (Tensor): target of shape :math:`(N,)`, used as class indices in
            :math:`[0, C)` when the confusion matrix is built
        eps (float, optional): constant added under the square root of the
            denominator, so the division stays finite when the denominator is zero
    """
    # the number of categories is read from the prediction's last dimension
    num_class = pred.size(-1)
    pred = pred.argmax(-1)
    ones = torch.ones(len(target), device=pred.device)
    # flatten every (target, pred) pair into one index of a C x C table and count the pairs
    confusion_matrix = scatter_add(ones, target * num_class + pred, dim=0, dim_size=num_class ** 2)
    confusion_matrix = confusion_matrix.view(num_class, num_class)
    # rows are indexed by target and columns by prediction
    t = confusion_matrix.sum(dim=1)  # number of samples per target class
    p = confusion_matrix.sum(dim=0)  # number of samples per predicted class
    c = confusion_matrix.trace()  # number of correctly predicted samples
    s = confusion_matrix.sum()  # total number of samples
    return (c * s - t @ p) / ((s * s - p @ p) * (s * s - t @ t) + eps).sqrt()
220+
221+
222+
@R.register("metrics.pearsonr")
def pearsonr(pred, target):
    """
    Pearson correlation between target and prediction.
    Mimics `scipy.stats.pearsonr`.

    Both inputs are centered by their mean and scaled to unit L2 norm; the
    correlation is the dot product of the two resulting vectors.

    Parameters:
        pred (Tensor): prediction of shape :math:`(N,)`
        target (Tensor): target of shape :math:`(N,)`
    """
    # the means are computed in floating point
    pred_deviation = pred - pred.float().mean()
    target_deviation = target - target.float().mean()
    unit_pred = pred_deviation / pred_deviation.norm(2)
    unit_target = target_deviation / target_deviation.norm(2)
    return unit_pred @ unit_target
240+
241+
242+
@R.register("metrics.spearmanr")
def spearmanr(pred, target, eps=1e-6):
    """
    Spearman correlation between target and prediction.
    Implemented in PyTorch, but non-differentiable (validation metric only).

    Both inputs are converted to rankings, with tied values sharing the mean of
    their ranks, and the correlation of the two rankings is returned.

    Parameters:
        pred (Tensor): prediction of shape :math:`(N,)`
        target (Tensor): target of shape :math:`(N,)`
        eps (float, optional): constant added to the product of the two standard
            deviations in the denominator
    """

    def get_ranking(values):
        unique_values, inverse = values.unique(return_inverse=True)
        num_element = len(values)
        ranking = torch.zeros(len(inverse), device=values.device)
        # ranks start at 1 and follow the ascending order of the values
        ranking[inverse.argsort()] = torch.arange(1, num_element + 1, dtype=torch.float, device=values.device)

        # elements with the same value all receive the mean of the ranks they occupy
        tie_mean = scatter_mean(ranking, inverse, dim=0, dim_size=len(unique_values))
        return tie_mean[inverse]

    pred_rank = get_ranking(pred)
    target_rank = get_ranking(target)
    covariance = (pred_rank * target_rank).mean() - pred_rank.mean() * target_rank.mean()
    denominator = pred_rank.std(unbiased=False) * target_rank.std(unbiased=False) + eps
    return covariance / denominator
181271

182272

273+
@R.register("metrics.variadic_accuracy")
183274
def variadic_accuracy(input, target, size):
184275
"""
185276
Compute classification accuracy over variadic sizes of categories.
186277
187-
Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math`B`.
278+
Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`.
188279
189280
Parameters:
190281
input (Tensor): prediction of shape :math:`(B,)`
@@ -196,4 +287,4 @@ def variadic_accuracy(input, target, size):
196287
input_class = scatter_max(input, index2graph)[1]
197288
target_index = target + size.cumsum(0) - size
198289
accuracy = (input_class == target_index).float()
199-
return accuracy
290+
return accuracy

torchdrug/tasks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"mse": "mean squared error",
2020
"rmse": "root mean squared error",
2121
"acc": "accuracy",
22+
"mcc": "matthews correlation coefficient",
2223
}
2324

2425

0 commit comments

Comments
 (0)