- 
                Notifications
    You must be signed in to change notification settings 
- Fork 9
Benchmarking #392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Draft
      
      
            jack8558
  wants to merge
  6
  commits into
  main
  
    
      
        
          
  
    
      Choose a base branch
      
     
    
      
        
      
      
        
          
          
        
        
          
            
              
              
              
  
           
        
        
          
            
              
              
           
        
       
     
  
        
          
            
          
            
          
        
       
    
      
from
jackoh/benchmark
  
      
      
   
  
    
  
  
  
 
  
      
    base: main
Could not load branches
            
              
  
    Branch not found: {{ refName }}
  
            
                
      Loading
              
            Could not load tags
            
            
              Nothing to show
            
              
  
            
                
      Loading
              
            Are you sure you want to change the base?
            Some commits from the old base branch may be removed from the timeline,
            and old review comments may become outdated.
          
          
  
     Draft
                    Benchmarking #392
Changes from 2 commits
      Commits
    
    
            Show all changes
          
          
            6 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      a58e697
              
                Benchmarking on hf_model
              
              
                jack8558 f3d0e38
              
                some config fixes
              
              
                jack8558 a1a538a
              
                Assign variable for logits
              
              
                jack8558 5bf4f3e
              
                Use wait_device_ops instead of sync
              
              
                jack8558 83da39e
              
                Fix on forward
              
              
                jack8558 a387f2e
              
                Use torch_xla sync in preheat
              
              
                jack8558 File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| from typing import Any | ||
|  | ||
| import torch | ||
| from transformers.models.llama import modeling_llama | ||
| from transformers.models.qwen3 import modeling_qwen3 | ||
|  | ||
|  | ||
| def get_llama3_model(torch_dtype: torch.dtype): | ||
| """Returns the Llama3.2 1B model.""" | ||
| config = modeling_llama.LlamaConfig( | ||
| attention_bias=False, | ||
| attention_dropout=0.0, | ||
| bos_token_id=128000, | ||
| eos_token_id=128001, | ||
| head_dim=64, | ||
| hidden_act="silu", | ||
| hidden_size=2048, | ||
| initializer_range=0.02, | ||
| intermediate_size=8192, | ||
| max_position_embeddings=131072, | ||
| mlp_bias=False, | ||
| num_attention_heads=32, | ||
| num_hidden_layers=16, | ||
| num_key_value_heads=8, | ||
| rms_norm_eps=1e-05, | ||
| rope_scaling={ | ||
| "factor": 32.0, | ||
| "high_freq_factor": 4.0, | ||
| "low_freq_factor": 1.0, | ||
| "original_max_position_embeddings": 8192, | ||
| "rope_type": "llama3", | ||
| }, | ||
| rope_theta=500000.0, | ||
| tie_word_embeddings=True, | ||
| use_cache=True, | ||
| vocab_size=128256, | ||
| _attn_implementation="eager", | ||
| ) | ||
| model = modeling_llama.LlamaForCausalLM(config).to(torch_dtype) | ||
| return model | ||
|  | ||
|  | ||
| def get_qwen3_model(torch_dtype: torch.dtype): | ||
| """Returns the Qwen3 1.7B model.""" | ||
| config = modeling_qwen3.Qwen3Config( | ||
| attention_bias=False, | ||
| attention_dropout=0.0, | ||
| bos_token_id=151643, | ||
| eos_token_id=151645, | ||
| head_dim=128, | ||
| hidden_act="silu", | ||
| hidden_size=2048, | ||
| initializer_range=0.02, | ||
| intermediate_size=6144, | ||
| max_position_embeddings=40960, | ||
| max_window_layers=28, | ||
| num_attention_heads=16, | ||
| num_hidden_layers=28, | ||
| num_key_value_heads=8, | ||
| rms_norm_eps=1e-06, | ||
| rope_scaling=None, | ||
| rope_theta=1000000, | ||
| sliding_window=None, | ||
| tie_word_embeddings=True, | ||
| use_cache=True, | ||
| use_sliding_window=False, | ||
| vocab_size=151936, | ||
| _attn_implementation="eager", | ||
| ) | ||
| model = modeling_qwen3.Qwen3ForCausalLM(config).to(torch_dtype) | ||
| return model | ||
|  | ||
|  | ||
| def get_model(model_name: str, dtype: torch.dtype) -> Any: | ||
| match model_name: | ||
| case "llama3.2-1B": | ||
| return get_llama3_model(dtype) | ||
| case "qwen3-1.7B": | ||
| return get_qwen3_model(dtype) | ||
| case _: | ||
| raise ValueError(f"Unsupported model: {model_name}") | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| import argparse | ||
| import os | ||
| import time | ||
| from typing import Any | ||
|  | ||
| import numpy as np | ||
| import torch | ||
| import torch_xla | ||
|  | ||
| from torchprime.experimental.benchmark.hf_model import get_model | ||
|  | ||
|  | ||
| def main(args): | ||
| # --- Configuration --- | ||
| dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32} | ||
| torch_dtype = dtype_map[args.dtype] | ||
|  | ||
| # It's good practice to define the device first. | ||
| device = torch_xla.device() | ||
|  | ||
| # Create the model on CPU first | ||
| model_cpu = get_model(args.model_name, torch_dtype) | ||
| config = model_cpu.config | ||
| model_cpu.eval() # Set to evaluation mode | ||
|  | ||
| # Move model to the XLA device. | ||
| model_tpu = model_cpu.to(device) | ||
|  | ||
| # Create dummy input_ids and move to the XLA device. | ||
| input_ids = torch.randint( | ||
| 0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long | ||
| ) | ||
| # Move inputs to the XLA device as well. | ||
| input_ids = input_ids.to(device) | ||
|  | ||
| # Preheat the cache. | ||
| print("Preheating...") | ||
| preheat_start_time = time.perf_counter() | ||
| with torch.no_grad(): | ||
| _ = model_tpu(input_ids).logits | ||
| torch_xla.sync() | ||
| preheat_end_time = time.perf_counter() | ||
| preheat_time = preheat_end_time - preheat_start_time | ||
| print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms") | ||
|  | ||
| # Initial run (warm-up) to trigger XLA compilation | ||
| print("Warming up...") | ||
| warmup_start_time = time.perf_counter() | ||
| with torch.no_grad(): | ||
| _ = model_tpu(input_ids).logits | ||
| torch_xla.sync() | ||
| warmup_end_time = time.perf_counter() | ||
| warmup_time = warmup_end_time - warmup_start_time | ||
|  | ||
| # Subsequent runs for measurement | ||
| print(f"Starting benchmark for {args.num_runs} runs...") | ||
| times = [] | ||
| for i in range(args.num_runs): | ||
| start_time = time.perf_counter() | ||
| with torch.no_grad(): | ||
| # The model forward pass is intentionally not assigned to a variable | ||
| # to measure only the execution time. | ||
|         
                  jack8558 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| model_tpu(input_ids) | ||
