
Implementation of the contrastive loss #4

@Adasunnylily

Description


Hi, I am confused about how the contrastive loss in the paper is computed. The paper says L_ret is calculated from (t, t', y, y'), but in the released code the model returns `(cpt - 0.5) * 2, cls, attn, updates`, and the loss appears to be computed by feeding `(cpt - 0.5) * 2` directly in as `y` and using `pairwise_loss` as L_ret. However, the released code calls it as `pairwise_loss(y, y, label, label)`, which seems to mean the two inputs are always identical (i.e., always similar), and this is really confusing.
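For reference, the `(cpt - 0.5) * 2` rescaling in the model's first return value maps activations in [0, 1] onto [-1, 1], the range hashing-style losses usually expect for `y`. This assumes `cpt` comes out of a sigmoid, which the snippet does not confirm; the values below are made up for illustration.

```python
import torch

# Hypothetical concept activations in [0, 1], e.g. sigmoid outputs (assumption).
cpt = torch.tensor([[0.0, 0.25, 0.5, 1.0]])

# Same rescaling as the model's first return value: [0, 1] -> [-1, 1].
y = (cpt - 0.5) * 2
print(y)  # values -1.0, -0.5, 0.0, 1.0
```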
Could you please tell me if there is anything wrong with my understanding?
I really appreciate your help!

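import torch

# pairwise_loss and quantization_loss are defined elsewhere in the released code.
def get_retrieval_loss(args, y, label, num_cls, device):
    b = label.shape[0]
    if args.dataset != "matplot":
        # Convert class indices of shape (b,) into one-hot rows of shape (b, num_cls).
        label = label.unsqueeze(-1)
        label = torch.zeros(b, num_cls).to(device).scatter(1, label, 1)
    # Note: y and label are passed as both members of the pair.
    similarity_loss = pairwise_loss(y, y, label, label, sigmoid_param=10. / 32)
    # similarity_loss = pairwise_loss2(y, y, label.float(), sigmoid_param=10. / 32)
    q_loss = quantization_loss(y)
    return similarity_loss, q_loss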
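Whether `pairwise_loss(y, y, label, label)` really means "always similar" depends on how `pairwise_loss` is defined, and that definition is not shown here. If it follows the HashNet-style form common in deep-hashing code (an assumption), passing the same batch twice builds a B x B matrix over every pair (i, j) in the batch: only the diagonal is trivially similar, while pairs with different labels count as dissimilar. The sketch below illustrates that reading with hypothetical helpers (`one_hot_labels`, `pairwise_similarity`, `pairwise_loss_sketch`) and a simplified pairwise logistic loss; it is not the repository's implementation.

```python
import torch
import torch.nn.functional as F


def one_hot_labels(label: torch.Tensor, num_cls: int) -> torch.Tensor:
    # Same scatter-based construction as get_retrieval_loss: indices -> one-hot rows.
    b = label.shape[0]
    return torch.zeros(b, num_cls, device=label.device).scatter(1, label.unsqueeze(-1), 1)


def pairwise_similarity(label1: torch.Tensor, label2: torch.Tensor) -> torch.Tensor:
    # S[i, j] = 1 if sample i (first batch) and sample j (second batch) share a label.
    return (label1.float() @ label2.float().t() > 0).float()


def pairwise_loss_sketch(y1, y2, label1, label2, sigmoid_param=10.0 / 32):
    # Hypothetical HashNet-style pairwise logistic loss (NOT the released code).
    s = pairwise_similarity(label1, label2)
    logits = sigmoid_param * (y1 @ y2.t())  # scaled inner products for all pairs
    # Negative log-likelihood of s under p = sigmoid(logits):
    # softplus(logits) - s * logits, averaged over every (i, j) pair.
    return (F.softplus(logits) - s * logits).mean()


if __name__ == "__main__":
    label = torch.tensor([0, 1, 0])
    onehot = one_hot_labels(label, num_cls=2)
    # Off-diagonal entries are not all 1: samples 0 and 1 have different labels.
    print(pairwise_similarity(onehot, onehot))
    y = torch.tanh(torch.randn(3, 8))
    print(pairwise_loss_sketch(y, y, onehot, onehot))
```

Under this reading, the self-pairs on the diagonal are only a small part of the loss; most of the signal comes from cross-sample pairs within the batch.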
