Description
Hi, this is great work!
We have three inputs, designated `i1`, `i2`, and `i3`, which are processed by llama-7b. For input `i1`, I extract hidden states at two distinct locations and label them `p11` and `p12`, respectively. For the remaining inputs, `i2` and `i3`, I select a single hidden state each, denoted `n21` and `n31`, respectively.
In this setup, `p11` paired with `n21` constitutes a positive pair, whereas `p11` paired with `n31` forms a negative pair. Likewise, `p12` paired with `n31` constitutes a positive pair, whereas `p12` paired with `n21` forms a negative pair. My objective is to compute the InfoNCE loss over these pairs.
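For concreteness, here is a minimal sketch (not from the repo) of the loss I have in mind, assuming dot-product similarity and a hypothetical temperature of 0.05; `p11`, `p12`, `n21`, `n31` stand for the extracted hidden-state vectors:

```python
import torch
import torch.nn.functional as F

def info_nce(p11, p12, n21, n31, temperature=0.05):
    """InfoNCE over the query reps (p11, p12) and candidate reps (n21, n31)."""
    queries = torch.stack([p11, p12])        # (2, hidden_dim)
    candidates = torch.stack([n21, n31])     # (2, hidden_dim)
    logits = queries @ candidates.t() / temperature  # (2, 2) similarity matrix
    # Row 0 (p11): positive is n21 (column 0); row 1 (p12): positive is n31 (column 1).
    targets = torch.tensor([0, 1])
    return F.cross_entropy(logits, targets)

# Example with random stand-ins for the extracted llama-7b hidden states:
p11, p12, n21, n31 = (torch.randn(4096) for _ in range(4))
loss = info_nce(p11, p12, n21, n31)
```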
So I set the `get_rep_fn` in the `GradCache` class to handle these different cases. Here is an example snippet:

```python
def get_rep_fn(x):
    # Inputs like i1 carry two representations (p11 and p12);
    # in my setup they are marked with label == 2.
    if x.label == 2:
        return [x.e1, x.e2]
    # Inputs like i2 and i3 carry a single representation (n21 or n31).
    else:
        return [x.e1]
```
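To show what this returns in each case, here is a hedged sketch that reuses the `get_rep_fn` above; the `ModelOutput` container below is hypothetical and only mirrors the `label`, `e1`, and `e2` attributes used in my code:

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ModelOutput:
    # Hypothetical stand-in for my model's output object.
    label: int
    e1: torch.Tensor
    e2: Optional[torch.Tensor] = None

out_i1 = ModelOutput(label=2, e1=torch.randn(4096), e2=torch.randn(4096))
out_i2 = ModelOutput(label=1, e1=torch.randn(4096))

print(len(get_rep_fn(out_i1)))  # 2 -> [p11, p12]
print(len(get_rep_fn(out_i2)))  # 1 -> [n21]
```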
At the same time, I changed the following calls from `append` to `extend`:

GradCache/src/grad_cache/grad_cache.py, line 187 at commit 0c33638:

```python
model_reps.append(self.get_reps(y))
```

GradCache/src/grad_cache/grad_cache.py, line 270 at commit 0c33638:

```python
all_reps.append(model_reps)
```
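As a toy illustration (independent of GradCache's actual internals) of why I switched `append` to `extend`: assuming `get_reps` simply applies `get_rep_fn` to each model output, it now yields a list per input, so `append` would nest those lists while `extend` keeps the collected representations flat:

```python
import torch

# Toy stand-ins for the extracted hidden states.
p11, p12, n21, n31 = (torch.randn(4096) for _ in range(4))

# What get_reps would yield per input once get_rep_fn returns lists.
reps_per_input = [[p11, p12], [n21], [n31]]

nested, flat = [], []
for reps in reps_per_input:
    nested.append(reps)  # -> [[p11, p12], [n21], [n31]] (nested lists)
    flat.extend(reps)    # -> [p11, p12, n21, n31] (flat list of tensors)

print(len(nested), len(flat))  # 3 4
```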
I'd like to ask whether the gradient computation is still correct with these changes. Could you please confirm?
Thanks!