From d17ffd390325d484f71bfcfbf7bdcfe258d723fc Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 6 May 2025 12:11:57 -0700 Subject: [PATCH 01/15] Eval hf models using lm_eval --- torchao/_models/llama/eval_hf.py | 95 ++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 torchao/_models/llama/eval_hf.py diff --git a/torchao/_models/llama/eval_hf.py b/torchao/_models/llama/eval_hf.py new file mode 100644 index 0000000000..dc87ff93c6 --- /dev/null +++ b/torchao/_models/llama/eval_hf.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow + + +def quantize_model_and_save(model_id, quant_config, output_dir="results"): + """Quantize the model and save it to the output directory.""" + quantization_config = TorchAoConfig(quant_type=quant_config) + quantized_model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + quantized_model.save_pretrained(output_dir, safe_serialization=False) + tokenizer.save_pretrained(output_dir, safe_serialization=False) + return quantized_model, tokenizer + + +# Run lm_eval +# lm_eval --model hf --model_args pretrained=llama-fp8 --tasks hellaswag --device cuda:0 --batch_size 8 + +import subprocess + + +def run_lm_eval(model_dir, tasks="hellaswag", device="cuda:0", batch_size=8): + """Run the lm_eval command using subprocess.""" + command = [ + "lm_eval", + "--model", + "hf", + "--model_args", + f"pretrained={model_dir}", + "--tasks", + f"{tasks}", + "--device", + f"{device}", + "--batch_size", + f"{batch_size}", + ] + subprocess.run(command, check=True) + + +# def push_to_hub(user_id, model_name, quant_recipe='float8dq'): +# """Push to hub""" +# save_to = f"{user_id}/{model_name}-{quant_recipe}" +# quantized_model.push_to_hub(save_to, safe_serialization=False) +# tokenizer.push_to_hub(save_to) + +# def prompt_testing(quantized_model, tokenizer): +# # Manual Testing +# prompt = "Hey, are you conscious? Can you talk to me?" 
+# messages = [ +# { +# "role": "system", +# "content": "", +# }, +# {"role": "user", "content": prompt}, +# ] +# templated_prompt = tokenizer.apply_chat_template( +# messages, +# tokenize=False, +# add_generation_prompt=True, +# ) +# print("Prompt:", prompt) +# print("Templated prompt:", templated_prompt) +# inputs = tokenizer( +# templated_prompt, +# return_tensors="pt", +# ).to("cuda") +# generated_ids = quantized_model.generate(**inputs, max_new_tokens=128) +# output_text = tokenizer.batch_decode( +# generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False +# ) +# print("Response:", output_text[0][len(prompt):]) + + +if __name__ == "__main__": + model_id = "meta-llama/Llama-3.1-8B" + model_name = model_id.split("/")[-1] + model_output_dir = f"quantized_model/{model_name}" + quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + quantized_model, tokenizer = quantize_model_and_save( + model_id, quant_config=quant_config, output_dir=model_output_dir + ) + run_lm_eval(model_output_dir) + # prompt_testing(quantized_model, tokenizer) From dc355d6844736ab1bcca7b5679fd917341b42fb6 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 6 May 2025 13:08:29 -0700 Subject: [PATCH 02/15] Throughput updates --- torchao/_models/llama/eval_hf.py | 89 +++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/torchao/_models/llama/eval_hf.py b/torchao/_models/llama/eval_hf.py index dc87ff93c6..bacc28ab58 100644 --- a/torchao/_models/llama/eval_hf.py +++ b/torchao/_models/llama/eval_hf.py @@ -49,12 +49,14 @@ def run_lm_eval(model_dir, tasks="hellaswag", device="cuda:0", batch_size=8): subprocess.run(command, check=True) +# TODO: Fix this # def push_to_hub(user_id, model_name, quant_recipe='float8dq'): # """Push to hub""" # save_to = f"{user_id}/{model_name}-{quant_recipe}" # quantized_model.push_to_hub(save_to, safe_serialization=False) # tokenizer.push_to_hub(save_to) +# TODO: Fix this # def prompt_testing(quantized_model, tokenizer): # # Manual Testing # prompt = "Hey, are you conscious? Can you talk to me?" @@ -83,6 +85,85 @@ def run_lm_eval(model_dir, tasks="hellaswag", device="cuda:0", batch_size=8): # print("Response:", output_text[0][len(prompt):]) +def model_throughput( + model, + tokenizer, + prompt="What are we having for dinner?", + max_new_tokens=10, + num_runs=5, + print_all_responses=False, +): + """ + Calculate model throughput in tokens per second. 
+ + Args: + model: The model to evaluate + tokenizer: The tokenizer to use + prompt: The input prompt + max_new_tokens: Number of tokens to generate + num_runs: Number of runs to average over for more accurate measurement + print_all_responses: Whether to print responses from all runs or just the last one + + Returns: + float: Throughput in tokens per second + """ + import time + + import torch + + # Tokenize the prompt + inputs = tokenizer( + prompt, + return_tensors="pt", + ).to("cuda") + + # Warmup run + with torch.no_grad(): + _ = model.generate(**inputs, max_new_tokens=max_new_tokens) + + # Measure generation time over multiple runs + total_tokens = 0 + total_time = 0 + generated_ids = None + + for _ in range(num_runs): + # Start timing + torch.cuda.synchronize() + start_time = time.time() + + # Generate text + with torch.no_grad(): + generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens) + + # End timing + torch.cuda.synchronize() + end_time = time.time() + + # Calculate tokens generated (excluding prompt tokens) + prompt_length = inputs.input_ids.shape[1] + total_length = generated_ids.shape[1] + new_tokens = total_length - prompt_length + + total_tokens += new_tokens + total_time += end_time - start_time + + # Calculate throughput + throughput = total_tokens / total_time + + # Get the output text for the last run + output_text = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + print(f"Response: {output_text[0][len(prompt) :]}") + print(f"Throughput: {throughput:.2f} tokens/sec") + print( + f"Average generation time: {(total_time / num_runs) * 1000:.2f} ms for {max_new_tokens} tokens" + ) + + return throughput + + if __name__ == "__main__": model_id = "meta-llama/Llama-3.1-8B" model_name = model_id.split("/")[-1] @@ -91,5 +172,11 @@ def run_lm_eval(model_dir, tasks="hellaswag", device="cuda:0", batch_size=8): quantized_model, tokenizer = quantize_model_and_save( model_id, quant_config=quant_config, output_dir=model_output_dir ) - run_lm_eval(model_output_dir) + # run_lm_eval(model_output_dir) + model_throughput( + quantized_model, + tokenizer, + prompt="What are we having for dinner?", + max_new_tokens=128, + ) # prompt_testing(quantized_model, tokenizer) From 8c7583b40fe58dbf08513535914b9ba314eacb50 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 6 May 2025 13:27:46 -0700 Subject: [PATCH 03/15] Updates --- torchao/_models/README.md | 5 ++++ torchao/_models/{llama => }/eval_hf.py | 39 ++------------------------ 2 files changed, 8 insertions(+), 36 deletions(-) rename torchao/_models/{llama => }/eval_hf.py (78%) diff --git a/torchao/_models/README.md b/torchao/_models/README.md index 074adf884c..859ffd80fa 100644 --- a/torchao/_models/README.md +++ b/torchao/_models/README.md @@ -1,3 +1,8 @@ +# TODO: Add info for _models here, what is in the repo + +# TODO: Add llama- eval_hf.py, reproducable code, and a eval table here + + ## SAM2 sam2 is a fork of https://github.com/facebookresearch/sam2 at commit c2ec8e14a185632b0a5d8b161928ceb50197eddc diff --git a/torchao/_models/llama/eval_hf.py b/torchao/_models/eval_hf.py similarity index 78% rename from torchao/_models/llama/eval_hf.py rename to torchao/_models/eval_hf.py index bacc28ab58..e6a37b1637 100644 --- a/torchao/_models/llama/eval_hf.py +++ b/torchao/_models/eval_hf.py @@ -9,6 +9,9 @@ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow +# TODO: Make it optional lm_eval dependency +# Add a check for lm_eval installed + 
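# A minimal sketch of such a check, kept as comments since it is not part of this
# patch (a later commit in this series imports torchao.quantization.utils._lm_eval_available instead):
#     try:
#         import lm_eval  # noqa: F401
#         _lm_eval_available = True
#     except ImportError:
#         _lm_eval_available = False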
def quantize_model_and_save(model_id, quant_config, output_dir="results"): """Quantize the model and save it to the output directory.""" @@ -49,42 +52,6 @@ def run_lm_eval(model_dir, tasks="hellaswag", device="cuda:0", batch_size=8): subprocess.run(command, check=True) -# TODO: Fix this -# def push_to_hub(user_id, model_name, quant_recipe='float8dq'): -# """Push to hub""" -# save_to = f"{user_id}/{model_name}-{quant_recipe}" -# quantized_model.push_to_hub(save_to, safe_serialization=False) -# tokenizer.push_to_hub(save_to) - -# TODO: Fix this -# def prompt_testing(quantized_model, tokenizer): -# # Manual Testing -# prompt = "Hey, are you conscious? Can you talk to me?" -# messages = [ -# { -# "role": "system", -# "content": "", -# }, -# {"role": "user", "content": prompt}, -# ] -# templated_prompt = tokenizer.apply_chat_template( -# messages, -# tokenize=False, -# add_generation_prompt=True, -# ) -# print("Prompt:", prompt) -# print("Templated prompt:", templated_prompt) -# inputs = tokenizer( -# templated_prompt, -# return_tensors="pt", -# ).to("cuda") -# generated_ids = quantized_model.generate(**inputs, max_new_tokens=128) -# output_text = tokenizer.batch_decode( -# generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False -# ) -# print("Response:", output_text[0][len(prompt):]) - - def model_throughput( model, tokenizer, From 2768c17e9f0085bf09eae88f7118e10f7b9d32dd Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 6 May 2025 21:34:38 -0700 Subject: [PATCH 04/15] Add sh script --- .../_models/eval_hf_models.py | 120 ++++++++++++++---- benchmarks/_models/eval_hf_models.sh | 24 ++++ 2 files changed, 121 insertions(+), 23 deletions(-) rename torchao/_models/eval_hf.py => benchmarks/_models/eval_hf_models.py (58%) create mode 100644 benchmarks/_models/eval_hf_models.sh diff --git a/torchao/_models/eval_hf.py b/benchmarks/_models/eval_hf_models.py similarity index 58% rename from torchao/_models/eval_hf.py rename to benchmarks/_models/eval_hf_models.py index e6a37b1637..dc24f8a84b 100644 --- a/torchao/_models/eval_hf.py +++ b/benchmarks/_models/eval_hf_models.py @@ -4,13 +4,15 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import argparse +import subprocess +import time + import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig -from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow - -# TODO: Make it optional lm_eval dependency -# Add a check for lm_eval installed +from benchmarks.microbenchmarks.utils import string_to_config +from torchao.quantization.utils import _lm_eval_available def quantize_model_and_save(model_id, quant_config, output_dir="results"): @@ -28,12 +30,6 @@ def quantize_model_and_save(model_id, quant_config, output_dir="results"): return quantized_model, tokenizer -# Run lm_eval -# lm_eval --model hf --model_args pretrained=llama-fp8 --tasks hellaswag --device cuda:0 --batch_size 8 - -import subprocess - - def run_lm_eval(model_dir, tasks="hellaswag", device="cuda:0", batch_size=8): """Run the lm_eval command using subprocess.""" command = [ @@ -58,7 +54,6 @@ def model_throughput( prompt="What are we having for dinner?", max_new_tokens=10, num_runs=5, - print_all_responses=False, ): """ Calculate model throughput in tokens per second. 
@@ -74,10 +69,6 @@ def model_throughput( Returns: float: Throughput in tokens per second """ - import time - - import torch - # Tokenize the prompt inputs = tokenizer( prompt, @@ -131,19 +122,102 @@ def model_throughput( return throughput -if __name__ == "__main__": - model_id = "meta-llama/Llama-3.1-8B" +def run( + model_id, + quantization, + tasks, + device, + batch_size, + prompt, + max_new_tokens, + num_runs, + model_output_dir, +): + print(f"Running model {model_id} with quantization {quantization}") model_name = model_id.split("/")[-1] - model_output_dir = f"quantized_model/{model_name}" - quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + model_output_dir = f"quantized_model/{model_name}-{quantization}" + quant_config = string_to_config(quantization, None) quantized_model, tokenizer = quantize_model_and_save( model_id, quant_config=quant_config, output_dir=model_output_dir ) - # run_lm_eval(model_output_dir) + run_lm_eval(model_output_dir, tasks=tasks, device=device, batch_size=batch_size) model_throughput( quantized_model, tokenizer, - prompt="What are we having for dinner?", - max_new_tokens=128, + prompt=prompt, + max_new_tokens=max_new_tokens, + num_runs=num_runs, + ) + # TODO: Add memory usage measurement + + +if __name__ == "__main__": + if not _lm_eval_available: + print( + "lm_eval is required to run this script. Please install it using pip install lm-eval." + ) + exit(0) + + # Set up argument parser + parser = argparse.ArgumentParser( + description="Quantize a model and evaluate its throughput." + ) + parser.add_argument( + "--model_id", + type=str, + default="meta-llama/Llama-3.1-8B", + help="The model ID to use.", + ) + parser.add_argument( + "--quantization", + type=str, + default="float8wo", + help="The quantization method to use.", + ) + parser.add_argument( + "--tasks", type=str, default="hellaswag", help="Tasks to run in lm_eval." + ) + parser.add_argument( + "--device", type=str, default="cuda:0", help="Device to run the model on." + ) + parser.add_argument( + "--batch_size", type=int, default=8, help="Batch size for lm_eval." + ) + parser.add_argument( + "--prompt", + type=str, + default="What are we having for dinner?", + help="Prompt for model throughput evaluation.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=10, + help="Max new tokens to generate for throughput evaluation.", + ) + parser.add_argument( + "--num_runs", + type=int, + default=5, + help="Number of runs to average over for throughput evaluation.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="quantized_models", + help="Output directory for quantized model.", + ) + args = parser.parse_args() + + # Use parsed arguments + run( + model_id=args.model_id, + quantization=args.quantization, + tasks=args.tasks, + device=args.device, + batch_size=args.batch_size, + prompt=args.prompt, + max_new_tokens=args.max_new_tokens, + num_runs=args.num_runs, + model_output_dir=args.output_dir, ) - # prompt_testing(quantized_model, tokenizer) diff --git a/benchmarks/_models/eval_hf_models.sh b/benchmarks/_models/eval_hf_models.sh new file mode 100644 index 0000000000..c2945144d4 --- /dev/null +++ b/benchmarks/_models/eval_hf_models.sh @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +# For llama3.1-8B + +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8wo +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4dq +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8adq-int4w-symm + + +# For llama3.2-3B + +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8wo +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4dq +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8adq-int4w-symm From fa24a4ffba58709a59c196a16de1b648789aa59a Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 7 May 2025 11:08:07 -0700 Subject: [PATCH 05/15] Add sh script to regenerate numbers --- benchmarks/_models/eval_hf_models.sh | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/benchmarks/_models/eval_hf_models.sh b/benchmarks/_models/eval_hf_models.sh index c2945144d4..f7a18667a1 100644 --- a/benchmarks/_models/eval_hf_models.sh +++ b/benchmarks/_models/eval_hf_models.sh @@ -8,17 +8,15 @@ python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8wo -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4dq -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8adq-int4w-symm +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-32 +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq # For llama3.2-3B python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8wo -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4dq -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8adq-int4w-symm +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization 
int4wo-32 +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq From 984e5184f06b63a2ac10337dea040e0099a25f8e Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 7 May 2025 11:18:56 -0700 Subject: [PATCH 06/15] Add sh script to regenerate numbers wikitext --- benchmarks/_models/eval_hf_models.sh | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/benchmarks/_models/eval_hf_models.sh b/benchmarks/_models/eval_hf_models.sh index f7a18667a1..3925d2eecd 100644 --- a/benchmarks/_models/eval_hf_models.sh +++ b/benchmarks/_models/eval_hf_models.sh @@ -3,20 +3,19 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. - # For llama3.1-8B -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-32 -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor --tasks wikitext +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-32 --tasks wikitext +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext # For llama3.2-3B -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-32 -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row --tasks wikitext +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor --tasks wikitext +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-32 --tasks wikitext +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo --tasks wikitext +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext From 79f8af6a23debfe402dc4a0dc214ceeaed37a523 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 8 May 2025 21:41:36 -0700 Subject: [PATCH 07/15] Updated code for model size --- benchmarks/_models/eval_hf_models.py | 58 ++++++++++++++++++++++++---- benchmarks/_models/eval_hf_models.sh | 30 ++++++++------ benchmarks/microbenchmarks/utils.py | 
10 +++++ torchao/_models/README.md | 36 ++++++++++++++++- 4 files changed, 113 insertions(+), 21 deletions(-) diff --git a/benchmarks/_models/eval_hf_models.py b/benchmarks/_models/eval_hf_models.py index dc24f8a84b..c3cb3fa510 100644 --- a/benchmarks/_models/eval_hf_models.py +++ b/benchmarks/_models/eval_hf_models.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import itertools import subprocess import time @@ -12,12 +13,17 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig from benchmarks.microbenchmarks.utils import string_to_config +from torchao.quantization import * # noqa: F401, F403 from torchao.quantization.utils import _lm_eval_available def quantize_model_and_save(model_id, quant_config, output_dir="results"): """Quantize the model and save it to the output directory.""" - quantization_config = TorchAoConfig(quant_type=quant_config) + print("Quantizing model with config: ", quant_config) + if quant_config is None: + quantization_config = None + else: + quantization_config = TorchAoConfig(quant_type=quant_config) quantized_model = AutoModelForCausalLM.from_pretrained( model_id, device_map="auto", @@ -30,8 +36,9 @@ def quantize_model_and_save(model_id, quant_config, output_dir="results"): return quantized_model, tokenizer -def run_lm_eval(model_dir, tasks="hellaswag", device="cuda:0", batch_size=8): +def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8): """Run the lm_eval command using subprocess.""" + tasks_str = ",".join(tasks_list) command = [ "lm_eval", "--model", @@ -39,7 +46,7 @@ def run_lm_eval(model_dir, tasks="hellaswag", device="cuda:0", batch_size=8): "--model_args", f"pretrained={model_dir}", "--tasks", - f"{tasks}", + f"{tasks_str}", "--device", f"{device}", "--batch_size", @@ -122,6 +129,36 @@ def model_throughput( return throughput +def get_model_size_in_bytes(model, ignore_embeddings=False): + """ + Returns the model size in bytes. The option to ignore embeddings + is useful for models with disproportionately large embeddings compared + to other model parameters that get quantized/sparsified. 
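    Example (informal; this is how run() below reports model size):
        size_gb = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9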
+ """ + + def flat_size(tensor): + if hasattr(tensor, "__tensor_flatten__"): + size = 0 + # 0th element is a list of attributes that + # hold tensors + for attr_name in tensor.__tensor_flatten__()[0]: + sub_tensor = getattr(tensor, attr_name) + size += flat_size(sub_tensor) + return size + else: + return tensor.numel() * tensor.element_size() + + model_size = 0 + for _, child in model.named_children(): + if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings): + for p in itertools.chain( + child.parameters(recurse=False), child.buffers(recurse=False) + ): + model_size += flat_size(p) + model_size += get_model_size_in_bytes(child, ignore_embeddings) + return model_size + + def run( model_id, quantization, @@ -140,7 +177,9 @@ def run( quantized_model, tokenizer = quantize_model_and_save( model_id, quant_config=quant_config, output_dir=model_output_dir ) - run_lm_eval(model_output_dir, tasks=tasks, device=device, batch_size=batch_size) + run_lm_eval( + model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size + ) model_throughput( quantized_model, tokenizer, @@ -148,7 +187,8 @@ def run( max_new_tokens=max_new_tokens, num_runs=num_runs, ) - # TODO: Add memory usage measurement + model_size = get_model_size_in_bytes(quantized_model, ignore_embeddings=True) / 1e9 + print(f"Model size: {model_size:.2f} GB") if __name__ == "__main__": @@ -171,11 +211,15 @@ def run( parser.add_argument( "--quantization", type=str, - default="float8wo", + default=None, help="The quantization method to use.", ) parser.add_argument( - "--tasks", type=str, default="hellaswag", help="Tasks to run in lm_eval." + "--tasks", + nargs="+", + type=str, + default=["wikitext"], + help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2", ) parser.add_argument( "--device", type=str, default="cuda:0", help="Device to run the model on." diff --git a/benchmarks/_models/eval_hf_models.sh b/benchmarks/_models/eval_hf_models.sh index 3925d2eecd..14feef7505 100644 --- a/benchmarks/_models/eval_hf_models.sh +++ b/benchmarks/_models/eval_hf_models.sh @@ -3,19 +3,25 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-# For llama3.1-8B -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor --tasks wikitext -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-32 --tasks wikitext -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext + +# For llama3.1-8B +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8wo --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-128 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128 --tasks wikitext hellaswag # For llama3.2-3B - -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row --tasks wikitext -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor --tasks wikitext -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-32 --tasks wikitext -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo --tasks wikitext -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8wo --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-128 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128 --tasks wikitext hellaswag diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index f591ec3669..f6e450226b 100644 --- a/benchmarks/microbenchmarks/utils.py +++ 
b/benchmarks/microbenchmarks/utils.py @@ -18,6 +18,7 @@ Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, @@ -287,6 +288,15 @@ def string_to_config( else: granularity = PerTensor() return Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + if "gemlitewo" in quantization: + group_size = int(quantization.split("-")[1]) + assert group_size in [ + 32, + 64, + 128, + 256, + ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" + return GemliteUIntXWeightOnlyConfig(group_size=group_size) return None diff --git a/torchao/_models/README.md b/torchao/_models/README.md index 859ffd80fa..fc4c495602 100644 --- a/torchao/_models/README.md +++ b/torchao/_models/README.md @@ -1,7 +1,39 @@ -# TODO: Add info for _models here, what is in the repo +# TODO: Add info for _models here -# TODO: Add llama- eval_hf.py, reproducable code, and a eval table here +# Eval on Llama 3.1 8B and Llama 3.2 3B +We use lm-eval tasks for evaluating TorchAO Quantization APIs on HuggingFace models. The results are in the table below: + +| Model Name | Quantization Technique | Acc |Acc Norm| Word perplexity| Throughput (tokens/sec)| Model Size (GB) | +|---------------|------------------------|-------|--------|----------------|------------------------|-------------------| +| Llama 3.1 8B | None | 60.01 | 78.84 | 7.33 | 44.95 | 15.01 | +| Llama 3.1 8B | int4wo-128 | 58.10 | 77.06 | 8.25 | 33.95 | 4.76 | +| Llama 3.1 8B | int8wo | 59.92 | 78.95 | 7.34 | 28.65 | 8.04 | +| Llama 3.1 8B | int8dq | 60.01 | 78.82 | 7.45 | 4.75 | 8.03 | +| Llama 3.1 8B | float8wo | 59.83 | 78.61 | 7.37 | 17.84 | 8.03 | +| Llama 3.1 8B | float8dq (PerRow) | 59.86 | 78.57 | 7.41 | 10.96 | 8.04 | +| Llama 3.1 8B | float8dq (PerTensor) | 59.95 | 78.66 | 7.42 | 10.63 | 8.03 | +| Llama 3.1 8B | gemlite (gp=128) | 58.48 | 77.34 | 8.07 | 14.42 | 4.76 | + +| Llama 3.2 3B | None | 55.27 | 73.70 | 9.26 | 53.08 | 6.43 | +| Llama 3.2 3B | int4wo-128 | 53.13 | 71.31 | 10.36 | 36.36 | 2.29 | +| Llama 3.2 3B | int8wo | 55.15 | 73.44 | 9.28 | 36.30 | 3.61 | +| Llama 3.2 3B | int8dq | 55.00 | 73.29 | 9.43 | 5.45 | 3.61 | +| Llama 3.2 3B | float8wo | 55.18 | 73.58 | 9.31 | 28.95 | 3.61 | +| Llama 3.2 3B | float8dq (PerRow) | 55.18 | 73.37 | 9.33 | 12.56 | 3.61 | +| Llama 3.2 3B | float8dq (PerTensor) | 55.16 | 73.53 | 9.35 | 12.21 | 3.61 | +| Llama 3.2 3B | gemlite (gp=128) | 53.71 | 71.99 | 10.05 | 16.52 | 2.29 | + +To generate the above results run: +``` +sh benchmarks/_models/eval_hf_models.sh +``` + +To run lm-eval for a different hf-model with AO quantization technique, run: +``` +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag +``` +Replace model id, quantization and tasks with your desired values Please refer to ([HuggingFace <-> TorchAO](https://huggingface.co/docs/transformers/main/en//quantization/torchao)) integration docs for more details about the supported quantization techniques. 
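
For reference, the quantization step that `eval_hf_models.py` performs is roughly the following (a minimal sketch using one recipe from the table above; adjust the model id and config to taste):
```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

model_id = "meta-llama/Llama-3.1-8B"
# float8dq-row: dynamic float8 activations and float8 weights with per-row scales
quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=TorchAoConfig(quant_type=quant_config),
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Save the quantized checkpoint so lm_eval can load it via --model_args pretrained=<dir>
output_dir = "quantized_model/Llama-3.1-8B-float8dq-row"
quantized_model.save_pretrained(output_dir, safe_serialization=False)
tokenizer.save_pretrained(output_dir)
```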
## SAM2 sam2 is a fork of https://github.com/facebookresearch/sam2 at commit c2ec8e14a185632b0a5d8b161928ceb50197eddc From c8a591a473e7e08027cbfb692a29fdfde75e6635 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 8 May 2025 22:24:07 -0700 Subject: [PATCH 08/15] Fix readme issues --- torchao/_models/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/_models/README.md b/torchao/_models/README.md index fc4c495602..3157844a3b 100644 --- a/torchao/_models/README.md +++ b/torchao/_models/README.md @@ -1,6 +1,6 @@ # TODO: Add info for _models here -# Eval on Llama 3.1 8B and Llama 3.2 3B +## Eval on Llama 3.1 8B and Llama 3.2 3B We use lm-eval tasks for evaluating TorchAO Quantization APIs on HuggingFace models. The results are in the table below: @@ -15,6 +15,8 @@ We use lm-eval tasks for evaluating TorchAO Quantization APIs on HuggingFace mod | Llama 3.1 8B | float8dq (PerTensor) | 59.95 | 78.66 | 7.42 | 10.63 | 8.03 | | Llama 3.1 8B | gemlite (gp=128) | 58.48 | 77.34 | 8.07 | 14.42 | 4.76 | +| Model Name | Quantization Technique | Acc |Acc Norm| Word perplexity| Throughput (tokens/sec)| Model Size (GB) | +|---------------|------------------------|-------|--------|----------------|------------------------|-------------------| | Llama 3.2 3B | None | 55.27 | 73.70 | 9.26 | 53.08 | 6.43 | | Llama 3.2 3B | int4wo-128 | 53.13 | 71.31 | 10.36 | 36.36 | 2.29 | | Llama 3.2 3B | int8wo | 55.15 | 73.44 | 9.28 | 36.30 | 3.61 | From 116bf96b342c21d37929d10f71fb7e97084f75ef Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 13 May 2025 16:19:36 -0700 Subject: [PATCH 09/15] Remove throughput --- benchmarks/_models/eval_hf_models.py | 96 ++-------------------------- 1 file changed, 7 insertions(+), 89 deletions(-) diff --git a/benchmarks/_models/eval_hf_models.py b/benchmarks/_models/eval_hf_models.py index c3cb3fa510..2bca1fe5f0 100644 --- a/benchmarks/_models/eval_hf_models.py +++ b/benchmarks/_models/eval_hf_models.py @@ -7,7 +7,6 @@ import argparse import itertools import subprocess -import time import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig @@ -55,80 +54,6 @@ def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size subprocess.run(command, check=True) -def model_throughput( - model, - tokenizer, - prompt="What are we having for dinner?", - max_new_tokens=10, - num_runs=5, -): - """ - Calculate model throughput in tokens per second. 
- - Args: - model: The model to evaluate - tokenizer: The tokenizer to use - prompt: The input prompt - max_new_tokens: Number of tokens to generate - num_runs: Number of runs to average over for more accurate measurement - print_all_responses: Whether to print responses from all runs or just the last one - - Returns: - float: Throughput in tokens per second - """ - # Tokenize the prompt - inputs = tokenizer( - prompt, - return_tensors="pt", - ).to("cuda") - - # Warmup run - with torch.no_grad(): - _ = model.generate(**inputs, max_new_tokens=max_new_tokens) - - # Measure generation time over multiple runs - total_tokens = 0 - total_time = 0 - generated_ids = None - - for _ in range(num_runs): - # Start timing - torch.cuda.synchronize() - start_time = time.time() - - # Generate text - with torch.no_grad(): - generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens) - - # End timing - torch.cuda.synchronize() - end_time = time.time() - - # Calculate tokens generated (excluding prompt tokens) - prompt_length = inputs.input_ids.shape[1] - total_length = generated_ids.shape[1] - new_tokens = total_length - prompt_length - - total_tokens += new_tokens - total_time += end_time - start_time - - # Calculate throughput - throughput = total_tokens / total_time - - # Get the output text for the last run - output_text = tokenizer.batch_decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - - print(f"Response: {output_text[0][len(prompt) :]}") - print(f"Throughput: {throughput:.2f} tokens/sec") - print( - f"Average generation time: {(total_time / num_runs) * 1000:.2f} ms for {max_new_tokens} tokens" - ) - - return throughput - - def get_model_size_in_bytes(model, ignore_embeddings=False): """ Returns the model size in bytes. The option to ignore embeddings @@ -165,9 +90,6 @@ def run( tasks, device, batch_size, - prompt, - max_new_tokens, - num_runs, model_output_dir, ): print(f"Running model {model_id} with quantization {quantization}") @@ -177,16 +99,15 @@ def run( quantized_model, tokenizer = quantize_model_and_save( model_id, quant_config=quant_config, output_dir=model_output_dir ) + print("Compiling model ....") + quantized_model = torch.compile( + quantized_model, + mode="reduce-overhead", + fullgraph=True, + ) run_lm_eval( model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size ) - model_throughput( - quantized_model, - tokenizer, - prompt=prompt, - max_new_tokens=max_new_tokens, - num_runs=num_runs, - ) model_size = get_model_size_in_bytes(quantized_model, ignore_embeddings=True) / 1e9 print(f"Model size: {model_size:.2f} GB") @@ -225,7 +146,7 @@ def run( "--device", type=str, default="cuda:0", help="Device to run the model on." ) parser.add_argument( - "--batch_size", type=int, default=8, help="Batch size for lm_eval." + "--batch_size", type=int, default=1, help="Batch size for lm_eval." 
) parser.add_argument( "--prompt", @@ -260,8 +181,5 @@ def run( tasks=args.tasks, device=args.device, batch_size=args.batch_size, - prompt=args.prompt, - max_new_tokens=args.max_new_tokens, - num_runs=args.num_runs, model_output_dir=args.output_dir, ) From 1b27430ffcce425ee5ecfc2f6a28e08ece31cfe7 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 21 May 2025 15:47:12 -0700 Subject: [PATCH 10/15] Update readme --- torchao/_models/README.md | 42 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/torchao/_models/README.md b/torchao/_models/README.md index 3157844a3b..33385e4f23 100644 --- a/torchao/_models/README.md +++ b/torchao/_models/README.md @@ -4,27 +4,27 @@ We use lm-eval tasks for evaluating TorchAO Quantization APIs on HuggingFace models. The results are in the table below: -| Model Name | Quantization Technique | Acc |Acc Norm| Word perplexity| Throughput (tokens/sec)| Model Size (GB) | -|---------------|------------------------|-------|--------|----------------|------------------------|-------------------| -| Llama 3.1 8B | None | 60.01 | 78.84 | 7.33 | 44.95 | 15.01 | -| Llama 3.1 8B | int4wo-128 | 58.10 | 77.06 | 8.25 | 33.95 | 4.76 | -| Llama 3.1 8B | int8wo | 59.92 | 78.95 | 7.34 | 28.65 | 8.04 | -| Llama 3.1 8B | int8dq | 60.01 | 78.82 | 7.45 | 4.75 | 8.03 | -| Llama 3.1 8B | float8wo | 59.83 | 78.61 | 7.37 | 17.84 | 8.03 | -| Llama 3.1 8B | float8dq (PerRow) | 59.86 | 78.57 | 7.41 | 10.96 | 8.04 | -| Llama 3.1 8B | float8dq (PerTensor) | 59.95 | 78.66 | 7.42 | 10.63 | 8.03 | -| Llama 3.1 8B | gemlite (gp=128) | 58.48 | 77.34 | 8.07 | 14.42 | 4.76 | - -| Model Name | Quantization Technique | Acc |Acc Norm| Word perplexity| Throughput (tokens/sec)| Model Size (GB) | -|---------------|------------------------|-------|--------|----------------|------------------------|-------------------| -| Llama 3.2 3B | None | 55.27 | 73.70 | 9.26 | 53.08 | 6.43 | -| Llama 3.2 3B | int4wo-128 | 53.13 | 71.31 | 10.36 | 36.36 | 2.29 | -| Llama 3.2 3B | int8wo | 55.15 | 73.44 | 9.28 | 36.30 | 3.61 | -| Llama 3.2 3B | int8dq | 55.00 | 73.29 | 9.43 | 5.45 | 3.61 | -| Llama 3.2 3B | float8wo | 55.18 | 73.58 | 9.31 | 28.95 | 3.61 | -| Llama 3.2 3B | float8dq (PerRow) | 55.18 | 73.37 | 9.33 | 12.56 | 3.61 | -| Llama 3.2 3B | float8dq (PerTensor) | 55.16 | 73.53 | 9.35 | 12.21 | 3.61 | -| Llama 3.2 3B | gemlite (gp=128) | 53.71 | 71.99 | 10.05 | 16.52 | 2.29 | +| Model Name | Quantization Technique | Acc |Acc Norm| Word perplexity| Model Size (GB) | +|------------|---------------------------|-------|--------|----------------|-------------------| +| Llama 3.1 8B | None | 60.01 | 78.84 | 7.33 | 15.01 | +| Llama 3.1 8B | int4wo-128 | 58.10 | 77.06 | 8.25 | 4.76 | +| Llama 3.1 8B | int8wo | 59.92 | 78.95 | 7.34 | 8.04 | +| Llama 3.1 8B | int8dq | 60.01 | 78.82 | 7.45 | 8.03 | +| Llama 3.1 8B | float8wo | 59.83 | 78.61 | 7.37 | 8.03 | +| Llama 3.1 8B | float8dq (PerRow) | 59.86 | 78.57 | 7.41 | 8.04 | +| Llama 3.1 8B | float8dq (PerTensor) | 59.95 | 78.66 | 7.42 | 8.03 | +| Llama 3.1 8B | gemlite (gp=128) | 58.48 | 77.34 | 8.07 | 4.76 | + +| Model Name | Quantization Technique | Acc |Acc Norm| Word perplexity| Model Size (GB) | +|------------|---------------------------|-------|--------|----------------|-------------------| +| Llama 3.2 3B | None | 55.27 | 73.70 | 9.26 | 6.43 | +| Llama 3.2 3B | int4wo-128 | 53.13 | 71.31 | 10.36 | 2.29 | +| Llama 3.2 3B | int8wo | 55.15 | 73.44 | 9.28 | 3.61 | +| Llama 3.2 3B | int8dq | 55.00 | 73.29 | 9.43 | 3.61 | +| 
Llama 3.2 3B | float8wo | 55.18 | 73.58 | 9.31 | 3.61 | +| Llama 3.2 3B | float8dq (PerRow) | 55.18 | 73.37 | 9.33 | 3.61 | +| Llama 3.2 3B | float8dq (PerTensor) | 55.16 | 73.53 | 9.35 | 3.61 | +| Llama 3.2 3B | gemlite (gp=128) | 53.71 | 71.99 | 10.05 | 2.29 | To generate the above results run: ``` From 0a3a2bbc38530fa7eea162bfbc2b5088d3d52f00 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 21 May 2025 16:05:40 -0700 Subject: [PATCH 11/15] Update readme --- torchao/_models/README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchao/_models/README.md b/torchao/_models/README.md index 33385e4f23..ab044355b1 100644 --- a/torchao/_models/README.md +++ b/torchao/_models/README.md @@ -1,6 +1,4 @@ -# TODO: Add info for _models here - -## Eval on Llama 3.1 8B and Llama 3.2 3B +# Eval on Llama 3.1 8B and Llama 3.2 3B We use lm-eval tasks for evaluating TorchAO Quantization APIs on HuggingFace models. The results are in the table below: From e14c1869cdef8d39690c4c3a1c95adf8cdf1ee43 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 21 May 2025 16:06:31 -0700 Subject: [PATCH 12/15] Update readme --- torchao/_models/README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchao/_models/README.md b/torchao/_models/README.md index ab044355b1..300f1ed7d3 100644 --- a/torchao/_models/README.md +++ b/torchao/_models/README.md @@ -1,4 +1,6 @@ -# Eval on Llama 3.1 8B and Llama 3.2 3B +# LLAMA + +## Eval on Llama 3.1 8B and Llama 3.2 3B We use lm-eval tasks for evaluating TorchAO Quantization APIs on HuggingFace models. The results are in the table below: @@ -35,7 +37,7 @@ python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B - ``` Replace model id, quantization and tasks with your desired values Please refer to ([HuggingFace <-> TorchAO](https://huggingface.co/docs/transformers/main/en//quantization/torchao)) integration docs for more details about the supported quantization techniques. 
-## SAM2 +# SAM2 sam2 is a fork of https://github.com/facebookresearch/sam2 at commit c2ec8e14a185632b0a5d8b161928ceb50197eddc It includes From 459cee64c51ac5cde1f5556366fe7d610d2df7c4 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 9 Jun 2025 13:16:04 -0700 Subject: [PATCH 13/15] Gemlite kernel --- benchmarks/_models/eval_hf_models.sh | 6 ++++-- benchmarks/microbenchmarks/README.md | 1 + benchmarks/microbenchmarks/utils.py | 6 ++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/benchmarks/_models/eval_hf_models.sh b/benchmarks/_models/eval_hf_models.sh index 14feef7505..d71d16e422 100644 --- a/benchmarks/_models/eval_hf_models.sh +++ b/benchmarks/_models/eval_hf_models.sh @@ -13,7 +13,8 @@ python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B - python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-128 --tasks wikitext hellaswag python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext hellaswag python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128-4 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128-8 --tasks wikitext hellaswag # For llama3.2-3B @@ -24,4 +25,5 @@ python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B - python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-128 --tasks wikitext hellaswag python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo --tasks wikitext hellaswag python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128-4 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128-8 --tasks wikitext hellaswag diff --git a/benchmarks/microbenchmarks/README.md b/benchmarks/microbenchmarks/README.md index f300bbab23..49868d88cd 100644 --- a/benchmarks/microbenchmarks/README.md +++ b/benchmarks/microbenchmarks/README.md @@ -71,6 +71,7 @@ Currently, quantization string is in same format as the one being passed in llam - `int8wo`: 8-bit weight-only quantization - `int4wo-{group_size}`: 4-bit weight-only quantization with specified group size - `int4wo-{group_size}-hqq`: 4-bit weight-only quantization with HQQ +- `gemlitewo-{group_size}-{bit_width}`: 4 or 8 bit integer quantization and utilizes the gemlite triton kernel ### Model Types - `linear`: Simple linear layer diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index b41a9268cb..1fb575dec8 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -293,14 +293,16 @@ def string_to_config( granularity = PerTensor() return Float8DynamicActivationFloat8WeightConfig(granularity=granularity) if "gemlitewo" in 
quantization: - group_size = int(quantization.split("-")[1]) + params = quantization.split("-") + group_size = int(params[1]) if len(params) > 1 else 64 assert group_size in [ 32, 64, 128, 256, ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" - return GemliteUIntXWeightOnlyConfig(group_size=group_size) + bit_width = int(params[2]) if len(params) > 2 else 4 + return GemliteUIntXWeightOnlyConfig(group_size=group_size, bit_width=bit_width) return None From 128d89854bad74fbb4793aca71de582081c7e5a2 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 23 Jun 2025 11:03:16 -0700 Subject: [PATCH 14/15] Updates --- benchmarks/microbenchmarks/utils.py | 10 ++++++++-- third_party/cutlass | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 1fb575dec8..40bce5c33d 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -294,14 +294,20 @@ def string_to_config( return Float8DynamicActivationFloat8WeightConfig(granularity=granularity) if "gemlitewo" in quantization: params = quantization.split("-") - group_size = int(params[1]) if len(params) > 1 else 64 + bit_width = int(params[1]) if len(params) > 1 else 4 + group_size = ( + int(params[2]) + if len(params) > 2 and bit_width == 4 + else None + if bit_width == 8 + else 64 + ) assert group_size in [ 32, 64, 128, 256, ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" - bit_width = int(params[2]) if len(params) > 2 else 4 return GemliteUIntXWeightOnlyConfig(group_size=group_size, bit_width=bit_width) return None diff --git a/third_party/cutlass b/third_party/cutlass index ad7b2f5e84..e94e888df3 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e +Subproject commit e94e888df3551224738bfa505787b515eae8352f From a1e28e3bc4ae6739fd391c3a2af9db2d50c3acfd Mon Sep 17 00:00:00 2001 From: jainapurva Date: Sun, 29 Jun 2025 12:15:57 -0700 Subject: [PATCH 15/15] Update readme --- benchmarks/microbenchmarks/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/microbenchmarks/README.md b/benchmarks/microbenchmarks/README.md index 49868d88cd..eb9564d7d7 100644 --- a/benchmarks/microbenchmarks/README.md +++ b/benchmarks/microbenchmarks/README.md @@ -71,7 +71,7 @@ Currently, quantization string is in same format as the one being passed in llam - `int8wo`: 8-bit weight-only quantization - `int4wo-{group_size}`: 4-bit weight-only quantization with specified group size - `int4wo-{group_size}-hqq`: 4-bit weight-only quantization with HQQ -- `gemlitewo-{group_size}-{bit_width}`: 4 or 8 bit integer quantization and utilizes the gemlite triton kernel +- `gemlitewo-{bit_width}-{group_size}`: 4 or 8 bit integer quantization and utilizes the gemlite triton kernel ### Model Types - `linear`: Simple linear layer
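
Taken together, the final `gemlitewo-{bit_width}-{group_size}` string documented above maps onto a torchao config roughly as follows (an informal sketch based on the `string_to_config` changes in this series; defaults in your torchao version may differ):
```
from torchao.quantization import GemliteUIntXWeightOnlyConfig

# "gemlitewo-4-128": 4-bit weight-only quantization, group size 128, gemlite Triton kernel
config = GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=128)

# "gemlitewo-8" parses to bit_width=8 with group_size left unset (None) by the helper,
# and a bare "gemlitewo" falls back to bit_width=4, group_size=64.
```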