Skip to content

Commit c02000d

Browse files
committed
update
1 parent de9f16a commit c02000d

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,7 @@ def update_config(self, save_directory: str):
790790
config_data = {}
791791

792792
# serialize configs into json
793+
breakpoint()
793794
qconfig_data = (
794795
self.quantization_config.model_dump(exclude=["quant_method"])
795796
if self.quantization_config is not None

src/compressed_tensors/quantization/quant_config.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import torch
1515
from collections import defaultdict
1616
from enum import Enum
1717
from typing import Annotated, Any, Dict, List, Optional, Set, Union
@@ -279,5 +279,27 @@ def requires_calibration_data(self):
279279

280280
return False
281281

282+
def model_dump(self, *args, **kwargs):
    """
    Dump the config to a dictionary, rendering ``torch.dtype`` values as
    plain strings.

    The parent ``model_dump`` is called with the given arguments. The
    ``config_groups`` and ``kv_cache_scheme`` sections of its result are then
    walked, and every ``torch.dtype`` value found is replaced by its name
    without the ``torch.`` prefix (e.g. ``torch.float16`` -> ``"float16"``).
    Sections that are absent or ``None`` are left untouched.

    :param args: positional arguments forwarded to the parent ``model_dump``
    :param kwargs: keyword arguments forwarded to the parent ``model_dump``
    :return: the dumped dictionary with dtype values converted to strings
    """

    def _stringify_dtypes(node):
        # Walk recursively instead of assuming a fixed layout: the two
        # sections may nest their quantization args at different depths and
        # individual entries (weights / activations) may be None.
        if isinstance(node, torch.dtype):
            return str(node).replace("torch.", "")
        if isinstance(node, dict):
            return {key: _stringify_dtypes(value) for key, value in node.items()}
        if isinstance(node, list):
            return [_stringify_dtypes(item) for item in node]
        return node

    # Call the parent dump first
    data = super().model_dump(*args, **kwargs)

    # Write the converted values back into the section they came from so the
    # dtype objects are actually replaced in the returned dictionary.
    for section in ("config_groups", "kv_cache_scheme"):
        if data.get(section) is not None:
            data[section] = _stringify_dtypes(data[section])

    return data
303+
282304
# TODO set `extra="forbid"` when upstream transformers is compatible
283305
model_config = ConfigDict(extra="ignore")

0 commit comments

Comments
 (0)