137 changes: 90 additions & 47 deletions fms_mo/dq.py
@@ -1,11 +1,11 @@
# Copyright The FMS Model Optimizer Authors
#

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#

# http://www.apache.org/licenses/LICENSE-2.0
#

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -34,6 +34,7 @@
)
import torch

import os
import sys
# Local
from fms_mo import qconfig_init, qmodel_prep
from fms_mo.custom_ext_kernels.utils import (
@@ -50,6 +51,11 @@
from fms_mo.utils.dq_utils import config_quantize_smooth_layers
from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
from fms_mo.utils.utils import patch_torch_bmm, prepare_input
from fms_mo.utils.dq_inf import (
save_vllm_fp8,
convert_fp8_vllm_to_fms_mo,
check_quantization_setting,
)

logger = logging.getLogger(__name__)

@@ -127,14 +133,42 @@
low_cpu_mem_usage=bool(model_args.device_map),
)

inference = model.config.to_dict().get("quantization_config", None)

if inference:
quant_setting = check_quantization_setting(inference)
if quant_setting:
logger.info("Quantization config settings validated")
model = convert_fp8_vllm_to_fms_mo(model=model)
else:
sys.exit("This quantization config is wrong/not supported")

embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))

logger.info(f"Initialized model is: \n {model}")
logger.info(f"Model is at {model.device} after intialization")
logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
qcfg = qconfig_init(recipe="dq", args=fms_mo_args)

if not inference:
logger.info("quantization mode activated, initalizing the qcfg file ")
qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
else:
logger.info("inference mode activated")
if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"):
if fms_mo_args.override_fms_args:
logger.info("qcfg file found, some parameters will be overridden by fms_mo_args")
qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args)
else:
logger.info("qcfg file found, loading the qcfg file ")
qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg")
else:
logger.info("qcfg file not found in {model_args.model_name_or_path},\
loading fms_mo_args and recipe"
)
qcfg = qconfig_init(recipe="dq", args=fms_mo_args)

model_size = model_size_Wb(model, unit="GB")
gpu_mem_util_per = model_size / total_gpu_memory
@@ -178,6 +212,7 @@
qcfg["model"] = model_args.model_name_or_path
qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg
qcfg["plotsvg"] = False
qcfg["output_folder"] = opt_args.output_dir

calibration_dataset = load_from_disk(data_args.training_data_path)
calibration_dataset = calibration_dataset.with_format("torch")
@@ -190,7 +225,7 @@
)

# For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
if qcfg["smoothq"]:
if not inference and qcfg["smoothq"]:
scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
if qcfg.get("act_scale_path", None):
# user provided a scale file (or a dir)
@@ -224,53 +259,61 @@
use_layer_name_pattern_matching=use_layer_name_pattern_matching,
use_dynamo=use_dynamo,
dev=dev,
mode=inference,
save_fname="dq",
)
logger.info(f"Quantized model {model}")
logger.info("==" * 20)

if qcfg["smoothq"]:
logger.info("Starting to apply smooth scale")
dq_llm(model, act_scales, qcfg)
logger.info("Finished applying smooth scale")

if qcfg["qmodel_calibration_new"] > 0:
logger.info("Starting to calibrate activation clip_val")
if qcfg["large_model"]:
calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
else:
model.to("cuda")
pbar = tqdm(
dq_dataloader,
desc=" calibration after applying smoothq scale and before inference",
total=qcfg["qmodel_calibration_new"],
if not inference:
if qcfg["smoothq"]:
logger.info("Starting to apply smooth scale")
dq_llm(model, act_scales, qcfg)
logger.info("Finished applying smooth scale")

if qcfg["qmodel_calibration_new"] > 0:
logger.info("Starting to calibrate activation clip_val")
if qcfg["large_model"]:
calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
else:
model.to("cuda")
pbar = tqdm(
dq_dataloader,
desc=" calibration after applying smoothq scale and before inference",
total=qcfg["qmodel_calibration_new"],
)
for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
data_mb = prepare_input(model.device, data_mb)
with patch_torch_bmm(qcfg):
model(**data_mb)

if opt_args.save_ckpt_for_aiu:
logger.info(
f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
)
save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
elif not opt_args.save_ckpt:
logger.info(
f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}"
)
save_vllm_fp8(model, qcfg, tokenizer, opt_args.output_dir)
elif opt_args.save_ckpt:
logger.info(
f"Saving quantized model and tokenizer to {opt_args.output_dir}"
)
model.save_pretrained(opt_args.output_dir, use_safetensors=True)
tokenizer.save_pretrained(opt_args.output_dir)

if fms_mo_args.aiu_sim_triton:
# NOTE plz apply correct HW settings here, defaults are not real HW params
lower_qmodel_triton(
model,
use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
max_acc_bits=qcfg.get("max_acc_bits", 32),
num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
chunk_size=qcfg.get("chunk_size", 32), # 1024
clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
# layer_to_exclude=["lm_head",]
)
for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
data_mb = prepare_input(model.device, data_mb)
with patch_torch_bmm(qcfg):
model(**data_mb)

if opt_args.save_ckpt_for_aiu:
logger.info(
f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
)
save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
elif opt_args.save_ckpt:
logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
model.save_pretrained(opt_args.output_dir, use_safetensors=True)
tokenizer.save_pretrained(opt_args.output_dir)

if fms_mo_args.aiu_sim_triton:
# NOTE plz apply correct HW settings here, defaults are not real HW params
lower_qmodel_triton(
model,
use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
max_acc_bits=qcfg.get("max_acc_bits", 32),
num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
chunk_size=qcfg.get("chunk_size", 32), # 1024
clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
# layer_to_exclude=["lm_head",]
)

if fms_mo_args.eval_ppl:
path_test = Path(data_args.test_data_path)
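The inference branch above calls check_quantization_setting from fms_mo/utils/dq_inf.py, a module that is not part of this diff. The following is a hypothetical minimal sketch of such a check, keyed to the fields carried by the bundled recipe fms_mo/recipes/quant.json added below in this PR; only the function name and its call site come from the diff, the body is an assumption.

```python
# Hypothetical sketch only -- fms_mo/utils/dq_inf.py is not shown in this PR.
# It validates the checkpoint's quantization_config dict (read in dq.py via
# model.config.to_dict().get("quantization_config")) before conversion.
def check_quantization_setting(quant_config: dict) -> bool:
    """Return True if the config looks like an FP8 compressed-tensors scheme
    that convert_fp8_vllm_to_fms_mo is expected to handle."""
    if quant_config.get("quant_method") != "compressed-tensors":
        return False
    groups = quant_config.get("config_groups", {})
    if not groups:
        return False
    for group in groups.values():
        weights = group.get("weights") or {}
        if weights.get("type") != "float" or weights.get("num_bits") != 8:
            return False
    return True
```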
2 changes: 1 addition & 1 deletion fms_mo/modules/linear.py
@@ -281,6 +281,7 @@
)

# pylint: disable=not-callable

return F.linear(x, self.W_fp, self.bias)
else:
qinput = self.quantize_feature(x / scale).to(x.dtype)
@@ -296,7 +297,6 @@
)

qbias = self.bias

# pylint: disable=not-callable
output = F.linear(qinput, qweight, qbias)

61 changes: 53 additions & 8 deletions fms_mo/prep.py
@@ -23,7 +23,7 @@
# Third Party
from torch import nn
import torch

from compressed_tensors.linear.compressed_linear import CompressedLinear

# Local
from fms_mo.calib import qmodel_calib
from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules
@@ -391,12 +391,14 @@
# For nn.Linear
elif isinstance(module, nn.Linear):
if module.__class__ != nn.Linear:
logger.warning(
f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
"Please make sure it doesn't wrap BN and activ func."
"Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']."
)

if not isinstance(module, CompressedLinear):
logger.warning(
f"{curr_full_name} {type(module)} seems to be a wrapper of Linear. "
"Please make sure it doesn't wrap BN and activation functions. "
"Otherwise please create an equivalent Linear wrapper "
"and change qcfg['mapping']."
)
QLin = mapping.get(nn.Linear, None)
if QLin is None:
if verbose:
@@ -570,6 +572,41 @@
"""Check if model is already quantized - do not want to quantize twice if so"""
return any(isinstance(m, quantized_modules) for m in model.modules())

def swap_qbmm(model: nn.Module, qcfg: dict):
"""Go through all model.named_modules(), try to create an equivalent
Qbmm layer to replace each of the existing linear Bmm layers.
Collaborator @chichun-charlie-liu commented on Sep 22, 2025:

this func name and description seem inaccurate and may cause confusion. what this func does is to create and attach new Qbmm modules to a module where torch.matmul/torch.bmm is being used. Has nothing to do with "swap" and "replace Linear BMM layers"...
Args:
model (nn.Module): input model to be "prepared"
qcfg (dict): quant config

Returns: updated model is returned with the Qbmm added

"""

from fms_mo.modules import QBmm

qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
"which2patch_contextmanager"
]
isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm"
for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
mod_bmm_happened = model.get_submodule(mod_name)
for whichQBmm, ln in enumerate(line_nums, start=1):
nbits = qcfg[f"nbits_bmm{whichQBmm}"]
newQBmm = QBmm(
num_bits_m1=max(nbits, 8) if whichQBmm == 2 else nbits,
num_bits_m2=nbits,
qm1_mode=qcfg[f"bmm{whichQBmm}_qm1_mode"],
qm2_mode=qcfg[f"bmm{whichQBmm}_qm2_mode"],
m1_unidirectional=(whichQBmm == 2),
m1_bounded=(whichQBmm == 2), # see Note 5
m2_unidirectional=False,
m2_bounded=False,
replaceBmm=isbmm,
qcfg=qcfg,
)
setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)

def qmodel_prep(
model,
@@ -582,6 +619,7 @@
Qcali=False,
dev=None,
use_dynamo=False,
mode=False,
verbose=False,
**kwargs,
):
@@ -657,7 +695,14 @@
Returns:
nn.Module: quantized model ready for further PTQ/QAT
"""
if mode:
if qcfg.get("QBmm"):

Check warning on line 700 in fms_mo/prep.py

View workflow job for this annotation

GitHub Actions / lint: pylint

C0303: Trailing whitespace (trailing-whitespace)
swap_qbmm(model, qcfg)

model = q_any_net_5(model, qcfg, verbose=False)
return model

sys.setrecursionlimit(4000)

currDev = next(model.parameters()).device if dev is None else dev
@@ -907,7 +952,7 @@
model, device_ids=DPorDDPdevices
)

qconfig_save(qcfg, fname="qcfg.json")
qconfig_save(qcfg, fname=qcfg["output_folder"] + "/qcfg.json")
qcfg["tb_writer"] = tb_writer

logger.info(f"--- Quantized model --- \n{model}\n")
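As the review comment on swap_qbmm notes, the function does not swap anything out: it attaches new QBmm children (named QBmm followed by the line number of the bmm/matmul call) to each parent module recorded in qcfg["bmm_prep"]["layers_with_bmm"]. A rough usage sketch of how those attached modules are exercised afterwards; the mode=True early return in qmodel_prep and the patch_torch_bmm context manager come from this diff, while the positional call shape (model, dataloader, qcfg) and the surrounding setup are assumptions.

```python
# Assumed setup: `model`, `dq_dataloader`, and a populated `qcfg` exist, as in fms_mo/dq.py.
from fms_mo import qmodel_prep
from fms_mo.utils.utils import patch_torch_bmm, prepare_input

# With mode=True (the new inference path), qmodel_prep() runs swap_qbmm() when
# qcfg enables QBmm, attaching QBmm<line-number> children to every module listed
# in qcfg["bmm_prep"]["layers_with_bmm"], then calls q_any_net_5() and returns early.
model = qmodel_prep(model, dq_dataloader, qcfg, mode=True, save_fname="dq")

for data_mb in dq_dataloader:
    data_mb = prepare_input(model.device, data_mb)
    # The context manager redirects torch.bmm / torch.matmul calls inside the
    # listed modules to their attached QBmm layers for this forward pass.
    with patch_torch_bmm(qcfg):
        model(**data_mb)
```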
3 changes: 2 additions & 1 deletion fms_mo/quant/quantizers.py
@@ -237,6 +237,7 @@ def get_weight_quantizer(
recompute=False,
perGp=None,
use_subnormal=False,
emulate=True,
):
"""Return a quantizer for weight quantization
Regular quantizers:
@@ -346,7 +347,7 @@ def get_weight_quantizer(
weight_quantizer = to_fp8(
nbits,
q_mode=qw_mode,
emulate=True,
emulate=emulate,
perCh=Nch,
)
else:
44 changes: 44 additions & 0 deletions fms_mo/recipes/quant.json
@@ -0,0 +1,44 @@
{
"quantization_config": {
"config_groups": {
"group_0": {
"input_activations": {
"actorder": null,
"block_structure": null,
"dynamic": true,
"group_size": null,
"num_bits": 8,
"observer": null,
"observer_kwargs": {},
"strategy": "token",
"symmetric": true,
"type": "float"
},
"output_activations": null,
"targets": [
"Linear"
],
"weights": {
"actorder": null,
"block_structure": null,
"dynamic": false,
"group_size": null,
"num_bits": 8,
"observer": "minmax",
"observer_kwargs": {},
"strategy": "channel",
"symmetric": true,
"type": "float"
}
}
},
"format": "float-quantized",
"global_compression_ratio": null,
"ignore": [
"lm_head"
],
"kv_cache_scheme": null,
"quant_method": "compressed-tensors",
"quantization_status": "compressed"
}
}
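This recipe mirrors the quantization_config block that a compressed-tensors FP8 checkpoint stores in its config.json, which is what the new branch in dq.py keys off via model.config.to_dict().get("quantization_config"). A quick, hedged way to check which path a given checkpoint will take (the checkpoint path below is hypothetical):

```python
from transformers import AutoConfig

# Hypothetical path -- substitute your own checkpoint directory.
cfg = AutoConfig.from_pretrained("path/to/fp8-checkpoint")
quant_cfg = cfg.to_dict().get("quantization_config", None)

if quant_cfg:
    print("inference mode: dq.py will validate and convert the FP8 checkpoint")
else:
    print("quantization mode: dq.py will quantize using the 'dq' recipe")
```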
2 changes: 2 additions & 0 deletions fms_mo/training_args.py
@@ -209,6 +209,8 @@ class FMSMOArguments(TypeChecker):
default=False,
metadata={"help": "Apply recomputation during checkpoint saving for AIU."},
)
fp8_use_subnormal: bool = field(default=False)
override_fms_args: bool = field(default=False)


@dataclass