Skip to content

Commit 3e22bac

Browse files
authored
Merge branch 'main' into attn_quant
2 parents 5d13e2b + 2a59554 commit 3e22bac

File tree

1 file changed

+15
-10
lines changed
  • src/llmcompressor/modifiers/pruning/magnitude

1 file changed

+15
-10
lines changed

src/llmcompressor/modifiers/pruning/magnitude/base.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import warnings
12
from typing import Any, Dict, List, Union
23

4+
from pydantic import field_validator
5+
36
from llmcompressor.core import Event, EventType, ModelParameterizedLayer, State
47
from llmcompressor.modifiers import Modifier
58
from llmcompressor.modifiers.pruning.helpers import (
@@ -25,7 +28,7 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
2528
update_scheduler: str = "cubic"
2629
scheduler_args: Dict[str, Any] = {}
2730
mask_structure: str = "unstructured"
28-
leave_enabled: bool = True
31+
leave_enabled: bool = False
2932
apply_globally: bool = False
3033

3134
parameterized_layers_: Dict[str, ModelParameterizedLayer] = None
@@ -35,6 +38,14 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
3538
mask_creator_function_: MaskCreatorType = None
3639
current_sparsity_: float = None
3740

41+
@field_validator("leave_enabled")
42+
def validate_leave_enabled(value: bool) -> bool:
43+
warnings.warn(
44+
"MagnitudePruningModifier.leave_enable has been deprecated",
45+
DeprecationWarning,
46+
)
47+
return False
48+
3849
def on_initialize(self, state: State, **kwargs) -> bool:
3950
if self.apply_globally:
4051
raise NotImplementedError("global pruning not implemented yet for PyTorch")
@@ -75,9 +86,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
7586
return True
7687

7788
def on_finalize(self, state: State, **kwargs) -> bool:
78-
if not self.leave_enabled:
79-
for layer_param_name, _ in self.parameterized_layers_.items():
80-
self.remove_mask(layer_param_name)
89+
for layer_param_name, _ in self.parameterized_layers_.items():
90+
self.remove_mask(layer_param_name)
8191

8292
return True
8393

@@ -119,12 +129,7 @@ def on_update(self, state: State, event: Event, **kwargs):
119129
self._update_masks(event)
120130

121131
def on_end(self, state: State, event: Event, **kwargs):
122-
if not self.leave_enabled:
123-
self.disable_masks()
124-
125-
def on_event(self, state: State, event: Event, **kwargs):
126-
if event.current_index >= self.end and self.leave_enabled:
127-
self._update_masks(event)
132+
self.disable_masks()
128133

129134
def _update_masks(self, event: Event):
130135
if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks:

0 commit comments

Comments
 (0)