@@ -1,5 +1,8 @@
+import warnings
 from typing import Any, Dict, List, Union
 
+from pydantic import field_validator
+
 from llmcompressor.core import Event, EventType, ModelParameterizedLayer, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.pruning.helpers import (
@@ -25,7 +28,7 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     update_scheduler: str = "cubic"
     scheduler_args: Dict[str, Any] = {}
     mask_structure: str = "unstructured"
-    leave_enabled: bool = True
+    leave_enabled: bool = False
     apply_globally: bool = False
 
     parameterized_layers_: Dict[str, ModelParameterizedLayer] = None
@@ -35,6 +38,14 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     mask_creator_function_: MaskCreatorType = None
     current_sparsity_: float = None
 
+    @field_validator("leave_enabled")
+    def validate_leave_enabled(value: bool) -> bool:
+        warnings.warn(
+            "MagnitudePruningModifier.leave_enabled has been deprecated",
+            DeprecationWarning,
+        )
+        return False
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.apply_globally:
             raise NotImplementedError("global pruning not implemented yet for PyTorch")
@@ -75,9 +86,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True
 
     def on_finalize(self, state: State, **kwargs) -> bool:
-        if not self.leave_enabled:
-            for layer_param_name, _ in self.parameterized_layers_.items():
-                self.remove_mask(layer_param_name)
+        for layer_param_name, _ in self.parameterized_layers_.items():
+            self.remove_mask(layer_param_name)
 
         return True
 
@@ -119,12 +129,7 @@ def on_update(self, state: State, event: Event, **kwargs):
         self._update_masks(event)
 
     def on_end(self, state: State, event: Event, **kwargs):
-        if not self.leave_enabled:
-            self.disable_masks()
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        if event.current_index >= self.end and self.leave_enabled:
-            self._update_masks(event)
+        self.disable_masks()
 
     def _update_masks(self, event: Event):
         if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks:
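
A minimal sketch of the resulting behavior, separate from the diff itself: with the new `field_validator`, any explicitly passed `leave_enabled` value emits a `DeprecationWarning` and is coerced to `False`. The import path and the other constructor arguments below are assumptions for illustration, not taken from this PR.

```python
import warnings

# Assumed import path for the modifier class.
from llmcompressor.modifiers.pruning import MagnitudePruningModifier

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    modifier = MagnitudePruningModifier(
        start=0.0,              # hypothetical scheduling arguments
        end=2.0,
        init_sparsity=0.05,
        final_sparsity=0.5,
        targets=["re:.*weight"],
        leave_enabled=True,     # deprecated input, coerced to False
    )

# The validator always stores False and emits the deprecation warning.
assert modifier.leave_enabled is False
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

Coercing the field to `False` is consistent with the branches removed above: masks are now always removed in `on_finalize` and disabled in `on_end`, regardless of the value a caller passes.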