@@ -1,10 +1,10 @@
 import warnings
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import lightning as L
 import torch
 from torch import nn
-from torchmetrics import JaccardIndex
+from torchmetrics import JaccardIndex, Metric
 
 from minerva.models.nets.vit import _VisionTransformerBackbone
 from minerva.utils.upsample import Upsample, resize
@@ -407,12 +407,9 @@ def __init__(
         conv_act: Optional[nn.Module] = None,
         interpolate_mode: str = "bilinear",
         loss_fn: Optional[nn.Module] = None,
-        log_train_metrics: bool = False,
-        train_metrics: Optional[nn.Module] = None,
-        log_val_metrics: bool = False,
-        val_metrics: Optional[nn.Module] = None,
-        log_test_metrics: bool = False,
-        test_metrics: Optional[nn.Module] = None,
+        train_metrics: Optional[Dict[str, Metric]] = None,
+        val_metrics: Optional[Dict[str, Metric]] = None,
+        test_metrics: Optional[Dict[str, Metric]] = None,
         aux_output: bool = True,
         aux_output_layers: list[int] | None = [9, 14, 19],
         aux_weights: list[float] = [0.3, 0.3, 0.3],
@@ -460,10 +457,18 @@ def __init__(
             The interpolation mode for upsampling in the decoder. Defaults to "bilinear".
         loss_fn : nn.Module, optional
             The loss function to be used during training. Defaults to None.
-        log_metrics : bool
-            Whether to log metrics during training. Defaults to True.
-        metrics : list[MetricTypeSetR], optional
-            The metrics to be used for evaluation. Defaults to [MetricTypeSetR.mIoU, MetricTypeSetR.mIoU, MetricTypeSetR.mIoU].
+        train_metrics : Dict[str, Metric], optional
+            The metrics to be used for training evaluation. Defaults to None.
+        val_metrics : Dict[str, Metric], optional
+            The metrics to be used for validation evaluation. Defaults to None.
+        test_metrics : Dict[str, Metric], optional
+            The metrics to be used for testing evaluation. Defaults to None.
+        aux_output : bool
+            Whether to include auxiliary output heads in the model. Defaults to True.
+        aux_output_layers : list[int] | None
+            The indices of the layers that produce auxiliary predictions. Defaults to [9, 14, 19].
+        aux_weights : list[float]
+            The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3].
 
         """
         super().__init__()
@@ -486,27 +491,11 @@ def __init__(
         self.num_classes = num_classes
         self.aux_weights = aux_weights
 
-        self.log_train_metrics = log_train_metrics
-        self.log_val_metrics = log_val_metrics
-        self.log_test_metrics = log_test_metrics
-
-        if log_train_metrics:
-            assert (
-                train_metrics is not None
-            ), "train_metrics must be provided if log_train_metrics is True"
-            self.train_metrics = train_metrics
-
-        if log_val_metrics:
-            assert (
-                val_metrics is not None
-            ), "val_metrics must be provided if log_val_metrics is True"
-            self.val_metrics = val_metrics
-
-        if log_test_metrics:
-            assert (
-                test_metrics is not None
-            ), "test_metrics must be provided if log_test_metrics is True"
-            self.test_metrics = test_metrics
+        self.metrics = {
+            "train": train_metrics,
+            "val": val_metrics,
+            "test": test_metrics,
+        }
 
         self.model = _SetR_PUP(
             image_size=image_size,
@@ -531,18 +520,20 @@ def __init__(
             aux_output_layers=aux_output_layers,
         )
 
-        self.train_step_outputs = []
-        self.train_step_labels = []
-
-        self.val_step_outputs = []
-        self.val_step_labels = []
-
-        self.test_step_outputs = []
-        self.test_step_labels = []
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x)
 
+    def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str):
+        if self.metrics[step_name] is None:
+            return {}
+
+        return {
+            f"{step_name}_{metric_name}": metric.to(self.device)(
+                torch.argmax(y_hat, dim=1, keepdim=True), y
+            )
+            for metric_name, metric in self.metrics[step_name].items()
+        }
+
     def _loss_func(
         self,
         y_hat: (
@@ -577,6 +568,7 @@ def _loss_func(
                 + (loss_aux3 * self.aux_weights[2])
             )
         loss = self.loss_fn(y_hat, y.long())
+
         return loss
 
     def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
@@ -600,86 +592,17 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
         y_hat = self.model(x.float())
         loss = self._loss_func(y_hat[0], y.squeeze(1))
 
-        if step_name == "train":
-            self.train_step_outputs.append(y_hat[0])
-            self.train_step_labels.append(y)
-        elif step_name == "val":
-            self.val_step_outputs.append(y_hat[0])
-            self.val_step_labels.append(y)
-        elif step_name == "test":
-            self.test_step_outputs.append(y_hat[0])
-            self.test_step_labels.append(y)
-
-        self.log_dict(
-            {
-                f"{step_name}_loss": loss,
-            },
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-
-        return loss
-
-    def on_train_epoch_end(self):
-        if self.log_train_metrics:
-            y_hat = torch.cat(self.train_step_outputs)
-            y = torch.cat(self.train_step_labels)
-            preds = torch.argmax(y_hat, dim=1, keepdim=True)
-            self.train_metrics(preds, y)
-            mIoU = self.train_metrics.compute()
-
+        metrics = self._compute_metrics(y_hat[0], y, step_name)
+        for metric_name, metric_value in metrics.items():
             self.log_dict(
-                {
-                    f"train_metrics": mIoU,
-                },
+                {metric_name: metric_value},
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
             )
-            self.train_step_outputs.clear()
-            self.train_step_labels.clear()
-
-    def on_validation_epoch_end(self):
-        if self.log_val_metrics:
-            y_hat = torch.cat(self.val_step_outputs)
-            y = torch.cat(self.val_step_labels)
-            preds = torch.argmax(y_hat, dim=1, keepdim=True)
-            self.val_metrics(preds, y)
-            mIoU = self.val_metrics.compute()
 
-            self.log_dict(
-                {
-                    f"val_metrics": mIoU,
-                },
-                on_step=False,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-            self.val_step_outputs.clear()
-            self.val_step_labels.clear()
-
-    def on_test_epoch_end(self):
-        if self.log_test_metrics:
-            y_hat = torch.cat(self.test_step_outputs)
-            y = torch.cat(self.test_step_labels)
-            preds = torch.argmax(y_hat, dim=1, keepdim=True)
-            self.test_metrics(preds, y)
-            mIoU = self.test_metrics.compute()
-            self.log_dict(
-                {
-                    f"test_metrics": mIoU,
-                },
-                on_step=False,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-            self.test_step_outputs.clear()
-            self.test_step_labels.clear()
+        return loss
 
     def training_step(self, batch: torch.Tensor, batch_idx: int):
         return self._single_step(batch, batch_idx, "train")
@@ -694,7 +617,7 @@ def predict_step(
         self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int | None = None
     ):
         x, _ = batch
-        return self.model(x)
+        return self.model(x)[0]
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.model.parameters(), lr=1e-3)
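
In short, the commit replaces the six per-split `log_*_metrics` flags and single-module metrics with three optional name-to-`torchmetrics.Metric` dictionaries, each entry logged as "<step>_<name>" at epoch end from `_single_step`. A minimal sketch of the new calling convention, assuming the enclosing Lightning module is the `SETR_PUP` class this `__init__` belongs to; `num_classes` appears in the diff, but the other constructor arguments shown here are illustrative:

from torchmetrics import JaccardIndex

# Hypothetical instantiation under the refactored API: any of the three
# dicts may be omitted (left as None), in which case that split logs
# only its loss.
model = SETR_PUP(
    num_classes=6,
    train_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=6)},
    val_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=6)},
    test_metrics=None,
)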
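And a standalone sketch of what `_compute_metrics` produces for one batch, under assumed shapes (logits [N, C, H, W], integer masks [N, 1, H, W], as implied by the `argmax(..., dim=1, keepdim=True)` call; inside the module the metric is additionally moved to the right device via `metric.to(self.device)`):

import torch
from torchmetrics import JaccardIndex

metrics = {"mIoU": JaccardIndex(task="multiclass", num_classes=6)}

y_hat = torch.randn(2, 6, 32, 32)        # logits [N, C, H, W]
y = torch.randint(0, 6, (2, 1, 32, 32))  # labels [N, 1, H, W]

# Mirrors the dict comprehension in _compute_metrics for step_name="val":
logged = {
    f"val_{name}": metric(torch.argmax(y_hat, dim=1, keepdim=True), y)
    for name, metric in metrics.items()
}
print(logged)  # e.g. {'val_mIoU': tensor(...)}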