-
Notifications
You must be signed in to change notification settings - Fork 33
Description
Dear professor,I have some questions about the 3DInfomax.
I want to get the evaluation metrics such as Precision,so I use the Function which you provided in your metric.py such as TruePositiveRate() and TrueNegativeRate() to get this metric. But I tried all OGB datasets and found that those metrics such as Precision,Accuracy and Recall were not ideal. I hope you can reply to me as soon as possible. Thank you, professor.
Here is the HIV dataset's metric:
Precision: 0.008995866402983665
Accuracy: 0.9988852739334106
Recall: 0.002496626228094101
F1_score: 0.003908519633114338
ROC_AUC: 0.7427065372467041
PR_AUC: 0.2141391634941101
ogbg-molhiv: 0.742706502636204
BCEWithLogitsLoss: 0.17792926660992883
Here is the BBBP dataset's metric:
Precision: 0.44607841968536377
Accuracy: 0.6127931475639343
Recall: 0.005654983688145876
F1_score: 0.011168383993208408
ROC_AUC: 0.6745756268501282
PR_AUC: 0.6546612977981567
ogbg-molbbbp: 0.6745756172839505
BCEWithLogitsLoss: 1.1453146849359785
Here is my metric code:
class Precision(nn.Module):
def init(self, threshold=0.5) -> None:
super(Precision, self).init()
self.threshold = threshold
def forward(self, x1: Tensor, x2: Tensor, pos_mask: Tensor = None) -> Tensor:
batch_size, _ = x1.size()
if x1.shape != x2.shape and pos_mask == None:
x2 = x2[:batch_size]
sim_matrix = torch.einsum('ik,jk->ij', x1, x2)
x1_abs = x1.norm(dim=1)
x2_abs = x2.norm(dim=1)
sim_matrix = sim_matrix / torch.einsum('i,j->ij', x1_abs, x2_abs)
preds: Tensor = (sim_matrix + 1) / 2 > self.threshold
if pos_mask == None: # if we are comparing global with global
pos_mask = torch.eye(batch_size, device=x1.device)
neg_mask = 1 - pos_mask
num_positives = len(x1)
num_negatives = len(x1) * (len(x2) - 1)
false_positives = ((preds.long() - pos_mask) * pos_mask).count_nonzero()
true_positives = num_positives - ((preds.long() - pos_mask) * pos_mask).count_nonzero()
false_negatives = (((~preds).long() - neg_mask) * neg_mask).count_nonzero()
true_negatives = num_negatives - (((~preds).long() - neg_mask) * neg_mask).count_nonzero()
pre = true_positives /(true_positives + false_positives)
return pre