-import os
-import sys
-
 import torch
 from torch.nn import functional as F
-from torch_scatter import scatter_max
+from torch_scatter import scatter_add, scatter_mean, scatter_max
 import networkx as nx
 from rdkit import Chem
-from rdkit.Chem import RDConfig, Descriptors
+from rdkit.Chem import Descriptors
 
 from torchdrug import utils
 from torchdrug.layers import functional
@@ -23,6 +20,8 @@ def area_under_roc(pred, target):
         pred (Tensor): predictions of shape :math:`(n,)`
         target (Tensor): binary targets of shape :math:`(n,)`
     """
+    if target.dtype != torch.long:
+        raise TypeError("Expect `target` to be torch.long, but found %s" % target.dtype)
     order = pred.argsort(descending=True)
     target = target[order]
     hit = target.cumsum(0)
@@ -40,6 +39,8 @@ def area_under_prc(pred, target):
         pred (Tensor): predictions of shape :math:`(n,)`
         target (Tensor): binary targets of shape :math:`(n,)`
     """
+    if target.dtype != torch.long:
+        raise TypeError("Expect `target` to be torch.long, but found %s" % target.dtype)
     order = pred.argsort(descending=True)
     target = target[order]
     precision = target.cumsum(0) / torch.arange(1, len(target) + 1, device=target.device)
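A minimal sketch of the new guard (assuming the functions are called as in `torchdrug.metrics`): `target` must now be integer-typed, and a float target raises immediately instead of silently producing a questionable score.

    pred = torch.tensor([0.9, 0.2, 0.7, 0.4])
    target = torch.tensor([1, 0, 1, 0])          # torch.long, as the guard expects
    auroc = area_under_roc(pred, target)         # works
    # area_under_roc(pred, target.float())       # now raises TypeError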
@@ -178,13 +179,103 @@ def chemical_validity(pred):
         validity.append(1 if mol else 0)
 
     return torch.tensor(validity, dtype=torch.float, device=pred.device)
+
+
+@R.register("metrics.accuracy")
+def accuracy(pred, target):
+    """
+    Compute classification accuracy over sets with equal size.
+
+    Suppose there are :math:`N` sets and :math:`C` categories.
+
+    Parameters:
+        pred (Tensor): prediction of shape :math:`(N, C)`
+        target (Tensor): target of shape :math:`(N,)`
+    """
+    return (pred.argmax(dim=-1) == target).float().mean()
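A quick usage sketch for the new `accuracy` metric (values below are illustrative only):

    pred = torch.tensor([[0.1, 0.2, 0.9, 0.3],
                         [0.8, 0.1, 0.05, 0.05],
                         [0.2, 0.3, 0.1, 0.4]])
    target = torch.tensor([2, 0, 1])
    acc = accuracy(pred, target)   # argmax matches 2 of 3 targets -> tensor(0.6667)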
+
+
+@R.register("metrics.mcc")
+def matthews_corrcoef(pred, target, eps=1e-6):
+    """
+    Matthews correlation coefficient between target and prediction.
+
+    Definition follows matthews_corrcoef for K classes in sklearn.
+    For details, see: https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef
+
+    Parameters:
+        pred (Tensor): prediction of shape :math:`(N, C)`
+        target (Tensor): target of shape :math:`(N,)`
+    """
+    num_class = pred.size(-1)
+    pred = pred.argmax(-1)
+    ones = torch.ones(len(target), device=pred.device)
+    confusion_matrix = scatter_add(ones, target * num_class + pred, dim=0, dim_size=num_class ** 2)
+    confusion_matrix = confusion_matrix.view(num_class, num_class)
+    t = confusion_matrix.sum(dim=1)
+    p = confusion_matrix.sum(dim=0)
+    c = confusion_matrix.trace()
+    s = confusion_matrix.sum()
+    return (c * s - t @ p) / ((s * s - p @ p) * (s * s - t @ t) + eps).sqrt()
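Since the docstring points at scikit-learn's definition, a small check like the following (a sketch, assuming scikit-learn is available) should agree up to the `eps` smoothing term:

    from sklearn.metrics import matthews_corrcoef as sk_mcc
    pred = torch.randn(100, 3)                       # logits over 3 classes
    target = torch.randint(0, 3, (100,))
    ours = matthews_corrcoef(pred, target)
    ref = sk_mcc(target.numpy(), pred.argmax(-1).numpy())
    # ours and ref should be nearly identical; eps only guards against a zero denominator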
+
+
+@R.register("metrics.pearsonr")
+def pearsonr(pred, target):
+    """
+    Pearson correlation between target and prediction.
+    Mimics `scipy.stats.pearsonr`.
+
+    Parameters:
+        pred (Tensor): prediction of shape :math:`(N,)`
+        target (Tensor): target of shape :math:`(N,)`
+    """
+    pred_mean = pred.float().mean()
+    target_mean = target.float().mean()
+    pred_centered = pred - pred_mean
+    target_centered = target - target_mean
+    pred_normalized = pred_centered / pred_centered.norm(2)
+    target_normalized = target_centered / target_centered.norm(2)
+    pearsonr = pred_normalized @ target_normalized
+    return pearsonr
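Because the docstring promises parity with `scipy.stats.pearsonr`, a small check (a sketch, assuming SciPy is available) should match to floating-point precision:

    from scipy import stats
    pred = torch.randn(100)
    target = 0.5 * pred + torch.randn(100)
    ours = pearsonr(pred, target)
    ref, _ = stats.pearsonr(pred.numpy(), target.numpy())
    # ours and ref should agree up to floating-point error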
+
+
+@R.register("metrics.spearmanr")
+def spearmanr(pred, target, eps=1e-6):
+    """
+    Spearman correlation between target and prediction.
+    Implemented in PyTorch, but non-differentiable (validation metric only).
+
+    Parameters:
+        pred (Tensor): prediction of shape :math:`(N,)`
+        target (Tensor): target of shape :math:`(N,)`
+    """
+
+    def get_ranking(input):
+        input_set, input_inverse = input.unique(return_inverse=True)
+        order = input_inverse.argsort()
+        ranking = torch.zeros(len(input_inverse), device=input.device)
+        ranking[order] = torch.arange(1, len(input) + 1, dtype=torch.float, device=input.device)
+
+        # for elements that have the same value, replace their rankings with the mean of their rankings
+        mean_ranking = scatter_mean(ranking, input_inverse, dim=0, dim_size=len(input_set))
+        ranking = mean_ranking[input_inverse]
+        return ranking
+
+    pred = get_ranking(pred)
+    target = get_ranking(target)
+    covariance = (pred * target).mean() - pred.mean() * target.mean()
+    pred_std = pred.std(unbiased=False)
+    target_std = target.std(unbiased=False)
+    spearmanr = covariance / (pred_std * target_std + eps)
+    return spearmanr
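Ties are resolved with average ranks, which is also SciPy's default behavior; a small check (a sketch, assuming SciPy is available) could look like this:

    from scipy import stats
    pred = torch.tensor([0.1, 0.4, 0.4, 0.8])      # contains a tie
    target = torch.tensor([1.0, 2.0, 3.0, 4.0])
    ours = spearmanr(pred, target)
    ref, _ = stats.spearmanr(pred.numpy(), target.numpy())
    # ours and ref should agree up to the eps term in the denominator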
 
 
+@R.register("metrics.variadic_accuracy")
 def variadic_accuracy(input, target, size):
     """
     Compute classification accuracy over variadic sizes of categories.
 
-    Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math`B`.
+    Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`.
 
     Parameters:
         input (Tensor): prediction of shape :math:`(B,)`
@@ -196,4 +287,4 @@ def variadic_accuracy(input, target, size):
     input_class = scatter_max(input, index2graph)[1]
     target_index = target + size.cumsum(0) - size
     accuracy = (input_class == target_index).float()
-    return accuracy
+    return accuracy
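A usage sketch for the variadic case (values below are illustrative): logits for all samples are packed into one flat tensor of length `B`, `size` gives the number of categories per sample, and `target` holds the class index local to each sample.

    input = torch.tensor([0.1, 0.7, 0.2, 0.9, 0.3])   # sample 0 has 3 categories, sample 1 has 2
    target = torch.tensor([1, 0])                      # local class index within each sample
    size = torch.tensor([3, 2])
    acc = variadic_accuracy(input, target, size)       # tensor([1., 1.]) -- both argmaxes are correct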