
NaN's and inf loss seen during training #73

@karanchahal-nv

Description

Hello!

Thanks for your very nice package. I have successfully applied it to estimate the density of a distribution trained on X with dimension 3, and it worked very well.

Now I am scaling it up to a distribution X that has 19 dimensions. I am using this model architecture:

import torch
import normflows as nf

def get_model():
    D = 19
    base = nf.distributions.base.DiagGaussian(D)

    # define a fixed half/half mask: first 10 condition, last 9 transformed
    # (note: this mask is never passed to the flow below; AffineCouplingBlock
    # performs its own internal split, so this variable is dead code as written)
    cond_size = (D + 1) // 2   # 10
    trans_size = D - cond_size  # 9
    mask = torch.cat([
        torch.ones(cond_size, dtype=torch.bool),
        torch.zeros(trans_size, dtype=torch.bool)
    ])

    num_layers = 32
    flows = []
    for _ in range(num_layers):
        # conditioner must output 2 * trans_size (shift + log_scale)
        param_map = nf.nets.MLP([cond_size, 64, 64, 2 * trans_size], init_zeros=True)
        flows.append(nf.flows.AffineCouplingBlock(param_map))
        # swap halves so the other side is transformed next time
        flows.append(nf.flows.Permute(D, mode='swap'))

    model = nf.NormalizingFlow(base, flows)

    enable_cuda = True
    device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
    model = model.to(device)
    return model
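One thing I plan to double-check first (a sketch, not part of my current pipeline): affine couplings that exponentiate a log-scale overflow easily when inputs have large magnitudes, so standardizing each of the 19 dimensions to zero mean and unit variance before training is a common safeguard. `x_raw` here is a hypothetical stand-in for the training tensor.

```python
import torch

def standardize(x_raw: torch.Tensor):
    # per-dimension statistics over the whole dataset;
    # clamp_min guards against a (near-)constant dimension dividing by ~0
    mean = x_raw.mean(dim=0)
    std = x_raw.std(dim=0).clamp_min(1e-6)
    x = (x_raw - mean) / std
    return x, mean, std

# usage: x, mean, std = standardize(dataset_tensor)
# log-densities in the original space then differ by -sum(log(std))
```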

I am using this optimizer and learning rate:

def main():
    model = get_model()
    num_epochs = 4000
    lr = 5e-5
    wd = 1e-5
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    train(
        model=model,
        dataloader=get_dataset(batch_size=1),
        num_epochs=num_epochs,
        optimizer=optimizer,
        tensorboard_writer=get_tensorboard_writer()
    )

And this is how I step the model:

def step_model_with_batch(
    model,
    x,
    optimizer
) -> float:
    optimizer.zero_grad()
    loss = -model.log_prob(x).mean()
    if torch.isfinite(loss):
        loss.backward()
        optimizer.step()
    return float(loss.detach().cpu())
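For reference, a hardened variant of the step above (a sketch, not what I am currently running) that clips the global gradient norm and skips the update when the gradients themselves are non-finite; `max_grad_norm=1.0` is an assumed value, not something tuned for this model.

```python
import torch

def step_model_with_batch_safe(model, x, optimizer, max_grad_norm=1.0) -> float:
    optimizer.zero_grad()
    loss = -model.log_prob(x).mean()
    if torch.isfinite(loss):
        loss.backward()
        # clip_grad_norm_ rescales gradients in place and returns the
        # pre-clip total norm, which also lets us detect NaN/inf gradients
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        if torch.isfinite(grad_norm):
            optimizer.step()
    return float(loss.detach().cpu())
```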

But I am seeing this sort of loss

Step 0, Loss: 25.38446807861328
Step 50, Loss: 22.443069458007812
Step 100, Loss: nan
Step 150, Loss: 21.203521728515625
Step 200, Loss: 21.044057846069336
Step 250, Loss: 20.569766998291016
Step 300, Loss: 20.03892707824707
Step 350, Loss: nan
Step 400, Loss: 19.617673873901367
...
...
Step 151600, Loss: 7.707192897796631
Step 151650, Loss: 7.9374237060546875
Step 151700, Loss: 7.607234954833984
Step 151750, Loss: 7.97566032409668
Step 151800, Loss: 7.6797285079956055
Step 151850, Loss: 157298064.0
Step 151900, Loss: 8.694652557373047
Step 151950, Loss: 7.95250129699707

Is this common in normalizing flow networks, or should any NaN or blow-up in the loss be treated seriously?

How should I start debugging this? Am I doing something wrong?

Thanks again !
