
Commit cdc2e82

Fix issues raised in PR review
1 parent 63b036e

File tree

3 files changed: +126 -313 lines changed


mmlearn/tasks/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@

 from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining
 from mmlearn.tasks.ijepa import IJEPA
-from mmlearn.tasks.linear_probing import LinearClassifierModule
+from mmlearn.tasks.linear_probing import LinearClassifier
 from mmlearn.tasks.zero_shot_classification import ZeroShotClassification
 from mmlearn.tasks.zero_shot_retrieval import ZeroShotCrossModalRetrieval

@@ -12,5 +12,5 @@
     "IJEPA",
     "ZeroShotCrossModalRetrieval",
     "ZeroShotClassification",
-    "LinearClassifierModule",
+    "LinearClassifier",
 ]
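Downstream code that imported the old name must be updated to match; a minimal sketch:

# The task class was renamed in this commit; the old name no longer resolves.
from mmlearn.tasks import LinearClassifier  # was: LinearClassifierModule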

mmlearn/tasks/linear_probing.py

Lines changed: 113 additions & 75 deletions
@@ -16,14 +16,17 @@

 from mmlearn.datasets.core import Modalities
 from mmlearn.modules.layers import MLP
+from mmlearn.tasks.base import TrainingTask
+
+from mmlearn.tasks.zero_shot_classification import ZeroShotClassification


 def extract_vision_encoder(
     encoder: Any,
     model_checkpoint_path: Optional[str],
+    modality_to_extract: Optional[str] = "rgb",
     keys_to_remove: Optional[List[str]] = None,
     keys_to_rename: Optional[Dict[str, str]] = None,  # Default for renaming
-    keys_to_ignore: Optional[List[str]] = None,
 ) -> nn.Module:
     """
     Extract the vision encoder from a PyTorch Lightning model.
@@ -61,12 +64,6 @@ def extract_vision_encoder(
             k: v for k, v in state_dict.items() if k not in keys_to_remove
         }

-    # Ignore specific keys
-    if keys_to_ignore:
-        state_dict = {
-            k: v for k, v in state_dict.items() if k not in keys_to_ignore
-        }
-
     # Rename keys based on input mappings
     if keys_to_rename:
         state_dict = {
@@ -78,15 +75,15 @@ def extract_vision_encoder(

     try:
         if state_dict:
-            model["rgb"].load_state_dict(state_dict, strict=True)
+            model[modality_to_extract].load_state_dict(state_dict, strict=True)
             print("Encoder state dict loaded successfully")
     except Exception as e:
         print(f"Error loading state dict: {e}")
-    return model["rgb"]
+    return model[modality_to_extract]


 @store(group="task", provider="mmlearn")
-class LinearClassifierModule(L.LightningModule):
+class LinearClassifier(TrainingTask):
     """A linear classifier module for evaluating pretrained encoders.

     Parameters
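For context, a hedged sketch of calling the generalized helper; the wrapped model and checkpoint path below are illustrative assumptions, not part of this commit:

# Hypothetical call site: extract the "rgb" backbone from a Lightning wrapper.
vision_encoder = extract_vision_encoder(
    encoder=pretrained_task,               # assumed module holding a dict of encoders
    model_checkpoint_path="probe.ckpt",    # assumed checkpoint on disk, or None
    modality_to_extract="rgb",             # new argument: any encoder key, not just "rgb"
    keys_to_rename={"encoders.rgb.": ""},  # strip the wrapper prefix from state-dict keys
)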
@@ -98,7 +95,7 @@ class LinearClassifierModule(L.LightningModule):
         `common.constants.Modality` for valid values. The target label key is
         inferred from this modality. This means, for example, that if the
         modality is 'rgb', the target label key is expected to be 'rgb_target'.
-    num_output_features : int
+    embed_dim : int
         Output feature size of the encoder, defining the linear classifier's input size.
     num_classes : int
         Number of classes for the classification task.
@@ -154,26 +151,27 @@ class LinearClassifierModule(L.LightningModule):

     def __init__(
         self,
-        # encoder: torch.nn.Module,
         encoder: nn.Module,
-        model_checkpoint_path: Optional[str],  # change name
+        model_checkpoint_path: Optional[str],
         modality: str,
-        num_output_features: int,
+        embed_dim: int,
         num_classes: int,
         hidden_dims: Optional[List[int]] = None,
         task: Literal["binary", "multiclass", "multilabel"] = "multiclass",
         freeze_encoder: bool = True,
-        pre_classifier_batch_norm: bool = False,
+        keys_to_remove: Optional[List[str]] = None,
+        keys_to_rename: Optional[Dict[str, str]] = {"encoders.rgb.": ""},
         top_k_list: Optional[List[int]] = None,
         optimizer: Optional[partial[torch.optim.Optimizer]] = None,
+        pre_classifier_batch_norm: bool = False,
         lr_scheduler: Optional[
             Union[
                 Dict[str, partial[torch.optim.lr_scheduler.LRScheduler]],
                 partial[torch.optim.lr_scheduler.LRScheduler],
             ]
         ] = None,
     ):
-        super().__init__()
+        super().__init__(loss_fn=nn.CrossEntropyLoss())
         assert task in ["binary", "multiclass", "multilabel"], (
             f"Invalid task type: {task}. "
             "Expected one of ['binary', 'multiclass', 'multilabel']."
@@ -182,16 +180,13 @@ def __init__(
         self.modality = modality

         self.encoder: nn.Module = extract_vision_encoder(
-            encoder, model_checkpoint_path, keys_to_rename={"encoders.rgb.": ""}
+            encoder, model_checkpoint_path, keys_to_rename=keys_to_rename,
+            keys_to_remove=keys_to_remove,
         )

-        linear_layer = MLP(num_output_features, num_classes, hidden_dims)
+        linear_layer = MLP(embed_dim, num_classes, hidden_dims,
+                           norm_layer=nn.BatchNorm1d if pre_classifier_batch_norm else None)

-        if pre_classifier_batch_norm:
-            linear_layer = nn.Sequential(
-                nn.BatchNorm1d(num_output_features, affine=False),
-                linear_layer,
-            )
         self.classifier = linear_layer

         self.freeze_encoder = freeze_encoder
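The norm_layer argument folds the old explicit wrapper into the MLP. A torch-only sketch of what the removed block did; note it used affine=False, whereas a norm layer built from plain nn.BatchNorm1d (assuming MLP passes no extra arguments) would keep the learnable scale and shift:

import torch.nn as nn

embed_dim, num_classes = 768, 10  # illustrative sizes
# Old behavior: whiten the frozen-encoder features, with no learnable
# scale/shift, before the linear probing head.
head = nn.Sequential(
    nn.BatchNorm1d(embed_dim, affine=False),
    nn.Linear(embed_dim, num_classes),
)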
@@ -201,61 +196,68 @@ def __init__(
         for param in self.encoder.parameters():
             param.requires_grad = False

-        self.loss_fn = nn.CrossEntropyLoss()
+        if task == "multilabel":
+            self.loss_fn = nn.BCEWithLogitsLoss()
+

         self.top_k_list = top_k_list
-        if task == "multiclass":
-            if self.top_k_list is None:
-                self.top_k_list = [1, 5]
-            accuracy_metrics = {
-                f"top_{k}_accuracy": Accuracy(
-                    task=task, num_classes=num_classes, top_k=k
-                )
-                for k in self.top_k_list
-            }
-
-            # Additional metrics for multiclass classification
-            additional_metrics = {
-                "precision": Precision(
-                    task=task, num_classes=num_classes, average="macro"
-                ),
-                "recall": Recall(task=task, num_classes=num_classes, average="macro"),
-                "f1_score": F1Score(
-                    task=task, num_classes=num_classes, average="macro"
-                ),
-                "auc": AUROC(
-                    task=task, num_classes=num_classes, average="macro"
-                ),  # AUROC for multiclass
-            }
-
-        elif task == "multilabel":
-            # Accuracy and other metrics for multilabel classification
-            accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)}
-
-            # Additional metrics for multilabel classification
-            additional_metrics = {
-                "precision": Precision(
-                    task=task, num_labels=num_classes, average="macro"
-                ),
-                "recall": Recall(task=task, num_labels=num_classes, average="macro"),
-                "f1_score": F1Score(task=task, num_labels=num_classes, average="macro"),
-                "auc": AUROC(task=task, num_labels=num_classes),  # AUC for multilabel
-            }
-
-        else:  # binary
-            # Accuracy and other metrics for binary classification
-            accuracy_metrics = {"accuracy": Accuracy(task=task)}
-
-            # Additional metrics for binary classification
-            additional_metrics = {
-                "precision": Precision(task=task),
-                "recall": Recall(task=task),
-                "f1_score": F1Score(task=task),
-                "auc": AUROC(task=task),  # AUROC for binary classification
-            }
+        # if task == "multiclass":
+        #     if self.top_k_list is None:
+        #         self.top_k_list = [1, 5]
+        #     accuracy_metrics = {
+        #         f"top_{k}_accuracy": Accuracy(
+        #             task=task, num_classes=num_classes, top_k=k
+        #         )
+        #         for k in self.top_k_list
+        #     }
+
+        #     # Additional metrics for multiclass classification
+        #     additional_metrics = {
+        #         "precision": Precision(
+        #             task=task, num_classes=num_classes, average="macro"
+        #         ),
+        #         "recall": Recall(task=task, num_classes=num_classes, average="macro"),
+        #         "f1_score": F1Score(
+        #             task=task, num_classes=num_classes, average="macro"
+        #         ),
+        #         "auc": AUROC(
+        #             task=task, num_classes=num_classes, average="macro"
+        #         ),  # AUROC for multiclass
+        #     }
+
+        # elif task == "multilabel":
+        #     # Accuracy and other metrics for multilabel classification
+        #     accuracy_metrics = {"accuracy": Accuracy(task=task, num_labels=num_classes)}
+
+        #     # Additional metrics for multilabel classification
+        #     additional_metrics = {
+        #         "precision": Precision(
+        #             task=task, num_labels=num_classes, average="macro"
+        #         ),
+        #         "recall": Recall(task=task, num_labels=num_classes, average="macro"),
+        #         "f1_score": F1Score(task=task, num_labels=num_classes, average="macro"),
+        #         "auc": AUROC(task=task, num_labels=num_classes),  # AUC for multilabel
+        #     }
+
+        # else:  # binary
+        #     # Accuracy and other metrics for binary classification
+        #     accuracy_metrics = {"accuracy": Accuracy(task=task)}
+
+        #     # Additional metrics for binary classification
+        #     additional_metrics = {
+        #         "precision": Precision(task=task),
+        #         "recall": Recall(task=task),
+        #         "f1_score": F1Score(task=task),
+        #         "auc": AUROC(task=task),  # AUROC for binary classification
+        #     }

         # combine all metrics
-        metrics = MetricCollection({**accuracy_metrics, **additional_metrics})
+        # metrics = MetricCollection({**accuracy_metrics, **additional_metrics})
+        metrics = ZeroShotClassification._create_metrics(num_classes=num_classes,
+                                                         top_k=self.top_k_list,
+                                                         prefix="",
+                                                         postfix="")
         self.train_metrics = metrics.clone(prefix="train/")
         self.valid_metrics = metrics.clone(prefix="val/")
+        self.test_metrics = metrics.clone(prefix="test/")  # fix: used by test_step below
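The loss switch matters because the two criteria expect differently shaped targets; a self-contained check:

import torch
import torch.nn as nn

logits = torch.randn(4, 10)  # (batch, num_classes)
# Multiclass: CrossEntropyLoss takes integer class indices of shape (batch,).
ce = nn.CrossEntropyLoss()(logits, torch.randint(0, 10, (4,)))
# Multilabel: BCEWithLogitsLoss takes multi-hot float targets of shape (batch, num_classes).
bce = nn.BCEWithLogitsLoss()(logits, torch.randint(0, 2, (4, 10)).float())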

@@ -349,12 +350,40 @@ def validation_step(
             The loss computed for the batch.
         """
         logits, y = self._get_logits_and_labels(batch)
-
+
         loss: torch.Tensor = self.loss_fn(logits, y)
         self.log("val/loss", self.all_gather(loss.clone().detach()).mean())

         self.valid_metrics.update(logits, y)
         return loss
+
+    def test_step(
+        self,
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+    ) -> torch.Tensor:
+        """
+        Execute a test step using a single batch.
+
+        Parameters
+        ----------
+        batch : Dict[str, torch.Tensor]
+            The current batch of test data, including input tensors and labels.
+        batch_idx : int
+            The index of the current test batch.
+
+        Returns
+        -------
+        torch.Tensor
+            The loss computed for the batch.
+        """
+        logits, y = self._get_logits_and_labels(batch)
+
+        loss: torch.Tensor = self.loss_fn(logits, y)
+        self.log("test/loss", self.all_gather(loss.clone().detach()).mean())
+
+        self.test_metrics.update(logits, y)
+        return loss

     def on_validation_epoch_end(self) -> None:
         """Compute validation metrics accumulated over the epoch."""
@@ -363,6 +392,15 @@ def on_validation_epoch_end(self) -> None:
             print(f" {metric}: {value.item()}")
         self.log_dict(val_metrics)
         self.valid_metrics.reset()
+
+
+    def on_test_epoch_end(self) -> None:
+        """Compute test metrics accumulated over the epoch."""
+        test_metrics = self.test_metrics.compute()
+        for metric, value in test_metrics.items():
+            print(f" {metric}: {value.item()}")
+        self.log_dict(test_metrics)
+        self.test_metrics.reset()

     def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
         """Configure the optimizer and learning rate scheduler."""
