@@ -14,8 +14,7 @@ def get_lr_decay_parameters(model: nn.Module, learning_rate: float, lr_multiplie
        groups: {"encoder": 0.1, "encoder.layer2": 0.2}
    """
    custom_lr_parameters = dict(
-        (group_name, {"params": [], "lr": learning_rate * lr_factor})
-        for (group_name, lr_factor) in lr_multipliers.items()
+        (group_name, {"params": [], "lr": learning_rate * lr_factor}) for (group_name, lr_factor) in lr_multipliers.items()
    )
    custom_lr_parameters["default"] = {"params": [], "lr": learning_rate}
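The generator expression above builds one {"params": [], "lr": ...} entry per prefix plus a "default" group, which is the per-group dict format that torch.optim optimizers accept. A minimal sketch of how groups shaped like this are consumed; the model, prefixes, and optimizer choice are illustrative and not part of this patch:

import torch
from torch import nn

# Illustrative model with an "encoder" prefix, mirroring the docstring example above.
model = nn.Sequential()
model.add_module("encoder", nn.Linear(16, 8))
model.add_module("head", nn.Linear(8, 2))

learning_rate = 1e-3
# Same shape as custom_lr_parameters: one group per prefix plus a "default" group.
param_groups = {
    "encoder": {"params": list(model.encoder.parameters()), "lr": learning_rate * 0.1},
    "default": {"params": list(model.head.parameters()), "lr": learning_rate},
}
# torch.optim optimizers take a list of such per-group dicts; a group's "lr" key
# overrides the optimizer-level default for that group.
optimizer = torch.optim.SGD(list(param_groups.values()), lr=learning_rate, momentum=0.9)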
@@ -46,7 +45,7 @@ def get_optimizable_parameters(model: nn.Module) -> Iterator[nn.Parameter]:
    return filter(lambda x: x.requires_grad, model.parameters())


-def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True):
+def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True) -> nn.Module:
    """
    Change 'requires_grad' value for module and its child modules and
    optionally freeze batchnorm modules.
@@ -70,3 +69,5 @@ def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, fr
    for m in module.modules():
        if isinstance(m, bn_types):
            m.track_running_stats = not freeze_bn
+
+    return module
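Since freeze_model now returns its argument, the call can be dropped inline into model setup and combined with get_optimizable_parameters above. A brief usage sketch, assuming both helpers are importable from this file and that freeze_parameters=True flips requires_grad off as the docstring describes; the model and optimizer below are illustrative:

import torch
from torch import nn

# freeze_model and get_optimizable_parameters are assumed imported from this module.

# Illustrative two-part model: a pretrained-style encoder and a fresh head.
model = nn.Sequential()
model.add_module("encoder", nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()))
model.add_module("head", nn.Conv2d(8, 1, 1))

# The return value lets the freeze call be used inline; only the head stays trainable.
model.encoder = freeze_model(model.encoder, freeze_parameters=True, freeze_bn=True)

# get_optimizable_parameters then yields only the parameters with requires_grad=True.
optimizer = torch.optim.Adam(get_optimizable_parameters(model), lr=1e-4)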