
Commit d606334

Merge pull request #89 from discovery-unicamp/merging-from-dev-12052025
Merging changes from development version (12/05/2025)
2 parents: 51c3fe6 + d32f546

File tree

7 files changed: +274 -167 lines changed


minerva/models/nets/mlp.py

Lines changed: 64 additions & 40 deletions

@@ -1,73 +1,97 @@
-from torch import nn
-from typing import Sequence
+import torch.nn as nn
+from typing import Sequence, Optional, List


 class MLP(nn.Sequential):
     """
-    A multilayer perceptron (MLP) implemented as a subclass of nn.Sequential.
+    A flexible multilayer perceptron (MLP) implemented as a subclass of nn.Sequential.

-    This MLP is composed of a sequence of linear layers interleaved with ReLU activation
-    functions, except for the final layer which remains purely linear.
+    This class allows you to quickly build an MLP with:
+    - Custom layer sizes
+    - Configurable activation functions
+    - Optional intermediate operations (e.g., BatchNorm, Dropout) after each linear layer
+    - An optional final operation (e.g., normalization, final activation)
+
+    Parameters
+    ----------
+    layer_sizes : Sequence[int]
+        A list of integers specifying the sizes of each layer. Must contain at least two values:
+        the input and output dimensions.
+    activation_cls : type, optional
+        The activation function class (must inherit from nn.Module) to use between layers.
+        Defaults to nn.ReLU.
+    intermediate_ops : Optional[List[Optional[nn.Module]]], optional
+        A list of modules (e.g., nn.BatchNorm1d, nn.Dropout) to apply after each linear layer
+        and before the activation. Each item corresponds to one linear layer. Use `None` to skip
+        an operation for that layer. Must be the same length as the number of linear layers.
+    final_op : Optional[nn.Module], optional
+        A module to apply after the last layer (e.g., a final activation or normalization).
+
+    *args, **kwargs :
+        Additional arguments passed to the activation function constructor.

     Example
     -------
-
-    >>> mlp = MLP(10, 20, 30, 40)
+    >>> from torch import nn
+    >>> mlp = MLP(
+    ...     [128, 256, 64, 10],
+    ...     activation_cls=nn.ReLU,
+    ...     intermediate_ops=[nn.BatchNorm1d(256), nn.BatchNorm1d(64), None],
+    ...     final_op=nn.Sigmoid()
+    ... )
     >>> print(mlp)
     MLP(
-      (0): Linear(in_features=10, out_features=20, bias=True)
-      (1): ReLU()
-      (2): Linear(in_features=20, out_features=30, bias=True)
-      (3): ReLU()
-      (4): Linear(in_features=30, out_features=40, bias=True)
+      (0): Linear(in_features=128, out_features=256, bias=True)
+      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+      (2): ReLU()
+      (3): Linear(in_features=256, out_features=64, bias=True)
+      (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+      (5): ReLU()
+      (6): Linear(in_features=64, out_features=10, bias=True)
+      (7): Sigmoid()
     )
     """

     def __init__(
         self,
         layer_sizes: Sequence[int],
         activation_cls: type = nn.ReLU,
+        intermediate_ops: Optional[List[Optional[nn.Module]]] = None,
+        final_op: Optional[nn.Module] = None,
         *args,
         **kwargs,
     ):
-        """
-        Initializes the MLP with specified layer sizes.
-
-        Parameters
-        ----------
-        layer_sizes : Sequence[int]
-            A sequence of positive integers indicating the size of each layer.
-            At least two integers are required, representing the input and output layers.
-        activation_cls : type
-            The class of the activation function to use between layers. Default is nn.ReLU.
-        *args
-            Additional arguments passed to the activation function.
-        **kwargs
-            Additional keyword arguments passed to the activation function.
-
-        Raises
-        ------
-        AssertionError
-            If fewer than two layer sizes are provided or if any layer size is not a positive integer.
-        AssertionError
-            If activation_cls does not inherit from torch.nn.Module.
-        """

         assert (
             len(layer_sizes) >= 2
         ), "Multilayer perceptron must have at least 2 layers"
         assert all(
-            ls > 0 and isinstance(ls, int) for ls in layer_sizes
+            isinstance(ls, int) and ls > 0 for ls in layer_sizes
         ), "All layer sizes must be positive integers"
-
         assert issubclass(
             activation_cls, nn.Module
         ), "activation_cls must inherit from torch.nn.Module"

+        num_layers = len(layer_sizes) - 1
+
+        if intermediate_ops is not None:
+            if len(intermediate_ops) != num_layers:
+                raise ValueError(
+                    f"Length of intermediate_ops ({len(intermediate_ops)}) must match number of layers ({num_layers})"
+                )
+
         layers = []
-        for i in range(len(layer_sizes) - 2):
-            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
-            layers.append(activation_cls(*args, **kwargs))
-        layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1]))
+        for i in range(num_layers):
+            in_dim, out_dim = layer_sizes[i], layer_sizes[i + 1]
+            layers.append(nn.Linear(in_dim, out_dim))
+
+            if intermediate_ops is not None and intermediate_ops[i] is not None:
+                layers.append(intermediate_ops[i])
+
+            if activation_cls is not None:
+                layers.append(activation_cls(*args, **kwargs))
+
+        if final_op is not None:
+            layers.append(final_op)

         super().__init__(*layers)
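
For reference, a minimal usage sketch of the updated class, assuming it is importable as minerva.models.nets.mlp.MLP (the path shown in this diff) and that the constructor behaves as in the hunk above; it mirrors the docstring example.

import torch
from torch import nn
from minerva.models.nets.mlp import MLP  # import path taken from this diff

# Three linear layers, per-layer intermediate ops, and a final Sigmoid
# applied once after the last layer.
mlp = MLP(
    [128, 256, 64, 10],
    activation_cls=nn.ReLU,
    intermediate_ops=[nn.BatchNorm1d(256), nn.BatchNorm1d(64), None],  # one entry per linear layer
    final_op=nn.Sigmoid(),
)

x = torch.randn(32, 128)   # batch of 32 vectors with 128 features
y = mlp(x)                 # expected shape: (32, 10)
print(y.shape)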

minerva/models/ssl/byol.py

Lines changed: 46 additions & 100 deletions

@@ -8,136 +8,82 @@
 from torch import nn
 from torch import Tensor
 from collections import OrderedDict
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence

 from minerva.losses.negative_cossine_similatiry import NegativeCosineSimilarity
-
-# --- Model Parts ---------------------------------------------------------
-
-# Borrowed from https://github.com/lightly-ai/lightly/blob/master/lightly/models/modules/heads.py#L15
-
-
-class ProjectionHead(nn.Module):
-    """Base class for all projection and prediction heads."""
-
-    def __init__(
-        self,
-        blocks: Sequence[
-            Union[
-                Tuple[int, int, Optional[nn.Module], Optional[nn.Module]],
-                Tuple[int, int, Optional[nn.Module], Optional[nn.Module], bool],
-            ],
-        ],
-    ) -> None:
-        super().__init__()
-
-        layers: List[nn.Module] = []
-        for block in blocks:
-            input_dim, output_dim, batch_norm, non_linearity, *bias = block
-            use_bias = bias[0] if bias else not bool(batch_norm)
-            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
-            if batch_norm:
-                layers.append(batch_norm)
-            if non_linearity:
-                layers.append(non_linearity)
-        self.layers = nn.Sequential(*layers)
-
-    def preprocess_step(self, x: Tensor) -> Tensor:
-        return x
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.preprocess_step(x)
-        projection: Tensor = self.layers(x)
-        return projection
-
-
-class BYOLProjectionHead(ProjectionHead):
-    """Projection head used for BYOL.
-    "This MLP consists in a linear layer with output size 4096 followed by
-    batch normalization, rectified linear units (ReLU), and a final
-    linear layer with output dimension 256." [0]
-    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
-    """
-
-    def __init__(
-        self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256
-    ):
-        super(BYOLProjectionHead, self).__init__(
-            [
-                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
-                (hidden_dim, output_dim, None, None),
-            ]
-        )
-
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-
-    def preprocess_step(self, x: Tensor) -> Tensor:
-        return self.avgpool(x).flatten(start_dim=1)
-
-
-class BYOLPredictionHead(ProjectionHead):
-    """Prediction head used for BYOL.
-    "This MLP consists in a linear layer with output size 4096 followed by
-    batch normalization, rectified linear units (ReLU), and a final
-    linear layer with output dimension 256." [0]
-    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
-    """
-
-    def __init__(
-        self, input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256
-    ):
-        super(BYOLPredictionHead, self).__init__(
-            [
-                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
-                (hidden_dim, output_dim, None, None),
-            ]
-        )
-
-
-# --- Class implementation ----------------------------------------------------------
+from minerva.models.nets.mlp import MLP
+from torch.optim import Optimizer
+from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone


 class BYOL(L.LightningModule):
-    """A Bootstrap Your Own Latent (BYOL) model for self-supervised learning.
+    """Bootstrap Your Own Latent (BYOL) model for self-supervised learning.

     References
     ----------
     Grill, J., Strub, F., Altché, F., Tallec, C., Richemond, P. H., Buchatskaya, E., ... & Valko, M. (2020).
-    "Bootstrap your own latent-a new approach to self-supervised learning." Advances in neural information processing systems, 33, 21271-21284.
+    "Bootstrap your own latent - a new approach to self-supervised learning." Advances in Neural Information Processing Systems, 33, 21271-21284.
     """

     def __init__(
         self,
         backbone: Optional[nn.Module] = None,
-        learning_rate: float = 0.025,
-        schedule: int = 90000,
+        projection_head: Optional[nn.Module] = None,
+        prediction_head: Optional[nn.Module] = None,
+        learning_rate: Optional[float] = 1e-3,
+        schedule: Optional[int] = 90000,
+        criterion: Optional[Optimizer] = None,
     ):
         """
         Initializes the BYOL model.

         Parameters
         ----------
-        backbone: Optional[nn.Module]
-            The backbone network for feature extraction. Defaults to ResNet18.
-        learning_rate: float
-            The learning rate for the optimizer. Defaults to 0.025.
-        schedule: int
+        backbone : Optional[nn.Module]
+            The backbone network for feature extraction. Defaults to DeepLabV3Backbone.
+        projection_head : Optional[nn.Module]
+            Optional custom projection head module. If None, a default MLP-based projection head is used.
+        prediction_head : Optional[nn.Module]
+            Optional custom prediction head module. If None, a default MLP-based prediction head is used.
+        learning_rate : float
+            The learning rate for the optimizer. Defaults to 1e-3.
+        schedule : int
             The total number of steps for cosine decay scheduling. Defaults to 90000.
+        criterion : Optional[Optimizer]
+            Loss function to use. Defaults to NegativeCosineSimilarity.
         """
         super().__init__()
-        self.backbone = backbone or nn.Sequential(
-            *list(torchvision.models.resnet18().children())[:-1]
-        )
+        self.backbone = backbone or DeepLabV3Backbone()
         self.learning_rate = learning_rate
-        self.projection_head = BYOLProjectionHead(2048, 4096, 256)
-        self.prediction_head = BYOLPredictionHead(256, 4096, 256)
+        self.projection_head = projection_head or self._default_projection_head()
+        self.prediction_head = prediction_head or self._default_prediction_head()
         self.backbone_momentum = copy.deepcopy(self.backbone)
         self.projection_head_momentum = copy.deepcopy(self.projection_head)
         self.deactivate_requires_grad(self.backbone_momentum)
         self.deactivate_requires_grad(self.projection_head_momentum)
-        self.criterion = NegativeCosineSimilarity()
+        self.criterion = criterion or NegativeCosineSimilarity()
         self.schedule_length = schedule

+    def _default_projection_head(self) -> nn.Module:
+        """Creates the default projection head used in BYOL."""
+        return nn.Sequential(
+            nn.AdaptiveAvgPool2d((1, 1)),
+            nn.Flatten(start_dim=1),
+            MLP(
+                layer_sizes=[2048, 4096, 256],
+                activation_cls=nn.ReLU,
+                intermediate_ops=[nn.BatchNorm1d(4096), None],
+            ),
+        )
+
+    def _default_prediction_head(self) -> nn.Module:
+        """Creates the default prediction head used in BYOL."""
+        return MLP(
+            layer_sizes=[256, 4096, 256],
+            activation_cls=nn.ReLU,
+            intermediate_ops=[nn.BatchNorm1d(4096), None],
+        )
+
     def forward(self, x: Tensor) -> Tensor:
         """
         Forward pass for the BYOL model.
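
A short sketch of how the reworked constructor can be called, assuming the import paths shown in this diff; the custom heads below simply mirror the defaults built by _default_projection_head and _default_prediction_head, so they are illustrative rather than required.

from torch import nn
from minerva.models.nets.mlp import MLP     # import paths taken from this diff
from minerva.models.ssl.byol import BYOL

custom_projection_head = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(start_dim=1),
    MLP([2048, 4096, 256], intermediate_ops=[nn.BatchNorm1d(4096), None]),
)
custom_prediction_head = MLP([256, 4096, 256], intermediate_ops=[nn.BatchNorm1d(4096), None])

model = BYOL(
    backbone=None,                          # None falls back to DeepLabV3Backbone()
    projection_head=custom_projection_head,
    prediction_head=custom_prediction_head,
    learning_rate=1e-3,
    schedule=90000,
    criterion=None,                         # None falls back to NegativeCosineSimilarity()
)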

minerva/transforms/random_transform.py

Lines changed: 18 additions & 5 deletions

@@ -170,15 +170,28 @@ def __init__(
         num_samples: int = 1,
         seed: Optional[int] = None,
     ):
+        """
+        Randomly applies a rotation to the image with a specified probability.

-        super().__init__(num_samples, seed)
-        self.prob = prob
+        Parameters
+        ----------
+        degrees : float
+            Maximum absolute value of the rotation angle in degrees. The angle is sampled
+            uniformly from [-degrees, +degrees].
+        prob : float
+            Probability that the rotation will be applied.
+        num_samples : int, optional
+            Number of samples to generate per call (for contrastive learning), default is 1.
+        seed : int, optional
+            Seed for the random number generator, useful for reproducibility.
+        """
+        super().__init__(num_samples=num_samples, seed=seed)
         self.degrees = degrees
+        self.prob = prob

     def select_transform(self):
-
         if self.rng.random() < self.prob:
-            degrees = self.rng.uniform(-self.degrees, self.degrees)
-            return Rotation(degrees=degrees)
+            angle = self.rng.uniform(-self.degrees, self.degrees)
+            return Rotation(degrees=angle)
         else:
             return EmptyTransform()
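
The selection logic above is easy to sanity-check in isolation. Below is a standalone sketch of that decision rule, assuming self.rng is a NumPy Generator (as the seed parameter suggests); Rotation and EmptyTransform are stubbed with placeholder tuples here, since their definitions are not part of this hunk.

import numpy as np

def select_rotation(rng: np.random.Generator, degrees: float, prob: float):
    # With probability `prob`, sample an angle uniformly from [-degrees, +degrees];
    # otherwise fall through to a no-op, mirroring the if/else in the hunk above.
    if rng.random() < prob:
        angle = rng.uniform(-degrees, degrees)
        return ("Rotation", angle)        # stands in for Rotation(degrees=angle)
    return ("EmptyTransform", None)       # stands in for EmptyTransform()

rng = np.random.default_rng(seed=0)       # fixed seed for reproducibility
print([select_rotation(rng, degrees=15.0, prob=0.5) for _ in range(4)])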
