Labels: bug, strategy: ddp, ver: 2.0.x
Bug description
I followed this tutorial to build a Lightning model for multi-label text classification, except that I'm using my own dataset.
I had to fix some of the code because, I believe, it was using deprecated syntax/features.
However, when I run `Trainer.fit`, the program hangs indefinitely, with no meaningful error output beforehand.
As I'm not sure how to go about debugging this, I'm creating this issue.
I'm running Lightning 2.0.8, Python 3.8.10, CUDA 11.7, and NVIDIA driver 470.199.02.
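One thing I can try in the meantime (suggestions welcome) is enabling the distributed debug logs before the process group initializes, so NCCL reports what it's waiting on. A minimal sketch, assuming the variables are set at the very top of the training script:

import os

# Debugging aid, not part of my original script:
# NCCL_DEBUG=INFO makes NCCL log its rendezvous and topology decisions;
# TORCH_DISTRIBUTED_DEBUG=DETAIL adds collective-mismatch checks in PyTorch.
# Both must be set before torch.distributed initializes the process group.
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"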
What version are you seeing the problem on?
v2.0
How to reproduce the bug
This is my custom-defined LightningModule for training the model.
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import AdamW
from torchmetrics.functional import auroc
from transformers import BertModel, get_linear_schedule_with_warmup

bert_model_name = "bert-base-cased"  # placeholder; defined earlier in my script
label_columns = [...]                # list of label column names, defined earlier

class Tagger(pl.LightningModule):
    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        self.criterion = nn.BCELoss()
        # Lightning 2.x no longer passes step outputs to the epoch-end hooks,
        # so they are collected manually here.
        self.training_step_outputs = []

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log('train_loss', loss, prog_bar=True, logger=True)
        # Detach so the stored tensors don't keep the autograd graph alive.
        self.training_step_outputs.append(
            {'predictions': outputs.detach(), 'labels': labels}
        )
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log('val_loss', loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log('test_loss', loss, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self):
        labels = []
        predictions = []
        for output in self.training_step_outputs:
            for out_labels in output['labels'].detach().cpu():
                labels.append(out_labels)
            for out_predictions in output['predictions'].cpu():
                predictions.append(out_predictions)
        self.training_step_outputs.clear()
        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        for i, name in enumerate(label_columns):
            # Newer torchmetrics requires an explicit task argument.
            class_roc_auc = auroc(predictions[:, i], labels[:, i], task='binary')
            self.logger.experiment.add_scalar(
                f'{name}_roc_auc/Train', class_roc_auc, self.current_epoch
            )

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=1e-5)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.n_warmup_steps,
            num_training_steps=self.n_training_steps
        )
        # In Lightning 2.x the interval belongs inside the lr_scheduler dict.
        return dict(
            optimizer=optimizer,
            lr_scheduler=dict(scheduler=scheduler, interval='step'),
        )
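The Trainer call itself isn't in the snippet above; a minimal sketch of the kind of invocation that produces the two-process DDP startup in the logs below (`model` and `data_module` are placeholder names for my objects, and the exact arguments may differ):

import pytorch_lightning as pl

pl.seed_everything(42)  # matches the "Global seed set to 42" lines in the log

# Placeholder invocation: `model` is a Tagger instance and `data_module`
# my LightningDataModule; devices=2 matches the two ranks in the log.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    default_root_dir="model/saved",
    max_epochs=10,
)
trainer.fit(model, datamodule=data_module)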
Error messages and logs
This is what I get before it hangs.
[rank: 0] Global seed set to 42
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[rank: 1] Global seed set to 42
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
Missing logger folder: model/saved/lightning_logs
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
Missing logger folder: model/saved/lightning_logs
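Since there is no traceback, a possible next step (a hedged suggestion, not something I've run yet) is to dump the Python stacks of each hung rank to see which call they are blocked in. Registering a faulthandler at the top of the script makes that possible with `kill -USR1 <pid>`:

import faulthandler
import signal

# Register at the top of the training script; sending SIGUSR1 to a hung
# rank then prints every thread's Python stack trace to stderr.
faulthandler.register(signal.SIGUSR1, all_threads=True)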
Environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): Trainer, LightningModule
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0.8
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9): 3.8.10
#- OS (e.g., Linux):
#- CUDA/cuDNN version: CUDA 11.7 (NVIDIA driver 470.199.02)
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
More info
No response