
Commit 98c4b07

anwai98, caroteu, and constantinpape authored

Add SSF method (#727)

Add SSF method for PEFT.

Co-authored-by: Anwai Archit <anwai.archit@gmail.com>
Co-authored-by: Carolin <carolin.teuber@stud.uni-goettingen.de>
Co-authored-by: Carolin Teuber <115626873+caroteu@users.noreply.github.com>
Co-authored-by: Constantin Pape <constantin.pape@informatik.uni-goettingen.de>
1 parent 0a62171 commit 98c4b07

File tree: 3 files changed (+80, -12 lines)

micro_sam/models/peft_sam.py

Lines changed: 58 additions & 6 deletions

@@ -108,6 +108,54 @@ def forward(self, x):
         return qkv
 
 
+class ScaleShiftLayer(nn.Module):
+    def __init__(self, layer, dim):
+        super().__init__()
+        self.layer = layer
+        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
+        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
+        layer = self
+
+    def forward(self, x):
+        x = self.layer(x)
+        assert self.scale.shape == self.shift.shape
+        if x.shape[-1] == self.scale.shape[0]:
+            return x * self.scale + self.shift
+        elif x.shape[1] == self.scale.shape[0]:
+            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
+        else:
+            raise ValueError('Input tensors do not match the shape of the scale factors.')
+
+
+class SSFSurgery(nn.Module):
+    """Operates on all layers in the transformer block for adding learnable scale and shift parameters.
+
+    Args:
+        rank: This parameter is not used in `SSFSurgery`. This is kept here for consistency.
+        block: The chosen attention blocks for implementing ssf.
+        dim: The input dimensions determining the shape of scale and shift parameters.
+    """
+    def __init__(self, rank: int, block: nn.Module):
+        super().__init__()
+        self.block = block
+
+        # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
+        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
+            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
+            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
+            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
+            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
+            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
+            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])
+
+        # If we get the embedding block, add one ScaleShiftLayer.
+        elif hasattr(block, "patch_embed"):
+            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)
+
+    def forward(self, x):
+        return x
+
+
 class SelectiveSurgery(nn.Module):
     """Base class for selectively allowing gradient updates for certain parameters.
     """
@@ -254,8 +302,10 @@ def __init__(
         super().__init__()
 
         assert rank > 0
-        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, AdaptFormer]), (
-            "Invalid PEFT module")
+
+        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
+            "Invalid PEFT module"
+        )
 
         if attention_layers_to_update:
             self.peft_layers = attention_layers_to_update
@@ -269,17 +319,19 @@ def __init__(
         for param in model.image_encoder.parameters():
            param.requires_grad = False
 
+        # Add scale and shift parameters to the patch embedding layers.
+        if issubclass(self.peft_module, SSFSurgery):
+            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))
+
         for t_layer_i, blk in enumerate(model.image_encoder.blocks):
             # If we only want specific layers with PEFT instead of all
             if t_layer_i not in self.peft_layers:
                 continue
 
             if issubclass(self.peft_module, SelectiveSurgery):
-                peft_block = self.peft_module(block=blk)
+                self.peft_blocks.append(self.peft_module(block=blk))
             else:
-                peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs)
-
-            self.peft_blocks.append(peft_block)
+                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))
 
         self.peft_blocks = nn.ModuleList(self.peft_blocks)

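As context for the diff above, here is a standalone sketch (not part of the commit) of the scale-shift idea: a frozen layer is wrapped so that its output is modulated by learnable per-feature scale and shift vectors, which is what ScaleShiftLayer does for the qkv, proj, MLP, and norm layers. The ToyScaleShift name and the layer sizes are illustrative only.

# Minimal SSF-style wrapper, assuming plain PyTorch.
import torch
import torch.nn as nn


class ToyScaleShift(nn.Module):
    def __init__(self, layer, dim):
        super().__init__()
        self.layer = layer  # the wrapped (frozen) layer
        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))

    def forward(self, x):
        # Feature-wise affine modulation of the frozen layer's output.
        return self.layer(x) * self.scale + self.shift


linear = nn.Linear(64, 128)
for p in linear.parameters():
    p.requires_grad = False  # the wrapped layer stays frozen

wrapped = ToyScaleShift(linear, dim=128)
out = wrapped(torch.rand(4, 64))
print(out.shape)  # torch.Size([4, 128])
print(sum(p.numel() for p in wrapped.parameters() if p.requires_grad))  # 256 trainable parameters

Only the 2 * dim scale/shift values are trained per wrapped layer, which is what makes SSF parameter-efficient compared to fine-tuning the full weights.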
micro_sam/training/training.py

Lines changed: 7 additions & 6 deletions

@@ -146,12 +146,6 @@ def set_description(self, desc, **kwargs):
         self._signals.pbar_description.emit(desc)
 
 
-def _count_parameters(model_parameters):
-    params = sum(p.numel() for p in model_parameters if p.requires_grad)
-    params = params / 1e6
-    print(f"The number of trainable parameters for the provided model is {round(params, 2)}M")
-
-
 @contextmanager
 def _filter_warnings(ignore_warnings):
     if ignore_warnings:
@@ -163,6 +157,12 @@ def _filter_warnings(ignore_warnings):
         yield
 
 
+def _count_parameters(model_parameters):
+    params = sum(p.numel() for p in model_parameters if p.requires_grad)
+    params = params / 1e6
+    print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")
+
+
 def train_sam(
     name: str,
     model_type: str,
@@ -249,6 +249,7 @@ def train_sam(
         peft_kwargs=peft_kwargs,
        **model_kwargs
    )
+
    # This class creates all the training data for a batch (inputs, prompts and labels).
    convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)
 

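The relocated _count_parameters helper simply sums numel() over parameters that require gradients. A standalone sketch (not part of the commit) of that counting pattern, using an illustrative toy model:

import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))
for p in model[0].parameters():
    p.requires_grad = False  # pretend the first layer is frozen, as PEFT does for the image encoder

# Same counting pattern as _count_parameters.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{trainable} trainable parameters (~{round(trainable / 1e6, 2)}M)")  # 22 trainable parameters (~0.0M)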
test/test_models/test_peft_sam.py

Lines changed: 15 additions & 0 deletions

@@ -78,12 +78,27 @@ def test_bias_layer_peft_sam(self):
             masks = output[0]["masks"]
             self.assertEqual(masks.shape, expected_shape)
 
+    def test_ssf_peft_sam(self):
+        from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery
+
+        _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
+        peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery)
+
+        shape = (3, 1024, 1024)
+        expected_shape = (1, 3, 1024, 1024)
+        with torch.no_grad():
+            batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
+            output = peft_sam(batched_input, multimask_output=True)
+            masks = output[0]["masks"]
+            self.assertEqual(masks.shape, expected_shape)
+
     def test_adaptformer_peft_sam(self):
         from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer
 
         _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
         peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5)
 
+
         shape = (3, 1024, 1024)
         expected_shape = (1, 3, 1024, 1024)
         with torch.no_grad():

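The new test doubles as a usage reference. Below is a hedged sketch along the same lines (it requires micro_sam and a downloadable SAM checkpoint; the "vit_b" model type is an assumption here, the test itself uses self.model_type):

import torch

from micro_sam import util
from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery

# Wrap a SAM model with SSF, mirroring test_ssf_peft_sam above.
_, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery)  # rank is unused by SSFSurgery

# In the image encoder, only the added scale/shift parameters should remain trainable.
trainable_encoder = [
    name for name, p in peft_sam.named_parameters()
    if p.requires_grad and "image_encoder" in name
]
print(len(trainable_encoder), "trainable image-encoder parameter tensors")

with torch.no_grad():
    batched_input = [{"image": torch.rand(3, 1024, 1024), "original_size": (1024, 1024)}]
    output = peft_sam(batched_input, multimask_output=True)
    print(output[0]["masks"].shape)  # expected: torch.Size([1, 3, 1024, 1024])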