1616
1717from mmlearn .datasets .core import Modalities
1818from mmlearn .modules .layers import MLP
19+ from mmlearn .tasks .base import TrainingTask
20+
21+ from mmlearn .tasks .zero_shot_classification import ZeroShotClassification
1922
2023
2124def extract_vision_encoder (
2225 encoder : Any ,
2326 model_checkpoint_path : Optional [str ],
27+ modality_to_extract : Optional [str ] = "rgb" ,
2428 keys_to_remove : Optional [List [str ]] = None ,
2529 keys_to_rename : Optional [Dict [str , str ]] = None , # Default for renaming
26- keys_to_ignore : Optional [List [str ]] = None ,
2730) -> nn .Module :
2831 """
2932 Extract the vision encoder from a PyTorch Lightning model.
@@ -61,12 +64,6 @@ def extract_vision_encoder(
6164 k : v for k , v in state_dict .items () if k not in keys_to_remove
6265 }
6366
64- # Ignore specific keys
65- if keys_to_ignore :
66- state_dict = {
67- k : v for k , v in state_dict .items () if k not in keys_to_ignore
68- }
69-
7067 # Rename keys based on input mappings
7168 if keys_to_rename :
7269 state_dict = {
@@ -78,15 +75,15 @@ def extract_vision_encoder(
7875
7976 try :
8077 if state_dict :
81- model ["rgb" ].load_state_dict (state_dict , strict = True )
78+ model [modality_to_extract ].load_state_dict (state_dict , strict = True )
8279 print ("Encoder state dict loaded successfully" )
8380 except Exception as e :
8481 print (f"Error loading state dict: { e } " )
85- return model ["rgb" ]
82+ return model [modality_to_extract ]
8683
8784
8885@store (group = "task" , provider = "mmlearn" )
89- class LinearClassifierModule ( L . LightningModule ):
86+ class LinearClassifier ( TrainingTask ):
9087 """A linear classifier module for evaluating pretrained encoders.
9188
9289 Parameters
@@ -98,7 +95,7 @@ class LinearClassifierModule(L.LightningModule):
9895 `common.constants.Modality` for valid values. The target label key is
10097 inferred from this modality. This means, for example, that if the
10097 modality is 'rgb', the target label key is expected to be 'rgb_target'.
101- num_output_features : int
98+ embed_dim : int
10299 Output features from the encoder, defining the linear classifier's input size.
103100 num_classes : int
104101 Number of classes for the classification task.
@@ -154,26 +151,27 @@ class LinearClassifierModule(L.LightningModule):
154151
155152 def __init__ (
156153 self ,
157- # encoder: torch.nn.Module,
158154 encoder : nn .Module ,
159- model_checkpoint_path : Optional [str ], # change name
155+ model_checkpoint_path : Optional [str ],
160156 modality : str ,
161- num_output_features : int ,
157+ embed_dim : int ,
162158 num_classes : int ,
163159 hidden_dims : Optional [List [int ]] = None ,
164160 task : Literal ["binary" , "multiclass" , "multilabel" ] = "multiclass" ,
165161 freeze_encoder : bool = True ,
166- pre_classifier_batch_norm : bool = False ,
162+ keys_to_remove : Optional [Dict [str , str ]] = None ,
163+ keys_to_rename : Optional [Dict [str , str ]] = {"encoders.rgb." : "" },
167164 top_k_list : Optional [List [int ]] = None ,
168165 optimizer : Optional [partial [torch .optim .Optimizer ]] = None ,
166+ pre_classifier_batch_norm : bool = False ,
169167 lr_scheduler : Optional [
170168 Union [
171169 Dict [str , partial [torch .optim .lr_scheduler .LRScheduler ]],
172170 partial [torch .optim .lr_scheduler .LRScheduler ],
173171 ]
174172 ] = None ,
175173 ):
176- super ().__init__ ()
174+ super ().__init__ (loss_fn = nn . CrossEntropyLoss () )
177175 assert task in ["binary" , "multiclass" , "multilabel" ], (
178176 f"Invalid task type: { task } . "
179177 "Expected one of ['binary', 'multiclass', 'multilabel']."
@@ -182,16 +180,13 @@ def __init__(
182180 self .modality = modality
183181
184182 self .encoder : nn .Module = extract_vision_encoder (
185- encoder , model_checkpoint_path , keys_to_rename = {"encoders.rgb." : "" }
183+ encoder , model_checkpoint_path , keys_to_rename = keys_to_rename ,
184+ keys_to_remove = keys_to_remove ,
186185 )
187186
188- linear_layer = MLP (num_output_features , num_classes , hidden_dims )
187+ linear_layer = MLP (embed_dim , num_classes , hidden_dims ,
188+ norm_layer = nn .BatchNorm1d if pre_classifier_batch_norm else None )
189189
190- if pre_classifier_batch_norm :
191- linear_layer = nn .Sequential (
192- nn .BatchNorm1d (num_output_features , affine = False ),
193- linear_layer ,
194- )
195190 self .classifier = linear_layer
196191
197192 self .freeze_encoder = freeze_encoder
@@ -201,61 +196,67 @@ def __init__(
201196 for param in self .encoder .parameters ():
202197 param .requires_grad = False
203198
204- self .loss_fn = nn .CrossEntropyLoss ()
199+ if task == "multilabel" :
200+ self .loss_fn = nn .BCEWithLogitsLoss ()
201+
205202
206203 self .top_k_list = top_k_list
207- if task == "multiclass" :
208- if self .top_k_list is None :
209- self .top_k_list = [1 , 5 ]
210- accuracy_metrics = {
211- f"top_{ k } _accuracy" : Accuracy (
212- task = task , num_classes = num_classes , top_k = k
213- )
214- for k in self .top_k_list
215- }
216-
217- # Additional metrics for multiclass classification
218- additional_metrics = {
219- "precision" : Precision (
220- task = task , num_classes = num_classes , average = "macro"
221- ),
222- "recall" : Recall (task = task , num_classes = num_classes , average = "macro" ),
223- "f1_score" : F1Score (
224- task = task , num_classes = num_classes , average = "macro"
225- ),
226- "auc" : AUROC (
227- task = task , num_classes = num_classes , average = "macro"
228- ), # AUROC for multiclass
229- }
230-
231- elif task == "multilabel" :
232- # Accuracy and other metrics for multilabel classification
233- accuracy_metrics = {"accuracy" : Accuracy (task = task , num_labels = num_classes )}
234-
235- # Additional metrics for multilabel classification
236- additional_metrics = {
237- "precision" : Precision (
238- task = task , num_labels = num_classes , average = "macro"
239- ),
240- "recall" : Recall (task = task , num_labels = num_classes , average = "macro" ),
241- "f1_score" : F1Score (task = task , num_labels = num_classes , average = "macro" ),
242- "auc" : AUROC (task = task , num_labels = num_classes ), # AUC for multilabel
243- }
244-
245- else : # binary
246- # Accuracy and other metrics for binary classification
247- accuracy_metrics = {"accuracy" : Accuracy (task = task )}
248-
249- # Additional metrics for binary classification
250- additional_metrics = {
251- "precision" : Precision (task = task ),
252- "recall" : Recall (task = task ),
253- "f1_score" : F1Score (task = task ),
254- "auc" : AUROC (task = task ), # AUROC for binary classification
255- }
204+ # if task == "multiclass":
205+ # if self.top_k_list is None:
206+ # self.top_k_list = [1, 5]
207+ # accuracy_metrics = {
208+ # f"top_{k}_accuracy": Accuracy(
209+ # task=task, num_classes=num_classes, top_k=k
210+ # )
211+ # for k in self.top_k_list
212+ # }
213+
214+ # # Additional metrics for multiclass classification
215+ # additional_metrics = {
216+ # "precision": Precision(
217+ # task=task, num_classes=num_classes, average="macro"
218+ # ),
219+ # "recall": Recall(task=task, num_classes=num_classes, average="macro"),
220+ # "f1_score": F1Score(
221+ # task=task, num_classes=num_classes, average="macro"
222+ # ),
223+ # "auc": AUROC(
224+ # task=task, num_classes=num_classes, average="macro"
225+ # ), # AUROC for multiclass
226+ # }
227+
228+ # elif task == "multilabel":
229+ # # Accuracy and other metrics for multilabel classification
230+ # accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)}
231+
232+ # # Additional metrics for multilabel classification
233+ # additional_metrics = {
234+ # "precision": Precision(
235+ # task=task, num_labels=num_classes, average="macro"
236+ # ),
237+ # "recall": Recall(task=task, num_labels=num_classes, average="macro"),
238+ # "f1_score": F1Score(task=task, num_labels=num_classes, average="macro"),
239+ # "auc": AUROC(task=task, num_labels=num_classes), # AUC for multilabel
240+ # }
241+
242+ # else: # binary
243+ # # Accuracy and other metrics for binary classification
244+ # accuracy_metrics = {"accuracy": Accuracy(task=task)}
245+
246+ # # Additional metrics for binary classification
247+ # additional_metrics = {
248+ # "precision": Precision(task=task),
249+ # "recall": Recall(task=task),
250+ # "f1_score": F1Score(task=task),
251+ # "auc": AUROC(task=task), # AUROC for binary classification
252+ # }
256253
257254 # combine all metrics
258- metrics = MetricCollection ({** accuracy_metrics , ** additional_metrics })
255+ # metrics = MetricCollection({**accuracy_metrics, **additional_metrics})
256+ metrics = ZeroShotClassification ._create_metrics (num_classes = num_classes ,
257+ top_k = self .top_k_list ,
258+ prefix = "" ,
259+ postfix = "" ,)
259260 self .train_metrics = metrics .clone (prefix = "train/" )
260261 self .valid_metrics = metrics .clone (prefix = "val/" )
261262
@@ -349,12 +350,40 @@ def validation_step(
349350 The loss computed for the batch.
350351 """
351352 logits , y = self ._get_logits_and_labels (batch )
352-
353+
353354 loss : torch .Tensor = self .loss_fn (logits , y )
354355 self .log ("val/loss" , self .all_gather (loss .clone ().detach ()).mean ())
355356
356357 self .valid_metrics .update (logits , y )
357358 return loss
359+
def test_step(
    self,
    batch: Dict[str, torch.Tensor],
    batch_idx: int,
) -> torch.Tensor:
    """
    Execute a test step using a single batch.

    The batch is turned into logits and labels via
    ``self._get_logits_and_labels``, the loss is computed with
    ``self.loss_fn``, logged under ``"test/loss"`` and the test metrics
    are updated with the logits and labels.

    Parameters
    ----------
    batch : Dict[str, torch.Tensor]
        The current batch of test data, including input tensors and labels.
    batch_idx : int
        The index of the current test batch.

    Returns
    -------
    torch.Tensor
        The loss computed for the batch.
    """
    logits, y = self._get_logits_and_labels(batch)

    loss: torch.Tensor = self.loss_fn(logits, y)
    # Log under the test namespace. This step previously reused the
    # "val/loss" key, which mixed test losses into the validation series.
    self.log("test/loss", self.all_gather(loss.clone().detach()).mean())

    # NOTE(review): `self.test_metrics` is not created in the part of
    # `__init__` visible here (only train/valid metrics are) -- confirm it
    # is defined before running `trainer.test`.
    self.test_metrics.update(logits, y)
    return loss
358387
359388 def on_validation_epoch_end (self ) -> None :
360389 """Compute validation metrics accumulated over the epoch."""
@@ -363,6 +392,15 @@ def on_validation_epoch_end(self) -> None:
363392 print (f" { metric } : { value .item ()} " )
364393 self .log_dict (val_metrics )
365394 self .valid_metrics .reset ()
395+
396+
def on_test_epoch_end(self) -> None:
    """Compute, print and log the test metrics accumulated over the epoch.

    After the values have been printed and passed to ``self.log_dict``,
    the metric state is reset.
    """
    test_results = self.test_metrics.compute()
    for name, result in test_results.items():
        scalar = result.item()
        print(f" {name} : {scalar} ")
    self.log_dict(test_results)
    self.test_metrics.reset()
366404
367405 def configure_optimizers (self ) -> OptimizerLRScheduler : # noqa: PLR0912
368406 """Configure the optimizer and learning rate scheduler."""
0 commit comments