Skip to content

Commit b0ff696

Browse files
committed
inference roofline
1 parent 3475aed commit b0ff696

File tree

2 files changed

+466
-0
lines changed

2 files changed

+466
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import torch
2+
import torchao
3+
import torch.nn as nn
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
from torch.profiler import profile, record_function, ProfilerActivity
7+
from torchao.quantization.quant_api import quantize_, float8_dynamic_activation_float8_weight
8+
import copy
9+
from utils import (
10+
get_name_to_shapes_iter,
11+
)
12+
import tqdm
13+
14+
# Set the device (GPU if available)
15+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16+
17+
class ToyLinearModel(torch.nn.Module):
18+
def __init__(self, m=64, n=32, k=64):
19+
super().__init__()
20+
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.bfloat16)
21+
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.bfloat16)
22+
23+
def example_inputs(self, batch_size=1, dtype=torch.float, device="cuda"):
24+
return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
25+
26+
def forward(self, x):
27+
x = self.linear1(x)
28+
x = self.linear2(x)
29+
return x
30+
31+
# Function to benchmark model evaluation with profiling
32+
def benchmark_model_with_profiling(model, input_data, dtype):
33+
print('Model before quantization: ', model)
34+
quantize_(model, float8_dynamic_activation_float8_weight())
35+
print('Model quantized: ', model)
36+
model.eval() # Set the model to evaluation mode
37+
# input_data = torch.randn(input_size, device=device)
38+
39+
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
40+
with record_function("model_inference"):
41+
with torch.no_grad():
42+
_ = model(*input_data)
43+
44+
# Return the profiler output
45+
return prof
46+
47+
name_to_shapes = get_name_to_shapes_iter("square", None, None, None)
48+
49+
50+
51+
# Set the data types
52+
float8_dtype = torch.float8_e4m3fn # Replace with the actual float8 dtype from TorchAO
53+
bf16_dtype = torch.bfloat16
54+
55+
# Dictionary to store performance data
56+
performance_data = {
57+
'Input Size': [],
58+
'float8 Kernel Times (ms)': [],
59+
'bf16 Kernel Times (ms)': []
60+
}
61+
62+
# Run benchmarks for each input size
63+
for idx, (name, (m, k, n)) in enumerate(tqdm.tqdm(name_to_shapes)):
64+
print(f"Profiling model with input size: {m, k, n}")
65+
66+
# Initialize the model with the specified dimensions
67+
model = ToyLinearModel().eval().to(device)
68+
example_inputs = model.example_inputs()
69+
model_bf16 = copy.deepcopy(model).to(device) # Copy the model to bf
70+
model_ref = copy.deepcopy(model).to(device) # Copy the model for quantization
71+
72+
73+
print('Model created: ', model)
74+
print('Example inputs: ', len(example_inputs), example_inputs[0].size())
75+
76+
# Profile float8 model evaluation
77+
prof_float8 = benchmark_model_with_profiling(model_ref, example_inputs, float8_dtype)
78+
prof_float8.export_chrome_trace(f"float8_model_{example_inputs[0].size()[0]}.json") # Save profiling details
79+
80+
# Profile bf16 model evaluation
81+
prof_bf16 = benchmark_model_with_profiling(model_bf16, example_inputs, bf16_dtype)
82+
prof_bf16.export_chrome_trace(f"bf16_model_{example_inputs[0].size()[0]}.json") # Save profiling details
83+
84+
print('Profiling keys: ', prof_float8.key_averages())
85+
# Calculate and store total GPU kernel times
86+
float8_kernel_time = sum([event.device_time for event in prof_float8.key_averages()])
87+
bf16_kernel_time = sum([event.device_time for event in prof_bf16.key_averages()])
88+
89+
performance_data['Input Size'].append(f"{example_inputs[0].size()[0]}")
90+
performance_data['float8 Kernel Times (ms)'].append(float8_kernel_time / 1000) # Convert from microseconds to milliseconds
91+
performance_data['bf16 Kernel Times (ms)'].append(bf16_kernel_time / 1000) # Convert from microseconds to milliseconds
92+
93+
print('Performance data: ', performance_data)
94+
95+
# Plotting the results
96+
plt.figure(figsize=(10, 6))
97+
plt.plot(performance_data['Input Size'], performance_data['float8 Kernel Times (ms)'], marker='o', label='float8')
98+
plt.plot(performance_data['Input Size'], performance_data['bf16 Kernel Times (ms)'], marker='s', label='bf16')
99+
plt.xlabel('Batch Size')
100+
plt.ylabel('Kernel Time (ms)')
101+
plt.title('Model Evaluation GPU Kernel Performance: float8 vs bf16')
102+
plt.legend()
103+
plt.grid(True)
104+
plt.show()

0 commit comments

Comments
 (0)