@@ -108,6 +108,54 @@ def forward(self, x):
         return qkv


+class ScaleShiftLayer(nn.Module):
+    def __init__(self, layer, dim):
+        super().__init__()
+        self.layer = layer
+        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
+        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
+
+    def forward(self, x):
+        x = self.layer(x)
+        assert self.scale.shape == self.shift.shape
+        if x.shape[-1] == self.scale.shape[0]:
+            return x * self.scale + self.shift
+        elif x.shape[1] == self.scale.shape[0]:
+            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
+        else:
+            raise ValueError('Input tensors do not match the shape of the scale factors.')
+
+
+class SSFSurgery(nn.Module):
+    """Adds learnable scale and shift parameters to all layers in the transformer block.
+
+    Args:
+        rank: This parameter is not used in `SSFSurgery`. It is kept here for consistency.
+        block: The chosen attention block for implementing SSF.
+    """
+    def __init__(self, rank: int, block: nn.Module):
+        super().__init__()
+        self.block = block
+
+        # If we get a transformer block (with multiple sub-layers), we perform surgery on each layer.
+        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
+            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features * 3)
+            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
+            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
+            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
+            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
+            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])
+
+        # If we get the embedding block, add one ScaleShiftLayer on its projection.
+        elif hasattr(block, "proj"):
+            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)
+
+    def forward(self, x):
+        return x
+
+
 class SelectiveSurgery(nn.Module):
     """Base class for selectively allowing gradient updates for certain parameters.
     """
@@ -254,8 +302,10 @@ def __init__(
         super().__init__()

         assert rank > 0
-        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, AdaptFormer]), (
-            "Invalid PEFT module")
+
+        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
+            "Invalid PEFT module"
+        )

         if attention_layers_to_update:
             self.peft_layers = attention_layers_to_update
@@ -269,17 +319,19 @@ def __init__(
         for param in model.image_encoder.parameters():
             param.requires_grad = False

+        # Add scale and shift parameters to the patch embedding layers.
+        if issubclass(self.peft_module, SSFSurgery):
+            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))
+
         for t_layer_i, blk in enumerate(model.image_encoder.blocks):
             # If we only want specific layers with PEFT instead of all
             if t_layer_i not in self.peft_layers:
                 continue

             if issubclass(self.peft_module, SelectiveSurgery):
-                peft_block = self.peft_module(block=blk)
+                self.peft_blocks.append(self.peft_module(block=blk))
             else:
-                peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs)
-
-            self.peft_blocks.append(peft_block)
+                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

         self.peft_blocks = nn.ModuleList(self.peft_blocks)

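
A minimal usage sketch of the new `ScaleShiftLayer` wrapper introduced above; the import path, toy layers, and dimensions are illustrative assumptions, not part of this commit:

```python
import torch
import torch.nn as nn

# Assumes ScaleShiftLayer is importable from the modified file, e.g.:
# from micro_sam.models.peft_sam import ScaleShiftLayer  # hypothetical import path

# Wrap a toy linear layer: the output keeps its shape, and the wrapper adds
# learnable per-feature scale/shift parameters (last-dimension branch).
lin = nn.Linear(16, 32)
ssf_lin = ScaleShiftLayer(layer=lin, dim=32)
print(ssf_lin(torch.randn(4, 16)).shape)            # torch.Size([4, 32])

# Wrap a toy conv layer: channels sit in dim 1, so scale and shift are broadcast
# over the spatial dimensions via .view(1, -1, 1, 1) (channels-first branch).
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
ssf_conv = ScaleShiftLayer(layer=conv, dim=8)
print(ssf_conv(torch.randn(2, 3, 14, 14)).shape)    # torch.Size([2, 8, 14, 14])
```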