
How to use GradCache in non-single input function?  #22


Description

@lxx909546478

Great work!
I find it works well when X and Y each have their own encoder, but for some reason I have to use the following setting:
X and Y have the same shape; X_i and Y_i form a positive pair, while X_i and every other Y_j form negative pairs. X and Y are fed into the same network (as below), and the loss function is contrastive_loss (roughly sketched after the class).

class A(nn.Module):
    ...
    def forward(self, X, Y):
        embx = self.fx(X)  # sub-encoder for X inside the same module
        emby = self.fy(Y)  # sub-encoder for Y inside the same module
        return embx, emby

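My contrastive_loss is roughly an in-batch softmax cross-entropy over the X-Y similarity matrix, something like the sketch below (the temperature value and the absence of embedding normalization are just placeholders, not my exact implementation):

import torch
import torch.nn.functional as F

def contrastive_loss(embx, emby, temperature=0.05):
    # X_i pairs with Y_i as the positive; every other Y_j in the batch is a negative.
    scores = embx @ emby.t() / temperature
    target = torch.arange(embx.size(0), device=embx.device)
    return F.cross_entropy(scores, target)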
The example you give looks like this:

cache_x = []
cache_y = []
closures_x = []
closures_y = []

for step, sub_batch in enumerate(loader):  
    xx, yy = sub_batch
    rx, cx = call_model(bert, xx)
    ry, cy = call_model(bert, yy)
    
    cache_x.append(rx)
    cache_y.append(ry)
    closures_x.append(cx)
    closures_y.append(cy)
    
    if (step + 1) % 16 == 0:
        loss = contrastive_loss(cache_x, cache_y)
        scaler.scale(loss).backward()
        
        for f, r in zip(closures_x, cache_x):
            f(r)
        for f, r in zip(closures_y, cache_y):
            f(r)

        cache_x = []
        cache_y = []
        closures_x = []
        closures_y = []

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
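For reference, my understanding is that call_model in this example is the helper from the functional examples wrapped with the cached decorator (so it returns the representation plus a closure that later replays the forward pass with gradients), and that contrastive_loss is wrapped with cat_input_tensor so it can accept the lists of cached representations directly. A sketch of call_model, assuming the encoder is a Hugging Face BERT model with a pooler_output field:

from torch.cuda.amp import autocast
from grad_cache.functional import cached

@cached
@autocast()
def call_model(model, input):
    # Returns (representation, closure); the closure re-runs this forward
    # with gradients enabled and backpropagates the cached gradient.
    return model(**input).pooler_output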

How could I use GradCache in this setting? Should I store two RandContexts, one for X and one for Y?
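Since both X and Y go through a single forward call of A, my guess is that one RandContext per sub-batch might be enough, captured before the no-grad forward and reused when replaying it. Below is a rough sketch of what I imagine the adapted loop could look like, assuming RandContext from grad_cache.context_managers records the RNG state of the given input tensors at construction and restores it when used as a context manager, and that contrastive_loss takes the two concatenated embedding matrices. The helper name call_shared_model and the accumulation step of 16 are placeholders, and I have dropped the AMP scaler to keep the sketch short. Please correct me if this is wrong.

import torch
from grad_cache.context_managers import RandContext

def call_shared_model(model, xx, yy):
    # First pass without gradients: cache the embeddings and the RNG state
    # so this forward can be replayed identically later.
    rnd = RandContext(xx, yy)
    with torch.no_grad():
        embx, emby = model(xx, yy)
    return embx, emby, rnd

cache_x, cache_y, rnd_states, inputs = [], [], [], []

for step, (xx, yy) in enumerate(loader):
    embx, emby, rnd = call_shared_model(model, xx, yy)
    cache_x.append(embx)
    cache_y.append(emby)
    rnd_states.append(rnd)
    inputs.append((xx, yy))

    if (step + 1) % 16 == 0:
        # Full-batch loss on the cached representations, with gradients
        # enabled only on the concatenated caches.
        all_x = torch.cat(cache_x).requires_grad_()
        all_y = torch.cat(cache_y).requires_grad_()
        loss = contrastive_loss(all_x, all_y)
        loss.backward()  # fills all_x.grad and all_y.grad

        # Replay each sub-batch under its saved RNG state and push the
        # cached gradients through the shared encoder.
        offset = 0
        for (xx, yy), rnd in zip(inputs, rnd_states):
            n = xx.size(0)
            with rnd:
                embx, emby = model(xx, yy)  # second forward, gradients on
            surrogate = (embx * all_x.grad[offset:offset + n]).sum() \
                      + (emby * all_y.grad[offset:offset + n]).sum()
            surrogate.backward()
            offset += n

        optimizer.step()
        optimizer.zero_grad()
        cache_x, cache_y, rnd_states, inputs = [], [], [], []

If I understand correctly, the dot-product surrogate plays the same role as the closures in the example above: it pushes the cached gradients from the loss back through the encoder parameters.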
Looking forward to your help.
Thanks a lot~
