Add Qwix quantization + per-tensor KV cache quantization #205
Signed-off-by: Hongmin Fan <fanhongmin@google.com>
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: Jacob Platin <jacobplatin@google.com>
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: wwl2755-google <wenlongwang@google.com>
Manual CI tests passed: https://buildkite.com/tpu-commons/tpu-commons-ci/builds/565 Signed-off-by: bzgoogle <beinuoz@google.com>
Signed-off-by: Xiang Xu <xiangxu@google.com>
Signed-off-by: Jacob Platin <jacobplatin@google.com>
Signed-off-by: Xiang Xu <xiangxu@google.com>
Signed-off-by: Jacob Platin <jacobplatin@google.com>
Signed-off-by: Hongmin Fan <fanhongmin@google.com>
Force-pushed from 6de3c9b to be643d1
Force-pushed from f7561a2 to 66ad626
Signed-off-by: Jacob Platin <jacobplatin@google.com>
Signed-off-by: Jacob Platin <jacobplatin@google.com>
num_kv_heads=kv_cache_spec.
num_kv_heads,  # NOTE: we'll multiply by 2 in the function
nit: can you make this into a single line? The current split is confusing, and if the comment is the reason for the split, can you add it on a separate line above?
README.md
Outdated
]
```

You may also create a file that defines your own rules (e.g. `tpu_commons/models/jax/utils/quantization/quantize_all_modules_int8_wa.yaml`), where each entry under `rules` corresponds to a `qwix.QuantizationRule`. To pass this file (which is mutually exclusive with `quantization.dtype`), you can do something similar to:
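For context, such a rules file might look like the following sketch. The field names under each rule are illustrative assumptions, not Qwix's exact schema; consult the `qwix.QuantizationRule` definition for the real fields:

```yaml
# Hypothetical Qwix rules file, e.g. my_rules.yaml.
# Each entry under `rules` maps to a qwix.QuantizationRule;
# the field names below are assumptions for illustration.
rules:
  - module_path: ".*attention.*"   # regex over module paths
    weight_qtype: int8             # weights-only for attention
  - module_path: ".*mlp.*"
    weight_qtype: int8
    act_qtype: int8                # weights + activations for MLP
```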
Why not make a standard dir for quantization configs and ask the user to add their config to that dir? Then they can just pass the file name. We can store the KV cache quantization config in the same file.

```
--additional_config='{"quantization": "int8_default.yaml"}'
```

And for a custom file:

```
--additional_config='{"quantization": "int8_default_int8_kv.yaml"}'
```

Files can be in `tpu_commons/models/jax/utils/quantization/configs/`. We will likely have different configs checked in to this dir for different models/datasets.
Agreed on making a standardized dir, but I think keeping KV cache / model quant separate makes more sense: we really only need the YAML for the Qwix rules, so the code / UX stays much cleaner if we separate the two quants (the KV quant is only a dtype specification). We can iterate more on this in the KV cache quant PR.
KV cache config also has multiple options: dtype, per_tensor/per_token/dimension_to_quantize_on, and different config for key/value. We will only support simple per-tensor int8 initially, but keeping it in the file keeps it flexible. Also, one config for all quantization is easier to read.
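A unified config of the kind described here might look something like the sketch below (all keys are hypothetical, illustrating the suggestion rather than an implemented schema):

```yaml
# Hypothetical unified quantization config (illustrative keys only).
rules:
  - module_path: ".*"
    weight_qtype: int8
kv_cache:
  dtype: int8
  granularity: per_tensor   # later: per_token / dimension_to_quantize_on
  key:
    dtype: int8             # key and value could diverge if needed
  value:
    dtype: int8
```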
mitalisi left a comment:
Can we add details on the perf and accuracy results for these techniques?
README.md
Outdated

By default, we will use the following Qwix rules (with the given `dtype`), which will quantize attention weights-only and MLP with weights and activations:

```
Better to add the path to the file here so that it stays up to date.
k_scale, v_scale = None, None
if k_scale_ref is not None:
    k_scale = k_scale_ref[0]
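The snippet above reads a single per-tensor scale (`k_scale_ref[0]`) for the whole key cache. The idea behind per-tensor int8 KV quantization can be sketched with NumPy (an illustrative sketch under those assumptions, not the PR's actual JAX/Pallas implementation):

```python
import numpy as np

def quantize_per_tensor_int8(x: np.ndarray):
    """Quantize a tensor using a single (per-tensor) int8 scale."""
    # One scale for the whole tensor: map max|x| onto the int8 range.
    scale = max(float(np.max(np.abs(x))) / 127.0, 1e-12)  # avoid div-by-zero
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_per_tensor(q: np.ndarray, scale: float) -> np.ndarray:
    # Dequantize back to float using the stored per-tensor scale.
    return q.astype(np.float32) * scale

np.random.seed(0)
k = np.random.randn(4, 8).astype(np.float32)  # stand-in for a K-cache block
qk, k_scale = quantize_per_tensor_int8(k)
k_restored = dequantize_per_tensor(qk, k_scale)
# Per-tensor rounding error is bounded by half a quantization step.
assert np.max(np.abs(k_restored - k)) <= 0.5 * k_scale + 1e-6
```

Per-token or per-dimension variants would store a vector of scales instead of a scalar, which is the flexibility the config discussion above refers to.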
TODO @jrplatin: do we want to keep this astype?
Signed-off-by: Jacob Platin <jacobplatin@google.com>