Commit fbccd40

fix serialization
1 parent c02000d commit fbccd40

2 files changed: +14 -8 lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 0 additions & 1 deletion

@@ -790,7 +790,6 @@ def update_config(self, save_directory: str):
         config_data = {}

         # serialize configs into json
-        breakpoint()
         qconfig_data = (
             self.quantization_config.model_dump(exclude=["quant_method"])
             if self.quantization_config is not None
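The removed `breakpoint()` was leftover debug code; the surrounding logic relies on `model_dump()` producing JSON-safe values before the config is written out. A minimal sketch (not part of the commit) of why raw `torch.dtype` values would otherwise break that json step:

import json
import torch

# A torch.dtype value cannot be serialized by the standard json module.
try:
    json.dumps({"zp_dtype": torch.int8})
except TypeError as err:
    print(err)  # Object of type dtype is not JSON serializable

# The string form produced after the conversion serializes cleanly.
print(json.dumps({"zp_dtype": str(torch.int8).replace("torch.", "")}))  # {"zp_dtype": "int8"}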

src/compressed_tensors/quantization/quant_config.py

Lines changed: 14 additions & 7 deletions

@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
 from collections import defaultdict
 from enum import Enum
 from typing import Annotated, Any, Dict, List, Optional, Set, Union

+import torch
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
 from compressed_tensors.quantization.quant_scheme import (

@@ -283,7 +283,17 @@ def model_dump(self, *args, **kwargs):
         # Call the parent dump first
         data = super().model_dump(*args, **kwargs)

-        # Convert any torch.dtype to string
+        def _convert_dtypes_in_dict(d):
+            for k, v in d.items():
+                if isinstance(v, torch.dtype):
+                    if k == "zp_dtype" and d.get("symmetric"):
+                        d[k] = None
+                    else:
+                        d[k] = str(v).replace("torch.", "")
+                elif isinstance(v, dict):
+                    _convert_dtypes_in_dict(v)
+            return d
+
         schemes = ["config_groups", "kv_cache_scheme"]
         for scheme in schemes:
             if data.get(scheme) is not None:

@@ -294,11 +304,8 @@ def model_dump(self, *args, **kwargs):

         args = [weight, input, output]
         for arg in args:
-            for key, value in arg.items():
-                if isinstance(value, torch.dtype):
-                    data[key] = str(value).replace("torch.", "")
-
-            breakpoint()
+            if arg is not None:
+                _convert_dtypes_in_dict(arg)
         return data

     # TODO set `extra="forbid"` when upstream transformers is compatible
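To show what the new helper does, here is a standalone sketch of `_convert_dtypes_in_dict` applied to a made-up nested dict shaped loosely like `model_dump()` output (the field layout is an assumption for illustration):

import torch

def _convert_dtypes_in_dict(d):
    # Recursively walk the dict: torch.dtype values become plain strings,
    # and zp_dtype is nulled out when the same entry is marked symmetric.
    for k, v in d.items():
        if isinstance(v, torch.dtype):
            if k == "zp_dtype" and d.get("symmetric"):
                d[k] = None
            else:
                d[k] = str(v).replace("torch.", "")
        elif isinstance(v, dict):
            _convert_dtypes_in_dict(v)
    return d

# Hypothetical input: nested args dicts carrying torch.dtype values
example = {
    "weights": {"num_bits": 4, "symmetric": True, "zp_dtype": torch.int8},
    "input_activations": {"num_bits": 8, "symmetric": False, "zp_dtype": torch.int8},
}
print(_convert_dtypes_in_dict(example))
# weights.zp_dtype -> None (symmetric), input_activations.zp_dtype -> "int8"

Because the helper recurses into nested dicts and writes back into the dict it was given, the final loop only needs `_convert_dtypes_in_dict(arg)` per non-None arg, replacing the previous flat key/value scan that wrote converted values into the top-level `data` dict.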
