-
Notifications
You must be signed in to change notification settings - Fork 16
feat: fast loading and saving functionality #180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
bayo-ibm
wants to merge
13
commits into
foundation-model-stack:main
Choose a base branch
from
bayo-ibm:fast_loading
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 3 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
b137e0c
feat: fast model inference
bayo-ibm e9874ef
Merge branch 'main' into fast_loading
bayo-ibm 0b5d68a
feat: enable fast loading and vllm format saving functionality in fms_mo
bayo-ibm b458d18
fix: updated the code to reflect PR update
bayo-ibm adb7f38
fix: re-naming of qcfg inference parameter
bayo-ibm 31dd8c7
fix: updated the inference file
bayo-ibm 4878ba1
fix: corrected the inference file
bayo-ibm 6217699
fix: corrected the lint error
bayo-ibm a2ae168
fix: corrected the ruff error
bayo-ibm aca818a
fix:minor edit on qmodel_prep
bayo-ibm fbdf19f
fix: type hinting arguments and returns
bayo-ibm d3e7c61
fix: improving argument hints and inferencing for models with skipped…
bayo-ibm 7510770
fix: correcting lint error
bayo-ibm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ | |
| # Third Party | ||
| from torch import nn | ||
| import torch | ||
|
|
||
| import compressed_tensors | ||
| # Local | ||
| from fms_mo.calib import qmodel_calib | ||
| from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules | ||
|
|
@@ -391,12 +391,14 @@ | |
| # For nn.Linear | ||
| elif isinstance(module, nn.Linear): | ||
| if module.__class__ != nn.Linear: | ||
| logger.warning( | ||
| f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." | ||
| "Please make sure it doesn't wrap BN and activ func." | ||
| "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']." | ||
| ) | ||
|
|
||
| if isinstance(module, compressed_tensors.linear.compressed_linear.CompressedLinear): | ||
| pass | ||
| else: | ||
| logger.warning( | ||
| f"{curr_full_name} {type(module)} seems to be a wrapper of Linear." | ||
| "Please make sure it doesn't wrap BN and activ func." | ||
| "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']." | ||
| ) | ||
| QLin = mapping.get(nn.Linear, None) | ||
| if QLin is None: | ||
| if verbose: | ||
|
|
@@ -570,6 +572,41 @@ | |
| """Check if model is already quantized - do not want to quantize twice if so""" | ||
| return any(isinstance(m, quantized_modules) for m in model.modules()) | ||
|
|
||
| def swap_qbmm(model: nn.Module, qcfg: dict): | ||
| """Go through all model.named_modules(), try to create an equivalent | ||
| Qbmm layer to replace each of the existing linear Bmm layers. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this func name and description seem inaccurate and may cause confusion. what this func does is to create and attach new Qbmm modules to a module where torch.matmul/torch.bmm is being used. Has nothing to do with "swap" and "replace Linear BMM layers"... |
||
|
|
||
| Args: | ||
| model (nn.Module): input model to be "prepared" | ||
| qcfg (dict): quant config | ||
|
|
||
| Returns: updated model is returned with the Qbmm added | ||
|
|
||
| """ | ||
|
|
||
| from fms_mo.modules import QBmm | ||
|
|
||
| qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][ | ||
| "which2patch_contextmanager" | ||
| ] | ||
| isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm" | ||
| for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items(): | ||
| mod_bmm_happened = model.get_submodule(mod_name) | ||
| for whichQBmm, ln in enumerate(line_nums, start=1): | ||
| nbits = qcfg[f"nbits_bmm{whichQBmm}"] | ||
| newQBmm = QBmm( | ||
| num_bits_m1=max(nbits, 8) if whichQBmm == 2 else nbits, | ||
| num_bits_m2=nbits, | ||
| qm1_mode=qcfg[f"bmm{whichQBmm}_qm1_mode"], | ||
| qm2_mode=qcfg[f"bmm{whichQBmm}_qm2_mode"], | ||
| m1_unidirectional=(whichQBmm == 2), | ||
| m1_bounded=(whichQBmm == 2), # see Note 5 | ||
| m2_unidirectional=False, | ||
| m2_bounded=False, | ||
| replaceBmm=isbmm, | ||
| qcfg=qcfg, | ||
| ) | ||
| setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm) | ||
|
|
||
| def qmodel_prep( | ||
| model, | ||
|
|
@@ -582,6 +619,7 @@ | |
| Qcali=False, | ||
| dev=None, | ||
| use_dynamo=False, | ||
| mode=False, | ||
andrea-fasoli marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| verbose=False, | ||
| **kwargs, | ||
| ): | ||
|
|
@@ -657,7 +695,14 @@ | |
| Returns: | ||
| nn.Module: quantized model ready for further PTQ/QAT | ||
| """ | ||
| if mode: | ||
|
|
||
| if qcfg.get("QBmm"): | ||
| swap_qbmm(model,qcfg) | ||
|
|
||
| model = q_any_net_5(model, qcfg, verbose = False) | ||
| return model | ||
|
|
||
| sys.setrecursionlimit(4000) | ||
|
|
||
| currDev = next(model.parameters()).device if dev is None else dev | ||
|
|
@@ -907,7 +952,7 @@ | |
| model, device_ids=DPorDDPdevices | ||
| ) | ||
|
|
||
| qconfig_save(qcfg, fname="qcfg.json") | ||
| qconfig_save(qcfg, fname=qcfg["output_folder"]+"/qcfg.json") | ||
| qcfg["tb_writer"] = tb_writer | ||
|
|
||
| logger.info(f"--- Quantized model --- \n{model}\n") | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
andrea-fasoli marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| { | ||
| "quantization_config": { | ||
| "config_groups": { | ||
| "group_0": { | ||
| "input_activations": { | ||
| "actorder": null, | ||
| "block_structure": null, | ||
| "dynamic": true, | ||
| "group_size": null, | ||
| "num_bits": 8, | ||
| "observer": null, | ||
| "observer_kwargs": {}, | ||
| "strategy": "token", | ||
| "symmetric": true, | ||
| "type": "float" | ||
| }, | ||
| "output_activations": null, | ||
| "targets": [ | ||
| "Linear" | ||
| ], | ||
| "weights": { | ||
| "actorder": null, | ||
| "block_structure": null, | ||
| "dynamic": false, | ||
| "group_size": null, | ||
| "num_bits": 8, | ||
| "observer": "minmax", | ||
| "observer_kwargs": {}, | ||
| "strategy": "channel", | ||
| "symmetric": true, | ||
| "type": "float" | ||
| } | ||
| } | ||
| }, | ||
| "format": "float-quantized", | ||
| "global_compression_ratio": null, | ||
| "ignore": [ | ||
| "lm_head" | ||
| ], | ||
| "kv_cache_scheme": null, | ||
| "quant_method": "compressed-tensors", | ||
| "quantization_status": "compressed" | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.