
Commit 579d201

Merge branch 'sa/quant_mod_refactor' of github.com:neuralmagic/sparseml into sa/quant_mod_refactor
2 parents bf7d0f6 + 90795bd commit 579d201


57 files changed: +1875 −970 lines changed

.github/workflows/test-check.yaml

Lines changed: 2 additions & 2 deletions

@@ -165,7 +165,7 @@ jobs:
       - name: "Clean sparsezoo directory"
         run: rm -r sparsezoo/
       - name: "⚙️ Install dependencies"
-        run: pip3 install .[dev,torchvision,onnxruntime]
+        run: pip3 install .[dev,torchvision,onnxruntime,transformers]
       - name: "🔬 Running pytorch tests"
         run: make test TARGETS=pytorch
   compat-pytorch-1_9-pytorch-tests:
@@ -194,7 +194,7 @@ jobs:
       - name: "Clean sparsezoo directory"
         run: rm -r sparsezoo/
       - name: "⚙️ Install dependencies"
-        run: pip3 install .[dev,torchvision,onnxruntime] torch==1.9.1
+        run: pip3 install .[dev,torchvision,onnxruntime,transformers]
       - name: "🔬 Running pytorch tests"
         run: make test TARGETS=pytorch
   compat-pytorch-1_9-onnx-tests:

integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md

Lines changed: 3 additions & 3 deletions

@@ -30,17 +30,17 @@ class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
 
 The new `SFTTrainer` class can now apply SparseML recipes and modifiers during
 supervised finetuning, with full support for all of the original TRL features. The full
-class is defined in [sft_trainer.py](sft_trainer.py) and requires very minimal
+class is defined in the script `sft_trainer.py` and requires very minimal
 additional code: just a dataset load override to support passing in tokenized datasets
 to the Trainer.
 
 ### Examples
 
-[ex_trl_sft_data.py](ex_trl_sft_data.py): finetunes a 50% sparse Llama-7b model,
+* Script `ex_trl_sft_data.py`: finetunes a 50% sparse Llama-7b model,
 using TRL's dataset preprocessing. Sparsity is maintained throughout training by
 applying a `ConstantPruningModifier` recipe to the `SFTTrainer`
 
-[ex_trl_distillation.py](ex_trl_distillation.py): finetunes a 50% sparse Llama-7b
+* Script `ex_trl_distillation.py`: finetunes a 50% sparse Llama-7b
 model using knowledge distillation from a dense Llama-7b model. Sparsity is maintained
 throughout training with a `ConstantPruningModifier` and layer-wise knowledge
 distillation is handled by the `OutputDistillationModifier`

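For orientation, below is a minimal sketch of the mixin composition the tutorial README describes. The `SessionManagerMixIn` import path is an assumption and not taken from this commit; the tutorial's `sft_trainer.py` remains the authoritative definition.

```python
from trl import SFTTrainer as TRLSFTTrainer

from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn  # assumed path


class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
    """TRL's SFT trainer with SparseML recipe and modifier support mixed in."""

    # per the README, the only extra code needed is a dataset-load override so
    # that already-tokenized datasets can be passed straight to the Trainer
```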
src/sparseml/core/recipe/recipe.py

Lines changed: 15 additions & 7 deletions

@@ -580,15 +580,23 @@ def _get_yaml_dict(self) -> Dict[str, Any]:
         # populate stages
         stages = original_recipe_dict["stages"]
         for stage_name, stage_list in stages.items():
-            # stage is always a list of size 1
-            stage = stage_list[0]
-            stage_dict = get_yaml_serializable_stage_dict(modifiers=stage["modifiers"])
+            for idx, stage in enumerate(stage_list):
+                if len(stage_list) > 1:
+                    # resolve name clashes caused by combining recipes with
+                    # duplicate stage names
+                    final_stage_name = f"{stage_name}_{idx}"
+                else:
+                    final_stage_name = stage_name
+                stage_dict = get_yaml_serializable_stage_dict(
+                    modifiers=stage["modifiers"]
+                )
+
+                # infer run_type from stage
+                if run_type := stage.get("run_type"):
+                    stage_dict["run_type"] = run_type
 
-            # infer run_type from stage
-            if run_type := stage.get("run_type"):
-                stage_dict["run_type"] = run_type
+                yaml_recipe_dict[final_stage_name] = stage_dict
 
-            yaml_recipe_dict[stage_name] = stage_dict
         return yaml_recipe_dict

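The behavioral change above is how stage names are serialized when recipes are combined. A standalone sketch of that naming rule, with illustrative stage names only:

```python
# Mirrors the f"{stage_name}_{idx}" logic added in _get_yaml_dict above: when a
# combined recipe holds several stages under one name, each serialized stage gets
# an index suffix; a single stage keeps its original name.
def resolve_stage_names(stage_name, stage_list):
    names = []
    for idx, _stage in enumerate(stage_list):
        suffix = f"_{idx}" if len(stage_list) > 1 else ""
        names.append(f"{stage_name}{suffix}")
    return names


print(resolve_stage_names("quant_stage", [{"modifiers": {}}, {"modifiers": {}}]))
# -> ['quant_stage_0', 'quant_stage_1']
print(resolve_stage_names("prune_stage", [{"modifiers": {}}]))
# -> ['prune_stage']
```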
Lines changed: 2 additions & 18 deletions

@@ -11,21 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from copy import deepcopy
-
-from torch import nn
-
-from sparseml.transformers.sparsification.modification import modify_model
-from sparseml.transformers.sparsification.modification.modification_objects import (
-    QATLinear,
-)
-
-
-def test_modifying_mobilebert(mobilebert_model):
-
-    mobilebert_ = deepcopy(mobilebert_model)
-    mobilebert = modify_model(mobilebert_model)
-
-    assert isinstance(mobilebert_.embeddings.embedding_transformation, nn.Linear)
-    assert isinstance(mobilebert.embeddings.embedding_transformation, QATLinear)
+# flake8: noqa
+from .modify_model import modify_model

src/sparseml/transformers/sparsification/modification/modification_objects.py renamed to src/sparseml/modifiers/quantization/modification/modification_objects.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 
 """
 Set of helper objects that are used to modify
-the HuggingFace transformer models
+the quantized models
 """
 
 import torch

src/sparseml/transformers/sparsification/modification/modify_model.py renamed to src/sparseml/modifiers/quantization/modification/modify_model.py

Lines changed: 17 additions & 14 deletions

@@ -15,31 +15,34 @@
 import logging
 import os
 
-import torch
-
-from sparseml.transformers.sparsification.modification.registry import (
-    ModificationRegistry,
-)
+from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
-def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Module:
+def modify_model(
+    model: "torch.nn.Module", disable: bool = False  # noqa: F821
+) -> "torch.nn.Module":  # noqa: F821
     """
-    Modify the original transformers model so that it is
-    compatible with the SparseML library.
+    Modify the original model so that it is
+    compatible with the quantization format required by the
+    SparseML library.
+
     The model will be modified, if there exist a modification
     function for the model in the registry of modifications.
     Otherwise, the original model will be returned.
 
-    :param model: The original HuggingFace transformers model
-    :return: The potentially modified model
+    :param model: The original model to be modified
+    :param disable: If True, the modification will be disabled
+    :return: The potentially modified model to support
+        SparseML quantization
     """
     model_name = model.__class__.__name__
-    NM_DISABLE_TRANSFORMERS_MODIFICATION = os.environ.get(
-        "NM_DISABLE_TRANSFORMERS_MODIFICATION", "False"
+    NM_DISABLE_QUANTIZATION_MODIFICATION = os.environ.get(
+        "NM_DISABLE_QUANTIZATION_MODIFICATION", "False"
     ).lower() in ["true", "1"]
+
     try:
         modification_func = ModificationRegistry.get_value_from_registry(model_name)
     except KeyError:
@@ -50,7 +53,7 @@ def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Module:
         )
         return model
 
-    if NM_DISABLE_TRANSFORMERS_MODIFICATION:
+    if NM_DISABLE_QUANTIZATION_MODIFICATION:
         _LOGGER.debug(
             "Application of the modification function to model "
             "disabled through the environment variable."
@@ -65,6 +68,6 @@ def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Module:
         return model
 
     _LOGGER.info(
-        f"Modifying the model {model_name} to be compatible with SparseML library"
+        f"Modifying the model {model_name} to be compatible with SparseML quantization"
    )
     return modification_func(model)

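Taken together with the registry below, the dispatch works by class name: `modify_model` looks up `model.__class__.__name__` in `ModificationRegistry` and applies the registered function, or returns the model untouched on a `KeyError`. A hedged sketch of registering such a function; the decorator form is assumed from sparsezoo's `RegistryMixin`, and `MobileBertModel` is only an illustrative key:

```python
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry


# Assumed registration style via sparsezoo's RegistryMixin; the key must match
# the model class __name__ that modify_model sees at call time.
@ModificationRegistry.register(name="MobileBertModel")
def modify_mobilebert(model):
    # swap submodules for quantization-friendly equivalents here
    # (e.g. the QATLinear helpers from modification_objects.py)
    return model
```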
src/sparseml/transformers/sparsification/modification/registry.py renamed to src/sparseml/modifiers/quantization/modification/registry.py

Lines changed: 2 additions & 17 deletions

@@ -11,27 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from sparseml.transformers.sparsification.modification.base import (
-    check_transformers_version,
-)
 from sparsezoo.utils.registry import RegistryMixin
 
 
 class ModificationRegistry(RegistryMixin):
     """
     A registry for modification functions that can be applied to models
-    so that they can be used in the context of sparseml.transformers
+    so that they can be compatible with the quantization format required by the
+    SparseML library.
     """
-
-    @classmethod
-    def get_value_from_registry(cls, name: str):
-        """
-        Extends the base class method to check the transformers version after
-        successfully retrieving the value from the registry. The motivation is
-        to ensure that the transformers version falls within the supported range
-        before we proceed with model modification.
-        """
-        retrieved_value = super().get_value_from_registry(name)
-        check_transformers_version()
-        return retrieved_value

src/sparseml/modifiers/quantization/pytorch.py

Lines changed: 6 additions & 0 deletions

@@ -20,6 +20,7 @@
 
 from sparseml.core import Event, EventType, State
 from sparseml.modifiers.quantization.base import QuantizationModifier
+from sparseml.modifiers.quantization.modification import modify_model
 from sparseml.modifiers.quantization.utils.helpers import (
     configure_module_bn_wrappers,
     freeze_bn_stats,
@@ -73,11 +74,16 @@ def __init__(self, **kwargs):
 
     def on_initialize_structure(self, state: State, **kwargs):
         module = state.model.model
+        # before the structure is modified to support quantization,
+        # we need to potentially modify the model architecture
+        module = modify_model(module)
         self._enable_module_qat(module)
         state.model.model.apply(torch.quantization.disable_observer)
 
     def on_initialize(self, state: State, **kwargs) -> bool:
         raise_if_torch_quantization_not_available()
+        module = state.model.model
+        module = modify_model(module)
         if self.end and self.end != -1:
             raise ValueError(
                 "end_epoch is disabled for QuantizationModifier and can only be set to"

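Because `QuantizationModifier` now routes the model through `modify_model` before enabling QAT, the architecture modification can be skipped via the environment variable introduced in the `modify_model` diff above, for example:

```python
import os

# Opt out of the pre-quantization architecture modification; the variable name
# and accepted values ("true"/"1", case-insensitive) come from modify_model.
os.environ["NM_DISABLE_QUANTIZATION_MODIFICATION"] = "1"
```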
src/sparseml/modifiers/smoothquant/base.py

Lines changed: 1 addition & 3 deletions

@@ -16,8 +16,6 @@
 from dataclasses import dataclass
 from typing import Dict, Generic, List, Optional, Tuple, TypeVar
 
-from pydantic import Field
-
 from sparseml.core import Modifier
 from sparseml.core.model import ModifiableModel
 from sparseml.core.model.base import LT
@@ -98,7 +96,7 @@ class SmoothQuantModifier(Modifier):
     use the whole dataset
     """
 
-    smoothing_strength: float = Field(validation_alias="alpha", default=0.5)
+    smoothing_strength: float = 0.5
     mappings: List[Tuple]
     ignore: Optional[List[str]] = None
     num_calibration_steps: Optional[int] = None

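One consequence of dropping the pydantic `Field` above is that the `alpha` validation alias goes away, so the setting is now populated under its plain field name. A hedged construction sketch; the mapping value is an illustrative placeholder, not taken from this commit:

```python
from sparseml.modifiers.smoothquant.base import SmoothQuantModifier

# smoothing_strength must now be passed by its field name; an "alpha" key would
# no longer be recognized after the alias removal. The mapping below is a
# made-up placeholder just to satisfy the required `mappings` field.
modifier = SmoothQuantModifier(
    smoothing_strength=0.5,
    mappings=[(["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm")],
)
```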
src/sparseml/transformers/sparsification/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@
 
 # flake8: noqa
 
+from .modification import *
 from .question_answering import *
 from .sparse_config import *
 from .sparse_model import *