|  | ||
| torch_xla.sync() | ||
| end_time = time.perf_counter() | ||
| times.append(end_time - start_time) | ||
| print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms") | ||
|  | ||
| # Print final performance results | ||
| print("\n--- Benchmark Results (Lazy Mode) ---") | ||
| print(f"Model: {args.model_name}, DType: {args.dtype}") | ||
| print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}") | ||
| print(f"Preheat time: {preheat_time * 1000:.2f} ms") | ||
| print(f"Warm-up time: {warmup_time * 1000:.2f} ms (includes compilation)") | ||
| print(f"Number of runs: {len(times)}") | ||
| print(f"Average latency: {np.mean(times) * 1000:.2f} ms") | ||
| print(f"Median latency: {np.median(times) * 1000:.2f} ms") | ||
| print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms") | ||
| print(f"Min latency: {np.min(times) * 1000:.2f} ms") | ||
| print(f"Max latency: {np.max(times) * 1000:.2f} ms") | ||
|  | ||
| # Add this line to wait for the TPU to finish and ensure a clean exit | ||
| torch_xla.sync() | ||
| print("Script finished and exited cleanly.") | ||
| os._exit(0) # <-- Use os._exit() instead of sys.exit() | ||
|  | ||
|  | ||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Lazy Mode).") | ||
| parser.add_argument( | ||
| "--model_name", | ||
| type=str, | ||
| default="llama3.2-1B", | ||
| choices=["llama3.2-1B", "qwen3-1.7B"], | ||
| help="Model to benchmark (must match a config file name).", | ||
| ) | ||
| parser.add_argument( | ||
| "--dtype", | ||
| type=str, | ||
| default="bfloat16", | ||
| choices=["bfloat16", "float32"], | ||
| help="Data type for the model.", | ||
| ) | ||
| parser.add_argument("--batch_size", type=int, default=1, help="Batch size.") | ||
| parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.") | ||
| parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.") | ||
| main(parser.parse_args()) | ||
        
          
          
            109 changes: 109 additions & 0 deletions
          
          109 
        
  torchprime/experimental/benchmark/hf_models_forward_eager.py
  
  
      
      
   
        
      
      
    
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| import argparse | ||
| import os | ||
| import time | ||
| from typing import Any | ||
|  | ||
| import numpy as np | ||
| import torch | ||
| import torch_xla | ||
|  | ||
| from torchprime.experimental.benchmark.hf_model import get_model | ||
|  | ||
|  | ||
| def main(args): | ||
| # --- Configuration --- | ||
| print("Running in PyTorch/XLA experimental eager mode.") | ||
| torch_xla.experimental.eager_mode(True) | ||
|  | ||
| dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32} | ||
| torch_dtype = dtype_map[args.dtype] | ||
|  | ||
| # It's good practice to define the device first. | ||
| device = torch_xla.device() | ||
|  | ||
| # Create the model on CPU first | ||
| model_cpu = get_model(args.model_name, torch_dtype) | ||
| config = model_cpu.config | ||
| model_cpu.eval() # Set to evaluation mode | ||
|  | ||
| # Move model to the XLA device. | ||
| model_tpu = model_cpu.to(device) | ||
|  | ||
| # Create dummy input_ids and move to the XLA device. | ||
| input_ids = torch.randint( | ||
| 0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long | ||
| ) | ||
| # Move inputs to the XLA device as well. | ||
| input_ids = input_ids.to(device) | ||
|  | ||
| # Preheat the cache. | ||
| print("Preheating...") | ||
| preheat_start_time = time.perf_counter() | ||
| with torch.no_grad(): | ||
| _ = model_tpu(input_ids).logits | ||
| preheat_end_time = time.perf_counter() | ||
| preheat_time = preheat_end_time - preheat_start_time | ||
| print(f"PREHEAT WALL TIME: {preheat_time*1000:.4f} ms") | ||
|  | ||
| # Initial run (warm-up) | ||
| print("Warming up...") | ||
| warmup_start_time = time.perf_counter() | ||
| with torch.no_grad(): | ||
| _ = model_tpu(input_ids).logits | ||
| warmup_end_time = time.perf_counter() | ||
| warmup_time = warmup_end_time - warmup_start_time | ||
|  | ||
| # Subsequent runs for measurement | ||
| print(f"Starting benchmark for {args.num_runs} runs...") | ||
| times = [] | ||
| for i in range(args.num_runs): | ||
| start_time = time.perf_counter() | ||
| with torch.no_grad(): | ||
| # The model forward pass is intentionally not assigned to a variable | ||
| # to measure only the execution time. | ||
| model_tpu(input_ids) | ||
|  | ||
| # Do we need this??? | ||
| torch_xla.sync() | ||
|  | ||
| end_time = time.perf_counter() | ||
| times.append(end_time - start_time) | ||
| print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms") | ||
|  | ||
| # Print final performance results | ||
| print("\n--- Benchmark Results (Eager Mode) ---") | ||
| print(f"Model: {args.model_name}, DType: {args.dtype}") | ||
| print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}") | ||
| print(f"Preheat time: {preheat_time * 1000:.2f} ms") | ||
| print(f"Warm-up time: {warmup_time * 1000:.2f} ms") | ||
| print(f"Number of runs: {len(times)}") | ||
| print(f"Average latency: {np.mean(times) * 1000:.2f} ms") | ||
| print(f"Median latency: {np.median(times) * 1000:.2f} ms") | ||
| print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms") | ||
| print(f"Min latency: {np.min(times) * 1000:.2f} ms") | ||
| print(f"Max latency: {np.max(times) * 1000:.2f} ms") | ||
|  | ||
| print("Script finished and exited cleanly.") | ||
| os._exit(0) # <-- Use os._exit() instead of sys.exit() | ||
|  | ||
|  | ||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Eager Mode).") | ||
| parser.add_argument( | ||
| "--model_name", | ||
| type=str, | ||
| default="llama3.2-1B", | ||
| choices=["llama3.2-1B", "qwen3-1.7B"], | ||
| help="Model to benchmark (must match a config file name).", | ||
| ) | ||
| parser.add_argument( | ||
| "--dtype", | ||
| type=str, | ||
| default="bfloat16", | ||
| choices=["bfloat16", "float32"], | ||
| help="Data type for the model.", | ||
| ) | ||
| parser.add_argument("--batch_size", type=int, default=1, help="Batch size.") | ||
| parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.") | ||
| parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.") | ||
| main(parser.parse_args()) | 
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.