diff --git a/.pylintrc b/.pylintrc
index 4effcbf7..da95ef1d 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -69,7 +69,8 @@ ignored-modules=gptqmodel,
                 llmcompressor,
                 cutlass_mm,
                 pygraphviz,
-                matplotlib
+                matplotlib,
+                compressed_tensors
 
 # Python code to execute, usually for sys.path manipulation such as
 # pygtk.require().
diff --git a/fms_mo/dq.py b/fms_mo/dq.py
index eb49bc30..2b4bdfa5 100644
--- a/fms_mo/dq.py
+++ b/fms_mo/dq.py
@@ -47,6 +47,12 @@
     get_act_scales_1gpu,
 )
 from fms_mo.utils.aiu_utils import save_for_aiu
+from fms_mo.utils.dq_inf import (
+    check_quantization_setting,
+    convert_fp8_vllm_to_fms_mo,
+    load_inference_qconfig_file,
+    save_vllm_fp8,
+)
 from fms_mo.utils.dq_utils import config_quantize_smooth_layers
 from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
 from fms_mo.utils.utils import patch_torch_bmm, prepare_input
@@ -134,7 +140,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Initialized model is: \n {model}")
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
-    qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
+
+    inference_only = check_quantization_setting(model)
+
+    if not inference_only:
+        logger.info("Quantization mode activated, initializing the qcfg file")
+        qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
+    else:
+        logger.info("Inference mode activated")
+        qcfg = load_inference_qconfig_file(model_args, fms_mo_args)
+
+    if inference_only:
+        model = convert_fp8_vllm_to_fms_mo(model=model)
 
     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
@@ -159,7 +176,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     qcfg["model"] = model_args.model_name_or_path
 
     # config layers to skip, smooth scale
-    config_quantize_smooth_layers(qcfg)
+    if not inference_only:
+        config_quantize_smooth_layers(qcfg)
 
     use_dynamo = True
     # use dynamo as default unless really needed, False -> fallback to TorchScript tracing
@@ -178,6 +196,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     qcfg["model"] = model_args.model_name_or_path
     qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg
     qcfg["plotsvg"] = False
+    qcfg["output_folder"] = opt_args.output_dir
 
     calibration_dataset = load_from_disk(data_args.training_data_path)
     calibration_dataset = calibration_dataset.with_format("torch")
@@ -190,7 +209,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )
 
     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if qcfg["smoothq"]:
+    if not inference_only and qcfg["smoothq"]:
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
             # user provided a scale file (or a dir)
@@ -229,48 +248,56 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Quantized model {model}")
     logger.info("==" * 20)
 
-    if qcfg["smoothq"]:
-        logger.info("Starting to apply smooth scale")
-        dq_llm(model, act_scales, qcfg)
-        logger.info("Finished applying smooth scale")
+    if not inference_only:
+        if qcfg["smoothq"]:
+            logger.info("Starting to apply smooth scale")
+            dq_llm(model, act_scales, qcfg)
+            logger.info("Finished applying smooth scale")
 
-    if qcfg["qmodel_calibration_new"] > 0:
-        logger.info("Starting to calibrate activation clip_val")
-        if qcfg["large_model"]:
-            calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
-        else:
-            model.to("cuda")
-            pbar = tqdm(
-                dq_dataloader,
-                desc=" calibration after applying smoothq scale and before inference",
-                total=qcfg["qmodel_calibration_new"],
+        if qcfg["qmodel_calibration_new"] > 0:
+            logger.info("Starting to calibrate activation clip_val")
+            if qcfg["large_model"]:
+                calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
+            else:
+                model.to("cuda")
+                pbar = tqdm(
+                    dq_dataloader,
+                    desc=" calibration after applying smoothq scale and before inference",
+                    total=qcfg["qmodel_calibration_new"],
+                )
+                for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
+                    data_mb = prepare_input(model.device, data_mb)
+                    with patch_torch_bmm(qcfg):
+                        model(**data_mb)
+
+        if opt_args.save_ckpt_for_aiu:
+            logger.info(
+                f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
+            )
+            save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
+        elif not opt_args.save_ckpt:
+            logger.info(
+                f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}"
+            )
+            save_vllm_fp8(model, qcfg, tokenizer, opt_args.output_dir)
+        elif opt_args.save_ckpt:
+            logger.info(
+                f"Saving quantized model and tokenizer to {opt_args.output_dir}"
+            )
+            model.save_pretrained(opt_args.output_dir, use_safetensors=True)
+            tokenizer.save_pretrained(opt_args.output_dir)
+
+        if fms_mo_args.aiu_sim_triton:
+            # NOTE plz apply correct HW settings here, defaults are not real HW params
+            lower_qmodel_triton(
+                model,
+                use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
+                max_acc_bits=qcfg.get("max_acc_bits", 32),
+                num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
+                chunk_size=qcfg.get("chunk_size", 32),  # 1024
+                clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
+                # layer_to_exclude=["lm_head",]
             )
-            for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
-                data_mb = prepare_input(model.device, data_mb)
-                with patch_torch_bmm(qcfg):
-                    model(**data_mb)
-
-    if opt_args.save_ckpt_for_aiu:
-        logger.info(
-            f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
-        )
-        save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
-    elif opt_args.save_ckpt:
-        logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
-        model.save_pretrained(opt_args.output_dir, use_safetensors=True)
-        tokenizer.save_pretrained(opt_args.output_dir)
-
-    if fms_mo_args.aiu_sim_triton:
-        # NOTE plz apply correct HW settings here, defaults are not real HW params
-        lower_qmodel_triton(
-            model,
-            use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
-            max_acc_bits=qcfg.get("max_acc_bits", 32),
-            num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
-            chunk_size=qcfg.get("chunk_size", 32),  # 1024
-            clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
-            # layer_to_exclude=["lm_head",]
-        )
 
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
diff --git a/fms_mo/prep.py b/fms_mo/prep.py
index 42e40b79..410d885f 100644
--- a/fms_mo/prep.py
+++ b/fms_mo/prep.py
@@ -391,11 +391,19 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     # For nn.Linear
     elif isinstance(module, nn.Linear):
         if module.__class__ != nn.Linear:
-            logger.warning(
-                f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
-                "Please make sure it doesn't wrap BN and activ func."
-                "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']."
-            )
+            if available_packages["compressed_tensors"]:
+                # Third Party
+                import compressed_tensors
+                # Check whether the layer is a CompressedLinear. If it is, do nothing.
+                # Otherwise, emit the warning below about the unrecognized Linear wrapper.
+                if not isinstance(
+                    module, compressed_tensors.linear.compressed_linear.CompressedLinear
+                ):
+                    logger.warning(
+                        f"{curr_full_name} {type(module)} seems to be a wrapper of Linear. "
+                        "Please make sure it doesn't wrap BN and activ func. Otherwise "
+                        "please create an equivalent Linear wrapper and change qcfg['mapping']."
+                    )
 
         QLin = mapping.get(nn.Linear, None)
         if QLin is None:
@@ -571,6 +579,42 @@ def has_quantized_module(model):
     return any(isinstance(m, quantized_modules) for m in model.modules())
 
 
+def swap_qbmm(model: nn.Module, qcfg: dict):
+    """Go through qcfg["bmm_prep"]["layers_with_bmm"] and attach an equivalent
+    QBmm module for each bmm/matmul call site recorded there.
+
+    Args:
+        model (nn.Module): input model to be "prepared"
+        qcfg (dict): quant config
+
+    Returns: None. The model is updated in place with the QBmm modules attached.
+
+    """
+
+    # Local
+    from fms_mo.modules import QBmm
+
+    qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"]["which2patch_contextmanager"]
+    isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm"
+    for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
+        mod_bmm_happened = model.get_submodule(mod_name)
+        for whichQBmm, ln in enumerate(line_nums, start=1):
+            nbits = qcfg[f"nbits_bmm{whichQBmm}"]
+            newQBmm = QBmm(
+                num_bits_m1=max(nbits, 8) if whichQBmm == 2 else nbits,
+                num_bits_m2=nbits,
+                qm1_mode=qcfg[f"bmm{whichQBmm}_qm1_mode"],
+                qm2_mode=qcfg[f"bmm{whichQBmm}_qm2_mode"],
+                m1_unidirectional=(whichQBmm == 2),
+                m1_bounded=(whichQBmm == 2),  # see Note 5
+                m2_unidirectional=False,
+                m2_bounded=False,
+                replaceBmm=isbmm,
+                qcfg=qcfg,
+            )
+            setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)
+
+
 def qmodel_prep(
     model,
     dloader,
@@ -657,6 +701,12 @@ def qmodel_prep(
     Returns:
         nn.Module: quantized model ready for further PTQ/QAT
     """
+    if qcfg["fp8_inference"]:
+        if qcfg.get("QBmm"):
+            swap_qbmm(model, qcfg)
+
+        model = q_any_net_5(model, qcfg, verbose=False)
+        return model
 
     sys.setrecursionlimit(4000)
 
@@ -906,8 +956,10 @@ def qmodel_prep(
         model = torch.nn.parallel.DistributedDataParallel(
             model, device_ids=DPorDDPdevices
         )
-
-    qconfig_save(qcfg, fname="qcfg.json")
+    if qcfg["output_folder"] is None:
+        qconfig_save(qcfg, fname="qcfg.json")
+    else:
+        qconfig_save(qcfg, fname=qcfg["output_folder"] + "/qcfg.json")
 
     qcfg["tb_writer"] = tb_writer
     logger.info(f"--- Quantized model --- \n{model}\n")
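Aside (not part of the patch): a minimal sketch of how the new `fp8_inference` early exit in `qmodel_prep` is meant to be driven. The two qcfg keys shown are the ones this diff introduces; the model object, placeholder paths, and the top-level `fms_mo` imports are assumptions based on the existing fms_mo API, not something defined here.

```python
# Hypothetical usage sketch for the fp8_inference path added above.
from transformers import AutoModelForCausalLM

from fms_mo import qconfig_init, qmodel_prep

model = AutoModelForCausalLM.from_pretrained("path/to/fp16_model")  # placeholder

qcfg = qconfig_init(recipe="dq")
qcfg["fp8_inference"] = True   # new flag, defaults to False in dq.json / config_defaults
qcfg["output_folder"] = "out"  # new key; on the regular path qcfg.json is saved here

# With fp8_inference set, qmodel_prep() swaps in QBmm modules (if qcfg["QBmm"] is set),
# calls q_any_net_5() and returns immediately -- no tracing, calibration, or dataloader use.
model = qmodel_prep(model, dloader=None, qcfg=qcfg)
```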
diff --git a/fms_mo/recipes/dq.json b/fms_mo/recipes/dq.json
index be425998..70c7a87d 100644
--- a/fms_mo/recipes/dq.json
+++ b/fms_mo/recipes/dq.json
@@ -10,5 +10,7 @@
     "eval_ckpt": true,
     "nbits_bmm1" : 32,
     "nbits_bmm2" : 32,
-    "nbits_kvcache" : 32
+    "nbits_kvcache" : 32,
+    "fp8_inference": false,
+    "output_folder": null
 }
\ No newline at end of file
diff --git a/fms_mo/recipes/fp8_vllm_quantization_config.json b/fms_mo/recipes/fp8_vllm_quantization_config.json
new file mode 100644
index 00000000..96b87619
--- /dev/null
+++ b/fms_mo/recipes/fp8_vllm_quantization_config.json
@@ -0,0 +1,44 @@
+{
+    "quantization_config": {
+        "config_groups": {
+            "group_0": {
+                "input_activations": {
+                    "actorder": null,
+                    "block_structure": null,
+                    "dynamic": true,
+                    "group_size": null,
+                    "num_bits": 8,
+                    "observer": null,
+                    "observer_kwargs": {},
+                    "strategy": "token",
+                    "symmetric": true,
+                    "type": "float"
+                },
+                "output_activations": null,
+                "targets": [
+                    "Linear"
+                ],
+                "weights": {
+                    "actorder": null,
+                    "block_structure": null,
+                    "dynamic": false,
+                    "group_size": null,
+                    "num_bits": 8,
+                    "observer": "minmax",
+                    "observer_kwargs": {},
+                    "strategy": "channel",
+                    "symmetric": true,
+                    "type": "float"
+                }
+            }
+        },
+        "format": "float-quantized",
+        "global_compression_ratio": null,
+        "ignore": [
+            "lm_head"
+        ],
+        "kv_cache_scheme": null,
+        "quant_method": "compressed-tensors",
+        "quantization_status": "compressed"
+    }
+}
\ No newline at end of file
diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py
index 95f38424..01e37377 100644
--- a/fms_mo/training_args.py
+++ b/fms_mo/training_args.py
@@ -209,6 +209,7 @@ class FMSMOArguments(TypeChecker):
         default=False,
         metadata={"help": "Apply recomputation during checkpoint saving for AIU."},
     )
+    override_qcfg_args: bool = field(default=False)
 
 
 @dataclass
diff --git a/fms_mo/utils/dq_inf.py b/fms_mo/utils/dq_inf.py
new file mode 100644
index 00000000..9f44d313
--- /dev/null
+++ b/fms_mo/utils/dq_inf.py
@@ -0,0 +1,317 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Evaluation utils for fast model loading and saving for FP8 DQ
+"""
+
+# Standard
+from typing import Any
+import glob
+import json
+import logging
+import os
+
+# Third Party
+from huggingface_hub import save_torch_state_dict
+from safetensors.torch import load_file
+from torch import nn
+import torch
+
+# Local
+from fms_mo import qconfig_init
+from fms_mo.quant.quantizers import to_fp8_scaled_perCh
+from fms_mo.utils.qconfig_utils import get_recipe
+
+logger = logging.getLogger(__name__)
+
+
+def check_quantization_setting(model: nn.Module) -> bool:
+    """
+    Check whether the checkpoint carries an FP8 (compressed-tensors) quantization config.
+    """
+    quant_config = None
+    if hasattr(model, "config"):
+        quant_config = model.config.to_dict().get("quantization_config", None)
+    if quant_config is None:
+        return False
+
+    logger.info("Validating config settings")
+    if "quant_method" in quant_config.keys():
+        if quant_config["quant_method"] == "compressed-tensors":
+            if quant_config["format"] != "float-quantized":
+                raise ValueError(
+                    "The input activation and weight quantization dtypes are not supported"
+                )
+
+            if (
+                quant_config["config_groups"]["group_0"]["input_activations"][
+                    "num_bits"
+                ]
+                != 8
+            ):
+                raise ValueError(
+                    "Only 8-bit FP input activation quantization is supported"
+                )
+
+            if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8:
+                raise ValueError("Only 8-bit FP weight quantization is supported")
+
+            if quant_config["kv_cache_scheme"] is not None:
+                if quant_config["kv_cache_scheme"]["type"] != "float":
+                    raise ValueError("The KV-Cache quantization dtype is not supported")
+
+                if quant_config["kv_cache_scheme"]["num_bits"] != 8:
+                    raise ValueError(
+                        "Only 8-bit KV-Cache quantization dtype is supported"
+                    )
+
+            return True
+        raise ValueError(
+            "The quantization method is not supported for inference. "
+            "Only FP8 quantization is supported"
+        )
+
+    raise ValueError(
+        "The quantization method is not found. Please check the config file"
+    )
+
+
+def load_inference_qconfig_file(
+    model_args: Any = None, fms_mo_args: Any = None
+) -> dict[str, int | float | str]:
+    """
+    Load the inference quantization config (qcfg) for fms_mo.
+    """
+    if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"):
+        if fms_mo_args.override_qcfg_args:
+            logger.info("qcfg file found and some parameters are being overwritten")
+            qcfg = qconfig_init(
+                recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
+            )
+        else:
+            logger.info("loading quantization configuration from "
+                        f"{model_args.model_name_or_path}/qcfg.json")
+            qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
+    else:
+        logger.info(
+            f"qcfg file not found in {model_args.model_name_or_path}, "
+            "loading fms_mo_args and recipe"
+        )
+        qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
+        qcfg = update_qcfg_from_model_config(model_args, qcfg)
+    qcfg["fp8_inference"] = True
+
+    return qcfg
+
+
+def update_qcfg_from_model_config(
+    model_args: Any = None, qcfg: dict = None
+) -> dict[str, int | float | str]:
+    """
+    Update the default qcfg settings with the quantization settings in the model config file.
+    Important for the case where a qcfg file does not exist.
+    """
+    config = get_recipe(model_args.model_name_or_path + "/config")
+    if (
+        config["quantization_config"]["config_groups"]["group_0"]["input_activations"][
+            "strategy"
+        ]
+        == "token"
+    ):
+        qcfg["qa_mode"] = "fp8_e4m3_scale_perToken"
+    else:
+        raise ValueError("Only perToken FP8 activation quantizer is supported")
+
+    if (
+        config["quantization_config"]["config_groups"]["group_0"]["weights"]["strategy"]
+        == "channel"
+    ):
+        qcfg["qw_mode"] = "fp8_e4m3_scale_perCh"
+    elif (
+        config["quantization_config"]["config_groups"]["group_0"]["weights"]["strategy"]
+        == "tensor"
+    ):
+        qcfg["qw_mode"] = "fp8_e4m3_scale"
+    else:
+        raise ValueError(
+            "Only perChannel or perTensor FP8 quantizers are currently supported"
+        )
+
+    qcfg["smoothq"] = False
+    qcfg["nbits_a"] = config["quantization_config"]["config_groups"]["group_0"][
+        "input_activations"
+    ]["num_bits"]
+    qcfg["nbits_w"] = config["quantization_config"]["config_groups"]["group_0"][
+        "weights"
+    ]["num_bits"]
+    qcfg["torch_dtype"] = "float16"
+    if config["quantization_config"]["ignore"] != []:
+        qcfg["qskip_layer_name"] = config["quantization_config"]["ignore"]
+        qcfg["qskip_large_mag_layers"] = True
+    else:
+        qcfg["qskip_layer_name"] = []
+        qcfg["qskip_large_mag_layers"] = False
+    return qcfg
+
+
+def rename_fms_dict_to_vllm_dict(
+    model_dict: dict = None,
+) -> tuple[dict[str, float | int], dict[str, float | int]]:
+    """
+    Rename a state dict in fms_mo format to the vLLM (compressed-tensors) format.
+    """
+    st_dict = {}
+    fms_dict = {}
+    keys = model_dict.keys()
+    logger.info("WARNING: only static per-channel weight quantization is supported at this time")
+    for k, v in model_dict.items():
+        if ".weight" in k:
+            key = k.split("weight")[0]
+            if key + "quantize_weight.scale" in keys:
+                weight, scale = to_fp8_scaled_perCh(v, emulate=False)
+                st_dict[key + "weight"] = weight
+                st_dict[key + "weight_scale"] = 1 / scale
+            else:
+                st_dict[k] = v
+        else:
+            fms_dict[k] = v
+    return st_dict, fms_dict
+
+
+def update_config(
+    model_config_file: dict = None, qcfg: dict = None
+) -> dict[str, float | int | str]:
+    """
+    Update the model config file with the quantization configuration.
+    """
+    data = get_recipe("fp8_vllm_quantization_config")
+    if "perCh" not in qcfg["qw_mode"]:
+        data["quantization_config"]["config_groups"]["group_0"]["weights"] = {
+            "num_bits": 8, "type": "float", "symmetric": True, "strategy": "tensor"
+        }
+    if qcfg["qskip_large_mag_layers"] is True:
+        data["quantization_config"]["ignore"] = qcfg["qskip_layer_name"]
+    model_config_file.update(data)
+    return model_config_file
+
+
+def save_vllm_fp8(
+    model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None
+) -> None:
+    """
+    Save an FP8 DQ model in the vLLM FP8 (compressed-tensors) format.
+    """
+    model_dict = model.state_dict()
+    vllm_dict, fms_dict = rename_fms_dict_to_vllm_dict(model_dict=model_dict)
+    config = model.config.to_dict()
+    config = update_config(config, qcfg)
+
+    save_torch_state_dict(vllm_dict, folder)
+    save_torch_state_dict(
+        fms_dict, folder, filename_pattern="fms_mo{suffix}.safetensors"
+    )
+    tokenizer.save_pretrained(folder)
+
+    with open(folder + "/config.json", "w+", encoding="utf-8") as f:
+        json.dump(config, f, indent=4)
+
+
+def convert_fms_mo_to_vllm_fp8_format(
+    checkpoint: str = None, folder: str = None
+) -> None:
+    """
+    Convert an FP8 fms_mo DQ model checkpoint to the vLLM FP8 format.
+    """
+    folder = checkpoint + "/" + folder
+    if os.path.isdir(folder):
+        logger.info(f"The folder '{folder}' exists.")
+    else:
+        os.mkdir(folder)
+        logger.info(f"The folder '{folder}' was created.")
+
+    qcfg = get_recipe(checkpoint + "/qcfg")
+    config = get_recipe(checkpoint + "/config")
+    files = find_file_glob("model-*", checkpoint)
+    merged_files_dict = {}
+
+    for file in files:
+        temp_dict = load_file(file)
+        merged_files_dict.update(temp_dict)
+
+    vllm_dict, fms_dict = rename_fms_dict_to_vllm_dict(merged_files_dict)
+    config = update_config(config, qcfg)
+
+    save_torch_state_dict(vllm_dict, folder)
+    save_torch_state_dict(
+        fms_dict, folder, filename_pattern="fms_mo{suffix}.safetensors"
+    )
+    with open(folder + "/config.json", "w+", encoding="utf-8") as f:
+        json.dump(config, f, indent=4)
+
+
+def find_file_glob(pattern: str, search_path: str) -> list[str]:
+    """
+    Find files matching a pattern within a directory and its subdirectories.
+    """
+    # Use '**' for recursive search in modern Python versions (3.5+)
+    full_pattern = os.path.join(search_path, "**", pattern)
+    found_files = glob.glob(full_pattern, recursive=True)
+    return sorted(found_files)
+
+
+def convert_fp8_vllm_dict_to_fms_mo_dict(
+    checkpoint: str = None, output_dir: str = None
+) -> None:
+    """
+    Helper to convert a vLLM FP8 checkpoint into the fms_mo FP8 format.
+    """
+    merged_files_dict = {}
+    files = find_file_glob("model-*", checkpoint)
+    for file in files:
+        temp_dict = load_file(file)
+        merged_files_dict.update(temp_dict)
+
+    fms_mo_dict = rename_vllm_dict_to_fms_mo(merged_files_dict)
+    save_torch_state_dict(fms_mo_dict, output_dir)
+
+
+def rename_vllm_dict_to_fms_mo(vllm_dict: dict) -> dict[str, float | int | str]:
+    """
+    Helper to rename a vLLM-format state dict to the fms_mo dict format.
+    """
+    fms_mo_dict = {}
+    for k, v in vllm_dict.items():
+        if "weight_scale" in k:
+            key = k.split("weight")[0]
+            fms_mo_dict[key + "weight"] = (
+                vllm_dict[key + "weight"].to(torch.float16) * v
+            )
+            fms_mo_dict[k] = v
+        else:
+            key = k.split("weight")[0]
+            if key + "weight_scale" not in vllm_dict.keys():
+                fms_mo_dict[k] = v
+    return fms_mo_dict
+
+
+def convert_fp8_vllm_to_fms_mo(model: nn.Module = None) -> nn.Module:
+    """
+    Helper to convert a model loaded from a vLLM FP8 checkpoint to the fms_mo FP8 format.
+    """
+    model_dict = model.state_dict()
+    fms_dict = rename_vllm_dict_to_fms_mo(model_dict)
+    model = model.to(torch.float16)
+    model.load_state_dict(fms_dict)
+    return model
diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py
index a695e39d..6298c353 100644
--- a/fms_mo/utils/import_utils.py
+++ b/fms_mo/utils/import_utils.py
@@ -42,6 +42,7 @@
     "torchvision",
     "huggingface_hub",
     "torchao",
+    "compressed_tensors",
 ]
 
 available_packages = {}
diff --git a/fms_mo/utils/qconfig_utils.py b/fms_mo/utils/qconfig_utils.py
index c8e9a093..b479d302 100644
--- a/fms_mo/utils/qconfig_utils.py
+++ b/fms_mo/utils/qconfig_utils.py
@@ -150,6 +150,8 @@ def config_defaults() -> dict:
         "smoothq_scale_layers": [],
         "smoothq_act_scale_path": None,
         # Other vars
+        "fp8_inference": False,
+        "output_folder": None,
         "which2patch_contextmanager": None,
         "force_stop_if_qbmm_auto_check_failed": False,
         "world_size": max(1, torch.cuda.device_count()),
@@ -299,6 +301,8 @@ def qconfig_init(recipe: str = None, args: Any = None, use_mx: bool = False) ->
     qcfg["w_init_method"] = "sawb"
     qcfg["a_init_method"] = "percentile"
     qcfg["clip_val_asst_percentile"] = (0.1, 99.9)
+    qcfg["fp8_inference"] = False
+    qcfg["output_folder"] = None
 
     # ways to control which layers to be quantized/skipped
     qcfg["qlayer_name_pattern"] = []
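Aside (not part of the patch): a short sketch of the two checkpoint directions the `dq_inf` utilities enable. The checkpoint paths and the `AutoModelForCausalLM` loading call are placeholders/assumptions; the helper names and signatures are the ones defined in `fms_mo/utils/dq_inf.py` above.

```python
# Hypothetical round-trip sketch for the FP8 DQ inference flow.
from transformers import AutoModelForCausalLM

from fms_mo.utils.dq_inf import (
    check_quantization_setting,
    convert_fms_mo_to_vllm_fp8_format,
    convert_fp8_vllm_to_fms_mo,
)

# 1) Load a checkpoint that was saved in vLLM / compressed-tensors FP8 format.
model = AutoModelForCausalLM.from_pretrained("path/to/fp8_vllm_ckpt")

# 2) If its config carries a supported compressed-tensors FP8 recipe, dequantize the
#    weights back to fp16 so fms_mo quantizers can wrap the model for DQ inference.
if check_quantization_setting(model):
    model = convert_fp8_vllm_to_fms_mo(model=model)

# 3) Conversely, an existing fms_mo FP8 DQ checkpoint on disk can be exported for vLLM;
#    the output lands in a "<checkpoint>/<folder>" subdirectory.
convert_fms_mo_to_vllm_fp8_format(checkpoint="path/to/fms_mo_ckpt", folder="vllm_fp8")
```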