Commit bad2cba

freeze_model now returns input module
1 parent 67b6af9 commit bad2cba

File tree: 1 file changed (+4 −3 lines)

pytorch_toolbelt/optimization/functional.py

Lines changed: 4 additions & 3 deletions
@@ -14,8 +14,7 @@ def get_lr_decay_parameters(model: nn.Module, learning_rate: float, lr_multiplie
     groups: {"encoder": 0.1 ,"encoder.layer2": 0.2}
     """
     custom_lr_parameters = dict(
-        (group_name, {"params": [], "lr": learning_rate * lr_factor})
-        for (group_name, lr_factor) in lr_multipliers.items()
+        (group_name, {"params": [], "lr": learning_rate * lr_factor}) for (group_name, lr_factor) in lr_multipliers.items()
     )
     custom_lr_parameters["default"] = {"params": [], "lr": learning_rate}
 
@@ -46,7 +45,7 @@ def get_optimizable_parameters(model: nn.Module) -> Iterator[nn.Parameter]:
     return filter(lambda x: x.requires_grad, model.parameters())
 
 
-def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True):
+def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True) -> nn.Module:
     """
     Change 'requires_grad' value for module and it's child modules and
     optionally freeze batchnorm modules.
@@ -70,3 +69,5 @@ def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, fr
     for m in module.modules():
         if isinstance(m, bn_types):
             module.track_running_stats = not freeze_bn
+
+    return module
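For context, a minimal usage sketch of the changed API. The commit makes freeze_model return the module it was given, so a freeze call can now be used inline rather than as a separate statement. The toy model below is hypothetical and only for illustration; the import path and signature follow the file shown in this diff.

    import torch.nn as nn

    from pytorch_toolbelt.optimization.functional import freeze_model

    # Hypothetical toy model, for illustration only.
    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())

    # After this commit, freeze_model returns its input module,
    # so freezing can happen inside an assignment:
    encoder = freeze_model(model, freeze_parameters=True, freeze_bn=True)

    assert encoder is model  # the same object is returned, enabling chaining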
