
Commit 8dd5e33

inference roofline
1 parent 3475aed commit 8dd5e33

File tree: 2 files changed, +498 -0 lines changed
Lines changed: 136 additions & 0 deletions
import torch
import torchao
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.profiler import profile, record_function, ProfilerActivity
from torchao.quantization.quant_api import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
import copy
from utils import (
    get_name_to_shapes_iter,
)
import tqdm
from tabulate import tabulate

# Set the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64, dtype=torch.bfloat16):
        super().__init__()
        self.dtype = dtype
        self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype)

    def example_inputs(self, m=1, device="cuda"):
        return (torch.randn(m, self.linear1.in_features, dtype=self.dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        return x

# Function to benchmark model evaluation with profiling
def benchmark_model_with_profiling(model, input_data, dtype):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        # with record_function("model_inference"):
        for _ in range(5):  # Run the model multiple times to warm up the cache
            with torch.no_grad():
                _ = model(*input_data)
                torch.cuda.synchronize()

    # Return the profiler output
    return prof


def get_gpu_kernel_times(profiler_chrome_trace, gpu_op_name):
    # Filter CUDA events
    event_data = [(event.key, event.device_time)
                  for event in profiler_chrome_trace.key_averages()
                  if event.device_type == torch.autograd.DeviceType.CUDA]

    # Split total kernel time into the target op's time and everything else
    gpu_op_time, gpu_overhead_time = 0, 0
    for event in event_data:
        if gpu_op_name in event[0]:
            gpu_op_time += event[1]
        else:
            gpu_overhead_time += event[1]
    return gpu_op_time, gpu_overhead_time

def run_gemm_benchmarks(name_to_shapes, float8_dtype=torch.float8_e4m3fn, other_dtype=torch.bfloat16, quantization_technique=float8_weight_only):
    # Dictionary to store performance data
    performance_data = {
        'Input Size': [],
        'float8 Op Kernel Times (ms)': [],
        'bf16 Op Kernel Times (ms)': [],
        'float8 Overhead Kernel Times (ms)': [],
        'bf16 Overhead Kernel Times (ms)': [],
        'float8 Total Kernel Times (ms)': [],
        'bf16 Total Kernel Times (ms)': [],
    }

    # Run benchmarks for each input size
    for idx, (name, (m, k, n)) in enumerate(tqdm.tqdm(name_to_shapes)):
        print(f"Profiling model with input size: {m, k, n} for quantization technique: {quantization_technique}, dtype: {float8_dtype} vs {other_dtype}")

        # Initialize the model with the specified dimensions
        # (pass by keyword: the constructor signature is (m, n, k), not (m, k, n))
        model = ToyLinearModel(m=m, n=n, k=k).eval().to(device)
        example_inputs = model.example_inputs(m)
        model_bf16 = copy.deepcopy(model).to(device)  # Copy of the model kept in bf16
        model_ref = copy.deepcopy(model).to(device)  # Copy of the model for quantization
        quantize_(model_ref, quantization_technique())  # Quantize model to float8

        # Profile float8 model evaluation
        prof_float8 = benchmark_model_with_profiling(model_ref, example_inputs, float8_dtype)
        prof_float8.export_chrome_trace(f"fp8_model_{example_inputs[0].size()[0]}.json")  # Save profiling details

        # Profile bf16 model evaluation
        prof_bf16 = benchmark_model_with_profiling(model_bf16, example_inputs, other_dtype)
        prof_bf16.export_chrome_trace(f"bf16_model_{example_inputs[0].size()[0]}.json")  # Save profiling details

        # Calculate and store GPU kernel times -> op time, overhead time
        float8_gpu_op_time, float8_gpu_overhead_time = get_gpu_kernel_times(prof_float8, 'gemm')
        bf16_gpu_op_time, bf16_gpu_overhead_time = get_gpu_kernel_times(prof_bf16, 'gemm')

        # # Print profiling details
        # print(f"bfloat16_gpu_overhead_time: {bf16_gpu_overhead_time} gpu_op_time: {bf16_gpu_op_time}")
        # print(f"float8_gpu_overhead_time: {float8_gpu_overhead_time} float8_gpu_op_time: {float8_gpu_op_time}")

        # Add the performance data to the dictionary
        # time / 1000 -> convert from microseconds to milliseconds
        performance_data['Input Size'].append(f"{tuple(example_inputs[0].shape)}")
        performance_data['float8 Total Kernel Times (ms)'].append((float8_gpu_op_time + float8_gpu_overhead_time) / 1000)
        performance_data['bf16 Total Kernel Times (ms)'].append((bf16_gpu_op_time + bf16_gpu_overhead_time) / 1000)
        performance_data['float8 Op Kernel Times (ms)'].append(float8_gpu_op_time / 1000)
        performance_data['bf16 Op Kernel Times (ms)'].append(bf16_gpu_op_time / 1000)
        performance_data['float8 Overhead Kernel Times (ms)'].append(float8_gpu_overhead_time / 1000)
        performance_data['bf16 Overhead Kernel Times (ms)'].append(bf16_gpu_overhead_time / 1000)

    return performance_data


def plot_performance_data(performance_data):
    # Plotting the results
    plt.figure(figsize=(10, 6))
    plt.plot(performance_data['Input Size'], performance_data['float8 Total Kernel Times (ms)'], marker='o', label='float8')
    plt.plot(performance_data['Input Size'], performance_data['bf16 Total Kernel Times (ms)'], marker='s', label='bf16')
    plt.xlabel('Input Size (m, k)')  # the x-axis holds the (m, k) input shapes, not just the batch size
    plt.ylabel('Kernel Time (ms)')
    plt.title('Model Evaluation GPU Kernel Performance: float8 vs bf16')
    plt.legend()
    plt.grid(True)
    plt.savefig('model_evaluation_gpu_kernel_performance.png')


if __name__ == '__main__':
    # Set the data types
    name_to_shapes = get_name_to_shapes_iter("square", None, None, None)
    float8_dtype = torch.float8_e4m3fn  # Change to the float8 dtype you want to use
    bf16_dtype = torch.bfloat16  # Change to the comparison dtype you want to use
    quantization_technique = float8_weight_only  # Change to the quantization technique you want to use

    performance_data = run_gemm_benchmarks(
        name_to_shapes=name_to_shapes,
        float8_dtype=float8_dtype,
        other_dtype=bf16_dtype,
        quantization_technique=quantization_technique,
    )
    print('Performance data: \n', tabulate(performance_data, headers=performance_data.keys()))
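
Note that the shape list comes from a local utils.get_name_to_shapes_iter helper that is not part of this diff. Below is a minimal sketch of a stand-in driver, assuming only the (name, (m, k, n)) tuple format that run_gemm_benchmarks iterates over; the square sizes and the swap to the already-imported float8_dynamic_activation_float8_weight technique are illustrative, not part of the commit:

    # Hypothetical stand-in for utils.get_name_to_shapes_iter: square GEMM
    # shapes in the (name, (m, k, n)) format consumed by run_gemm_benchmarks.
    name_to_shapes = [(str(size), (size, size, size)) for size in (256, 1024, 4096)]

    performance_data = run_gemm_benchmarks(
        name_to_shapes=name_to_shapes,
        float8_dtype=torch.float8_e4m3fn,
        other_dtype=torch.bfloat16,
        # float8_dynamic_activation_float8_weight quantizes activations as
        # well as weights; float8_weight_only (the script's default) quantizes
        # weights only.
        quantization_technique=float8_dynamic_activation_float8_weight,
    )
    plot_performance_data(performance_data)

Comparing the two techniques this way separates the effect of weight-only quantization from dynamic activation scaling on the measured GEMM kernel times.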
