import copy

import torch
import matplotlib.pyplot as plt
import tqdm
from tabulate import tabulate
from torch.profiler import profile, ProfilerActivity
from torchao.quantization.quant_api import (
    quantize_,
    float8_dynamic_activation_float8_weight,  # alternative technique, kept importable for easy swapping
    float8_weight_only,
)

from utils import get_name_to_shapes_iter

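# `get_name_to_shapes_iter` comes from a local `utils` module (as in torchao's
# float8 benchmark helpers) and is assumed to yield (name, (m, k, n)) GEMM
# problem sizes. If that module is unavailable, a minimal stand-in for the
# "square" case (a sketch matching how the iterator is consumed below) could be:
#
# def get_name_to_shapes_iter(shape_gen_name, M, K, N):
#     assert shape_gen_name == "square" and (M, K, N) == (None, None, None)
#     return iter({f"{s}x{s}x{s}": (s, s, s) for s in (1024, 2048, 4096, 8192)}.items())
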
# Set the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ToyLinearModel(torch.nn.Module):
    # Parameter order follows the (m, k, n) GEMM convention: the linear layer
    # maps k input features to n output features. m is accepted so call sites
    # can pass the full (m, k, n) shape, but it is not used by the module itself.
    def __init__(self, m=64, k=64, n=32, dtype=torch.bfloat16):
        super().__init__()
        self.dtype = dtype
        self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype)

    def example_inputs(self, m=1, device="cuda"):
        return (torch.randn(m, self.linear1.in_features, dtype=self.dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        return x

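# Quick shape sanity check (illustrative, not part of the benchmark): an (m, k)
# activation through nn.Linear(k, n) is an m x k x n GEMM.
#
# m, k, n = 4, 64, 32
# toy = ToyLinearModel(m, k, n)
# (x,) = toy.example_inputs(m, device="cpu")
# assert toy(x).shape == (m, n)
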
# Function to benchmark model evaluation with profiling
def benchmark_model_with_profiling(model, input_data):
    # Warm up outside the profiler so one-time costs (kernel selection,
    # workspace allocation) do not pollute the measured kernel times
    with torch.no_grad():
        for _ in range(3):
            _ = model(*input_data)
    torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with torch.no_grad():
            for _ in range(5):  # average over several iterations
                _ = model(*input_data)
        torch.cuda.synchronize()

    # Return the profiler output
    return prof


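# Optional cross-check on the profiler numbers: wall-clock timing with CUDA
# events gives milliseconds per iteration (a minimal sketch; not used by the
# benchmark below)
def time_with_cuda_events(model, input_data, iters=20):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        for _ in range(3):  # warmup
            _ = model(*input_data)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            _ = model(*input_data)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per iteration

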
| 43 | + |
| 44 | +def get_gpu_kernel_times(profiler_chrome_trace, gpu_op_name): |
| 45 | + # Filter CUDA events |
| 46 | + event_data = [(event.key, event.device_time) |
| 47 | + for event in profiler_chrome_trace.key_averages() |
| 48 | + if event.device_type == torch.autograd.DeviceType.CUDA] |
| 49 | + |
| 50 | + # Calculate overhead time and op time |
| 51 | + gpu_op_time, gpu_overhead_time = 0, 0 |
| 52 | + for event in event_data: |
| 53 | + if gpu_op_name in event[0]: |
| 54 | + gpu_op_time += event[1] |
| 55 | + else: |
| 56 | + gpu_overhead_time += event[1] |
| 57 | + return gpu_op_time, gpu_overhead_time |
| 58 | + |
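# The 'gemm' substring match used below is a heuristic: both the cuBLAS bf16
# kernels and the CUTLASS float8 kernels typically carry "gemm" in their names,
# but kernel naming can vary across GPUs and PyTorch versions. To inspect which
# kernels actually ran (illustrative):
#
# for evt in prof.key_averages():
#     if evt.device_type == torch.autograd.DeviceType.CUDA:
#         print(evt.key, evt.device_time)
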
def run_gemm_benchmarks(name_to_shapes, float8_dtype=torch.float8_e4m3fn, other_dtype=torch.bfloat16, quantization_technique=float8_weight_only):
    # Dictionary to store performance data
    performance_data = {
        'Input Size': [],
        'float8 Op Kernel Times (ms)': [],
        'bf16 Op Kernel Times (ms)': [],
        'float8 Overhead Kernel Times (ms)': [],
        'bf16 Overhead Kernel Times (ms)': [],
        'float8 Total Kernel Times (ms)': [],
        'bf16 Total Kernel Times (ms)': [],
    }
    # Run benchmarks for each input size
    for name, (m, k, n) in tqdm.tqdm(name_to_shapes):
        print(f"Profiling model with input size: {m, k, n} for quantization technique: {quantization_technique}, dtype: {float8_dtype} vs {other_dtype}")

        # Initialize the model with the specified dimensions
        model = ToyLinearModel(m, k, n).eval().to(device)
        example_inputs = model.example_inputs(m, device=device)
        model_bf16 = copy.deepcopy(model).to(device)  # bf16 baseline copy
        model_ref = copy.deepcopy(model).to(device)   # Copy to be quantized
        quantize_(model_ref, quantization_technique())  # Quantize model to float8

        # Profile float8 model evaluation
        prof_float8 = benchmark_model_with_profiling(model_ref, example_inputs)
        prof_float8.export_chrome_trace(f"fp8_model_{m}x{k}x{n}.json")  # Save profiling details

        # Profile bf16 model evaluation
        prof_bf16 = benchmark_model_with_profiling(model_bf16, example_inputs)
        prof_bf16.export_chrome_trace(f"bf16_model_{m}x{k}x{n}.json")  # Save profiling details

        # Split GPU kernel times into GEMM op time and overhead time
        float8_gpu_op_time, float8_gpu_overhead_time = get_gpu_kernel_times(prof_float8, 'gemm')
        bf16_gpu_op_time, bf16_gpu_overhead_time = get_gpu_kernel_times(prof_bf16, 'gemm')

        # Add the performance data to the dictionary
        # time / 1000 -> convert from microseconds to milliseconds
        performance_data['Input Size'].append(f"{tuple(example_inputs[0].shape)}")
        performance_data['float8 Total Kernel Times (ms)'].append((float8_gpu_op_time + float8_gpu_overhead_time) / 1000)
        performance_data['bf16 Total Kernel Times (ms)'].append((bf16_gpu_op_time + bf16_gpu_overhead_time) / 1000)
        performance_data['float8 Op Kernel Times (ms)'].append(float8_gpu_op_time / 1000)
        performance_data['bf16 Op Kernel Times (ms)'].append(bf16_gpu_op_time / 1000)
        performance_data['float8 Overhead Kernel Times (ms)'].append(float8_gpu_overhead_time / 1000)
        performance_data['bf16 Overhead Kernel Times (ms)'].append(bf16_gpu_overhead_time / 1000)

    return performance_data

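# Note on the two torchao techniques imported above: `float8_weight_only` stores
# only the weights in float8, while `float8_dynamic_activation_float8_weight`
# also quantizes activations dynamically at runtime, which exercises the float8
# GEMM path end to end (this description reflects the torchao APIs as understood
# at the time of writing). Swapping techniques is a one-line change:
#
# quantize_(model_ref, float8_dynamic_activation_float8_weight())
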

def plot_performance_data(performance_data):
    # Plot total kernel times for each input size
    plt.figure(figsize=(10, 6))
    plt.plot(performance_data['Input Size'], performance_data['float8 Total Kernel Times (ms)'], marker='o', label='float8')
    plt.plot(performance_data['Input Size'], performance_data['bf16 Total Kernel Times (ms)'], marker='s', label='bf16')
    plt.xlabel('Input Size (m, k, n)')
    plt.ylabel('Kernel Time (ms)')
    plt.title('Model Evaluation GPU Kernel Performance: float8 vs bf16')
    plt.xticks(rotation=45)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('model_evaluation_gpu_kernel_performance.png')


if __name__ == '__main__':
    # Set the shapes and data types
    name_to_shapes = get_name_to_shapes_iter("square", None, None, None)
    float8_dtype = torch.float8_e4m3fn  # Change to the float8 dtype you want to use
    bf16_dtype = torch.bfloat16  # Change to the dtype you want to compare against
    quantization_technique = float8_weight_only  # Change to the quantization technique you want to use

    performance_data = run_gemm_benchmarks(
        name_to_shapes=name_to_shapes,
        float8_dtype=float8_dtype,
        other_dtype=bf16_dtype,
        quantization_technique=quantization_technique
    )
    print('Performance data:\n', tabulate(performance_data, headers=performance_data.keys()))
    plot_performance_data(performance_data)
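
    # Note: float8 matmul kernels generally require a recent PyTorch/torchao
    # build and a GPU with compute capability >= 8.9 (e.g. Ada or Hopper); on
    # older GPUs quantization may fall back to slower paths or error out. An
    # illustrative guard, if you want to fail fast:
    #
    # major, minor = torch.cuda.get_device_capability()
    # assert (major, minor) >= (8, 9), "float8 GEMM needs SM 8.9+"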