|
| 1 | +import torch |
| 2 | +import torchao |
| 3 | +import torch.nn as nn |
| 4 | +import numpy as np |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +from torch.profiler import profile, record_function, ProfilerActivity |
| 7 | +from torchao.quantization.quant_api import quantize_, float8_dynamic_activation_float8_weight |
| 8 | +import copy |
| 9 | +from utils import ( |
| 10 | + get_name_to_shapes_iter, |
| 11 | +) |
| 12 | +import tqdm |
| 13 | + |
| 14 | +# Set the device (GPU if available) |
| 15 | +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 16 | + |
| 17 | +class ToyLinearModel(torch.nn.Module): |
| 18 | + def __init__(self, m=64, n=32, k=64): |
| 19 | + super().__init__() |
| 20 | + self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.bfloat16) |
| 21 | + self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.bfloat16) |
| 22 | + |
| 23 | + def example_inputs(self, batch_size=1, dtype=torch.float, device="cuda"): |
| 24 | + return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) |
| 25 | + |
| 26 | + def forward(self, x): |
| 27 | + x = self.linear1(x) |
| 28 | + x = self.linear2(x) |
| 29 | + return x |
| 30 | + |
| 31 | +# Function to benchmark model evaluation with profiling |
| 32 | +def benchmark_model_with_profiling(model, input_data, dtype): |
| 33 | + print('Model before quantization: ', model) |
| 34 | + quantize_(model, float8_dynamic_activation_float8_weight()) |
| 35 | + print('Model quantized: ', model) |
| 36 | + model.eval() # Set the model to evaluation mode |
| 37 | + # input_data = torch.randn(input_size, device=device) |
| 38 | + |
| 39 | + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: |
| 40 | + with record_function("model_inference"): |
| 41 | + with torch.no_grad(): |
| 42 | + _ = model(*input_data) |
| 43 | + |
| 44 | + # Return the profiler output |
| 45 | + return prof |
| 46 | + |
| 47 | +name_to_shapes = get_name_to_shapes_iter("square", None, None, None) |
| 48 | + |
| 49 | + |
| 50 | + |
| 51 | +# Set the data types |
| 52 | +float8_dtype = torch.float8_e4m3fn # Replace with the actual float8 dtype from TorchAO |
| 53 | +bf16_dtype = torch.bfloat16 |
| 54 | + |
| 55 | +# Dictionary to store performance data |
| 56 | +performance_data = { |
| 57 | + 'Input Size': [], |
| 58 | + 'float8 Kernel Times (ms)': [], |
| 59 | + 'bf16 Kernel Times (ms)': [] |
| 60 | +} |
| 61 | + |
| 62 | +# Run benchmarks for each input size |
| 63 | +for idx, (name, (m, k, n)) in enumerate(tqdm.tqdm(name_to_shapes)): |
| 64 | + print(f"Profiling model with input size: {m, k, n}") |
| 65 | + |
| 66 | + # Initialize the model with the specified dimensions |
| 67 | + model = ToyLinearModel().eval().to(device) |
| 68 | + example_inputs = model.example_inputs() |
| 69 | + model_bf16 = copy.deepcopy(model).to(device) # Copy the model to bf |
| 70 | + model_ref = copy.deepcopy(model).to(device) # Copy the model for quantization |
| 71 | + |
| 72 | + |
| 73 | + print('Model created: ', model) |
| 74 | + print('Example inputs: ', len(example_inputs), example_inputs[0].size()) |
| 75 | + |
| 76 | + # Profile float8 model evaluation |
| 77 | + prof_float8 = benchmark_model_with_profiling(model_ref, example_inputs, float8_dtype) |
| 78 | + prof_float8.export_chrome_trace(f"float8_model_{example_inputs[0].size()[0]}.json") # Save profiling details |
| 79 | + |
| 80 | + # Profile bf16 model evaluation |
| 81 | + prof_bf16 = benchmark_model_with_profiling(model_bf16, example_inputs, bf16_dtype) |
| 82 | + prof_bf16.export_chrome_trace(f"bf16_model_{example_inputs[0].size()[0]}.json") # Save profiling details |
| 83 | + |
| 84 | + print('Profiling keys: ', prof_float8.key_averages()) |
| 85 | + # Calculate and store total GPU kernel times |
| 86 | + float8_kernel_time = sum([event.device_time for event in prof_float8.key_averages()]) |
| 87 | + bf16_kernel_time = sum([event.device_time for event in prof_bf16.key_averages()]) |
| 88 | + |
| 89 | + performance_data['Input Size'].append(f"{example_inputs[0].size()[0]}") |
| 90 | + performance_data['float8 Kernel Times (ms)'].append(float8_kernel_time / 1000) # Convert from microseconds to milliseconds |
| 91 | + performance_data['bf16 Kernel Times (ms)'].append(bf16_kernel_time / 1000) # Convert from microseconds to milliseconds |
| 92 | + |
| 93 | + print('Performance data: ', performance_data) |
| 94 | + |
| 95 | +# Plotting the results |
| 96 | +plt.figure(figsize=(10, 6)) |
| 97 | +plt.plot(performance_data['Input Size'], performance_data['float8 Kernel Times (ms)'], marker='o', label='float8') |
| 98 | +plt.plot(performance_data['Input Size'], performance_data['bf16 Kernel Times (ms)'], marker='s', label='bf16') |
| 99 | +plt.xlabel('Batch Size') |
| 100 | +plt.ylabel('Kernel Time (ms)') |
| 101 | +plt.title('Model Evaluation GPU Kernel Performance: float8 vs bf16') |
| 102 | +plt.legend() |
| 103 | +plt.grid(True) |
| 104 | +plt.show() |
0 commit comments