From b137e0c43f973987d81d5da81841775094dbb96b Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Tue, 5 Aug 2025 12:59:16 -0400 Subject: [PATCH 01/12] feat: fast model inference Signed-off-by: Omobayode Fagbohungbe --- fms_mo/dq.py | 111 ++++++++++++++++++++++++-------------- fms_mo/prep.py | 46 +++++++++++++++- fms_mo/recipes/quant.json | 44 +++++++++++++++ fms_mo/training_args.py | 7 +++ fms_mo/utils/dq_inf.py | 89 ++++++++++++++++++++++++++++++ fms_mo/utils/dq_utils.py | 2 +- 6 files changed, 256 insertions(+), 43 deletions(-) create mode 100644 fms_mo/recipes/quant.json create mode 100644 fms_mo/utils/dq_inf.py diff --git a/fms_mo/dq.py b/fms_mo/dq.py index eb49bc30..aecc8c64 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -50,6 +50,8 @@ from fms_mo.utils.dq_utils import config_quantize_smooth_layers from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU from fms_mo.utils.utils import patch_torch_bmm, prepare_input +from fms_mo.utils.dq_inf import load_fp8_vllm, save_vllm_fp8 +from accelerate import load_checkpoint_and_dispatch logger = logging.getLogger(__name__) @@ -134,7 +136,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): logger.info(f"Initialized model is: \n {model}") logger.info(f"Model is at {model.device} after intialization") logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}") - qcfg = qconfig_init(recipe="dq", args=fms_mo_args) + + if not fms_mo_args.inference or fms_mo_args.vllm_fp8_load: + qcfg = qconfig_init(recipe="dq", args=fms_mo_args) + else: + qcfg = qconfig_init(recipe=opt_args.output_dir+"/qcfg") model_size = model_size_Wb(model, unit="GB") gpu_mem_util_per = model_size / total_gpu_memory @@ -190,7 +196,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): ) # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well. 
- if qcfg["smoothq"]: + if not fms_mo_args.inference and qcfg["smoothq"] : scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt") if qcfg.get("act_scale_path", None): # user provided a scale file (or a dir) @@ -224,53 +230,76 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): use_layer_name_pattern_matching=use_layer_name_pattern_matching, use_dynamo=use_dynamo, dev=dev, + mode=fms_mo_args.inference, save_fname="dq", + folder=opt_args.output_dir, ) logger.info(f"Quantized model {model}") logger.info("==" * 20) - if qcfg["smoothq"]: - logger.info("Starting to apply smooth scale") - dq_llm(model, act_scales, qcfg) - logger.info("Finished applying smooth scale") + if not fms_mo_args.inference: + if qcfg["smoothq"]: + logger.info("Starting to apply smooth scale") + dq_llm(model, act_scales, qcfg) + logger.info("Finished applying smooth scale") + + if qcfg["qmodel_calibration_new"] > 0: + logger.info("Starting to calibrate activation clip_val") + if qcfg["large_model"]: + calibration_llm_1GPU_v2(qcfg, model, dq_dataloader) + else: + model.to("cuda") + pbar = tqdm( + dq_dataloader, + desc=" calibration after applying smoothq scale and before inference", + total=qcfg["qmodel_calibration_new"], + ) + for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])): + data_mb = prepare_input(model.device, data_mb) + with patch_torch_bmm(qcfg): + model(**data_mb) + + if opt_args.save_ckpt_for_aiu: + logger.info( + f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}" + ) + save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True) + elif opt_args.save_ckpt_for_vllm: + logger.info( + f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}" + ) + save_vllm_fp8(model,qcfg,tokenizer,opt_args.output_dir) + elif opt_args.save_ckpt: + logger.info( + f"Saving quantized model and tokenizer to {opt_args.output_dir}" + ) + model.save_pretrained(opt_args.output_dir, use_safetensors=True) + tokenizer.save_pretrained(opt_args.output_dir) + + if fms_mo_args.aiu_sim_triton: + # NOTE plz apply correct HW settings here, defaults are not real HW params + lower_qmodel_triton( + model, + use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False, + max_acc_bits=qcfg.get("max_acc_bits", 32), + num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0), + chunk_size=qcfg.get("chunk_size", 32), # 1024 + clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8", + # layer_to_exclude=["lm_head",] + ) + else: + if fms_mo_args.vllm_fp8_load: + logger.info("loading llmcompressor fp8 model saved_checkpoint") + model = load_fp8_vllm( model=model, checkpoint=opt_args.output_dir) - if qcfg["qmodel_calibration_new"] > 0: - logger.info("Starting to calibrate activation clip_val") - if qcfg["large_model"]: - calibration_llm_1GPU_v2(qcfg, model, dq_dataloader) else: - model.to("cuda") - pbar = tqdm( - dq_dataloader, - desc=" calibration after applying smoothq scale and before inference", - total=qcfg["qmodel_calibration_new"], + logger.info("loading dq fms_mo fp8 model saved_checkpoint") + model = load_checkpoint_and_dispatch( + model, + checkpoint=opt_args.output_dir, + device_map=None, + no_split_module_classes=['Block'] ) - for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])): - data_mb = prepare_input(model.device, data_mb) - with patch_torch_bmm(qcfg): - model(**data_mb) - - if opt_args.save_ckpt_for_aiu: - logger.info( - f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}" - ) - save_for_aiu(model, qcfg, 
output_dir=opt_args.output_dir, verbose=True) - elif opt_args.save_ckpt: - logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}") - model.save_pretrained(opt_args.output_dir, use_safetensors=True) - tokenizer.save_pretrained(opt_args.output_dir) - - if fms_mo_args.aiu_sim_triton: - # NOTE plz apply correct HW settings here, defaults are not real HW params - lower_qmodel_triton( - model, - use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False, - max_acc_bits=qcfg.get("max_acc_bits", 32), - num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0), - chunk_size=qcfg.get("chunk_size", 32), # 1024 - clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8", - # layer_to_exclude=["lm_head",] - ) if fms_mo_args.eval_ppl: path_test = Path(data_args.test_data_path) diff --git a/fms_mo/prep.py b/fms_mo/prep.py index 42e40b79..e097f585 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -570,7 +570,42 @@ def has_quantized_module(model): """Check if model is already quantized - do not want to quantize twice if so""" return any(isinstance(m, quantized_modules) for m in model.modules()) +def swap_qbmm(model: nn.Module, qcfg: dict): + """Go through all model.named_modules(), try to create an equivalent Qbmm layer to replace each of + the existing linear Bmm layers. + Args: + model (nn.Module): input model to be "prepared" + qcfg (dict): quant config + + Returns: updated model is returned with the Qbmm added + + """ + + from fms_mo.modules import QBmm + + qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][ + "which2patch_contextmanager" + ] + isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm" + for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items(): + mod_bmm_happened = model.get_submodule(mod_name) + for whichQBmm, ln in enumerate(line_nums, start=1): + nbits = qcfg[f"nbits_bmm{whichQBmm}"] + newQBmm = QBmm( + num_bits_m1=max(nbits, 8) if whichQBmm == 2 else nbits, + num_bits_m2=nbits, + qm1_mode=qcfg[f"bmm{whichQBmm}_qm1_mode"], + qm2_mode=qcfg[f"bmm{whichQBmm}_qm2_mode"], + m1_unidirectional=(whichQBmm == 2), + m1_bounded=(whichQBmm == 2), # see Note 5 + m2_unidirectional=False, + m2_bounded=False, + replaceBmm=isbmm, + qcfg=qcfg, + ) + setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm) + def qmodel_prep( model, dloader, @@ -582,7 +617,9 @@ def qmodel_prep( Qcali=False, dev=None, use_dynamo=False, + mode=False, verbose=False, + folder=None, **kwargs, ): """Prepare a given PyTorch model for quantization process through three parts: @@ -657,7 +694,14 @@ def qmodel_prep( Returns: nn.Module: quantized model ready for further PTQ/QAT """ + if mode: + + if qcfg.get("QBmm"): + swap_qbmm(model,qcfg) + model = q_any_net_5(model, qcfg, verbose = False) + return model + sys.setrecursionlimit(4000) currDev = next(model.parameters()).device if dev is None else dev @@ -907,7 +951,7 @@ def qmodel_prep( model, device_ids=DPorDDPdevices ) - qconfig_save(qcfg, fname="qcfg.json") + qconfig_save(qcfg, fname=folder+"/qcfg.json") qcfg["tb_writer"] = tb_writer logger.info(f"--- Quantized model --- \n{model}\n") diff --git a/fms_mo/recipes/quant.json b/fms_mo/recipes/quant.json new file mode 100644 index 00000000..96b87619 --- /dev/null +++ b/fms_mo/recipes/quant.json @@ -0,0 +1,44 @@ +{ +"quantization_config": { + "config_groups": { + "group_0": { + "input_activations": { + "actorder": null, + "block_structure": null, + "dynamic": true, + "group_size": null, + "num_bits": 8, + "observer": null, + "observer_kwargs": {}, + "strategy": "token", + "symmetric": true, + "type": "float" 
+ }, + "output_activations": null, + "targets": [ + "Linear" + ], + "weights": { + "actorder": null, + "block_structure": null, + "dynamic": false, + "group_size": null, + "num_bits": 8, + "observer": "minmax", + "observer_kwargs": {}, + "strategy": "channel", + "symmetric": true, + "type": "float" + } + } + }, + "format": "float-quantized", + "global_compression_ratio": null, + "ignore": [ + "lm_head" + ], + "kv_cache_scheme": null, + "quant_method": "compressed-tensors", + "quantization_status": "compressed" + } +} \ No newline at end of file diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py index 95f38424..a6d6394b 100644 --- a/fms_mo/training_args.py +++ b/fms_mo/training_args.py @@ -160,6 +160,10 @@ class OptArguments(TypeChecker): default=False, metadata={"help": "Prepare and save AIU-compliant checkpoint."}, ) + save_ckpt_for_vllm: bool = field( + default=False, + metadata={"help": "Prepare and save vllm-compliant checkpoint."}, + ) @dataclass @@ -209,6 +213,9 @@ class FMSMOArguments(TypeChecker): default=False, metadata={"help": "Apply recomputation during checkpoint saving for AIU."}, ) + fp8_use_subnormal: bool = field(default=False) + inference: bool = field(default=False) + vllm_fp8_load: bool = field(default=False) @dataclass diff --git a/fms_mo/utils/dq_inf.py b/fms_mo/utils/dq_inf.py new file mode 100644 index 00000000..d9fb3928 --- /dev/null +++ b/fms_mo/utils/dq_inf.py @@ -0,0 +1,89 @@ +import torch +import fms_mo +from fms_mo.quant.quantizers import to_fp8_scaled_perCh as fp8 +from huggingface_hub import save_torch_state_dict +import json +import os +import glob +from fms_mo.utils.qconfig_utils import get_recipe +from safetensors.torch import load_file, save_file +from torch import nn + +def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer = None, folder: str = None): + """ + Function to save fms_mo fp8 checkpoint in vllm fp8 format + """ + + st_dict={} + + for k,v in model.state_dict().items(): + if k[-11:] == "proj.weight": + weight, scale = fp8(v,emulate=False) + st_dict[k]= weight + + if k[:-7] in qcfg["qskip_layer_name"]: + pass + else: + st_dict[k + "_scale"] = 1/scale + + elif k[-6:] == "weight": + st_dict[k]=v + else: + pass + + config = model.config.to_dict() + + #TO DO: To support multiple recipes, check qconfig arguments and update data loaded from quant.json + data = get_recipe('quant') + + config.update(data) + + save_torch_state_dict(st_dict, folder) + + tokenizer.save_pretrained(folder) + + with open(folder+'/config.json', 'a') as f: + json.dump(config, f, indent=4) + + + +def find_file_glob(pattern: str , search_path: str): + """ + Finds files matching a pattern within a directory and its subdirectories. 
+ """ + # Use '**' for recursive search in modern Python versions (3.5+) + full_pattern = os.path.join(search_path, '**', pattern) + found_files = glob.glob(full_pattern, recursive=True) + return sorted(found_files) + +def load_fp8_vllm(model: nn.Module = None, checkpoint: str=None): + """ + Function to help load vllm fp8 checkpoint into fms_mo + """ + + merged_files_dict={} + + files = find_file_glob('*.safetensors',checkpoint) + + model_dict = model.state_dict() + + for file in files: + merged_files_dict = load_file(file) + + for k,v in merged_files_dict.items(): + + if k[-11:] == "proj.weight": + scale = merged_files_dict[k+ "_scale"].reshape(-1,1) + model_dict[k]= merged_files_dict[k].to(torch.float16) * scale + + elif k[-6:] == "weight": + model_dict[k]=v + + else: + pass + + return model + + + + diff --git a/fms_mo/utils/dq_utils.py b/fms_mo/utils/dq_utils.py index 2eb51caf..36d2806e 100644 --- a/fms_mo/utils/dq_utils.py +++ b/fms_mo/utils/dq_utils.py @@ -74,7 +74,7 @@ def config_quantize_smooth_layers(qcfg: dict): for llama_family, layers in large_mag_layers.items(): if llama_family in qcfg["model"]: qcfg["qskip_layer_name"] += [ - f"model.layers.{i}.mlp.down_proj" for i in layers + f"model.layers.{i}.mlp.down_projj" for i in layers ] break elif any(model in qcfg["model"] for model in granite_architecture) or any( From 0b5d68ab5e9dd7db947e54c9cc49888b9a913c47 Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Wed, 20 Aug 2025 11:14:25 -0400 Subject: [PATCH 02/12] feat: enable fast loading and vllm format saving functionality in fms_mo Signed-off-by: Omobayode Fagbohungbe --- fms_mo/dq.py | 68 +++++++++------ fms_mo/modules/linear.py | 2 +- fms_mo/prep.py | 25 +++--- fms_mo/quant/quantizers.py | 3 +- fms_mo/training_args.py | 7 +- fms_mo/utils/dq_inf.py | 165 +++++++++++++++++++++++++------------ 6 files changed, 172 insertions(+), 98 deletions(-) diff --git a/fms_mo/dq.py b/fms_mo/dq.py index aecc8c64..72094173 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -1,11 +1,11 @@ # Copyright The FMS Model Optimizer Authors -# + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 -# + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -34,6 +34,7 @@ ) import torch +import os # Local from fms_mo import qconfig_init, qmodel_prep from fms_mo.custom_ext_kernels.utils import ( @@ -50,8 +51,11 @@ from fms_mo.utils.dq_utils import config_quantize_smooth_layers from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU from fms_mo.utils.utils import patch_torch_bmm, prepare_input -from fms_mo.utils.dq_inf import load_fp8_vllm, save_vllm_fp8 -from accelerate import load_checkpoint_and_dispatch +from fms_mo.utils.dq_inf import ( + save_vllm_fp8, + convert_fp8_vllm_to_fms_mo, + check_quantization_setting, +) logger = logging.getLogger(__name__) @@ -129,6 +133,17 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): low_cpu_mem_usage=bool(model_args.device_map), ) + inference= model.config.to_dict().get("quantization_config",None) + + if inference: + quant_setting = check_quantization_setting(inference) + if quant_setting: + logger.info("Quantization config settings validated ") + model = convert_fp8_vllm_to_fms_mo(model = model) + else: + exit("__This quantization config is wrong/not supported__") + + embedding_size = model.get_input_embeddings().weight.shape[0] if len(tokenizer) > embedding_size: model.resize_token_embeddings(len(tokenizer)) @@ -136,11 +151,24 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): logger.info(f"Initialized model is: \n {model}") logger.info(f"Model is at {model.device} after intialization") logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}") - - if not fms_mo_args.inference or fms_mo_args.vllm_fp8_load: + + if not inference: + logger.info("quantization mode activated, initalizing the qcfg file ") qcfg = qconfig_init(recipe="dq", args=fms_mo_args) else: - qcfg = qconfig_init(recipe=opt_args.output_dir+"/qcfg") + logger.info("inference mode activated") + if os.path.isfile(model_args.model_name_or_path+"/qcfg.json"): + if fms_mo_args.override_fms_args: + logger.info("qcfg file found and some parameters are being over-written ") + qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg", args=fms_mo_args) + else: + logger.info("qcfg file found, loading the qcfg file ") + qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg") + else: + logger.info("qcfg file not found in {model_args.model_name_or_path},\ + loading fms_mo_args and recipe" + ) + qcfg = qconfig_init(recipe="dq", args=fms_mo_args) model_size = model_size_Wb(model, unit="GB") gpu_mem_util_per = model_size / total_gpu_memory @@ -184,6 +212,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): qcfg["model"] = model_args.model_name_or_path qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg qcfg["plotsvg"] = False + qcfg["output_folder"] = opt_args.output_dir calibration_dataset = load_from_disk(data_args.training_data_path) calibration_dataset = calibration_dataset.with_format("torch") @@ -196,7 +225,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): ) # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well. 
- if not fms_mo_args.inference and qcfg["smoothq"] : + if not inference and qcfg["smoothq"] : scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt") if qcfg.get("act_scale_path", None): # user provided a scale file (or a dir) @@ -230,14 +259,12 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): use_layer_name_pattern_matching=use_layer_name_pattern_matching, use_dynamo=use_dynamo, dev=dev, - mode=fms_mo_args.inference, + mode=inference, save_fname="dq", - folder=opt_args.output_dir, ) logger.info(f"Quantized model {model}") logger.info("==" * 20) - - if not fms_mo_args.inference: + if not inference: if qcfg["smoothq"]: logger.info("Starting to apply smooth scale") dq_llm(model, act_scales, qcfg) @@ -264,7 +291,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}" ) save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True) - elif opt_args.save_ckpt_for_vllm: + elif not opt_args.save_ckpt: logger.info( f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}" ) @@ -287,19 +314,6 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8", # layer_to_exclude=["lm_head",] ) - else: - if fms_mo_args.vllm_fp8_load: - logger.info("loading llmcompressor fp8 model saved_checkpoint") - model = load_fp8_vllm( model=model, checkpoint=opt_args.output_dir) - - else: - logger.info("loading dq fms_mo fp8 model saved_checkpoint") - model = load_checkpoint_and_dispatch( - model, - checkpoint=opt_args.output_dir, - device_map=None, - no_split_module_classes=['Block'] - ) if fms_mo_args.eval_ppl: path_test = Path(data_args.test_data_path) diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 3a39bb30..ee5d8202 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -281,6 +281,7 @@ def forward(self, x): ) # pylint: disable=not-callable + return F.linear(x, self.W_fp, self.bias) else: qinput = self.quantize_feature(x / scale).to(x.dtype) @@ -296,7 +297,6 @@ def forward(self, x): ) qbias = self.bias - # pylint: disable=not-callable output = F.linear(qinput, qweight, qbias) diff --git a/fms_mo/prep.py b/fms_mo/prep.py index e097f585..62171a19 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -23,7 +23,7 @@ # Third Party from torch import nn import torch - +import compressed_tensors # Local from fms_mo.calib import qmodel_calib from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules @@ -391,12 +391,14 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): # For nn.Linear elif isinstance(module, nn.Linear): if module.__class__ != nn.Linear: - logger.warning( - f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." - "Please make sure it doesn't wrap BN and activ func." - "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']." - ) - + if isinstance(module, compressed_tensors.linear.compressed_linear.CompressedLinear): + pass + else: + logger.warning( + f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." + "Please make sure it doesn't wrap BN and activ func." + "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']." 
+ ) QLin = mapping.get(nn.Linear, None) if QLin is None: if verbose: @@ -571,8 +573,8 @@ def has_quantized_module(model): return any(isinstance(m, quantized_modules) for m in model.modules()) def swap_qbmm(model: nn.Module, qcfg: dict): - """Go through all model.named_modules(), try to create an equivalent Qbmm layer to replace each of - the existing linear Bmm layers. + """Go through all model.named_modules(), try to create an equivalent + Qbmm layer to replace each of the existing linear Bmm layers. Args: model (nn.Module): input model to be "prepared" @@ -605,7 +607,7 @@ def swap_qbmm(model: nn.Module, qcfg: dict): qcfg=qcfg, ) setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm) - + def qmodel_prep( model, dloader, @@ -619,7 +621,6 @@ def qmodel_prep( use_dynamo=False, mode=False, verbose=False, - folder=None, **kwargs, ): """Prepare a given PyTorch model for quantization process through three parts: @@ -951,7 +952,7 @@ def qmodel_prep( model, device_ids=DPorDDPdevices ) - qconfig_save(qcfg, fname=folder+"/qcfg.json") + qconfig_save(qcfg, fname=qcfg["output_folder"]+"/qcfg.json") qcfg["tb_writer"] = tb_writer logger.info(f"--- Quantized model --- \n{model}\n") diff --git a/fms_mo/quant/quantizers.py b/fms_mo/quant/quantizers.py index 405ddc63..0ce66a14 100644 --- a/fms_mo/quant/quantizers.py +++ b/fms_mo/quant/quantizers.py @@ -237,6 +237,7 @@ def get_weight_quantizer( recompute=False, perGp=None, use_subnormal=False, + emulate = True, ): """Return a quantizer for weight quantization Regular quantizers: @@ -346,7 +347,7 @@ def get_weight_quantizer( weight_quantizer = to_fp8( nbits, q_mode=qw_mode, - emulate=True, + emulate=emulate, perCh=Nch, ) else: diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py index a6d6394b..66d230fb 100644 --- a/fms_mo/training_args.py +++ b/fms_mo/training_args.py @@ -160,10 +160,6 @@ class OptArguments(TypeChecker): default=False, metadata={"help": "Prepare and save AIU-compliant checkpoint."}, ) - save_ckpt_for_vllm: bool = field( - default=False, - metadata={"help": "Prepare and save vllm-compliant checkpoint."}, - ) @dataclass @@ -214,8 +210,7 @@ class FMSMOArguments(TypeChecker): metadata={"help": "Apply recomputation during checkpoint saving for AIU."}, ) fp8_use_subnormal: bool = field(default=False) - inference: bool = field(default=False) - vllm_fp8_load: bool = field(default=False) + override_fms_args: bool = field(default=False) @dataclass diff --git a/fms_mo/utils/dq_inf.py b/fms_mo/utils/dq_inf.py index d9fb3928..c7639611 100644 --- a/fms_mo/utils/dq_inf.py +++ b/fms_mo/utils/dq_inf.py @@ -1,89 +1,152 @@ -import torch -import fms_mo -from fms_mo.quant.quantizers import to_fp8_scaled_perCh as fp8 from huggingface_hub import save_torch_state_dict import json import os import glob -from fms_mo.utils.qconfig_utils import get_recipe -from safetensors.torch import load_file, save_file +from safetensors.torch import load_file from torch import nn +import torch +from fms_mo.utils.qconfig_utils import get_recipe +from fms_mo.quant.quantizers import to_fp8_scaled_perCh -def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer = None, folder: str = None): +def check_quantization_setting(inference :dict = None): """ - Function to save fms_mo fp8 checkpoint in vllm fp8 format + function checks if the checkpoint is from fp8 quantization """ + status= False + if inference["config_groups"]["group_0"]["input_activations"]["num_bits"]== 8 \ + and inference["config_groups"]["group_0"]["weights"]["num_bits"] == 8 \ + and 
inference["config_groups"]["group_0"]["weights"]["type"] == "float" \ + and inference["config_groups"]["group_0"]["input_activations"]["type"] == "float": - st_dict={} - - for k,v in model.state_dict().items(): - if k[-11:] == "proj.weight": - weight, scale = fp8(v,emulate=False) - st_dict[k]= weight + status = True + return status - if k[:-7] in qcfg["qskip_layer_name"]: - pass +#def rename_fms_dict_to_vllm_dict (model_dict : dict= None, qcfg : dict = None): +def rename_fms_dict_to_vllm_dict (model_dict : dict= None): + """ + Function to rename the dict in fms_mo format to vllm_format. + """ + st_dict={} + fms_dict={} + keys = model_dict.keys() + print(keys) + count=0 + for k,v in model_dict.items(): + if ".weight" in k: + count+=1 + key= k.split("weight")[0] + if key+"quantize_weight.scale" in keys: + weight, scale = to_fp8_scaled_perCh(v,emulate=False) + st_dict[key+"weight"]= weight + st_dict[key + "weight_scale"] = 1/scale else: - st_dict[k + "_scale"] = 1/scale - - elif k[-6:] == "weight": - st_dict[k]=v + st_dict[k]= v else: - pass + fms_dict[k] = v + return st_dict, fms_dict - config = model.config.to_dict() +def update_config(model_config_file : dict = None, qcfg : dict = None): + """ + Function to update the model config file with quantization configuration + """ + data = get_recipe("quant") + if "perCh" not in qcfg["qw_mode"]: + data["quantization_config"]["config_groups"]["group_0"]["weights"] = \ + "{num_bits: 8, type: float, symmetric: true, strategy: tensor}" - #TO DO: To support multiple recipes, check qconfig arguments and update data loaded from quant.json - data = get_recipe('quant') - - config.update(data) + model_config_file.update(data) + return model_config_file - save_torch_state_dict(st_dict, folder) +def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer = None, folder: str = None): + """ + Function to save fp8 DQ model in vllm fp8 format + """ + model_dict = model.state_dict() + vllm_dict, fms_dict = rename_fms_dict_to_vllm_dict(model_dict=model_dict) + config = model.config.to_dict() + config = update_config( config, qcfg) + save_torch_state_dict(vllm_dict, folder) + save_torch_state_dict(fms_dict, folder, filename_pattern="fms_mo{suffix}.safetensors") tokenizer.save_pretrained(folder) - with open(folder+'/config.json', 'a') as f: + with open(folder+"/config.json", "w+") as f: json.dump(config, f, indent=4) +def convert_fms_mo_to_vllm_fp8_format(checkpoint : str = None, folder: str = None): + """ + Function to convert fp8 fms_mo DQ model checkpoint to vllm fp8 format + """ + folder = checkpoint+"/" + folder + if os.path.isdir(folder): + print(f"The folder '{folder}' exists.") + else: + os.mkdir(folder) + print(f"The folder '{folder}' created.") + + qcfg = get_recipe(checkpoint+"/qcfg") + config = get_recipe(checkpoint+"/config") + files = find_file_glob("model-*",checkpoint) + merged_files_dict={} + + for file in files: + temp_dict = load_file(file) + merged_files_dict.update(temp_dict) + vllm_dict, fms_dict = rename_fms_dict_to_vllm_dict(merged_files_dict) + config = update_config(config, qcfg) + + save_torch_state_dict(vllm_dict, folder) + save_torch_state_dict(fms_dict, folder, filename_pattern="fms_mo{suffix}.safetensors") + with open(folder+"/config.json", "w+") as f: + json.dump(config, f, indent=4) def find_file_glob(pattern: str , search_path: str): """ Finds files matching a pattern within a directory and its subdirectories. 
""" # Use '**' for recursive search in modern Python versions (3.5+) - full_pattern = os.path.join(search_path, '**', pattern) + full_pattern = os.path.join(search_path, "**", pattern) found_files = glob.glob(full_pattern, recursive=True) return sorted(found_files) -def load_fp8_vllm(model: nn.Module = None, checkpoint: str=None): +def convert_fp8_vllm_dict_to_fms_mo_dict(checkpoint: str=None, output_dir : str=None): + """ + Function to help convert vllm fp8 checkpoint into fms_mo fp8 format """ - Function to help load vllm fp8 checkpoint into fms_mo - """ - merged_files_dict={} - - files = find_file_glob('*.safetensors',checkpoint) - - model_dict = model.state_dict() - + files = find_file_glob("model-*",checkpoint) for file in files: - merged_files_dict = load_file(file) - - for k,v in merged_files_dict.items(): + temp_dict = load_file(file) + merged_files_dict.update(temp_dict) - if k[-11:] == "proj.weight": - scale = merged_files_dict[k+ "_scale"].reshape(-1,1) - model_dict[k]= merged_files_dict[k].to(torch.float16) * scale + fms_mo_dict = rename_vllm_dict_to_fms_mo(merged_files_dict) + save_torch_state_dict(fms_mo_dict, output_dir) - elif k[-6:] == "weight": - model_dict[k]=v - - else: +def rename_vllm_dict_to_fms_mo(vllm_dict : dict = None): + """ + Function to help rename vllm dict format to fms_mo dict format + """ + fms_mo_dict ={} + for k,v in vllm_dict.items(): + if "weight_scale" in k: + key = k.split("weight")[0] + fms_mo_dict[key+"weight"]=vllm_dict[key+"weight"].to(torch.float16) *v + fms_mo_dict[k]= v + else: + key = k.split("weight")[0] + if key+"weight_scale" in vllm_dict.keys(): pass + else: + fms_mo_dict[k]= v + return fms_mo_dict +def convert_fp8_vllm_to_fms_mo(model: nn.Module = None): + """ + Function to help convert fp8 vllm model dict format to fms_mo fp8 format + """ + model_dict = model.state_dict() + fms_dict = rename_vllm_dict_to_fms_mo(model_dict) + model = model.to(torch.float16) + model.load_state_dict(fms_dict) return model - - - - From b458d18df344eeb899c3aeecb4c64c1dd6375ef5 Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Fri, 22 Aug 2025 00:09:54 -0400 Subject: [PATCH 03/12] fix: updated the code to reflect PR update Signed-off-by: Omobayode Fagbohungbe --- .pylintrc | 3 +- fms_mo/dq.py | 64 ++++---- fms_mo/modules/linear.py | 2 +- fms_mo/prep.py | 49 +++--- fms_mo/quant/quantizers.py | 2 +- fms_mo/recipes/dq.json | 4 +- ...json => fp8_vllm_quantization_config.json} | 0 fms_mo/utils/dq_inf.py | 153 +++++++++++------- fms_mo/utils/dq_utils.py | 2 +- fms_mo/utils/import_utils.py | 1 + fms_mo/utils/qconfig_utils.py | 4 + pyproject.toml | 2 +- 12 files changed, 175 insertions(+), 111 deletions(-) rename fms_mo/recipes/{quant.json => fp8_vllm_quantization_config.json} (100%) diff --git a/.pylintrc b/.pylintrc index 4effcbf7..da95ef1d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -69,7 +69,8 @@ ignored-modules=gptqmodel, llmcompressor, cutlass_mm, pygraphviz, - matplotlib + matplotlib, + compressed_tensors # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). diff --git a/fms_mo/dq.py b/fms_mo/dq.py index 72094173..6f2c0e7b 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -1,11 +1,11 @@ # Copyright The FMS Model Optimizer Authors - +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -21,6 +21,8 @@ # Standard from pathlib import Path import logging +import os +import sys # Third Party from datasets import load_from_disk @@ -34,7 +36,6 @@ ) import torch -import os # Local from fms_mo import qconfig_init, qmodel_prep from fms_mo.custom_ext_kernels.utils import ( @@ -48,14 +49,14 @@ get_act_scales_1gpu, ) from fms_mo.utils.aiu_utils import save_for_aiu -from fms_mo.utils.dq_utils import config_quantize_smooth_layers -from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU -from fms_mo.utils.utils import patch_torch_bmm, prepare_input from fms_mo.utils.dq_inf import ( - save_vllm_fp8, - convert_fp8_vllm_to_fms_mo, check_quantization_setting, + convert_fp8_vllm_to_fms_mo, + save_vllm_fp8, ) +from fms_mo.utils.dq_utils import config_quantize_smooth_layers +from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU +from fms_mo.utils.utils import patch_torch_bmm, prepare_input logger = logging.getLogger(__name__) @@ -133,16 +134,17 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): low_cpu_mem_usage=bool(model_args.device_map), ) - inference= model.config.to_dict().get("quantization_config",None) + inference_qconfig = None + if hasattr(model, "config"): + inference_qconfig = model.config.to_dict().get("quantization_config", None) - if inference: - quant_setting = check_quantization_setting(inference) + if inference_qconfig: + quant_setting = check_quantization_setting(inference_qconfig) if quant_setting: logger.info("Quantization config settings validated ") - model = convert_fp8_vllm_to_fms_mo(model = model) + model = convert_fp8_vllm_to_fms_mo(model=model) else: - exit("__This quantization config is wrong/not supported__") - + sys.exit("Error: This quantization config is wrong/not supported") embedding_size = model.get_input_embeddings().weight.shape[0] if len(tokenizer) > embedding_size: @@ -152,23 +154,29 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): logger.info(f"Model is at {model.device} after intialization") logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}") - if not inference: + if not inference_qconfig: logger.info("quantization mode activated, initalizing the qcfg file ") qcfg = qconfig_init(recipe="dq", args=fms_mo_args) else: logger.info("inference mode activated") - if os.path.isfile(model_args.model_name_or_path+"/qcfg.json"): + if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"): if fms_mo_args.override_fms_args: - logger.info("qcfg file found and some parameters are being over-written ") - qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg", args=fms_mo_args) + logger.info( + "qcfg file found and some parameters are being over-written " + ) + qcfg = qconfig_init( + recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args + ) else: logger.info("qcfg file found, loading the qcfg file ") - qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg") + qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg") else: - logger.info("qcfg file not found in {model_args.model_name_or_path},\ + logger.info( + "qcfg file not found in {model_args.model_name_or_path},\ loading fms_mo_args and recipe" - ) + ) qcfg = qconfig_init(recipe="dq", args=fms_mo_args) + 
qcfg["inference"] = True model_size = model_size_Wb(model, unit="GB") gpu_mem_util_per = model_size / total_gpu_memory @@ -193,7 +201,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): qcfg["model"] = model_args.model_name_or_path # config layers to skip, smooth scale - config_quantize_smooth_layers(qcfg) + if not inference_qconfig: + config_quantize_smooth_layers(qcfg) use_dynamo = True # use dynamo as default unless really needed, False -> fallback to TorchScript tracing @@ -225,7 +234,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): ) # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well. - if not inference and qcfg["smoothq"] : + if not inference_qconfig and qcfg["smoothq"]: scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt") if qcfg.get("act_scale_path", None): # user provided a scale file (or a dir) @@ -259,12 +268,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): use_layer_name_pattern_matching=use_layer_name_pattern_matching, use_dynamo=use_dynamo, dev=dev, - mode=inference, save_fname="dq", ) logger.info(f"Quantized model {model}") logger.info("==" * 20) - if not inference: + if not inference_qconfig: if qcfg["smoothq"]: logger.info("Starting to apply smooth scale") dq_llm(model, act_scales, qcfg) @@ -295,11 +303,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): logger.info( f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}" ) - save_vllm_fp8(model,qcfg,tokenizer,opt_args.output_dir) + save_vllm_fp8(model, qcfg, tokenizer, opt_args.output_dir) elif opt_args.save_ckpt: logger.info( f"Saving quantized model and tokenizer to {opt_args.output_dir}" - ) + ) model.save_pretrained(opt_args.output_dir, use_safetensors=True) tokenizer.save_pretrained(opt_args.output_dir) diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index ee5d8202..3a39bb30 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -281,7 +281,6 @@ def forward(self, x): ) # pylint: disable=not-callable - return F.linear(x, self.W_fp, self.bias) else: qinput = self.quantize_feature(x / scale).to(x.dtype) @@ -297,6 +296,7 @@ def forward(self, x): ) qbias = self.bias + # pylint: disable=not-callable output = F.linear(qinput, qweight, qbias) diff --git a/fms_mo/prep.py b/fms_mo/prep.py index 62171a19..d5a4554b 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -23,7 +23,7 @@ # Third Party from torch import nn import torch -import compressed_tensors + # Local from fms_mo.calib import qmodel_calib from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules @@ -391,13 +391,19 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): # For nn.Linear elif isinstance(module, nn.Linear): if module.__class__ != nn.Linear: - if isinstance(module, compressed_tensors.linear.compressed_linear.CompressedLinear): - pass - else: - logger.warning( - f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." - "Please make sure it doesn't wrap BN and activ func." - "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']." + if available_packages["compressed_tensors"]: + # Third Party + import compressed_tensors + + if isinstance( + module, compressed_tensors.linear.compressed_linear.CompressedLinear + ): + pass + else: + logger.warning( + f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." + "Please make sure it doesn't wrap BN and activ func. 
Otherwise" + "please create an equivalen Linear wrapper and change qcfg['mapping']." ) QLin = mapping.get(nn.Linear, None) if QLin is None: @@ -572,6 +578,7 @@ def has_quantized_module(model): """Check if model is already quantized - do not want to quantize twice if so""" return any(isinstance(m, quantized_modules) for m in model.modules()) + def swap_qbmm(model: nn.Module, qcfg: dict): """Go through all model.named_modules(), try to create an equivalent Qbmm layer to replace each of the existing linear Bmm layers. @@ -581,14 +588,13 @@ def swap_qbmm(model: nn.Module, qcfg: dict): qcfg (dict): quant config Returns: updated model is returned with the Qbmm added - + """ + # Local from fms_mo.modules import QBmm - qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][ - "which2patch_contextmanager" - ] + qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"]["which2patch_contextmanager"] isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm" for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items(): mod_bmm_happened = model.get_submodule(mod_name) @@ -608,6 +614,7 @@ def swap_qbmm(model: nn.Module, qcfg: dict): ) setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm) + def qmodel_prep( model, dloader, @@ -619,7 +626,6 @@ def qmodel_prep( Qcali=False, dev=None, use_dynamo=False, - mode=False, verbose=False, **kwargs, ): @@ -695,14 +701,13 @@ def qmodel_prep( Returns: nn.Module: quantized model ready for further PTQ/QAT """ - if mode: - - if qcfg.get("QBmm"): - swap_qbmm(model,qcfg) + if qcfg["inference"]: + if qcfg.get("QBmm"): + swap_qbmm(model, qcfg) - model = q_any_net_5(model, qcfg, verbose = False) + model = q_any_net_5(model, qcfg, verbose=False) return model - + sys.setrecursionlimit(4000) currDev = next(model.parameters()).device if dev is None else dev @@ -951,8 +956,10 @@ def qmodel_prep( model = torch.nn.parallel.DistributedDataParallel( model, device_ids=DPorDDPdevices ) - - qconfig_save(qcfg, fname=qcfg["output_folder"]+"/qcfg.json") + if qcfg["output_folder"] is None: + qconfig_save(qcfg, fname="qcfg.json") + else: + qconfig_save(qcfg, fname=qcfg["output_folder"] + "/qcfg.json") qcfg["tb_writer"] = tb_writer logger.info(f"--- Quantized model --- \n{model}\n") diff --git a/fms_mo/quant/quantizers.py b/fms_mo/quant/quantizers.py index 0ce66a14..371632c6 100644 --- a/fms_mo/quant/quantizers.py +++ b/fms_mo/quant/quantizers.py @@ -237,7 +237,7 @@ def get_weight_quantizer( recompute=False, perGp=None, use_subnormal=False, - emulate = True, + emulate=True, ): """Return a quantizer for weight quantization Regular quantizers: diff --git a/fms_mo/recipes/dq.json b/fms_mo/recipes/dq.json index be425998..cee7c505 100644 --- a/fms_mo/recipes/dq.json +++ b/fms_mo/recipes/dq.json @@ -10,5 +10,7 @@ "eval_ckpt": true, "nbits_bmm1" : 32, "nbits_bmm2" : 32, - "nbits_kvcache" : 32 + "nbits_kvcache" : 32, + "inference": false, + "output_folder": null } \ No newline at end of file diff --git a/fms_mo/recipes/quant.json b/fms_mo/recipes/fp8_vllm_quantization_config.json similarity index 100% rename from fms_mo/recipes/quant.json rename to fms_mo/recipes/fp8_vllm_quantization_config.json diff --git a/fms_mo/utils/dq_inf.py b/fms_mo/utils/dq_inf.py index c7639611..e56eb285 100644 --- a/fms_mo/utils/dq_inf.py +++ b/fms_mo/utils/dq_inf.py @@ -1,93 +1,124 @@ -from huggingface_hub import save_torch_state_dict -import json -import os +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with 
the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Evaluation utils for fast model loading and saving for FP8 DQ +""" + +# Standard import glob +import json +import logging +import os + +# Third Party +from huggingface_hub import save_torch_state_dict from safetensors.torch import load_file from torch import nn import torch -from fms_mo.utils.qconfig_utils import get_recipe + +# Local from fms_mo.quant.quantizers import to_fp8_scaled_perCh +from fms_mo.utils.qconfig_utils import get_recipe + +logger = logging.getLogger(__name__) -def check_quantization_setting(inference :dict = None): + +def check_quantization_setting(inference: dict = None): """ function checks if the checkpoint is from fp8 quantization """ - status= False - if inference["config_groups"]["group_0"]["input_activations"]["num_bits"]== 8 \ - and inference["config_groups"]["group_0"]["weights"]["num_bits"] == 8 \ - and inference["config_groups"]["group_0"]["weights"]["type"] == "float" \ - and inference["config_groups"]["group_0"]["input_activations"]["type"] == "float": + return ( + inference["config_groups"]["group_0"]["input_activations"]["num_bits"] == 8 + and inference["config_groups"]["group_0"]["weights"]["num_bits"] == 8 + and inference["config_groups"]["group_0"]["weights"]["type"] == "float" + and inference["config_groups"]["group_0"]["input_activations"]["type"] + == "float" + ) - status = True - return status -#def rename_fms_dict_to_vllm_dict (model_dict : dict= None, qcfg : dict = None): -def rename_fms_dict_to_vllm_dict (model_dict : dict= None): +# def rename_fms_dict_to_vllm_dict (model_dict : dict= None, qcfg : dict = None): +def rename_fms_dict_to_vllm_dict(model_dict: dict = None): """ Function to rename the dict in fms_mo format to vllm_format. 
""" - st_dict={} - fms_dict={} + st_dict = {} + fms_dict = {} keys = model_dict.keys() - print(keys) - count=0 - for k,v in model_dict.items(): + + for k, v in model_dict.items(): if ".weight" in k: - count+=1 - key= k.split("weight")[0] - if key+"quantize_weight.scale" in keys: - weight, scale = to_fp8_scaled_perCh(v,emulate=False) - st_dict[key+"weight"]= weight - st_dict[key + "weight_scale"] = 1/scale + key = k.split("weight")[0] + if key + "quantize_weight.scale" in keys: + weight, scale = to_fp8_scaled_perCh(v, emulate=False) + st_dict[key + "weight"] = weight + st_dict[key + "weight_scale"] = 1 / scale else: - st_dict[k]= v + st_dict[k] = v else: fms_dict[k] = v return st_dict, fms_dict -def update_config(model_config_file : dict = None, qcfg : dict = None): + +def update_config(model_config_file: dict = None, qcfg: dict = None): """ Function to update the model config file with quantization configuration """ - data = get_recipe("quant") + data = get_recipe("fp8_vllm_quantization_config") if "perCh" not in qcfg["qw_mode"]: - data["quantization_config"]["config_groups"]["group_0"]["weights"] = \ - "{num_bits: 8, type: float, symmetric: true, strategy: tensor}" + data["quantization_config"]["config_groups"]["group_0"]["weights"] = ( + "{num_bits: 8, type: float, symmetric: true, strategy: tensor}" + ) model_config_file.update(data) return model_config_file -def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer = None, folder: str = None): + +def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None): """ Function to save fp8 DQ model in vllm fp8 format """ model_dict = model.state_dict() vllm_dict, fms_dict = rename_fms_dict_to_vllm_dict(model_dict=model_dict) config = model.config.to_dict() - config = update_config( config, qcfg) + config = update_config(config, qcfg) save_torch_state_dict(vllm_dict, folder) - save_torch_state_dict(fms_dict, folder, filename_pattern="fms_mo{suffix}.safetensors") + save_torch_state_dict( + fms_dict, folder, filename_pattern="fms_mo{suffix}.safetensors" + ) tokenizer.save_pretrained(folder) - with open(folder+"/config.json", "w+") as f: + with open(folder + "/config.json", "w+", encoding="utf-8") as f: json.dump(config, f, indent=4) - -def convert_fms_mo_to_vllm_fp8_format(checkpoint : str = None, folder: str = None): + + +def convert_fms_mo_to_vllm_fp8_format(checkpoint: str = None, folder: str = None): """ Function to convert fp8 fms_mo DQ model checkpoint to vllm fp8 format """ - folder = checkpoint+"/" + folder + folder = checkpoint + "/" + folder if os.path.isdir(folder): - print(f"The folder '{folder}' exists.") + logger(f"The folder '{folder}' exists.") else: os.mkdir(folder) - print(f"The folder '{folder}' created.") + logger(f"The folder '{folder}' created.") - qcfg = get_recipe(checkpoint+"/qcfg") - config = get_recipe(checkpoint+"/config") - files = find_file_glob("model-*",checkpoint) - merged_files_dict={} + qcfg = get_recipe(checkpoint + "/qcfg") + config = get_recipe(checkpoint + "/config") + files = find_file_glob("model-*", checkpoint) + merged_files_dict = {} for file in files: temp_dict = load_file(file) @@ -97,11 +128,14 @@ def convert_fms_mo_to_vllm_fp8_format(checkpoint : str = None, folder: str = Non config = update_config(config, qcfg) save_torch_state_dict(vllm_dict, folder) - save_torch_state_dict(fms_dict, folder, filename_pattern="fms_mo{suffix}.safetensors") - with open(folder+"/config.json", "w+") as f: + save_torch_state_dict( + fms_dict, folder, filename_pattern="fms_mo{suffix}.safetensors" + 
) + with open(folder + "/config.json", "w+", encoding="utf-8") as f: json.dump(config, f, indent=4) -def find_file_glob(pattern: str , search_path: str): + +def find_file_glob(pattern: str, search_path: str): """ Finds files matching a pattern within a directory and its subdirectories. """ @@ -110,12 +144,15 @@ def find_file_glob(pattern: str , search_path: str): found_files = glob.glob(full_pattern, recursive=True) return sorted(found_files) -def convert_fp8_vllm_dict_to_fms_mo_dict(checkpoint: str=None, output_dir : str=None): + +def convert_fp8_vllm_dict_to_fms_mo_dict( + checkpoint: str = None, output_dir: str = None +): """ Function to help convert vllm fp8 checkpoint into fms_mo fp8 format """ - merged_files_dict={} - files = find_file_glob("model-*",checkpoint) + merged_files_dict = {} + files = find_file_glob("model-*", checkpoint) for file in files: temp_dict = load_file(file) merged_files_dict.update(temp_dict) @@ -123,24 +160,28 @@ def convert_fp8_vllm_dict_to_fms_mo_dict(checkpoint: str=None, output_dir : str= fms_mo_dict = rename_vllm_dict_to_fms_mo(merged_files_dict) save_torch_state_dict(fms_mo_dict, output_dir) -def rename_vllm_dict_to_fms_mo(vllm_dict : dict = None): + +def rename_vllm_dict_to_fms_mo(vllm_dict: dict = None): """ Function to help rename vllm dict format to fms_mo dict format """ - fms_mo_dict ={} - for k,v in vllm_dict.items(): + fms_mo_dict = {} + for k, v in vllm_dict.items(): if "weight_scale" in k: key = k.split("weight")[0] - fms_mo_dict[key+"weight"]=vllm_dict[key+"weight"].to(torch.float16) *v - fms_mo_dict[k]= v + fms_mo_dict[key + "weight"] = ( + vllm_dict[key + "weight"].to(torch.float16) * v + ) + fms_mo_dict[k] = v else: key = k.split("weight")[0] - if key+"weight_scale" in vllm_dict.keys(): + if key + "weight_scale" in vllm_dict.keys(): pass else: - fms_mo_dict[k]= v + fms_mo_dict[k] = v return fms_mo_dict + def convert_fp8_vllm_to_fms_mo(model: nn.Module = None): """ Function to help convert fp8 vllm model dict format to fms_mo fp8 format diff --git a/fms_mo/utils/dq_utils.py b/fms_mo/utils/dq_utils.py index 36d2806e..2eb51caf 100644 --- a/fms_mo/utils/dq_utils.py +++ b/fms_mo/utils/dq_utils.py @@ -74,7 +74,7 @@ def config_quantize_smooth_layers(qcfg: dict): for llama_family, layers in large_mag_layers.items(): if llama_family in qcfg["model"]: qcfg["qskip_layer_name"] += [ - f"model.layers.{i}.mlp.down_projj" for i in layers + f"model.layers.{i}.mlp.down_proj" for i in layers ] break elif any(model in qcfg["model"] for model in granite_architecture) or any( diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index a695e39d..6298c353 100644 --- a/fms_mo/utils/import_utils.py +++ b/fms_mo/utils/import_utils.py @@ -42,6 +42,7 @@ "torchvision", "huggingface_hub", "torchao", + "compressed_tensors", ] available_packages = {} diff --git a/fms_mo/utils/qconfig_utils.py b/fms_mo/utils/qconfig_utils.py index c8e9a093..d05c3579 100644 --- a/fms_mo/utils/qconfig_utils.py +++ b/fms_mo/utils/qconfig_utils.py @@ -88,6 +88,8 @@ def config_defaults() -> dict: "nbits_w_lstm": None, "nbits_i_lstm": None, "nbits_h_lstm": None, + "inference": False, + "output_folder": None, # qmodes vars "qa_mode": "pact+", "qw_mode": "sawb+", @@ -299,6 +301,8 @@ def qconfig_init(recipe: str = None, args: Any = None, use_mx: bool = False) -> qcfg["w_init_method"] = "sawb" qcfg["a_init_method"] = "percentile" qcfg["clip_val_asst_percentile"] = (0.1, 99.9) + qcfg["inference"] = False + qcfg["output_folder"] = None # ways to control which layers to be 
quantized/skipped qcfg["qlayer_name_pattern"] = [] diff --git a/pyproject.toml b/pyproject.toml index abd8fd1a..5ed690e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ [project.optional-dependencies] examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"] -fp8 = ["llmcompressor", "torchao==0.11"] +fp8 = ["llmcompressor", "torchao==0.11", "compressed_tensors"] gptq = ["Cython", "gptqmodel>=1.7.3"] mx = ["microxcaling>=1.1"] opt = ["fms-model-optimizer[fp8, gptq, mx]"] From adb7f3838592ee89b87b478e9336fa9bff08bdec Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Wed, 27 Aug 2025 11:40:00 -0400 Subject: [PATCH 04/12] fix: re-naming of qcfg inference parameter Signed-off-by: Omobayode Fagbohungbe --- fms_mo/dq.py | 2 +- fms_mo/prep.py | 2 +- fms_mo/recipes/dq.json | 2 +- fms_mo/utils/qconfig_utils.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fms_mo/dq.py b/fms_mo/dq.py index 6f2c0e7b..ef71203b 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -176,7 +176,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): loading fms_mo_args and recipe" ) qcfg = qconfig_init(recipe="dq", args=fms_mo_args) - qcfg["inference"] = True + qcfg["fp8_inference"] = True model_size = model_size_Wb(model, unit="GB") gpu_mem_util_per = model_size / total_gpu_memory diff --git a/fms_mo/prep.py b/fms_mo/prep.py index d5a4554b..b56525ea 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -701,7 +701,7 @@ def qmodel_prep( Returns: nn.Module: quantized model ready for further PTQ/QAT """ - if qcfg["inference"]: + if qcfg["fp8_inference"]: if qcfg.get("QBmm"): swap_qbmm(model, qcfg) diff --git a/fms_mo/recipes/dq.json b/fms_mo/recipes/dq.json index cee7c505..70c7a87d 100644 --- a/fms_mo/recipes/dq.json +++ b/fms_mo/recipes/dq.json @@ -11,6 +11,6 @@ "nbits_bmm1" : 32, "nbits_bmm2" : 32, "nbits_kvcache" : 32, - "inference": false, + "fp8_inference": false, "output_folder": null } \ No newline at end of file diff --git a/fms_mo/utils/qconfig_utils.py b/fms_mo/utils/qconfig_utils.py index d05c3579..b479d302 100644 --- a/fms_mo/utils/qconfig_utils.py +++ b/fms_mo/utils/qconfig_utils.py @@ -88,8 +88,6 @@ def config_defaults() -> dict: "nbits_w_lstm": None, "nbits_i_lstm": None, "nbits_h_lstm": None, - "inference": False, - "output_folder": None, # qmodes vars "qa_mode": "pact+", "qw_mode": "sawb+", @@ -152,6 +150,8 @@ def config_defaults() -> dict: "smoothq_scale_layers": [], "smoothq_act_scale_path": None, # Other vars + "fp8_inference": False, + "output_folder": None, "which2patch_contextmanager": None, "force_stop_if_qbmm_auto_check_failed": False, "world_size": max(1, torch.cuda.device_count()), @@ -301,7 +301,7 @@ def qconfig_init(recipe: str = None, args: Any = None, use_mx: bool = False) -> qcfg["w_init_method"] = "sawb" qcfg["a_init_method"] = "percentile" qcfg["clip_val_asst_percentile"] = (0.1, 99.9) - qcfg["inference"] = False + qcfg["fp8_inference"] = False qcfg["output_folder"] = None # ways to control which layers to be quantized/skipped From 31dd8c7ed0480b27f7a3a74afd9fdf959c2c128a Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Tue, 2 Sep 2025 20:25:14 -0400 Subject: [PATCH 05/12] fix: updated the inference file Signed-off-by: Omobayode Fagbohungbe --- fms_mo/dq.py | 48 ++++------------ fms_mo/prep.py | 20 +++---- fms_mo/quant/quantizers.py | 3 +- fms_mo/utils/dq_inf.py | 112 ++++++++++++++++++++++++++++++++++--- pyproject.toml | 2 +- 5 files changed, 128 insertions(+), 57 deletions(-) diff --git 
a/fms_mo/dq.py b/fms_mo/dq.py index ef71203b..6c28fea0 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -21,8 +21,6 @@ # Standard from pathlib import Path import logging -import os -import sys # Third Party from datasets import load_from_disk @@ -52,6 +50,7 @@ from fms_mo.utils.dq_inf import ( check_quantization_setting, convert_fp8_vllm_to_fms_mo, + load_inference_qconfig_file, save_vllm_fp8, ) from fms_mo.utils.dq_utils import config_quantize_smooth_layers @@ -134,18 +133,6 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): low_cpu_mem_usage=bool(model_args.device_map), ) - inference_qconfig = None - if hasattr(model, "config"): - inference_qconfig = model.config.to_dict().get("quantization_config", None) - - if inference_qconfig: - quant_setting = check_quantization_setting(inference_qconfig) - if quant_setting: - logger.info("Quantization config settings validated ") - model = convert_fp8_vllm_to_fms_mo(model=model) - else: - sys.exit("Error: This quantization config is wrong/not supported") - embedding_size = model.get_input_embeddings().weight.shape[0] if len(tokenizer) > embedding_size: model.resize_token_embeddings(len(tokenizer)) @@ -154,29 +141,17 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): logger.info(f"Model is at {model.device} after intialization") logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}") - if not inference_qconfig: + quant_mode = check_quantization_setting(model) + + if not quant_mode: logger.info("quantization mode activated, initalizing the qcfg file ") qcfg = qconfig_init(recipe="dq", args=fms_mo_args) else: logger.info("inference mode activated") - if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"): - if fms_mo_args.override_fms_args: - logger.info( - "qcfg file found and some parameters are being over-written " - ) - qcfg = qconfig_init( - recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args - ) - else: - logger.info("qcfg file found, loading the qcfg file ") - qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg") - else: - logger.info( - "qcfg file not found in {model_args.model_name_or_path},\ - loading fms_mo_args and recipe" - ) - qcfg = qconfig_init(recipe="dq", args=fms_mo_args) - qcfg["fp8_inference"] = True + qcfg = load_inference_qconfig_file(model_args, fms_mo_args) + + if quant_mode: + model = convert_fp8_vllm_to_fms_mo(model=model) model_size = model_size_Wb(model, unit="GB") gpu_mem_util_per = model_size / total_gpu_memory @@ -201,7 +176,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): qcfg["model"] = model_args.model_name_or_path # config layers to skip, smooth scale - if not inference_qconfig: + if not quant_mode: config_quantize_smooth_layers(qcfg) use_dynamo = True @@ -234,7 +209,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): ) # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well. 
- if not inference_qconfig and qcfg["smoothq"]: + if not quant_mode and qcfg["smoothq"]: scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt") if qcfg.get("act_scale_path", None): # user provided a scale file (or a dir) @@ -272,7 +247,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): ) logger.info(f"Quantized model {model}") logger.info("==" * 20) - if not inference_qconfig: + + if not quant_mode: if qcfg["smoothq"]: logger.info("Starting to apply smooth scale") dq_llm(model, act_scales, qcfg) diff --git a/fms_mo/prep.py b/fms_mo/prep.py index b56525ea..7230bc88 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -395,16 +395,16 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): # Third Party import compressed_tensors - if isinstance( - module, compressed_tensors.linear.compressed_linear.CompressedLinear - ): - pass - else: - logger.warning( - f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." - "Please make sure it doesn't wrap BN and activ func. Otherwise" - "please create an equivalen Linear wrapper and change qcfg['mapping']." - ) + if isinstance( + module, compressed_tensors.linear.compressed_linear.CompressedLinear + ): + pass + else: + logger.warning( + f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." + "Please make sure it doesn't wrap BN and activ func. Otherwise" + "please create an equivalent Linear wrapper and change qcfg['mapping']." + ) QLin = mapping.get(nn.Linear, None) if QLin is None: if verbose: diff --git a/fms_mo/quant/quantizers.py b/fms_mo/quant/quantizers.py index 371632c6..405ddc63 100644 --- a/fms_mo/quant/quantizers.py +++ b/fms_mo/quant/quantizers.py @@ -237,7 +237,6 @@ def get_weight_quantizer( recompute=False, perGp=None, use_subnormal=False, - emulate=True, ): """Return a quantizer for weight quantization Regular quantizers: @@ -347,7 +346,7 @@ def get_weight_quantizer( weight_quantizer = to_fp8( nbits, q_mode=qw_mode, - emulate=emulate, + emulate=True, perCh=Nch, ) else: diff --git a/fms_mo/utils/dq_inf.py b/fms_mo/utils/dq_inf.py index e56eb285..b16c3a98 100644 --- a/fms_mo/utils/dq_inf.py +++ b/fms_mo/utils/dq_inf.py @@ -29,23 +29,119 @@ import torch # Local +from fms_mo import qconfig_init from fms_mo.quant.quantizers import to_fp8_scaled_perCh from fms_mo.utils.qconfig_utils import get_recipe logger = logging.getLogger(__name__) -def check_quantization_setting(inference: dict = None): +def check_quantization_setting(model: nn.Module = None): """ function checks if the checkpoint is from fp8 quantization """ - return ( - inference["config_groups"]["group_0"]["input_activations"]["num_bits"] == 8 - and inference["config_groups"]["group_0"]["weights"]["num_bits"] == 8 - and inference["config_groups"]["group_0"]["weights"]["type"] == "float" - and inference["config_groups"]["group_0"]["input_activations"]["type"] - == "float" - ) + quant_config = None + if hasattr(model, "config"): + quant_config = model.config.to_dict().get("quantization_config", None) + if quant_config is None: + return False + + logger.info("Validating config settings") + if quant_config["quant_method"] == "compressed-tensors": + if quant_config["format"] != "float-quantized": + raise Exception( + "The input activation and weight quantization dtypes are not supported" + ) + + if ( + quant_config["config_groups"]["group_0"]["input_activations"]["num_bits"] + != 8 + ): + raise Exception("Only 8 bit FP input activation quantization is supported") + + if 
quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8: + raise Exception("Only 8-bit FP weight quantization is supported") + + if quant_config["kv_cache_scheme"] is None: + pass + else: + if quant_config["kv_cache_scheme"]["type"] is not float: + raise Exception("The KV-Cache quantization dtype is not supported") + + if quant_config["kv_cache_scheme"]["num_bits"] != 8: + raise Exception("Only 8-bit KV-Cache quantization dtype is supported") + + return True + + raise Exception("This quantization method is not supported for inferencing") + + +def load_inference_qconfig_file(model_args, fms_mo_args): + """ + Function to load the inference quantization config for fms_mo + """ + if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"): + if fms_mo_args.override_qcfg_args: + logger.info("qcfg file found and some parameters are being over-written") + qcfg = qconfig_init( + recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args + ) + else: + logger.info("qcfg file found, loading the qcfg file ") + qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg") + else: + logger.info( + f"qcfg file not found in {model_args.model_name_or_path},\ + loading fms_mo_args and recipe" + ) + qcfg = qconfig_init(recipe="dq", args=fms_mo_args) + qcfg = update_qcfg_from_model_config(model_args, qcfg) + qcfg["fp8_inference"] = True + + return qcfg + + +def update_qcfg_from_model_config(model_args, qcfg): + """ + function to update the default qcfg setting with settings in the model config file. + Important for the case where qcfg file does not exist. + """ + config = get_recipe(model_args.model_name_or_path + "/config") + if ( + config["quantization_config"]["config_groups"]["group_0"]["input_activations"][ + "strategy" + ] + == "token" + ): + qcfg["qa_mode"] = "fp8_e4m3_scale_perToken" + else: + raise Exception("Only perToken Fp8 activation quantizer is supported") + + if ( + config["quantization_config"]["config_groups"]["group_0"]["weights"]["strategy"] + == "channel" + ): + qcfg["qw_mode"] = "fp8_e4m3_scale_perCh" + elif ( + config["quantization_config"]["config_groups"]["group_0"]["weights"]["strategy"] + == "tensor" + ): + qcfg["qw_mode"] = "fp8_e4m3_scale" + else: + raise Exception( + "Only perChannel or pertensor FP8 quantizers are currently supported" + ) + + qcfg["smoothq"] = False + qcfg["nbits_a"] = config["quantization_config"]["config_groups"]["group_0"][ + "input_activations" + ]["num_bits"] + qcfg["nbits_w"] = config["quantization_config"]["config_groups"]["group_0"][ + "weights" + ]["num_bits"] + qcfg["torch_dtype"] = "float16" + + return qcfg # def rename_fms_dict_to_vllm_dict (model_dict : dict= None, qcfg : dict = None): diff --git a/pyproject.toml b/pyproject.toml index 5ed690e6..abd8fd1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ [project.optional-dependencies] examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"] -fp8 = ["llmcompressor", "torchao==0.11", "compressed_tensors"] +fp8 = ["llmcompressor", "torchao==0.11"] gptq = ["Cython", "gptqmodel>=1.7.3"] mx = ["microxcaling>=1.1"] opt = ["fms-model-optimizer[fp8, gptq, mx]"] From 4878ba1ff7395244c51c72f062ea557b2a041e0a Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Tue, 2 Sep 2025 20:59:30 -0400 Subject: [PATCH 06/12] fix: corrected the inference file Signed-off-by: Omobayode Fagbohungbe --- fms_mo/prep.py | 2 +- fms_mo/utils/dq_inf.py | 16 ++++++++-------- fms_mo/utils/import_utils.py | 2 +- 3 files changed, 10 insertions(+), 10 
deletions(-) diff --git a/fms_mo/prep.py b/fms_mo/prep.py index 7230bc88..1d408ec8 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -404,7 +404,7 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." "Please make sure it doesn't wrap BN and activ func. Otherwise" "please create an equivalent Linear wrapper and change qcfg['mapping']." - ) + ) QLin = mapping.get(nn.Linear, None) if QLin is None: if verbose: diff --git a/fms_mo/utils/dq_inf.py b/fms_mo/utils/dq_inf.py index b16c3a98..63d6792e 100644 --- a/fms_mo/utils/dq_inf.py +++ b/fms_mo/utils/dq_inf.py @@ -49,7 +49,7 @@ def check_quantization_setting(model: nn.Module = None): logger.info("Validating config settings") if quant_config["quant_method"] == "compressed-tensors": if quant_config["format"] != "float-quantized": - raise Exception( + raise ValueError( "The input activation and weight quantization dtypes are not supported" ) @@ -57,23 +57,23 @@ def check_quantization_setting(model: nn.Module = None): quant_config["config_groups"]["group_0"]["input_activations"]["num_bits"] != 8 ): - raise Exception("Only 8 bit FP input activation quantization is supported") + raise ValueError("Only 8 bit FP input activation quantization is supported") if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8: - raise Exception("Only 8-bit FP weight quantization is supported") + raise ValueError("Only 8-bit FP weight quantization is supported") if quant_config["kv_cache_scheme"] is None: pass else: if quant_config["kv_cache_scheme"]["type"] is not float: - raise Exception("The KV-Cache quantization dtype is not supported") + raise ValueError("The KV-Cache quantization dtype is not supported") if quant_config["kv_cache_scheme"]["num_bits"] != 8: - raise Exception("Only 8-bit KV-Cache quantization dtype is supported") + raise ValueError("Only 8-bit KV-Cache quantization dtype is supported") return True - raise Exception("This quantization method is not supported for inferencing") + raise ValueError("This quantization method is not supported for inferencing") def load_inference_qconfig_file(model_args, fms_mo_args): @@ -115,7 +115,7 @@ def update_qcfg_from_model_config(model_args, qcfg): ): qcfg["qa_mode"] = "fp8_e4m3_scale_perToken" else: - raise Exception("Only perToken Fp8 activation quantizer is supported") + raise ValueError("Only perToken Fp8 activation quantizer is supported") if ( config["quantization_config"]["config_groups"]["group_0"]["weights"]["strategy"] @@ -128,7 +128,7 @@ def update_qcfg_from_model_config(model_args, qcfg): ): qcfg["qw_mode"] = "fp8_e4m3_scale" else: - raise Exception( + raise ValueError( "Only perChannel or pertensor FP8 quantizers are currently supported" ) diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index 6298c353..3d7348f5 100644 --- a/fms_mo/utils/import_utils.py +++ b/fms_mo/utils/import_utils.py @@ -42,7 +42,7 @@ "torchvision", "huggingface_hub", "torchao", - "compressed_tensors", + #"compressed_tensors", ] available_packages = {} From 621769995f36cb7566474ac3b87ca2252fc1bbb2 Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Tue, 2 Sep 2025 21:14:37 -0400 Subject: [PATCH 07/12] fix: corrected the lint error Signed-off-by: Omobayode Fagbohungbe --- fms_mo/utils/import_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index 3d7348f5..a695e39d 100644 --- a/fms_mo/utils/import_utils.py +++ 
b/fms_mo/utils/import_utils.py @@ -42,7 +42,6 @@ "torchvision", "huggingface_hub", "torchao", - #"compressed_tensors", ] available_packages = {} From a2ae168b448a8d2ee99c67e0e01ecff37fbb9df5 Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Tue, 2 Sep 2025 21:21:06 -0400 Subject: [PATCH 08/12] fix: corrected the ruff error Signed-off-by: Omobayode Fagbohungbe --- fms_mo/prep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_mo/prep.py b/fms_mo/prep.py index 1d408ec8..8fdb7ac4 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -410,7 +410,7 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): if verbose: logger.info( f"Skip quantization of {curr_full_name} - mapping of Linear is None" - ) + ) return module_output # None means no swap for this type module_output = QLin( From aca818a72cebdfeac119fd808689e33d57f591ee Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Tue, 2 Sep 2025 21:26:08 -0400 Subject: [PATCH 09/12] fix:minor edit on qmodel_prep Signed-off-by: Omobayode Fagbohungbe --- fms_mo/prep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_mo/prep.py b/fms_mo/prep.py index 8fdb7ac4..7230bc88 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -404,13 +404,13 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." "Please make sure it doesn't wrap BN and activ func. Otherwise" "please create an equivalent Linear wrapper and change qcfg['mapping']." - ) + ) QLin = mapping.get(nn.Linear, None) if QLin is None: if verbose: logger.info( f"Skip quantization of {curr_full_name} - mapping of Linear is None" - ) + ) return module_output # None means no swap for this type module_output = QLin( From fbdf19f63cb57362091a18cf5f311fdc7b59fb79 Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Fri, 5 Sep 2025 18:30:53 -0400 Subject: [PATCH 10/12] fix: type hinting arguments and returns Signed-off-by: Omobayode Fagbohungbe --- fms_mo/dq.py | 12 ++--- fms_mo/prep.py | 8 +-- fms_mo/training_args.py | 3 +- fms_mo/utils/dq_inf.py | 112 ++++++++++++++++++++++++---------------- 4 files changed, 78 insertions(+), 57 deletions(-) diff --git a/fms_mo/dq.py b/fms_mo/dq.py index 6c28fea0..2b4bdfa5 100644 --- a/fms_mo/dq.py +++ b/fms_mo/dq.py @@ -141,16 +141,16 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): logger.info(f"Model is at {model.device} after intialization") logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}") - quant_mode = check_quantization_setting(model) + inference_only = check_quantization_setting(model) - if not quant_mode: + if not inference_only: logger.info("quantization mode activated, initalizing the qcfg file ") qcfg = qconfig_init(recipe="dq", args=fms_mo_args) else: logger.info("inference mode activated") qcfg = load_inference_qconfig_file(model_args, fms_mo_args) - if quant_mode: + if inference_only: model = convert_fp8_vllm_to_fms_mo(model=model) model_size = model_size_Wb(model, unit="GB") @@ -176,7 +176,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): qcfg["model"] = model_args.model_name_or_path # config layers to skip, smooth scale - if not quant_mode: + if not inference_only: config_quantize_smooth_layers(qcfg) use_dynamo = True @@ -209,7 +209,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): ) # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well. 
- if not quant_mode and qcfg["smoothq"]: + if not inference_only and qcfg["smoothq"]: scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt") if qcfg.get("act_scale_path", None): # user provided a scale file (or a dir) @@ -248,7 +248,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args): logger.info(f"Quantized model {model}") logger.info("==" * 20) - if not quant_mode: + if not inference_only: if qcfg["smoothq"]: logger.info("Starting to apply smooth scale") dq_llm(model, act_scales, qcfg) diff --git a/fms_mo/prep.py b/fms_mo/prep.py index 7230bc88..410d885f 100644 --- a/fms_mo/prep.py +++ b/fms_mo/prep.py @@ -394,17 +394,17 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): if available_packages["compressed_tensors"]: # Third Party import compressed_tensors - - if isinstance( + # checks if the layer is CompressedLinear. If it is a CompressedLinear layer, + # it does nothing. Otherwise, it throws the warning sign + if not isinstance( module, compressed_tensors.linear.compressed_linear.CompressedLinear ): - pass - else: logger.warning( f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." "Please make sure it doesn't wrap BN and activ func. Otherwise" "please create an equivalent Linear wrapper and change qcfg['mapping']." ) + QLin = mapping.get(nn.Linear, None) if QLin is None: if verbose: diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py index 66d230fb..01e37377 100644 --- a/fms_mo/training_args.py +++ b/fms_mo/training_args.py @@ -209,8 +209,7 @@ class FMSMOArguments(TypeChecker): default=False, metadata={"help": "Apply recomputation during checkpoint saving for AIU."}, ) - fp8_use_subnormal: bool = field(default=False) - override_fms_args: bool = field(default=False) + override_qcfg_args: bool = field(default=False) @dataclass diff --git a/fms_mo/utils/dq_inf.py b/fms_mo/utils/dq_inf.py index 63d6792e..6550eb7a 100644 --- a/fms_mo/utils/dq_inf.py +++ b/fms_mo/utils/dq_inf.py @@ -17,6 +17,7 @@ """ # Standard +from typing import Any, Dict, List, Tuple, Union import glob import json import logging @@ -36,7 +37,7 @@ logger = logging.getLogger(__name__) -def check_quantization_setting(model: nn.Module = None): +def check_quantization_setting(model: nn.Module) -> bool: """ function checks if the checkpoint is from fp8 quantization """ @@ -47,36 +48,49 @@ def check_quantization_setting(model: nn.Module = None): return False logger.info("Validating config settings") - if quant_config["quant_method"] == "compressed-tensors": - if quant_config["format"] != "float-quantized": - raise ValueError( - "The input activation and weight quantization dtypes are not supported" - ) - - if ( - quant_config["config_groups"]["group_0"]["input_activations"]["num_bits"] - != 8 - ): - raise ValueError("Only 8 bit FP input activation quantization is supported") - - if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8: - raise ValueError("Only 8-bit FP weight quantization is supported") - - if quant_config["kv_cache_scheme"] is None: - pass - else: - if quant_config["kv_cache_scheme"]["type"] is not float: - raise ValueError("The KV-Cache quantization dtype is not supported") - - if quant_config["kv_cache_scheme"]["num_bits"] != 8: - raise ValueError("Only 8-bit KV-Cache quantization dtype is supported") - - return True + if "quant_method" in quant_config.keys(): + if quant_config["quant_method"] == "compressed-tensors": + if quant_config["format"] != "float-quantized": + raise ValueError( + "The input activation and weight 
quantization dtypes are not supported" + ) + + if ( + quant_config["config_groups"]["group_0"]["input_activations"][ + "num_bits" + ] + != 8 + ): + raise ValueError( + "Only 8 bit FP input activation quantization is supported" + ) + + if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8: + raise ValueError("Only 8-bit FP weight quantization is supported") + + if quant_config["kv_cache_scheme"] is not None: + if quant_config["kv_cache_scheme"]["type"] is not float: + raise ValueError("The KV-Cache quantization dtype is not supported") + + if quant_config["kv_cache_scheme"]["num_bits"] != 8: + raise ValueError( + "Only 8-bit KV-Cache quantization dtype is supported" + ) + + return True + raise ValueError( + "The quantization method is not supported for inferencing." + "Only Fp8 quantization is supported" + ) - raise ValueError("This quantization method is not supported for inferencing") + raise ValueError( + "The quantization method is not found. Please check the config file" + ) -def load_inference_qconfig_file(model_args, fms_mo_args): +def load_inference_qconfig_file( + model_args: Any = None, fms_mo_args: Any = None +) -> Dict[str, Union[int, float, str]]: """ Function to load the inference quantization config for fms_mo """ @@ -87,12 +101,13 @@ def load_inference_qconfig_file(model_args, fms_mo_args): recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args ) else: - logger.info("qcfg file found, loading the qcfg file ") + logger.info(f"loading quantization configuration from\ + {model_args.model_name_or_path + '/qcfg.json'}") qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg") else: logger.info( - f"qcfg file not found in {model_args.model_name_or_path},\ - loading fms_mo_args and recipe" + f"qcfg file not found in {model_args.model_name_or_path}," + "loading fms_mo_args and recipe" ) qcfg = qconfig_init(recipe="dq", args=fms_mo_args) qcfg = update_qcfg_from_model_config(model_args, qcfg) @@ -101,7 +116,9 @@ def load_inference_qconfig_file(model_args, fms_mo_args): return qcfg -def update_qcfg_from_model_config(model_args, qcfg): +def update_qcfg_from_model_config( + model_args: Any = None, qcfg: dict = None +) -> Dict[str, Union[int, float, str]]: """ function to update the default qcfg setting with settings in the model config file. Important for the case where qcfg file does not exist. @@ -144,15 +161,16 @@ def update_qcfg_from_model_config(model_args, qcfg): return qcfg -# def rename_fms_dict_to_vllm_dict (model_dict : dict= None, qcfg : dict = None): -def rename_fms_dict_to_vllm_dict(model_dict: dict = None): +def rename_fms_dict_to_vllm_dict( + model_dict: dict = None, +) -> Tuple[Dict[str, Union[int, float]], Dict[str, Union[int, float]]]: """ Function to rename the dict in fms_mo format to vllm_format. 
""" st_dict = {} fms_dict = {} keys = model_dict.keys() - + logger.info("WARNING: only static weights per-channel is supported at this time") for k, v in model_dict.items(): if ".weight" in k: key = k.split("weight")[0] @@ -167,7 +185,9 @@ def rename_fms_dict_to_vllm_dict(model_dict: dict = None): return st_dict, fms_dict -def update_config(model_config_file: dict = None, qcfg: dict = None): +def update_config( + model_config_file: dict = None, qcfg: dict = None +) -> Dict[str, Union[int, str]]: """ Function to update the model config file with quantization configuration """ @@ -181,7 +201,9 @@ def update_config(model_config_file: dict = None, qcfg: dict = None): return model_config_file -def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None): +def save_vllm_fp8( + model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None +) -> None: """ Function to save fp8 DQ model in vllm fp8 format """ @@ -200,7 +222,9 @@ def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer=None, folder: str = No json.dump(config, f, indent=4) -def convert_fms_mo_to_vllm_fp8_format(checkpoint: str = None, folder: str = None): +def convert_fms_mo_to_vllm_fp8_format( + checkpoint: str = None, folder: str = None +) -> None: """ Function to convert fp8 fms_mo DQ model checkpoint to vllm fp8 format """ @@ -231,7 +255,7 @@ def convert_fms_mo_to_vllm_fp8_format(checkpoint: str = None, folder: str = None json.dump(config, f, indent=4) -def find_file_glob(pattern: str, search_path: str): +def find_file_glob(pattern: str, search_path: str) -> List[str]: """ Finds files matching a pattern within a directory and its subdirectories. """ @@ -243,7 +267,7 @@ def find_file_glob(pattern: str, search_path: str): def convert_fp8_vllm_dict_to_fms_mo_dict( checkpoint: str = None, output_dir: str = None -): +) -> None: """ Function to help convert vllm fp8 checkpoint into fms_mo fp8 format """ @@ -257,7 +281,7 @@ def convert_fp8_vllm_dict_to_fms_mo_dict( save_torch_state_dict(fms_mo_dict, output_dir) -def rename_vllm_dict_to_fms_mo(vllm_dict: dict = None): +def rename_vllm_dict_to_fms_mo(vllm_dict: dict) -> dict: """ Function to help rename vllm dict format to fms_mo dict format """ @@ -271,14 +295,12 @@ def rename_vllm_dict_to_fms_mo(vllm_dict: dict = None): fms_mo_dict[k] = v else: key = k.split("weight")[0] - if key + "weight_scale" in vllm_dict.keys(): - pass - else: + if key + "weight_scale" not in vllm_dict.keys(): fms_mo_dict[k] = v return fms_mo_dict -def convert_fp8_vllm_to_fms_mo(model: nn.Module = None): +def convert_fp8_vllm_to_fms_mo(model: nn.Module = None) -> nn.Module: """ Function to help convert fp8 vllm model dict format to fms_mo fp8 format """ From d3e7c6163f9e8f23ce061661ab4d35f306a7886b Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Wed, 10 Sep 2025 08:23:05 -0400 Subject: [PATCH 11/12] fix: improving argument hints and inferencing for models with skipped layers Signed-off-by: Omobayode Fagbohungbe --- fms_mo/utils/dq_inf.py | 24 +++++++++++++++--------- fms_mo/utils/import_utils.py | 1 + 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/fms_mo/utils/dq_inf.py b/fms_mo/utils/dq_inf.py index 6550eb7a..99194955 100644 --- a/fms_mo/utils/dq_inf.py +++ b/fms_mo/utils/dq_inf.py @@ -17,7 +17,7 @@ """ # Standard -from typing import Any, Dict, List, Tuple, Union +from typing import Any import glob import json import logging @@ -90,7 +90,7 @@ def check_quantization_setting(model: nn.Module) -> bool: def load_inference_qconfig_file( model_args: Any = None, 
fms_mo_args: Any = None -) -> Dict[str, Union[int, float, str]]: +) -> dict[str, int | float | str]: """ Function to load the inference quantization config for fms_mo """ @@ -118,7 +118,7 @@ def load_inference_qconfig_file( def update_qcfg_from_model_config( model_args: Any = None, qcfg: dict = None -) -> Dict[str, Union[int, float, str]]: +) -> dict[str, int | float | str]: """ function to update the default qcfg setting with settings in the model config file. Important for the case where qcfg file does not exist. @@ -157,13 +157,18 @@ def update_qcfg_from_model_config( "weights" ]["num_bits"] qcfg["torch_dtype"] = "float16" - + if config["quantization_config"]["ignore"] is not []: + qcfg["qskip_layer_name"] = config["quantization_config"]["ignore"] + qcfg["qskip_large_mag_layers"] = True + else: + qcfg["qskip_layer_name"] = [] + qcfg["qskip_large_mag_layers"] = False return qcfg def rename_fms_dict_to_vllm_dict( model_dict: dict = None, -) -> Tuple[Dict[str, Union[int, float]], Dict[str, Union[int, float]]]: +) -> tuple[dict[str, float | int], dict[str, float | int]]: """ Function to rename the dict in fms_mo format to vllm_format. """ @@ -187,7 +192,7 @@ def rename_fms_dict_to_vllm_dict( def update_config( model_config_file: dict = None, qcfg: dict = None -) -> Dict[str, Union[int, str]]: +) -> dict[str, float | int | str]: """ Function to update the model config file with quantization configuration """ @@ -196,7 +201,8 @@ def update_config( data["quantization_config"]["config_groups"]["group_0"]["weights"] = ( "{num_bits: 8, type: float, symmetric: true, strategy: tensor}" ) - + if qcfg["qskip_large_mag_layers"] == True: + data["quantization_config"]["ignore"] = qcfg["qskip_layer_name"] model_config_file.update(data) return model_config_file @@ -255,7 +261,7 @@ def convert_fms_mo_to_vllm_fp8_format( json.dump(config, f, indent=4) -def find_file_glob(pattern: str, search_path: str) -> List[str]: +def find_file_glob(pattern: str, search_path: str) -> list[str]: """ Finds files matching a pattern within a directory and its subdirectories. 
""" @@ -281,7 +287,7 @@ def convert_fp8_vllm_dict_to_fms_mo_dict( save_torch_state_dict(fms_mo_dict, output_dir) -def rename_vllm_dict_to_fms_mo(vllm_dict: dict) -> dict: +def rename_vllm_dict_to_fms_mo(vllm_dict: dict) -> dict[str, float | int | str]: """ Function to help rename vllm dict format to fms_mo dict format """ diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index a695e39d..3490564c 100644 --- a/fms_mo/utils/import_utils.py +++ b/fms_mo/utils/import_utils.py @@ -42,6 +42,7 @@ "torchvision", "huggingface_hub", "torchao", + "compressed_tensors" ] available_packages = {} From 751077044c9c0a53f50943852fbc0904d5a149a6 Mon Sep 17 00:00:00 2001 From: Omobayode Fagbohungbe Date: Wed, 10 Sep 2025 09:26:58 -0400 Subject: [PATCH 12/12] fix: correcting lint error Signed-off-by: Omobayode Fagbohungbe --- fms_mo/utils/dq_inf.py | 4 ++-- fms_mo/utils/import_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fms_mo/utils/dq_inf.py b/fms_mo/utils/dq_inf.py index 99194955..9f44d313 100644 --- a/fms_mo/utils/dq_inf.py +++ b/fms_mo/utils/dq_inf.py @@ -157,7 +157,7 @@ def update_qcfg_from_model_config( "weights" ]["num_bits"] qcfg["torch_dtype"] = "float16" - if config["quantization_config"]["ignore"] is not []: + if config["quantization_config"]["ignore"] != []: qcfg["qskip_layer_name"] = config["quantization_config"]["ignore"] qcfg["qskip_large_mag_layers"] = True else: @@ -201,7 +201,7 @@ def update_config( data["quantization_config"]["config_groups"]["group_0"]["weights"] = ( "{num_bits: 8, type: float, symmetric: true, strategy: tensor}" ) - if qcfg["qskip_large_mag_layers"] == True: + if qcfg["qskip_large_mag_layers"] is True: data["quantization_config"]["ignore"] = qcfg["qskip_layer_name"] model_config_file.update(data) return model_config_file diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index 3490564c..6298c353 100644 --- a/fms_mo/utils/import_utils.py +++ b/fms_mo/utils/import_utils.py @@ -42,7 +42,7 @@ "torchvision", "huggingface_hub", "torchao", - "compressed_tensors" + "compressed_tensors", ] available_packages = {}