Changes from 10 commits
3 changes: 2 additions & 1 deletion .pylintrc
@@ -69,7 +69,8 @@ ignored-modules=gptqmodel,
llmcompressor,
cutlass_mm,
pygraphviz,
matplotlib
matplotlib,
compressed_tensors

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
113 changes: 70 additions & 43 deletions fms_mo/dq.py
@@ -47,6 +47,12 @@
get_act_scales_1gpu,
)
from fms_mo.utils.aiu_utils import save_for_aiu
from fms_mo.utils.dq_inf import (
check_quantization_setting,
convert_fp8_vllm_to_fms_mo,
load_inference_qconfig_file,
save_vllm_fp8,
)
from fms_mo.utils.dq_utils import config_quantize_smooth_layers
from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
from fms_mo.utils.utils import patch_torch_bmm, prepare_input
@@ -134,7 +140,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
logger.info(f"Initialized model is: \n {model}")
logger.info(f"Model is at {model.device} after intialization")
logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
qcfg = qconfig_init(recipe="dq", args=fms_mo_args)

quant_mode = check_quantization_setting(model)
Collaborator: quant_mode is boolean but it isn't clear from the name of this variable what its role is (to me, it implies a quantization strategy). I think do_quantization or run_quantization or (with opposite meaning) inference_only would be clearer names. Consider updating this name.

Contributor Author: changed to inference_only

Contributor Author: corrected

if not quant_mode:
logger.info("quantization mode activated, initalizing the qcfg file ")
qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
else:
logger.info("inference mode activated")
qcfg = load_inference_qconfig_file(model_args, fms_mo_args)

if quant_mode:
model = convert_fp8_vllm_to_fms_mo(model=model)

model_size = model_size_Wb(model, unit="GB")
gpu_mem_util_per = model_size / total_gpu_memory
@@ -159,7 +176,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):

qcfg["model"] = model_args.model_name_or_path
# config layers to skip, smooth scale
config_quantize_smooth_layers(qcfg)
if not quant_mode:
config_quantize_smooth_layers(qcfg)

use_dynamo = True
# use dynamo as default unless really needed, False -> fallback to TorchScript tracing
@@ -178,6 +196,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
qcfg["model"] = model_args.model_name_or_path
qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg
qcfg["plotsvg"] = False
qcfg["output_folder"] = opt_args.output_dir

calibration_dataset = load_from_disk(data_args.training_data_path)
calibration_dataset = calibration_dataset.with_format("torch")
@@ -190,7 +209,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
)

# For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
if qcfg["smoothq"]:
if not quant_mode and qcfg["smoothq"]:
scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
if qcfg.get("act_scale_path", None):
# user provided a scale file (or a dir)
@@ -229,48 +248,56 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
logger.info(f"Quantized model {model}")
logger.info("==" * 20)

if qcfg["smoothq"]:
logger.info("Starting to apply smooth scale")
dq_llm(model, act_scales, qcfg)
logger.info("Finished applying smooth scale")
if not quant_mode:
if qcfg["smoothq"]:
logger.info("Starting to apply smooth scale")
dq_llm(model, act_scales, qcfg)
logger.info("Finished applying smooth scale")

if qcfg["qmodel_calibration_new"] > 0:
logger.info("Starting to calibrate activation clip_val")
if qcfg["large_model"]:
calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
else:
model.to("cuda")
pbar = tqdm(
dq_dataloader,
desc=" calibration after applying smoothq scale and before inference",
total=qcfg["qmodel_calibration_new"],
if qcfg["qmodel_calibration_new"] > 0:
logger.info("Starting to calibrate activation clip_val")
if qcfg["large_model"]:
calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
else:
model.to("cuda")
pbar = tqdm(
dq_dataloader,
desc=" calibration after applying smoothq scale and before inference",
total=qcfg["qmodel_calibration_new"],
)
for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
data_mb = prepare_input(model.device, data_mb)
with patch_torch_bmm(qcfg):
model(**data_mb)

if opt_args.save_ckpt_for_aiu:
logger.info(
f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
)
save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
elif not opt_args.save_ckpt:
logger.info(
f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}"
)
save_vllm_fp8(model, qcfg, tokenizer, opt_args.output_dir)
elif opt_args.save_ckpt:
logger.info(
f"Saving quantized model and tokenizer to {opt_args.output_dir}"
)
model.save_pretrained(opt_args.output_dir, use_safetensors=True)
tokenizer.save_pretrained(opt_args.output_dir)

if fms_mo_args.aiu_sim_triton:
# NOTE plz apply correct HW settings here, defaults are not real HW params
lower_qmodel_triton(
model,
use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
max_acc_bits=qcfg.get("max_acc_bits", 32),
num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
chunk_size=qcfg.get("chunk_size", 32), # 1024
clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
# layer_to_exclude=["lm_head",]
)
for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
data_mb = prepare_input(model.device, data_mb)
with patch_torch_bmm(qcfg):
model(**data_mb)

if opt_args.save_ckpt_for_aiu:
logger.info(
f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
)
save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
elif opt_args.save_ckpt:
logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
model.save_pretrained(opt_args.output_dir, use_safetensors=True)
tokenizer.save_pretrained(opt_args.output_dir)

if fms_mo_args.aiu_sim_triton:
# NOTE plz apply correct HW settings here, defaults are not real HW params
lower_qmodel_triton(
model,
use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
max_acc_bits=qcfg.get("max_acc_bits", 32),
num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
chunk_size=qcfg.get("chunk_size", 32), # 1024
clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
# layer_to_exclude=["lm_head",]
)

if fms_mo_args.eval_ppl:
path_test = Path(data_args.test_data_path)
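The new helpers imported from fms_mo.utils.dq_inf drive this branching: check_quantization_setting only needs to decide whether the loaded checkpoint already carries a quantization config (e.g. FP8 in compressed-tensors format for vLLM), in which case run_dq takes the inference-only path instead of quantizing from scratch. A minimal sketch of that kind of check, assuming it inspects the Hugging Face quantization_config; the real helper in fms_mo/utils/dq_inf.py may differ:

    def check_quantization_setting_sketch(model) -> bool:
        """Sketch only: return True when the checkpoint already ships a
        compressed-tensors quantization_config, i.e. run_dq should follow the
        inference-only path instead of quantizing from scratch."""
        qconf = getattr(model.config, "quantization_config", None)
        if qconf is None:
            return False
        # Transformers may expose this as a dict or as a config object.
        if not isinstance(qconf, dict):
            qconf = qconf.to_dict() if hasattr(qconf, "to_dict") else vars(qconf)
        return qconf.get("quant_method") == "compressed-tensors"
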
66 changes: 59 additions & 7 deletions fms_mo/prep.py
@@ -391,12 +391,20 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
# For nn.Linear
elif isinstance(module, nn.Linear):
if module.__class__ != nn.Linear:
logger.warning(
f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
"Please make sure it doesn't wrap BN and activ func."
"Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']."
)
if available_packages["compressed_tensors"]:
Collaborator: this block is incorrect because if compressed_tensors is not installed, the warning is never triggered.

Collaborator: you can incorporate the 3 nested ifs into a single check:

    if (
        module.__class__ != nn.Linear
        and (
            not available_packages["compressed_tensors"]
            or not isinstance(
                module, compressed_tensors.linear.compressed_linear.CompressedLinear
            )
        )
    ):
        logger.warning(...)

Collaborator: if you set it up this way, add a comment to explain when this if clause is triggered, because it is not immediately obvious.

Collaborator: not yet fixed

Contributor Author: fixed now

# Third Party
import compressed_tensors

if isinstance(
module, compressed_tensors.linear.compressed_linear.CompressedLinear
):
pass
else:
logger.warning(
f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
"Please make sure it doesn't wrap BN and activ func. Otherwise"
"please create an equivalent Linear wrapper and change qcfg['mapping']."
)
QLin = mapping.get(nn.Linear, None)
if QLin is None:
if verbose:
@@ -571,6 +579,42 @@ def has_quantized_module(model):
return any(isinstance(m, quantized_modules) for m in model.modules())


def swap_qbmm(model: nn.Module, qcfg: dict):
"""Go through all model.named_modules(), try to create an equivalent
Qbmm layer to replace each of the existing linear Bmm layers.
Collaborator (@chichun-charlie-liu, Sep 22, 2025): this func name and description seem inaccurate and may cause confusion. what this func does is to create and attach new Qbmm modules to a module where torch.matmul/torch.bmm is being used. Has nothing to do with "swap" and "replace Linear BMM layers"...


Args:
model (nn.Module): input model to be "prepared"
qcfg (dict): quant config

Returns: updated model is returned with the Qbmm added

"""

# Local
from fms_mo.modules import QBmm

qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"]["which2patch_contextmanager"]
isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm"
for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
mod_bmm_happened = model.get_submodule(mod_name)
for whichQBmm, ln in enumerate(line_nums, start=1):
nbits = qcfg[f"nbits_bmm{whichQBmm}"]
newQBmm = QBmm(
num_bits_m1=max(nbits, 8) if whichQBmm == 2 else nbits,
num_bits_m2=nbits,
qm1_mode=qcfg[f"bmm{whichQBmm}_qm1_mode"],
qm2_mode=qcfg[f"bmm{whichQBmm}_qm2_mode"],
m1_unidirectional=(whichQBmm == 2),
m1_bounded=(whichQBmm == 2), # see Note 5
m2_unidirectional=False,
m2_bounded=False,
replaceBmm=isbmm,
qcfg=qcfg,
)
setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)


def qmodel_prep(
model,
dloader,
@@ -657,6 +701,12 @@ def qmodel_prep(
Returns:
nn.Module: quantized model ready for further PTQ/QAT
"""
if qcfg["fp8_inference"]:
if qcfg.get("QBmm"):
swap_qbmm(model, qcfg)

model = q_any_net_5(model, qcfg, verbose=False)
return model

sys.setrecursionlimit(4000)

@@ -906,8 +956,10 @@
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=DPorDDPdevices
)

qconfig_save(qcfg, fname="qcfg.json")
if qcfg["output_folder"] is None:
qconfig_save(qcfg, fname="qcfg.json")
else:
qconfig_save(qcfg, fname=qcfg["output_folder"] + "/qcfg.json")
qcfg["tb_writer"] = tb_writer

logger.info(f"--- Quantized model --- \n{model}\n")
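On the qconfig_save change above: the new output_folder handling boils down to choosing a path for qcfg.json. A small self-contained sketch of that choice, written with os.path.join so a trailing slash in output_folder is also handled; behaviourally it matches the string concatenation in the diff:

    import os
    from typing import Optional

    def qcfg_save_path(output_folder: Optional[str]) -> str:
        """Where prep.py writes qcfg.json under the new output_folder logic."""
        if output_folder is None:
            return "qcfg.json"                      # previous default behaviour
        return os.path.join(output_folder, "qcfg.json")

    print(qcfg_save_path(None))          # qcfg.json
    print(qcfg_save_path("runs/exp1"))   # runs/exp1/qcfg.json
    print(qcfg_save_path("runs/exp1/"))  # runs/exp1/qcfg.json (trailing slash OK)
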
4 changes: 3 additions & 1 deletion fms_mo/recipes/dq.json
@@ -10,5 +10,7 @@
"eval_ckpt": true,
"nbits_bmm1" : 32,
"nbits_bmm2" : 32,
"nbits_kvcache" : 32
"nbits_kvcache" : 32,
"fp8_inference": false,
"output_folder": null
}
44 changes: 44 additions & 0 deletions fms_mo/recipes/fp8_vllm_quantization_config.json
@@ -0,0 +1,44 @@
{
"quantization_config": {
"config_groups": {
"group_0": {
"input_activations": {
"actorder": null,
"block_structure": null,
"dynamic": true,
"group_size": null,
"num_bits": 8,
"observer": null,
"observer_kwargs": {},
"strategy": "token",
"symmetric": true,
"type": "float"
},
"output_activations": null,
"targets": [
"Linear"
],
"weights": {
"actorder": null,
"block_structure": null,
"dynamic": false,
"group_size": null,
"num_bits": 8,
"observer": "minmax",
"observer_kwargs": {},
"strategy": "channel",
"symmetric": true,
"type": "float"
}
}
},
"format": "float-quantized",
"global_compression_ratio": null,
"ignore": [
"lm_head"
],
"kv_cache_scheme": null,
"quant_method": "compressed-tensors",
"quantization_status": "compressed"
}
}
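In plain terms, this recipe requests static per-channel FP8 weights and dynamic per-token FP8 activations for every Linear layer except lm_head, in compressed-tensors format. A short sketch of reading the file directly; it is presumably consumed via load_inference_qconfig_file in fms_mo/utils/dq_inf.py, which may merge these values into qcfg differently:

    import json

    with open("fms_mo/recipes/fp8_vllm_quantization_config.json") as f:
        recipe = json.load(f)["quantization_config"]

    assert recipe["quant_method"] == "compressed-tensors"
    group0 = recipe["config_groups"]["group_0"]
    print(group0["weights"]["strategy"])            # "channel"  (static FP8 weights)
    print(group0["input_activations"]["strategy"])  # "token"    (dynamic FP8 activations)
    print(recipe["ignore"])                         # ["lm_head"] is left unquantized
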
2 changes: 2 additions & 0 deletions fms_mo/training_args.py
@@ -209,6 +209,8 @@ class FMSMOArguments(TypeChecker):
default=False,
metadata={"help": "Apply recomputation during checkpoint saving for AIU."},
)
fp8_use_subnormal: bool = field(default=False)
override_fms_args: bool = field(default=False)


@dataclass