1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- import torch
1514from collections import defaultdict
1615from enum import Enum
1716from typing import Annotated , Any , Dict , List , Optional , Set , Union
1817
18+ import torch
1919from compressed_tensors .config import CompressionFormat
2020from compressed_tensors .quantization .quant_args import DynamicType , QuantizationArgs
2121from compressed_tensors .quantization .quant_scheme import (
@@ -283,7 +283,17 @@ def model_dump(self, *args, **kwargs):
283283 # Call the parent dump first
284284 data = super ().model_dump (* args , ** kwargs )
285285
286- # Convert any torch.dtype to string
286+ def _convert_dtypes_in_dict (d ):
287+ for k , v in d .items ():
288+ if isinstance (v , torch .dtype ):
289+ if k == "zp_dtype" and d .get ("symmetric" ):
290+ d [k ] = None
291+ else :
292+ d [k ] = str (v ).replace ("torch." , "" )
293+ elif isinstance (v , dict ):
294+ _convert_dtypes_in_dict (v )
295+ return d
296+
287297 schemes = ["config_groups" , "kv_cache_scheme" ]
288298 for scheme in schemes :
289299 if data .get (scheme ) is not None :
@@ -294,11 +304,8 @@ def model_dump(self, *args, **kwargs):
294304
295305 args = [weight , input , output ]
296306 for arg in args :
297- for key , value in arg .items ():
298- if isinstance (value , torch .dtype ):
299- data [key ] = str (value ).replace ("torch." , "" )
300-
301- breakpoint ()
307+ if arg is not None :
308+ _convert_dtypes_in_dict (arg )
302309 return data
303310
304311 # TODO set `extra="forbid"` when upstream transformers is compatible
0 commit comments