Commit 7a846d5

Gemlite generate.py fix (#2372)
* fix get_plain() with FMA mode
* update
* fix in_features/out_features metadata mismatch
* update gemlite slice test
* add packing_bitwidth support
* add packing_bitwidth support and cleanup
* update default gemlite layout
* cleanup
* fix symmetric use case and relax _same_meta_data
* _copy() metadata
* fix (4,) in autoquant
* Add dynamic mode in gemlite layout
* mode explanation
* use weights_only instead of static
* generate fix
* remove set_packing_bitwidth

Signed-off-by: mobicham <hicham@mobiuslabs.com>
1 parent: 32599be

2 files changed: +20 -37 lines

torchao/_models/llama/benchmarks.sh

Lines changed: 6 additions & 12 deletions
@@ -97,19 +97,13 @@
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured
 
 # gemlite benchmarks
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-4-64-wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-4-128-wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-None-dq --write_result benchmark_results.txt
 
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt --batch_size 32
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt --batch_size 32
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt --batch_size 32
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt --batch_size 32
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt --batch_size 32
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-4-64-wo --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-4-128-wo --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-None-dq --write_result benchmark_results.txt --batch_size 32
 
 # 2:4 sparse model
 export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
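Note the flag format change: the old strings were gemlite-<packing_bitwidth>-<bit_width>-<group_size>, while the new ones are gemlite-<bit_width>-<group_size>-<mode>, where mode is wo (weight-only) or dq (dynamic activation quantization). A minimal sketch of how the new strings decompose, mirroring the parsing added to generate.py below; the helper name parse_gemlite_flag is illustrative, not part of the commit:

def parse_gemlite_flag(quantization: str):
    # gemlite-<bit_width>-<group_size>-<mode>, e.g. "gemlite-4-64-wo"
    parts = quantization.split("-")
    bit_width = int(parts[1])                                    # 4- or 8-bit weights
    group_size = None if parts[2] == "None" else int(parts[2])   # e.g. 64, 128, or None
    mode = "dynamic" if parts[3] == "dq" else "weight_only"      # dq -> dynamic, wo -> weight-only
    return bit_width, group_size, mode

assert parse_gemlite_flag("gemlite-4-64-wo") == (4, 64, "weight_only")
assert parse_gemlite_flag("gemlite-8-None-dq") == (8, None, "dynamic")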

torchao/_models/llama/generate.py

Lines changed: 14 additions & 25 deletions
@@ -244,7 +244,7 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
 
 def _load_model(checkpoint_path, device, precision):
     checkpoint = torch.load(
-        str(checkpoint_path), mmap=True, weights_only=True, map_location=device
+        str(checkpoint_path), mmap=True, weights_only=True, map_location="cpu"
     )
     if "model" in checkpoint and "stories" in str(checkpoint_path):
         checkpoint = checkpoint["model"]
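Mapping the checkpoint to "cpu" rather than the target device keeps torch.load cheap and mmap-friendly; the model is moved to the requested device and dtype afterwards. A runnable toy illustration of that pattern (not the commit's code; _load_model's surrounding structure is assumed):

import torch
import torch.nn as nn

# Save and reload a tiny state dict the same way: weights land on CPU first,
# then the assembled model is moved to the target device/precision.
model = nn.Linear(8, 8)
torch.save(model.state_dict(), "/tmp/toy_ckpt.pth")

checkpoint = torch.load("/tmp/toy_ckpt.pth", mmap=True, weights_only=True, map_location="cpu")
model.load_state_dict(checkpoint, assign=True)  # assign=True adopts the loaded tensors directly
model = model.to(device="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.float16)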
@@ -366,34 +366,24 @@ def ffn_or_attn_only(mod, fqn):
             import os
             import pwd
 
-            from gemlite.core import GemLiteLinearTriton
+            import gemlite
+
+            gemlite.set_autotune("max")
+            config_file = f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
 
             _quant_args = quantization.split("-")
-            bit_width = int(_quant_args[-2])
-            group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
-            try:
-                packing_bitwidth = int(_quant_args[-3])
-            except:
-                # if only 2 inputs found, use default value
-                packing_bitwidth = 32
+            bit_width = int(_quant_args[1])
+            group_size = None if _quant_args[2] == "None" else int(_quant_args[2])
+            mode = "dynamic" if _quant_args[3] == "dq" else "weight_only"
 
             quantize_(
                 model,
-                gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth),
+                gemlite_uintx_weight_only(
+                    bit_width=bit_width, group_size=group_size, mode=mode
+                ),
             )
 
-            # try to load gemlite kernel config
-            try:
-                GemLiteLinearTriton.load_config(
-                    f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-                )
-                print(
-                    f"loaded gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-                )
-            except:
-                print(
-                    f"unable to load gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-                )
+            gemlite.load_config(config_file)
 
             print("running gemlite warmup")
             generate(
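For reference, a self-contained usage sketch of the new keyword-argument API on a single layer. It assumes a CUDA machine with gemlite installed, and that gemlite_uintx_weight_only is importable from torchao.quantization as in this file:

import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only  # import path assumed

# mode="weight_only" corresponds to the -wo flags, mode="dynamic" to -dq.
linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.float16, device="cuda")
quantize_(linear, gemlite_uintx_weight_only(bit_width=4, group_size=64, mode="weight_only"))

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = linear(x)  # dispatches to gemlite's Triton kernels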
@@ -405,9 +395,8 @@ def ffn_or_attn_only(mod, fqn):
                 temperature=temperature,
                 top_k=top_k,
             )
-            GemLiteLinearTriton.cache_config(
-                f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-            )
+            gemlite.cache_config(config_file)
+
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
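The load_config/cache_config pair persists gemlite's autotuned Triton kernel configurations across runs: the first run tunes during the warmup generation, later runs reload the cached configs. A condensed sketch of the flow this commit sets up (top-level gemlite calls as used in the diff; the config path and the warmup step are placeholders):

import gemlite

config_file = "/tmp/example_gemlite.json"  # the script derives the real path from the user name

gemlite.set_autotune("max")       # allow full autotuning during warmup
gemlite.load_config(config_file)  # reuse previously tuned configs if the file exists

# ... quantize_ the model with a gemlite layout, then run one warmup
# generation so Triton autotunes on the shapes the model will actually see ...

gemlite.cache_config(config_file)  # persist tuned configs for subsequent runs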
