feat: fast loading and saving functionality #180
base: main
@@ -391,12 +391,20 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
    # For nn.Linear
    elif isinstance(module, nn.Linear):
        if module.__class__ != nn.Linear:
            logger.warning(
                f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
                "Please make sure it doesn't wrap BN and activ func."
                "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']."
            )
            if available_packages["compressed_tensors"]:
Review comments on this block:

- this block is incorrect because if
- you can incorporate the 3 nested ifs into a single check:
- if you set it up this way, add a comment to explain when this if clause is triggered, because it is not immediately clear
- not yet fixed
- fixed now

(One possible shape for such a combined check is sketched after this hunk.)
                # Third Party
                import compressed_tensors

                if isinstance(
                    module, compressed_tensors.linear.compressed_linear.CompressedLinear
                ):
                    pass
                else:
                    logger.warning(
                        f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
                        "Please make sure it doesn't wrap BN and activ func. Otherwise"
                        "please create an equivalent Linear wrapper and change qcfg['mapping']."
                    )
        QLin = mapping.get(nn.Linear, None)
        if QLin is None:
            if verbose:
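One possible shape for the single check suggested in the review above. This is a sketch only, not the code that landed in this PR; module, nn, available_packages, logger, and the warning text are all taken from the surrounding diff:

    # Detect CompressedLinear only when compressed_tensors is installed.
    is_compressed_linear = False
    if available_packages["compressed_tensors"]:
        # Third Party
        import compressed_tensors

        is_compressed_linear = isinstance(
            module, compressed_tensors.linear.compressed_linear.CompressedLinear
        )

    # Warn on any unrecognized Linear wrapper, even when compressed_tensors
    # is not installed (the case the reviewer flags as incorrect above).
    if module.__class__ != nn.Linear and not is_compressed_linear:
        logger.warning(
            f"{curr_full_name} {type(module)} seems to be a wrapper of Linear. "
            "Please make sure it doesn't wrap BN and activ func. Otherwise "
            "please create an equivalent Linear wrapper and change qcfg['mapping']."
        )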
@@ -571,6 +579,42 @@ def has_quantized_module(model):
    return any(isinstance(m, quantized_modules) for m in model.modules())


def swap_qbmm(model: nn.Module, qcfg: dict):
    """Go through all model.named_modules(), try to create an equivalent
    Qbmm layer to replace each of the existing linear Bmm layers.
Review comment: this func name and description seem inaccurate and may cause confusion. What this func does is create and attach new Qbmm modules to a module where torch.matmul/torch.bmm is being used. It has nothing to do with "swap" or "replace Linear BMM layers"...
    Args:
        model (nn.Module): input model to be "prepared"
        qcfg (dict): quant config

    Returns: updated model is returned with the Qbmm added
    """

    # Local
    from fms_mo.modules import QBmm

    qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"]["which2patch_contextmanager"]
    isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm"
    for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
        mod_bmm_happened = model.get_submodule(mod_name)
        for whichQBmm, ln in enumerate(line_nums, start=1):
            nbits = qcfg[f"nbits_bmm{whichQBmm}"]
            newQBmm = QBmm(
                num_bits_m1=max(nbits, 8) if whichQBmm == 2 else nbits,
                num_bits_m2=nbits,
                qm1_mode=qcfg[f"bmm{whichQBmm}_qm1_mode"],
                qm2_mode=qcfg[f"bmm{whichQBmm}_qm2_mode"],
                m1_unidirectional=(whichQBmm == 2),
                m1_bounded=(whichQBmm == 2),  # see Note 5
                m2_unidirectional=False,
                m2_bounded=False,
                replaceBmm=isbmm,
                qcfg=qcfg,
            )
            setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)
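To make the loop above easier to follow, here is a hypothetical illustration of the qcfg entries swap_qbmm() consumes; only the key names come from the diff, the values are invented for illustration:

    # Hypothetical example config consumed by swap_qbmm() (values are placeholders).
    qcfg_example = {
        "nbits_bmm1": 8,
        "nbits_bmm2": 8,
        "bmm1_qm1_mode": "fp8_e4m3",  # placeholder mode names, not from this PR
        "bmm1_qm2_mode": "fp8_e4m3",
        "bmm2_qm1_mode": "fp8_e4m3",
        "bmm2_qm2_mode": "fp8_e4m3",
        "bmm_prep": {
            "which2patch_contextmanager": "torch.matmul",  # or "torch.bmm"
            # module name -> source line numbers where matmul/bmm was found;
            # one QBmm{line} submodule is attached per line number.
            "layers_with_bmm": {"model.layers.0.self_attn": [123, 145]},
        },
    }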
def qmodel_prep(
    model,
    dloader,
@@ -657,6 +701,12 @@ def qmodel_prep(
    Returns:
        nn.Module: quantized model ready for further PTQ/QAT
    """
    if qcfg["fp8_inference"]:
        if qcfg.get("QBmm"):
            swap_qbmm(model, qcfg)

        model = q_any_net_5(model, qcfg, verbose=False)
        return model

    sys.setrecursionlimit(4000)
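For reference, a hypothetical call pattern that would exercise this new early-return path; only model, dloader, and the qcfg keys shown in the diff are taken from the PR, and the full qmodel_prep signature (including whether qcfg is the third positional argument) is not shown in this excerpt:

    # Hypothetical: load an already-quantized FP8 checkpoint for inference only.
    qcfg["fp8_inference"] = True   # skip the usual tracing/calibration flow
    qcfg["QBmm"] = True            # also attach QBmm modules before q_any_net_5()
    model = qmodel_prep(model, dloader, qcfg)  # returns right after q_any_net_5()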
@@ -906,8 +956,10 @@ def qmodel_prep(
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=DPorDDPdevices
        )

    qconfig_save(qcfg, fname="qcfg.json")
    if qcfg["output_folder"] is None:
        qconfig_save(qcfg, fname="qcfg.json")
    else:
        qconfig_save(qcfg, fname=qcfg["output_folder"] + "/qcfg.json")
    qcfg["tb_writer"] = tb_writer

    logger.info(f"--- Quantized model --- \n{model}\n")
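A small alternative for the path handling above, using os.path.join so that a trailing "/" in output_folder is tolerated. This is a sketch only, not part of this PR; qconfig_save is used exactly as in the diff:

    import os

    fname = (
        "qcfg.json"
        if qcfg["output_folder"] is None
        else os.path.join(qcfg["output_folder"], "qcfg.json")
    )
    qconfig_save(qcfg, fname=fname)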
@@ -0,0 +1,44 @@
{
    "quantization_config": {
        "config_groups": {
            "group_0": {
                "input_activations": {
                    "actorder": null,
                    "block_structure": null,
                    "dynamic": true,
                    "group_size": null,
                    "num_bits": 8,
                    "observer": null,
                    "observer_kwargs": {},
                    "strategy": "token",
                    "symmetric": true,
                    "type": "float"
                },
                "output_activations": null,
                "targets": [
                    "Linear"
                ],
                "weights": {
                    "actorder": null,
                    "block_structure": null,
                    "dynamic": false,
                    "group_size": null,
                    "num_bits": 8,
                    "observer": "minmax",
                    "observer_kwargs": {},
                    "strategy": "channel",
                    "symmetric": true,
                    "type": "float"
                }
            }
        },
        "format": "float-quantized",
        "global_compression_ratio": null,
        "ignore": [
            "lm_head"
        ],
        "kv_cache_scheme": null,
        "quant_method": "compressed-tensors",
        "quantization_status": "compressed"
    }
}
Review comments:

- quant_mode is boolean, but it isn't clear from the name of this variable what its role is (to me, it implies a quantization strategy). I think do_quantization or run_quantization or (with opposite meaning) inference_only would be clearer names. Consider updating this name.
- changed to inference_only
- corrected