Hi, I am confused about how the contrastive loss in the paper is computed. The paper says `L_ret` is calculated from (t, t', y, y'), but in the code the model returns `(cpt - 0.5) * 2, cls, attn, updates`, and the loss seems to be computed by feeding `(cpt - 0.5) * 2` directly in as `y` and using `pairwise_loss` as `L_ret`. However, in the released code the call is `pairwise_loss(y, y, label, label)`, which seems to mean the two inputs are always identical (i.e., always similar), and that is really confusing.
Could you please tell me if there is anything wrong with my understanding?
I'd really appreciate your help!
```python
import torch

# pairwise_loss and quantization_loss are defined elsewhere in the released code.
def get_retrieval_loss(args, y, label, num_cls, device):
    b = label.shape[0]
    if args.dataset != "matplot":
        # Convert class indices of shape (b,) into one-hot labels of shape (b, num_cls).
        label = label.unsqueeze(-1)
        label = torch.zeros(b, num_cls).to(device).scatter(1, label, 1)
    similarity_loss = pairwise_loss(y, y, label, label, sigmoid_param=10. / 32)
    # similarity_loss = pairwise_loss2(y, y, label.float(), sigmoid_param=10. / 32)
    q_loss = quantization_loss(y)
    return similarity_loss, q_loss
```
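For reference, here is a minimal sketch of what a call like `pairwise_loss(y, y, label, label, ...)` computes, *assuming* the repo's `pairwise_loss` follows the common HashNet-style form: a label-similarity matrix `S = (label @ label.T > 0)`, scaled inner products `theta = sigmoid_param * y @ y.T`, and the pairwise logistic loss `log(1 + exp(theta)) - S * theta`. That form is not confirmed from the released code. `pairwise_loss_sketch`, the toy batch, and the random stand-in for `(cpt - 0.5) * 2` are hypothetical, and HashNet's class-balancing weights and overflow threshold are omitted.

```python
import torch
import torch.nn.functional as F


def pairwise_loss_sketch(y1, y2, label1, label2, sigmoid_param=10.0 / 32):
    """Hypothetical HashNet-style pairwise loss (assumed form, not the repo's code).

    y1, y2: (B, K) codes in (-1, 1); label1, label2: (B, C) one-/multi-hot labels.
    """
    # S[i, j] = 1 if sample i and sample j share at least one label, else 0.
    similarity = (label1.float() @ label2.float().t() > 0).float()
    # theta[i, j] = scaled inner product between code i and code j.
    theta = sigmoid_param * (y1 @ y2.t())
    # Negative log-likelihood of S under p(S=1) = sigmoid(theta):
    # log(1 + exp(theta)) - S * theta, computed stably via softplus.
    loss = F.softplus(theta) - similarity * theta
    return loss.mean(), similarity


torch.manual_seed(0)
num_cls = 3
label = torch.tensor([0, 1, 0, 2])  # toy class indices for a batch of 4
b = label.shape[0]
# Same one-hot conversion as in get_retrieval_loss (device handling omitted).
one_hot = torch.zeros(b, num_cls).scatter(1, label.unsqueeze(-1), 1)
# Hypothetical stand-in for (cpt - 0.5) * 2: sigmoid outputs mapped to (-1, 1).
y = (torch.sigmoid(torch.randn(b, 8)) - 0.5) * 2

loss, S = pairwise_loss_sketch(y, y, one_hot, one_hot)
print(S)
# tensor([[1., 0., 1., 0.],
#         [0., 1., 0., 0.],
#         [1., 0., 1., 0.],
#         [0., 0., 0., 1.]])
```

In this sketch, `S[i, j]` is 0 whenever samples i and j carry different labels, so passing the same batch twice compares every pair within the batch rather than each sample only with itself; only the diagonal entries are trivially similar.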