
Commit 94bf3f4

Merge pull request #59 from discovery-unicamp/setr-logging-fix

Setr logging fix

2 parents: cb7de07 + 9424c0b

File tree: 2 files changed (+40, −117 lines)


minerva/models/nets/setr.py

Lines changed: 39 additions & 116 deletions
@@ -1,10 +1,10 @@
 import warnings
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import lightning as L
 import torch
 from torch import nn
-from torchmetrics import JaccardIndex
+from torchmetrics import JaccardIndex, Metric
 
 from minerva.models.nets.vit import _VisionTransformerBackbone
 from minerva.utils.upsample import Upsample, resize
@@ -407,12 +407,9 @@ def __init__(
         conv_act: Optional[nn.Module] = None,
         interpolate_mode: str = "bilinear",
         loss_fn: Optional[nn.Module] = None,
-        log_train_metrics: bool = False,
-        train_metrics: Optional[nn.Module] = None,
-        log_val_metrics: bool = False,
-        val_metrics: Optional[nn.Module] = None,
-        log_test_metrics: bool = False,
-        test_metrics: Optional[nn.Module] = None,
+        train_metrics: Optional[Dict[str, Metric]] = None,
+        val_metrics: Optional[Dict[str, Metric]] = None,
+        test_metrics: Optional[Dict[str, Metric]] = None,
         aux_output: bool = True,
         aux_output_layers: list[int] | None = [9, 14, 19],
         aux_weights: list[float] = [0.3, 0.3, 0.3],
@@ -460,10 +457,18 @@ def __init__(
             The interpolation mode for upsampling in the decoder. Defaults to "bilinear".
         loss_fn : nn.Module, optional
             The loss function to be used during training. Defaults to None.
-        log_metrics : bool
-            Whether to log metrics during training. Defaults to True.
-        metrics : list[MetricTypeSetR], optional
-            The metrics to be used for evaluation. Defaults to [MetricTypeSetR.mIoU, MetricTypeSetR.mIoU, MetricTypeSetR.mIoU].
+        train_metrics : Dict[str, Metric], optional
+            The metrics to be used for training evaluation. Defaults to None.
+        val_metrics : Dict[str, Metric], optional
+            The metrics to be used for validation evaluation. Defaults to None.
+        test_metrics : Dict[str, Metric], optional
+            The metrics to be used for testing evaluation. Defaults to None.
+        aux_output : bool
+            Whether to include auxiliary output heads in the model. Defaults to True.
+        aux_output_layers : list[int] | None
+            The indices of the layers to output auxiliary predictions. Defaults to [9, 14, 19].
+        aux_weights : list[float]
+            The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3].
 
         """
         super().__init__()
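With this change, each split takes its own dictionary of plain torchmetrics objects keyed by name. A minimal construction sketch, assuming the public Lightning wrapper is SETR_PUP and a six-class setup (both illustrative):

    from torchmetrics import JaccardIndex
    from minerva.models.nets.setr import SETR_PUP

    # Hypothetical values; only num_classes and the metric dicts matter here.
    model = SETR_PUP(
        num_classes=6,
        train_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=6)},
        val_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=6)},
    )
    # Logged keys are prefixed per split, e.g. "train_mIoU" and "val_mIoU".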
@@ -486,27 +491,11 @@ def __init__(
         self.num_classes = num_classes
         self.aux_weights = aux_weights
 
-        self.log_train_metrics = log_train_metrics
-        self.log_val_metrics = log_val_metrics
-        self.log_test_metrics = log_test_metrics
-
-        if log_train_metrics:
-            assert (
-                train_metrics is not None
-            ), "train_metrics must be provided if log_train_metrics is True"
-            self.train_metrics = train_metrics
-
-        if log_val_metrics:
-            assert (
-                val_metrics is not None
-            ), "val_metrics must be provided if log_val_metrics is True"
-            self.val_metrics = val_metrics
-
-        if log_test_metrics:
-            assert (
-                test_metrics is not None
-            ), "test_metrics must be provided if log_test_metrics is True"
-            self.test_metrics = test_metrics
+        self.metrics = {
+            "train": train_metrics,
+            "val": val_metrics,
+            "test": test_metrics,
+        }
 
         self.model = _SetR_PUP(
             image_size=image_size,
@@ -531,18 +520,20 @@ def __init__(
             aux_output_layers=aux_output_layers,
         )
 
-        self.train_step_outputs = []
-        self.train_step_labels = []
-
-        self.val_step_outputs = []
-        self.val_step_labels = []
-
-        self.test_step_outputs = []
-        self.test_step_labels = []
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x)
 
+    def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str):
+        if self.metrics[step_name] is None:
+            return {}
+
+        return {
+            f"{step_name}_{metric_name}": metric.to(self.device)(
+                torch.argmax(y_hat, dim=1, keepdim=True), y
+            )
+            for metric_name, metric in self.metrics[step_name].items()
+        }
+
     def _loss_func(
         self,
         y_hat: (
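The new _compute_metrics helper replaces the three epoch-end hooks removed further down: each step turns per-class logits into hard predictions via argmax and applies every metric configured for that split. A standalone sketch of the same pattern (function and variable names are illustrative):

    from typing import Dict, Optional

    import torch
    from torchmetrics import JaccardIndex, Metric

    def compute_metrics(
        metrics: Optional[Dict[str, Metric]],
        y_hat: torch.Tensor,
        y: torch.Tensor,
        step_name: str,
    ) -> Dict[str, torch.Tensor]:
        if metrics is None:  # split has no metrics configured
            return {}
        # Per-class logits (N, C, H, W) -> hard predictions (N, 1, H, W).
        preds = torch.argmax(y_hat, dim=1, keepdim=True)
        return {f"{step_name}_{name}": m(preds, y) for name, m in metrics.items()}

    metrics = {"mIoU": JaccardIndex(task="multiclass", num_classes=6)}
    y_hat = torch.randn(2, 6, 32, 32)         # fake logits
    y = torch.randint(0, 6, (2, 1, 32, 32))   # fake ground-truth mask
    print(compute_metrics(metrics, y_hat, y, "val"))  # {'val_mIoU': tensor(...)}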
@@ -577,6 +568,7 @@ def _loss_func(
                 + (loss_aux3 * self.aux_weights[2])
             )
         loss = self.loss_fn(y_hat, y.long())
+
         return loss
 
     def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
@@ -600,86 +592,17 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
         y_hat = self.model(x.float())
         loss = self._loss_func(y_hat[0], y.squeeze(1))
 
-        if step_name == "train":
-            self.train_step_outputs.append(y_hat[0])
-            self.train_step_labels.append(y)
-        elif step_name == "val":
-            self.val_step_outputs.append(y_hat[0])
-            self.val_step_labels.append(y)
-        elif step_name == "test":
-            self.test_step_outputs.append(y_hat[0])
-            self.test_step_labels.append(y)
-
-        self.log_dict(
-            {
-                f"{step_name}_loss": loss,
-            },
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-
-        return loss
-
-    def on_train_epoch_end(self):
-        if self.log_train_metrics:
-            y_hat = torch.cat(self.train_step_outputs)
-            y = torch.cat(self.train_step_labels)
-            preds = torch.argmax(y_hat, dim=1, keepdim=True)
-            self.train_metrics(preds, y)
-            mIoU = self.train_metrics.compute()
-
+        metrics = self._compute_metrics(y_hat[0], y, step_name)
+        for metric_name, metric_value in metrics.items():
             self.log_dict(
-                {
-                    f"train_metrics": mIoU,
-                },
+                {metric_name: metric_value},
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
             )
-            self.train_step_outputs.clear()
-            self.train_step_labels.clear()
-
-    def on_validation_epoch_end(self):
-        if self.log_val_metrics:
-            y_hat = torch.cat(self.val_step_outputs)
-            y = torch.cat(self.val_step_labels)
-            preds = torch.argmax(y_hat, dim=1, keepdim=True)
-            self.val_metrics(preds, y)
-            mIoU = self.val_metrics.compute()
 
-            self.log_dict(
-                {
-                    f"val_metrics": mIoU,
-                },
-                on_step=False,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-            self.val_step_outputs.clear()
-            self.val_step_labels.clear()
-
-    def on_test_epoch_end(self):
-        if self.log_test_metrics:
-            y_hat = torch.cat(self.test_step_outputs)
-            y = torch.cat(self.test_step_labels)
-            preds = torch.argmax(y_hat, dim=1, keepdim=True)
-            self.test_metrics(preds, y)
-            mIoU = self.test_metrics.compute()
-            self.log_dict(
-                {
-                    f"test_metrics": mIoU,
-                },
-                on_step=False,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-            self.test_step_outputs.clear()
-            self.test_step_labels.clear()
+        return loss
 
     def training_step(self, batch: torch.Tensor, batch_idx: int):
         return self._single_step(batch, batch_idx, "train")
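With on_step=False and on_epoch=True, Lightning accumulates the per-batch metric values logged in _single_step and reports their epoch-level mean, so the manual torch.cat buffers and epoch-end hooks are no longer needed. One caveat worth flagging: the mean of per-batch IoU values is not in general identical to the IoU computed over a whole epoch's confusion matrix, which is effectively what the removed hooks (and torchmetrics' stateful update/compute protocol) produced.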
@@ -694,7 +617,7 @@ def predict_step(
         self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int | None = None
     ):
         x, _ = batch
-        return self.model(x)
+        return self.model(x)[0]
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.model.parameters(), lr=1e-3)
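predict_step now unwraps the tuple returned by the model (main head plus auxiliary heads when aux_output=True) and yields only the primary segmentation output, which is what the updated test below asserts. A usage sketch with illustrative shapes, reusing the model constructed earlier:

    import torch

    # Hypothetical input/mask shapes for illustration.
    x = torch.randn(1, 3, 512, 512)
    mask = torch.zeros(1, 6, 512, 512)
    preds = model.predict_step((x, mask), 0)
    # Previously a tuple; now a single tensor, so no preds[0] indexing is needed.
    print(preds.shape)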

tests/models/nets/test_setr.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def test_setr_predict():
     preds = model.predict_step((x, mask), 0)
     assert preds is not None
     assert (
-        preds[0].shape == mask_shape
+        preds.shape == mask_shape
     ), f"Expected shape {mask_shape}, but got {preds[0].shape}"

0 commit comments