81 changes: 81 additions & 0 deletions torchprime/experimental/benchmark/hf_model.py
@@ -0,0 +1,81 @@
from typing import Any

import torch
from transformers.models.llama import modeling_llama
from transformers.models.qwen3 import modeling_qwen3


def get_llama3_model(torch_dtype: torch.dtype):
"""Returns the Llama3.2 1B model."""
config = modeling_llama.LlamaConfig(
attention_bias=False,
attention_dropout=0.0,
bos_token_id=128000,
eos_token_id=128001,
head_dim=64,
hidden_act="silu",
hidden_size=2048,
initializer_range=0.02,
intermediate_size=8192,
max_position_embeddings=131072,
mlp_bias=False,
num_attention_heads=32,
num_hidden_layers=16,
num_key_value_heads=8,
rms_norm_eps=1e-05,
rope_scaling={
"factor": 32.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
rope_theta=500000.0,
tie_word_embeddings=True,
use_cache=True,
vocab_size=128256,
_attn_implementation="eager",
)
model = modeling_llama.LlamaForCausalLM(config).to(torch_dtype)
return model


def get_qwen3_model(torch_dtype: torch.dtype):
"""Returns the Qwen3 1.7B model."""
config = modeling_qwen3.Qwen3Config(
attention_bias=False,
attention_dropout=0.0,
bos_token_id=151643,
eos_token_id=151645,
head_dim=128,
hidden_act="silu",
hidden_size=2048,
initializer_range=0.02,
intermediate_size=6144,
max_position_embeddings=40960,
max_window_layers=28,
num_attention_heads=16,
num_hidden_layers=28,
num_key_value_heads=8,
rms_norm_eps=1e-06,
rope_scaling=None,
rope_theta=1000000,
sliding_window=None,
tie_word_embeddings=True,
use_cache=True,
use_sliding_window=False,
vocab_size=151936,
_attn_implementation="eager",
)
model = modeling_qwen3.Qwen3ForCausalLM(config).to(torch_dtype)
return model


def get_model(model_name: str, dtype: torch.dtype) -> Any:
match model_name:
case "llama3.2-1B":
return get_llama3_model(dtype)
case "qwen3-1.7B":
return get_qwen3_model(dtype)
case _:
raise ValueError(f"Unsupported model: {model_name}")
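
For a quick CPU-only sanity check of the two configs above (no TPU required), a parameter count is enough to confirm they line up with the 1B / 1.7B labels. This snippet is a hypothetical sketch, not part of the PR; it assumes only the module path used by the benchmark scripts below and enough host RAM for a bfloat16 copy of each model.

import torch

from torchprime.experimental.benchmark.hf_model import get_model

if __name__ == "__main__":
  for name in ("llama3.2-1B", "qwen3-1.7B"):
    model = get_model(name, torch.bfloat16)
    n_params = sum(p.numel() for p in model.parameters())
    # Expect roughly 1.2B parameters for llama3.2-1B and 1.7B for qwen3-1.7B.
    print(f"{name}: {n_params / 1e9:.2f}B parameters")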
108 changes: 108 additions & 0 deletions torchprime/experimental/benchmark/hf_models_forward.py
@@ -0,0 +1,108 @@
import argparse
import os
import time

import numpy as np
import torch
import torch_xla

from torchprime.experimental.benchmark.hf_model import get_model


def main(args):
# --- Configuration ---
dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
torch_dtype = dtype_map[args.dtype]

  # Acquire the XLA device up front.
device = torch_xla.device()

# Create the model on CPU first
model_cpu = get_model(args.model_name, torch_dtype)
config = model_cpu.config
model_cpu.eval() # Set to evaluation mode

# Move model to the XLA device.
model_tpu = model_cpu.to(device)

# Create dummy input_ids and move to the XLA device.
input_ids = torch.randint(
0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long
)
# Move inputs to the XLA device as well.
input_ids = input_ids.to(device)

  # Preheat: the first pass traces the lazy graph and triggers XLA compilation,
  # populating the compilation cache.
  print("Preheating...")
  preheat_start_time = time.perf_counter()
  with torch.no_grad():
    _ = model_tpu(input_ids).logits
  torch_xla.sync()
  preheat_end_time = time.perf_counter()
  preheat_time = preheat_end_time - preheat_start_time
  print(f"Preheat wall time: {preheat_time * 1000:.2f} ms")

  # Warm-up: a second pass over the same graph, which should hit the compilation cache.
print("Warming up...")
warmup_start_time = time.perf_counter()
with torch.no_grad():
_ = model_tpu(input_ids).logits
torch_xla.sync()
warmup_end_time = time.perf_counter()
warmup_time = warmup_end_time - warmup_start_time

# Subsequent runs for measurement
print(f"Starting benchmark for {args.num_runs} runs...")
times = []
for i in range(args.num_runs):
start_time = time.perf_counter()
with torch.no_grad():
      # The output is discarded; the sync below cuts the lazy graph for this
      # step and dispatches it for execution.
model_tpu(input_ids)

torch_xla.sync()
end_time = time.perf_counter()
times.append(end_time - start_time)
print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms")

# Print final performance results
print("\n--- Benchmark Results (Lazy Mode) ---")
print(f"Model: {args.model_name}, DType: {args.dtype}")
print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}")
print(f"Preheat time: {preheat_time * 1000:.2f} ms")
print(f"Warm-up time: {warmup_time * 1000:.2f} ms (includes compilation)")
print(f"Number of runs: {len(times)}")
print(f"Average latency: {np.mean(times) * 1000:.2f} ms")
print(f"Median latency: {np.median(times) * 1000:.2f} ms")
print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms")
print(f"Min latency: {np.min(times) * 1000:.2f} ms")
print(f"Max latency: {np.max(times) * 1000:.2f} ms")

  # Final sync so all pending device work is flushed before exiting.
torch_xla.sync()
print("Script finished and exited cleanly.")
  os._exit(0)  # Hard exit; skips interpreter/runtime teardown that can hang on some XLA setups.


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Lazy Mode).")
parser.add_argument(
"--model_name",
type=str,
default="llama3.2-1B",
choices=["llama3.2-1B", "qwen3-1.7B"],
help="Model to benchmark (must match a config file name).",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float32"],
help="Data type for the model.",
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.")
parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.")
main(parser.parse_args())
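
One caveat on the loop above: torch_xla.sync() dispatches the traced step but does not necessarily block until the device has finished, so the numbers lean toward trace-plus-dispatch time. If end-to-end device latency is the quantity of interest, a variant along these lines could be used instead; timed_step is a hypothetical helper sketched under that assumption, not part of the PR.

import time

import torch
import torch_xla
import torch_xla.core.xla_model as xm


def timed_step(model, input_ids):
  """Time one forward pass, including device execution (hypothetical helper)."""
  start = time.perf_counter()
  with torch.no_grad():
    model(input_ids)
  torch_xla.sync()      # cut the lazy graph and dispatch it
  xm.wait_device_ops()  # block until the device has drained all pending work
  return time.perf_counter() - start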
109 changes: 109 additions & 0 deletions torchprime/experimental/benchmark/hf_models_forward_eager.py
@@ -0,0 +1,109 @@
import argparse
import os
import time

import numpy as np
import torch
import torch_xla

from torchprime.experimental.benchmark.hf_model import get_model


def main(args):
# --- Configuration ---
print("Running in PyTorch/XLA experimental eager mode.")
torch_xla.experimental.eager_mode(True)

dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
torch_dtype = dtype_map[args.dtype]

  # Acquire the XLA device up front.
device = torch_xla.device()

# Create the model on CPU first
model_cpu = get_model(args.model_name, torch_dtype)
config = model_cpu.config
model_cpu.eval() # Set to evaluation mode

# Move model to the XLA device.
model_tpu = model_cpu.to(device)

# Create dummy input_ids and move to the XLA device.
input_ids = torch.randint(
0, config.vocab_size, (args.batch_size, args.seq_len), dtype=torch.long
)
# Move inputs to the XLA device as well.
input_ids = input_ids.to(device)

  # Preheat: in eager mode the first pass compiles each dispatched op and
  # populates the per-op compilation caches.
  print("Preheating...")
  preheat_start_time = time.perf_counter()
  with torch.no_grad():
    _ = model_tpu(input_ids).logits
  preheat_end_time = time.perf_counter()
  preheat_time = preheat_end_time - preheat_start_time
  print(f"Preheat wall time: {preheat_time * 1000:.2f} ms")

  # Warm-up: a second pass with the op-level caches already populated.
print("Warming up...")
warmup_start_time = time.perf_counter()
with torch.no_grad():
_ = model_tpu(input_ids).logits
warmup_end_time = time.perf_counter()
warmup_time = warmup_end_time - warmup_start_time

# Subsequent runs for measurement
print(f"Starting benchmark for {args.num_runs} runs...")
times = []
for i in range(args.num_runs):
start_time = time.perf_counter()
with torch.no_grad():
      # The output is discarded; only the forward pass itself is being timed.
model_tpu(input_ids)

    # Keep the sync so the timing methodology matches the lazy-mode script;
    # even in eager mode, dispatched work may still be in flight when the
    # forward call returns.
torch_xla.sync()

end_time = time.perf_counter()
times.append(end_time - start_time)
print(f"Run {i+1}/{args.num_runs}: {(end_time - start_time) * 1000:.2f} ms")

# Print final performance results
print("\n--- Benchmark Results (Eager Mode) ---")
print(f"Model: {args.model_name}, DType: {args.dtype}")
print(f"Batch Size: {args.batch_size}, Sequence Length: {args.seq_len}")
print(f"Preheat time: {preheat_time * 1000:.2f} ms")
print(f"Warm-up time: {warmup_time * 1000:.2f} ms")
print(f"Number of runs: {len(times)}")
print(f"Average latency: {np.mean(times) * 1000:.2f} ms")
print(f"Median latency: {np.median(times) * 1000:.2f} ms")
print(f"P90 latency: {np.percentile(times, 90) * 1000:.2f} ms")
print(f"Min latency: {np.min(times) * 1000:.2f} ms")
print(f"Max latency: {np.max(times) * 1000:.2f} ms")

print("Script finished and exited cleanly.")
  os._exit(0)  # Hard exit; skips interpreter/runtime teardown that can hang on some XLA setups.


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark HF models on XLA (Eager Mode).")
parser.add_argument(
"--model_name",
type=str,
default="llama3.2-1B",
choices=["llama3.2-1B", "qwen3-1.7B"],
help="Model to benchmark (must match a config file name).",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float32"],
help="Data type for the model.",
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length.")
parser.add_argument("--num_runs", type=int, default=10, help="Number of benchmark runs.")
main(parser.parse_args())
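
A third configuration that recent PyTorch/XLA releases document alongside eager mode is eager mode with the whole forward pass wrapped by torch_xla.compile, so each step runs as a single cached graph. The sketch below is hypothetical and only a comparison point: it assumes the installed torch_xla exposes torch_xla.compile and reuses get_model from this PR; benchmark_eager_plus_compile is not part of the change.

import time

import numpy as np
import torch
import torch_xla

from torchprime.experimental.benchmark.hf_model import get_model


def benchmark_eager_plus_compile(num_runs: int = 10, batch_size: int = 1, seq_len: int = 128):
  # Eager mode plus a compiled region around the forward pass.
  torch_xla.experimental.eager_mode(True)
  device = torch_xla.device()
  model = get_model("llama3.2-1B", torch.bfloat16).eval().to(device)
  input_ids = torch.randint(
    0, model.config.vocab_size, (batch_size, seq_len), dtype=torch.long
  ).to(device)

  def forward(ids):
    with torch.no_grad():
      return model(ids).logits

  compiled_forward = torch_xla.compile(forward)
  compiled_forward(input_ids)  # first call traces and compiles the step graph

  times = []
  for _ in range(num_runs):
    start = time.perf_counter()
    compiled_forward(input_ids)
    times.append(time.perf_counter() - start)
  print(f"eager + torch_xla.compile median latency: {np.median(times) * 1000:.2f} ms")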