feat: fast loading and saving functionality #180
Conversation
| "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']." | ||
| ) | ||
|
|
||
| if available_packages["compressed_tensors"]: |
this block is incorrect because if compressed_tensors is not installed, the warning is never triggered.
you can incorporate the 3 nested ifs into a single check:

    if (
        module.__class__ != nn.Linear
        and (
            not available_packages["compressed_tensors"]
            or not isinstance(
                module, compressed_tensors.linear.compressed_linear.CompressedLinear
            )
        )
    ):
        logger.warning(...)
if you set it up this way, add a comment to explain when this if clause is triggered, because it is not immediately obvious
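For illustration, a sketch of the combined check with such a comment (the comment wording is only a suggestion, not taken from the PR):

    # Triggered when the module is neither a plain nn.Linear nor (when
    # compressed_tensors is installed) a CompressedLinear, i.e. an unrecognized
    # Linear variant whose weights cannot be mapped.
    if module.__class__ != nn.Linear and (
        not available_packages["compressed_tensors"]
        or not isinstance(
            module, compressed_tensors.linear.compressed_linear.CompressedLinear
        )
    ):
        logger.warning(...)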
not yet fixed
fixed now
pyproject.toml
Outdated
    [project.optional-dependencies]
    examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
    - fp8 = ["llmcompressor", "torchao==0.11"]
    + fp8 = ["llmcompressor", "torchao==0.11", "compressed_tensors"]
we have to be careful with this because it adds another requirement to our FP8 configuration, which would only be used in the specific scenario of loading an llm_compressor model back into fms-mo for evaluation.
This is unless compressed_tensors is already a requirement of llmcompressor, in which case this additional dependency is not needed.
following up on this to confirm that compressed_tensors is already a dependency of llmcompressor:
from llm-compressor/setup.py:
    install_requires=[
        ...
        "compressed-tensors==0.11.0"
        if BUILD_TYPE == "release"
        else "compressed-tensors>=0.11.1a2"
corrected
fms_mo/utils/dq_inf.py
Outdated
        raise ValueError("This quantization method is not supported for inferencing")


    def load_inference_qconfig_file(model_args, fms_mo_args):
add type hint for return (I suppose dict, please check)
also add type hint for input arguments
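For illustration, a possible type-hinted signature; the concrete classes of model_args and fms_mo_args are not shown in this thread, so Any is used as a placeholder, and the dict return type follows the reviewer's guess above:

    from typing import Any

    def load_inference_qconfig_file(model_args: Any, fms_mo_args: Any) -> dict:
        ...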
done
fms_mo/utils/dq_inf.py
Outdated
        return qcfg


    def update_qcfg_from_model_config(model_args, qcfg):
add type hint on input arguments and return across all functions in this file
done
fms_mo/dq.py
Outdated
    logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
    qcfg = qconfig_init(recipe="dq", args=fms_mo_args)

    quant_mode = check_quantization_setting(model)
quant_mode is boolean but it isn't clear from the name of this variable what its role is (to me, it implies a quantization strategy). I think do_quantization or run_quantization or (with opposite meaning) inference_only would be clearer names. Consider updating this name.
changed to inference_only
corrected
fms_mo/utils/dq_inf.py
Outdated
    if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8:
        raise ValueError("Only 8-bit FP weight quantization is supported")

    if quant_config["kv_cache_scheme"] is None:
replace the empty if branch that uses pass with a check against is not None
corrected
    if quant_config is None:
        return False

    logger.info("Validating config settings")
check if the quant_config dict has a quant_method key; if not, raise an error.
done.
done
fms_mo/utils/dq_inf.py
Outdated
        return True

    raise ValueError("This quantization method is not supported for inferencing")
only "compressed-tensors" quant_method is supported for inference-only run, correct? If so, update this error to be more specific about what quant_method options (a single one) are supported.
corrected
Done
fms_mo/utils/dq_inf.py
Outdated
            recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
        )
    else:
        logger.info("qcfg file found, loading the qcfg file ")
I would rephrase this message as:
logger.info(f"loading quantization configuration from {model_args.model_name_or_path + '/qcfg.json'}")
done
done
        )
    else:
        logger.info("qcfg file found, loading the qcfg file ")
        qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
is it OK here if we use "/qcfg" instead of "/qcfg.json"?
I don't think "qcfg" alone will work.
fms_mo/utils/dq_inf.py
Outdated
        return qcfg


    # def rename_fms_dict_to_vllm_dict (model_dict : dict= None, qcfg : dict = None):
remove comment
done
removed
| if ".weight" in k: | ||
| key = k.split("weight")[0] | ||
| if key + "quantize_weight.scale" in keys: | ||
| weight, scale = to_fp8_scaled_perCh(v, emulate=False) |
if missing, add a warning (somewhere, could be outside this function) that only static per-channel weight quantization is supported for conversion from FMS-MO to vLLM at this time
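A minimal sketch of such a warning; the placement and exact wording are assumptions:

    logger.warning(
        "Only static per-channel weight quantization is currently supported when "
        "converting an FMS-MO checkpoint to the HF/vLLM compressed-tensors format."
    )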
warning added
added
        return False

    logger.info("Validating config settings")
    if "quant_method" in quant_config.keys():
this nested if/raise can be further cleaned up a bit. For example,

    if "quant_method" in quant_config:  # NOTE: "in dict.keys()" can be simplified to "in dict"
        if quant_config["quant_method"] == "compressed-tensors":
            if quant_config["format"] != "float-quantized":
                <do something....>
            raise Error1
        raise Error2

can be rewritten as

    if "quant_method" not in quant_config:
        raise Error2
    if quant_config.get("quant_method", None) != "compressed-tensors":
        raise Error1
    if quant_config["format"] != "float-quantized":
        <do something....>

which should give us fewer indents and possibly fewer line breaks => improves readability.
Also, @andrea-fasoli's earlier note was not addressed: use dict.get(xxx, default) to avoid missing keys or additional checks.
        raise ValueError("Only 8-bit FP weight quantization is supported")

    if quant_config["kv_cache_scheme"] is not None:
        if quant_config["kv_cache_scheme"]["type"] is not float:
we may want to allow dq to run with no kv-cache compression as well, i.e. it does not always have to be 8-bit.
    logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
    qcfg = qconfig_init(recipe="dq", args=fms_mo_args)

    inference_only = check_quantization_setting(model)
"inference only" seems more appropriate to be a new "mode" in run_quant.quantize() rather than part of dq, we may want to separate "inference only" code here to an independent function/file and add to run_quant
    def swap_qbmm(model: nn.Module, qcfg: dict):
        """Go through all model.named_modules(), try to create an equivalent
        Qbmm layer to replace each of the existing linear Bmm layers.
this func name and description seem inaccurate and may cause confusion. What this func actually does is create and attach new Qbmm modules to any module where torch.matmul/torch.bmm is being used; it has nothing to do with "swapping" or "replacing Linear BMM layers".
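For illustration, a name and docstring closer to that behavior might look like this (attach_qbmm is a hypothetical name, not something from the PR):

    import torch.nn as nn

    def attach_qbmm(model: nn.Module, qcfg: dict):
        """Go through model.named_modules() and create/attach a new Qbmm module to
        each module that uses torch.matmul/torch.bmm; no existing layer is swapped."""
        ...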
Description of the change
This PR enables fast inference loading for FP8 DQ. It also enables saving an fms_mo checkpoint in a format compatible with HF/vLLM, and adds a function that converts a saved fms_mo checkpoint directly to the HF/vLLM checkpoint format.
Related issues or PRs
How to verify the PR
Was the PR tested
yes
Checklist for passing CI/CD:
- git commit --signoff or equivalent
- tox -e fix
- tox -e lint
- tox -e spellcheck
- tox -e unit

Note: CI/CD performs unit tests on multiple versions of Python from a fresh install. There may be differences between your local environment and the test environment.