From bd0fa8fece67523d60bf3b0b978d50a5491b9dd1 Mon Sep 17 00:00:00 2001 From: cthi Date: Fri, 13 Jun 2025 14:09:54 -0700 Subject: [PATCH 1/4] Add initial version of TuningCache and scripts for heuristic + kernel Differential Revision: D75540999 --- fbgemm_gpu/experimental/gen_ai/CMakeLists.txt | 4 +- .../fbgemm_gpu/quantize/tuning_cache.hpp | 292 +++++++++++++++ .../include/fbgemm_gpu/quantize/utils.h | 15 + .../quantize/common/scripts/gen_kernels.py | 337 ++++++++++++++++++ .../quantize/common/scripts/make_heuristic.py | 170 +++++++++ .../gen_ai/src/quantize/common/utils.cpp | 26 ++ 6 files changed, 843 insertions(+), 1 deletion(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/tuning_cache.hpp create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/utils.h create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/gen_kernels.py create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/make_heuristic.py create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/common/utils.cpp diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt index 1b631297cc..a9334b9c3d 100644 --- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -11,7 +11,8 @@ glob_files_nohip(experimental_gen_ai_cpp_source_files_cpu src/attention/*.cpp src/coalesce/*.cpp - src/quantize/*.cpp) + src/quantize/*.cpp + src/quantize/common/*.cpp) glob_files_nohip(experimental_gen_ai_cpp_source_files_gpu src/attention/*.cu @@ -98,6 +99,7 @@ gpu_cpp_library( INCLUDE_DIRS ${fbgemm_sources_include_directories} ${CMAKE_CURRENT_SOURCE_DIR}/src/quantize + ${CMAKE_CURRENT_SOURCE_DIR}/src/quantize/common/include ${CMAKE_CURRENT_SOURCE_DIR}/src/kv_cache CPU_SRCS ${experimental_gen_ai_cpp_source_files_cpu} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/tuning_cache.hpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/tuning_cache.hpp new file mode 100644 index 0000000000..633a068fb4 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/tuning_cache.hpp @@ -0,0 +1,292 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +/** + * Tuning cache for kernels. This class is responsible for evaluating new + * problem shapes (keyed by a string) against a predefined set of kernels, and + * caching the best kernel found. + */ +class TuningCache final { + public: + // kernelName should be unique for each type of kernel, as it is used to + // construct the filename. + explicit TuningCache(const std::string& kernelName) + : useCudaGraph_(std::getenv("FBGEMM_AUTOTUNE_USE_CUDA_GRAPH") != nullptr), + cacheDirectory_(getCacheDirectory()), + cacheFilename_(getCacheFilename(kernelName)), + detailedFilename_(getDetailedFilename(kernelName)) { + std::cout << "Using cache file at " << cacheFilename_ << std::endl; + + createCacheDirectory(); + loadCache(); + } + + TuningCache(const TuningCache&) = delete; + TuningCache& operator=(const TuningCache&) = delete; + TuningCache(TuningCache&&) = delete; + TuningCache& operator=(TuningCache&&) = delete; + + ~TuningCache() { + saveCache(); + } + + template + Kernel findBestKernelMaybeAutotune( + const std::string& cache_key, + const std::unordered_map& kernels, + Args&&... args) { + TORCH_CHECK(!kernels.empty(), "Kernels to tune over is empty."); + + auto it = cache_.find(cache_key); + if (it != cache_.end()) { + return getKernel(it->second, kernels); + } + + const auto start = std::chrono::high_resolution_clock::now(); + auto kernel_key = + findBestKernel(cache_key, kernels, std::forward(args)...); + const auto end = std::chrono::high_resolution_clock::now(); + const auto elapsed = + std::chrono::duration_cast(end - start); + std::cout << "Tuned " << kernel_key << " for key " << cache_key << " in " + << elapsed.count() << " ms." << std::endl; + + cache_.insert({cache_key, kernel_key}); + return getKernel(kernel_key, kernels); + } + + private: + template + Kernel getKernel( + const std::string& kernel_key, + const std::unordered_map& kernels) { + auto it = kernels.find(kernel_key); + TORCH_CHECK( + it != kernels.end(), + "Failed to find kernel keyed by " + kernel_key + + ". Consider deleting your fbgemm cache (~/.fbgemm)."); + return it->second; + } + + std::string getCacheDirectory() { + // If the environment variable is set, use that instead of the default + const char* cache_dir = std::getenv("FBGEMM_CACHE_DIR"); + if (cache_dir) { + return cache_dir; + } + + return std::string(std::getenv("HOME")) + "/" + + std::string(FBGEMM_CACHE_DIR); + } + + std::string getCacheFilename(const std::string& kernel_name) { + return getCacheDirectory() + "/" + kernel_name + ".txt"; + } + + std::string getDetailedFilename(const std::string& kernel_name) { + return getCacheDirectory() + "/" + kernel_name + "_detailed.txt"; + } + + bool cacheDirExists() { + return std::filesystem::exists(cacheDirectory_) && + std::filesystem::is_directory(cacheDirectory_); + } + + void createCacheDirectory() { + if (!cacheDirExists()) { + // Try to create the directory, multiple caches/processes may attempt + // this, and only one would succeed. + std::string error; + try { + if (std::filesystem::create_directory(cacheDirectory_)) { + return; + } + } catch (const std::filesystem::filesystem_error& e) { + error = e.what(); + } + + // If the directory still doesn't exist, error out + TORCH_CHECK( + cacheDirExists(), + "FBGEMM cache directory creation at " + cacheDirectory_ + + " failed: " + error); + } + } + + void loadCache() { + std::ifstream file(cacheFilename_); + if (!file.is_open()) { + // Create a new cache file if it doesn't exist + std::ofstream newFile(cacheFilename_); + newFile.close(); + } else { + std::string line; + while (std::getline(file, line)) { + size_t pos = line.find('='); + if (pos != std::string::npos) { + std::string key = line.substr(0, pos); + std::string value = line.substr(pos + 1); + cache_.insert_or_assign(key, value); + } + } + file.close(); + } + } + + void saveCache() { + // Only one rank needs to save the cache. This is fine as the cache + // should be largely equivalent across ranks. + if (at::cuda::current_device() != 0) { + return; + } + + std::ofstream file(cacheFilename_); + if (file.is_open()) { + for (const auto& pair : cache_) { + file << pair.first << "=" << pair.second << std::endl; + } + file.close(); + } + + if (!detailedTuningInfo_.empty()) { + std::ofstream detailed_file(detailedFilename_, std::ios_base::app); + if (detailed_file.is_open()) { + for (auto& [cache_key, kernels] : detailedTuningInfo_) { + // Sort for convenience in descending order of time_ms + std::sort( + kernels.begin(), kernels.end(), [](const auto& a, const auto& b) { + return a.second < b.second; + }); + for (const auto& [kernel_name, time_ms] : kernels) { + detailed_file << cache_key << "," << kernel_name << "," << time_ms + << std::endl; + } + } + + detailed_file.close(); + } + } + } + + template + float benchmark(Kernel kernel, Args&&... args) { + // Warmup iteration + kernel(std::forward(args)...); + + // Estimate the number of iterations needed to run for 10 ms. This + // helps with stability for fast kernels. + start_.record(); + kernel(std::forward(args)...); + stop_.record(); + stop_.synchronize(); + const auto estimated_time_ms = start_.elapsed_time(stop_); + const int num_iters = std::max(1, int(10 / estimated_time_ms)); + + if (useCudaGraph_) { + at::cuda::CUDAGraph graph; + { + // CUDAGraph capture must happen on non-default stream + at::cuda::CUDAStream stream = at::cuda::getStreamFromPool(true); + at::cuda::CUDAStreamGuard streamGuard(stream); + + // For flexibility, we use cudaStreamCaptureModeRelaxed. + // - cudaStreamCaptureModeGlobal prevents other threads from calling + // certain CUDA APIs such as cudaEventQuery. This can conflict with + // things like ProcessGroupNCCL. + // - cudaStreamCaptureModeThreadLocal prevents CCA from freeing memory. + // Since CUDA graph is preferred for offline benchmark this should be + // fine. + graph.capture_begin({0, 0}, cudaStreamCaptureModeRelaxed); + for (int i = 0; i < num_iters; ++i) { + kernel(std::forward(args)...); + } + graph.capture_end(); + } + + // Time execution of graph + start_.record(); + graph.replay(); + stop_.record(); + stop_.synchronize(); + const auto graph_time_ms = start_.elapsed_time(stop_); + + return graph_time_ms / num_iters; + } else { + // Time execution of kernels + start_.record(); + for (int i = 0; i < num_iters; ++i) { + kernel(std::forward(args)...); + } + stop_.record(); + stop_.synchronize(); + const auto kernels_time_ms = start_.elapsed_time(stop_); + + return kernels_time_ms / num_iters; + } + } + + template + std::string findBestKernel( + const std::string& cache_key, + const std::unordered_map& kernels, + Args&&... args) { + std::string best_kernel; + float best_time = FLT_MAX; + + for (const auto& [kernel_name, kernel] : kernels) { + const float time = benchmark(kernel, std::forward(args)...); + if (time < best_time) { + best_time = time; + best_kernel = kernel_name; + } + if (std::getenv("FBGEMM_AUTOTUNE_COLLECT_STATS")) { + detailedTuningInfo_[cache_key].push_back({kernel_name, time}); + } + } + + return best_kernel; + } + + constexpr static std::string_view FBGEMM_CACHE_DIR = ".fbgemm"; + + at::cuda::CUDAEvent start_ = at::cuda::CUDAEvent(cudaEventDefault); + at::cuda::CUDAEvent stop_ = at::cuda::CUDAEvent(cudaEventDefault); + + // If FBGEMM_AUTOTUNE_USE_CUDA_GRAPH is set, use CUDA graph for benchmarking. + // CUDA graphs use a separate memory pool to do allocation in PyTorch + // CUDACachingAllocator to ensure the memory is valid throughout the graph, + // which can memory fragmentation (and higher chance of CUDA OOM). We can + // prefer to use CUDA graph for offline benchmarking, but not for online + // serving. + bool useCudaGraph_; + // Absolute path of the cache directory + std::string cacheDirectory_; + // Absolute path of the cache file for the kernel + std::string cacheFilename_; + // Absolute path of the detailed tuning info + std::string detailedFilename_; + // (cache key, best kernel) + std::unordered_map cache_; + // If FBGEMM_AUTOTUNE_COLLECT_STATS is set, we will log the timing for each + // kernel for each problem shape. This is useful to distill the best kernels + // into a smaller set. + std::unordered_map>> + detailedTuningInfo_; +}; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/utils.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/utils.h new file mode 100644 index 0000000000..1422fc0e9c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/utils.h @@ -0,0 +1,15 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +namespace fbgemm_gpu { + +int nextPowerOf2(int n); + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/gen_kernels.py b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/gen_kernels.py new file mode 100644 index 0000000000..cb826a3c3b --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/gen_kernels.py @@ -0,0 +1,337 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import subprocess + +COPYRIGHT = """/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + """ + +KERNEL_ID_TEMPLATES = { + "bf16bf16bf16_grouped": "bf16bf16bf16_grouped_{tM}_{tN}_{tK}_{cM}_{cN}_{cK}_{pong[0]}", + "f8f8bf16_rowwise": "f8f8bf16_rowwise_{tM}_{tN}_{tK}_{cM}_{cN}_{cK}_{arch}_{pong[0]}_{coop[0]}", +} + +bf16bf16bf16_grouped_decl_template = """ +at::Tensor {kernel_id}( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor {kernel_id}( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + """ + +f8f8bf16_rowwise_decl_template = """ +at::Tensor {kernel_id}( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + bool use_fast_accum = true, + std::optional bias = std::nullopt, + std::optional output = std::nullopt); + """ + +DECL_TEMPLATES = { + "bf16bf16bf16_grouped": bf16bf16bf16_grouped_decl_template, + "f8f8bf16_rowwise": f8f8bf16_rowwise_decl_template, +} + + +bf16bf16bf16_grouped_file_template = """ +at::Tensor {kernel_id}( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) {{ + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +}} + +at::Tensor {kernel_id}( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) {{ + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +}} +""" + +f8f8bf16_rowwise_file_template = """ +at::Tensor {kernel_id}( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + bool use_fast_accum = true, + std::optional bias = std::nullopt, + std::optional output = std::nullopt) {{ + // Dispatch this kernel to the correct underlying implementation. + return f8f8bf16_rowwise_wrapper<{tM}, {tN}, {tK}, {cM}, {cN}, {cK}, {arch}, {pong}, {coop}>( + XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); +}} +""" + +FILE_TEMPLATES = { + "bf16bf16bf16_grouped": bf16bf16bf16_grouped_file_template, + "f8f8bf16_rowwise": f8f8bf16_rowwise_file_template, +} + +bf16bf16bf16_grouped_kernel_map_template = """ +template +using Kernel_bf16bf16bf16_grouped = at::Tensor (*)( + InputType, + InputType, + at::Tensor, + std::optional, + std::optional); + +template +const std::unordered_map>& +get_bf16bf16bf16_grouped_kernels() {{ + static const std::unordered_map> kernels = {{ + {body} + }}; + return kernels; +}} +""" + +f8f8bf16_rowwise_kernel_map_template = """ +using Kernel_f8f8bf16_rowwise = at::Tensor (*)( + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + bool, + std::optional, + std::optional); + +const std::unordered_map& +get_f8f8bf16_rowwise_kernels(int arch) {{ + static const std::unordered_map kernelsSM90 = {{ + {bodySM90} + }}; + static const std::unordered_map kernelsSM100 = {{ + {bodySM100} + }}; + if (arch == 10) {{ + return kernelsSM100; + }} else {{ + return kernelsSM90; + }} +}} +""" + +ARCH_MAP_TEMPLATES = {"f8f8bf16_rowwise": f8f8bf16_rowwise_kernel_map_template} +MAP_TEMPLATES = {"bf16bf16bf16_grouped": bf16bf16bf16_grouped_kernel_map_template} + + +def gen_kernel_map_body(kernel_confs, arch=None): + return "\n".join( + [ + f'{{"{kernel_conf['kernel_id']}", {kernel_conf['kernel_id']}}},' + for kernel_conf in kernel_confs + if arch is None or kernel_conf["arch"] == arch + ] + ) + + +def gen_kernel_map(kernel_name, kernel_confs): + if kernel_name in MAP_TEMPLATES: + body = gen_kernel_map_body(kernel_confs) + return MAP_TEMPLATES[kernel_name].format(body=body) + + # ARCH_MAP_TEMPLATES + bodySM90 = gen_kernel_map_body(kernel_confs, 9) + bodySM100 = gen_kernel_map_body(kernel_confs, 10) + return ARCH_MAP_TEMPLATES[kernel_name].format( + bodySM90=bodySM90, bodySM100=bodySM100 + ) + + +def get_kernel_confs_nv(kernel_name): + # Change these as needed to explore different kernels configurations. + tiles = [ + (M, N, K) + for M in (64, 128, 256) + for N in ( + 16, + 32, + 64, + 128, + 256, + ) + for K in (128,) + ] + clusters = [(1, 1, 1), (2, 1, 1), (4, 1, 1)] + schedules = [("false", "false"), ("true", "false"), ("false", "true")] + # SM90 and SM100 + archs = [9, 10] + + # Some kernels may not support all parameters (e.g. only 1 type of schedule), filter them out to prevent duplicates. + generated = set() + + kernel_confs = [] + for arch in archs: + for tM, tN, tK in tiles: + for cM, cN, cK in clusters: + for pong, coop in schedules: + # Co-operative requires tM >= 128 + if tM < 128 and ( + coop == "true" + # This kernel only supports pong OR coop, and not regular warp persistent + or (kernel_name == "bf16bf16bf16_grouped" and pong == "false") + ): + continue + + # This tile size is generally bad + if tM == 256 and tN == 256: + continue + + # To compile less kernels skip pong & coop for smaller tiles as they don't reach the compute roofline + if (pong == "true" or coop == "true") and not ( + tM >= 128 and tN >= 128 + ): + continue + + # SM100 specific + if arch == 10: + # M cluster == 1 requires specific M tile size + if cM == 1 and not (tM == 64 or tM == 128): + continue + + # M cluter > 1 requires N tile >= 32 + if cM > 1 and tN < 32: + continue + + kernel_conf = { + "arch": arch, + "tM": tM, + "tN": tN, + "tK": tK, + "cM": cM, + "cN": cN, + "cK": cK, + "pong": pong, + "coop": coop, + } + kernel_id = gen_kernel_id(kernel_name, kernel_conf) + if kernel_id not in generated: + generated.add(kernel_id) + + kernel_conf["kernel_id"] = kernel_id + kernel_confs.append(kernel_conf) + + return kernel_confs + + +def gen_kernel_id(kernel_name, kernel_conf): + template = KERNEL_ID_TEMPLATES[kernel_name] + return template.format(**kernel_conf) + + +def gen_kernel_file(kernel_name, kernel_conf): + template = FILE_TEMPLATES[kernel_name] + formatted = template.format(**kernel_conf) + + return f"""{COPYRIGHT} +#include "{kernel_name}_common.cuh" + +namespace fbgemm_gpu {{ + +{formatted} + +}} // namespace fbgemm_gpu +""" + + +def gen_kernel_files(kernel_name, kernel_confs, output_dir): + for kernel_conf in kernel_confs: + kernel_id = kernel_conf["kernel_id"] + file_path = os.path.join(output_dir, f"{kernel_id}.cu") + with open(file_path, "w") as f: + f.write(gen_kernel_file(kernel_name, kernel_conf)) + + +def gen_kernel_manifest_decl(kernel_name, kernel_conf): + template = DECL_TEMPLATES[kernel_name] + return template.format(**kernel_conf) + + +def gen_kernel_manifest(kernel_name, kernel_confs, output_dir): + body = "\n".join( + gen_kernel_manifest_decl(kernel_name, kernel_conf) + for kernel_conf in kernel_confs + ) + kernel_map = gen_kernel_map(kernel_name, kernel_confs) + + manifest_content = f"""{COPYRIGHT} +#pragma once + +#include + +namespace fbgemm_gpu {{ + +{body} + +{kernel_map} + +}} // namespace fbgemm_gpu +""" + + manifest_path = os.path.join(output_dir, f"{kernel_name}_manifest.cuh") + with open(manifest_path, "w") as f: + f.write(manifest_content) + + +def main(): + parser = argparse.ArgumentParser(description="Generate kernel files and manifest.") + parser.add_argument( + "--kernel_name", + type=str, + required=True, + help="Name of the kernel to generate, e.g. bf16bf16bf16_grouped.", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Directory to place generated kernels. If unset, will default to the current working directory.", + ) + args = parser.parse_args() + + # Determine the output directory + output_dir = args.output_dir if args.output_dir is not None else os.getcwd() + print(f"Will place generated files in {output_dir}") + + kernel_confs = get_kernel_confs_nv(args.kernel_name) + print(f"Will generate {len(kernel_confs)} kernels") + gen_kernel_files(args.kernel_name, kernel_confs, output_dir) + gen_kernel_manifest(args.kernel_name, kernel_confs, output_dir) + + # Format the generated files + command = f"clang-format -i {output_dir}/*.{{cu,cuh}}" + subprocess.run(command, shell=True) + + +if __name__ == "__main__": + main() diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/make_heuristic.py b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/make_heuristic.py new file mode 100644 index 0000000000..ec49a499c5 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/scripts/make_heuristic.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from collections import defaultdict +from typing import Dict, List, Set, Tuple + + +def build_heuristic( + file_path: str, threshold: float +) -> Dict[Tuple[int, ...], List[Tuple[str, str]]]: + """ + Builds a heuristic from a set of profiling runs on a kernel. + The heuristic is currently built in a greedy approach: + 1. For all problem shapes consider all kernels performing within (1+threshold) of the fastest kernel. + 2. For the above kernels, count how often it appeared across all problem shapes. + 3. When assigning a kernel to a problem shape, prioritize kernels that appear more often to minimize the number of kernels used. + + Assumptions: + + The ordering of the problem shape dimension is assumed to be (outer_dim, inner_shapes) where outer_dim varies but inner_shapes is fixed in a set of problem shapes. + This is how we decide how to map a set of problem shapes into a heuristic. + E.g. For the following problem shapes (5, 5120, 1024), (6, 5120, 1024), (7, 2048, 1024), (8, 2048, 1024) we would build a heuristic along: + + Problem Shape (5120, 1024): + 5: ... + 6: ... + + Problem Shape (2048, 1024): + 7: ... + 8: ... + """ + # Inner problem shapes + inner_shapes: Set[Tuple[int, ...]] = set() + # Problem Shape -> Best kernel time + best_times_ms: Dict[Tuple[int, ...], float] = {} + # Kernels count across all problem shapes + kernel_count: Dict[str, int] = defaultdict(int) + # Problem Shape -> Candidate kernels + kernel_candidates: Dict[Tuple[int, ...], Set[str]] = defaultdict(set) + # Problem Shape -> Assigned kernel + kernel_assignment: Dict[Tuple[int, ...], str] = {} + # Inner problem shape -> (Outer Dim, Kernel) + heuristics: Dict[Tuple[int, ...], List[Tuple[str, str]]] = {} + + with open(file_path, "r") as file: + parsed_rows = [] + + # Parse CSV and find the best time for each problem shape. + rows = file.readlines() + for row in rows: + problem_shape_str, kernel, time_ms_str = row.split(",") + problem_shape = tuple(int(x) for x in problem_shape_str.split("_")) + time_ms = float(time_ms_str) + + inner_shapes.add(problem_shape[1:]) + best_times_ms[problem_shape] = ( + time_ms + if problem_shape not in best_times_ms + else min(best_times_ms[problem_shape], time_ms) + ) + + parsed_rows.append((problem_shape, kernel, time_ms)) + + # Filter kernels for each problem shape based on the permitted threshold + for problem_shape, kernel, time_ms in parsed_rows: + if time_ms < (best_times_ms[problem_shape] * (1 + threshold)): + kernel_candidates[problem_shape].add(kernel) + kernel_count[kernel] += 1 + + # Prefer kernels that are used more often + kernel_order = sorted(kernel_count.keys(), key=kernel_count.get, reverse=True) + + for kernel in kernel_order: + for problem_shape, candidates in kernel_candidates.items(): + if problem_shape not in kernel_assignment and kernel in candidates: + kernel_assignment[problem_shape] = kernel + + for inner_shape in inner_shapes: + outer_dims_and_kernel = sorted( + [ + (problem_shape[0], kernel) + for problem_shape, kernel in kernel_assignment.items() + if problem_shape[1:] == inner_shape + ], + key=lambda x: x[0], + ) + + heuristic: List[Tuple[str, str]] = [] + last_outer_dim, last_kernel = outer_dims_and_kernel[0] + for outer_dim, assigned_kernel in outer_dims_and_kernel[1:]: + if last_kernel != assigned_kernel: + heuristic.append((str(last_outer_dim), last_kernel)) + last_outer_dim, last_kernel = outer_dim, assigned_kernel + heuristic.append(("else", last_kernel)) + + heuristics[inner_shape] = heuristic + + return heuristics + + +# A basic codegen to make the if statements, customize as needed for your kernel. +def print_heuristic_cpp( + heuristics: Dict[Tuple[int, ...], List[Tuple[str, str]]], + varnames_arg: str, +) -> None: + varnames = dict(enumerate(varnames_arg.split(","))) + + for inner_shape, outer_dims_and_kernels in heuristics.items(): + condition = " && ".join( + [f"{varnames[idx + 1]} == {val}" for idx, val in enumerate(inner_shape)] + ) + print(f"if ({condition}) {{") + for idx, (outer_dim, kernel) in enumerate(outer_dims_and_kernels): + if idx == 0: + print(f" if ({varnames[0]} <= {outer_dim}) {{") + elif idx == len(outer_dims_and_kernels) - 1: + print(" } else {") + else: + print(f" }} else if ({varnames[0]} <= {outer_dim}) {{") + print(f" return {kernel};") + print(" }") + print("}\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate heuristics from FBGEMM kernel tuning cache detailed info." + ) + parser.add_argument( + "--file-path", + type=str, + required=True, + help="Path to the input data file generated by the FBGEMM tuning cache.", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.01, + help="Kernels performing within --threshold of the best fastest kernel will be considered.", + ) + parser.add_argument( + "--cpp", action="store_true", help="Generate C++ code for the heuristic." + ) + parser.add_argument( + "--cpp_varnames", + type=str, + help="Variable names to use for C++ heuristic generation, comma separated in same order of the problem shape. E.g. M,N,K", + ) + + args = parser.parse_args() + + heuristic = build_heuristic(args.file_path, args.threshold) + if args.cpp: + if args.cpp_varnames is None: + print("If setting --cpp must also set --cpp_varnames.") + exit(1) + print_heuristic_cpp(heuristic, args.cpp_varnames) + else: + for inner_shape, outer_dims in heuristic.items(): + print(f"Problem Shape: {inner_shape}") + for outer_dim, assigned_kernel in outer_dims: + print(f" {outer_dim}: {assigned_kernel}") + + +if __name__ == "__main__": + main() diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/common/utils.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/utils.cpp new file mode 100644 index 0000000000..9bdb0d5e74 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/common/utils.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fbgemm_gpu/quantize/utils.h" // @manual + +namespace fbgemm_gpu { + +int nextPowerOf2(int n) { + if (n == 0) { + return 1; + } + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + return n + 1; +} + +} // namespace fbgemm_gpu From 71020f1570975e8be90407f328432f3a96e66ef1 Mon Sep 17 00:00:00 2001 From: cthi Date: Fri, 13 Jun 2025 14:09:54 -0700 Subject: [PATCH 2/4] Support tuning cache for Cutlass BF16 grouped GEMM Differential Revision: D75541013 --- .../bf16bf16bf16_grouped.cu | 155 ++++++++++-------- .../bf16bf16bf16_grouped_manifest.cuh | 17 ++ 2 files changed, 107 insertions(+), 65 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu index b1a57b33dd..f1a3ceccc4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu @@ -10,59 +10,52 @@ #include #include "bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh" +#include "fbgemm_gpu/quantize/tuning_cache.hpp" +#include "fbgemm_gpu/quantize/utils.h" namespace fbgemm_gpu { #if CUDART_VERSION >= 12000 -// BF16 grouped cutlass kernel dispatch. +namespace { +TuningCache& getTuningCache() { + // This kernel has multiple APIs templated based on InputType, so we use this + // to have a single cache instance across APIs. + static TuningCache cache("bf16bf16bf16_grouped"); + return cache; +} +} // namespace + template -at::Tensor dispatch_bf16_grouped_kernel( - int G, - int total_M, - int N, - int K, - InputType X, // BF16 - InputType W, // BF16 - at::Tensor output, - std::optional zero_start_index_M = std::nullopt, - std::optional M_sizes = std::nullopt) { +Kernel_bf16bf16bf16_grouped +get_kernel_via_heuristic(int G, int total_M, int N, int K) { // Use heuristics to pick best kernel implementation. // Llama4 128E if (G == 128) { if (N == 5120 && K == 1024) { if (total_M <= 128) { - return bf16bf16bf16_grouped_128_16_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 256) { - return bf16bf16bf16_grouped_128_32_128_2_1_1_t( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_32_128_2_1_1_t; } else if (total_M <= 2048) { - return bf16bf16bf16_grouped_128_16_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 4096) { - return bf16bf16bf16_grouped_128_32_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_32_128_2_1_1_f; } else if (total_M <= 8192) { - return bf16bf16bf16_grouped_128_64_128_1_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_64_128_1_1_1_f; } else if (total_M <= 16384) { - return bf16bf16bf16_grouped_128_128_128_2_1_1_t( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_128_128_2_1_1_t; } else { - return bf16bf16bf16_grouped_128_256_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_256_128_2_1_1_f; } } if (N == 2048 && K == 5120) { if (total_M <= 2048) { - return bf16bf16bf16_grouped_128_16_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else { - return bf16bf16bf16_grouped_128_128_128_2_1_1_t( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_128_128_2_1_1_t; } } } @@ -71,71 +64,103 @@ at::Tensor dispatch_bf16_grouped_kernel( if (G == 16) { if (N == 5120 && K == 1024) { if (total_M <= 32) { - return bf16bf16bf16_grouped_128_16_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 64) { - return bf16bf16bf16_grouped_128_32_128_2_1_1_t( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_32_128_2_1_1_t; } else if (total_M <= 256) { - return bf16bf16bf16_grouped_128_16_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 512) { - return bf16bf16bf16_grouped_128_32_128_2_1_1_t( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_32_128_2_1_1_t; } else if (total_M <= 1024) { - return bf16bf16bf16_grouped_128_64_128_2_1_1_t( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_64_128_2_1_1_t; } else { - return bf16bf16bf16_grouped_128_256_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_256_128_2_1_1_f; } } if (N == 2048 && K == 5120) { if (total_M <= 16) { - return bf16bf16bf16_grouped_128_16_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 64) { - return bf16bf16bf16_grouped_128_32_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_32_128_2_1_1_f; } else if (total_M <= 256) { - return bf16bf16bf16_grouped_128_16_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 512) { - return bf16bf16bf16_grouped_128_32_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_32_128_2_1_1_f; } else if (total_M <= 1024) { - return bf16bf16bf16_grouped_128_64_128_1_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_64_128_1_1_1_f; } else { - return bf16bf16bf16_grouped_128_128_128_2_1_1_t( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_128_128_2_1_1_t; } } } // Fallback to legacy heuristic for now. if (total_M <= 16) { - return bf16bf16bf16_grouped_128_16_128_1_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_16_128_1_1_1_f; } else if (total_M <= 32) { - return bf16bf16bf16_grouped_128_32_128_1_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_32_128_1_1_1_f; } else if (total_M <= 64) { - return bf16bf16bf16_grouped_128_64_128_1_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_64_128_1_1_1_f; } else if (total_M <= 128) { - return bf16bf16bf16_grouped_128_128_128_1_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_128_128_1_1_1_f; } else if (total_M <= 512) { - return bf16bf16bf16_grouped_256_128_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_256_128_128_2_1_1_f; } else { - return bf16bf16bf16_grouped_128_256_128_2_1_1_f( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_128_256_128_2_1_1_f; } } +template +Kernel_bf16bf16bf16_grouped get_kernel_via_tuning( + int G, + int total_M, + int N, + int K, + InputType X, // BF16 + InputType W, // BF16 + at::Tensor output, + std::optional zero_start_index_M = std::nullopt, + std::optional M_sizes = std::nullopt) { + auto& cache = getTuningCache(); + + // Reducing amount of auto tuning by rounding up total_m to next power of 2. + total_M = nextPowerOf2(total_M); + // Use (total_M, N, K, G) shape as the key. + const std::string shape_key = std::to_string(total_M) + "_" + + std::to_string(N) + "_" + std::to_string(K) + "_" + std::to_string(G); + const auto& kernels = get_bf16bf16bf16_grouped_kernels(); + auto kernel = cache.findBestKernelMaybeAutotune( + shape_key, kernels, X, W, output, zero_start_index_M, M_sizes); + + return kernel; +} + +// BF16 grouped cutlass kernel dispatch. +template +at::Tensor dispatch_bf16_grouped_kernel( + int G, + int total_M, + int N, + int K, + InputType X, // BF16 + InputType W, // BF16 + at::Tensor output, + std::optional zero_start_index_M = std::nullopt, + std::optional M_sizes = std::nullopt) { + // Select kernel to run via heuristics or tuning. + auto kernel = [&]() { + if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) { + return get_kernel_via_tuning( + G, total_M, N, K, X, W, output, zero_start_index_M, M_sizes); + } else { + return get_kernel_via_heuristic(G, total_M, N, K); + } + }(); + // Invoke kernel + return kernel(X, W, output, zero_start_index_M, M_sizes); +} + template OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) { at::Tensor Y; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh index 88c27de20b..43532fec4c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh @@ -180,4 +180,21 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_f( std::optional zero_start_index_M, std::optional M_sizes); +template +using Kernel_bf16bf16bf16_grouped = at::Tensor (*)( + InputType, + InputType, + at::Tensor, + std::optional, + std::optional); + +template +const std::unordered_map>& +get_bf16bf16bf16_grouped_kernels() { + static const std:: + unordered_map> + kernels = {}; + return kernels; +} + } // namespace fbgemm_gpu From 0a5666758cbdce75117720d9bceb172bead242f1 Mon Sep 17 00:00:00 2001 From: cthi Date: Fri, 13 Jun 2025 14:09:54 -0700 Subject: [PATCH 3/4] Add new kernels for Cutlass BF16 grouped GEMM for tuning cache Differential Revision: D75806957 --- .../bf16bf16bf16_grouped.cu | 8 +- ...f16bf16bf16_grouped_128_128_128_1_1_1_t.cu | 40 ++ ...f16bf16bf16_grouped_128_128_128_2_1_1_f.cu | 40 ++ ...f16bf16bf16_grouped_128_128_128_4_1_1_f.cu | 40 ++ ...f16bf16bf16_grouped_128_128_128_4_1_1_t.cu | 40 ++ ...f16bf16bf16_grouped_128_16_128_4_1_1_f.cu} | 17 +- ...f16bf16bf16_grouped_128_256_128_1_1_1_t.cu | 40 ++ ...f16bf16bf16_grouped_128_256_128_2_1_1_t.cu | 40 ++ ...f16bf16bf16_grouped_128_256_128_4_1_1_f.cu | 40 ++ ...f16bf16bf16_grouped_128_256_128_4_1_1_t.cu | 40 ++ ...bf16bf16bf16_grouped_128_32_128_4_1_1_f.cu | 40 ++ ...f16bf16bf16_grouped_128_64_128_2_1_1_f.cu} | 17 +- ...bf16bf16bf16_grouped_128_64_128_4_1_1_f.cu | 40 ++ ...f16bf16bf16_grouped_256_128_128_1_1_1_f.cu | 40 ++ ...f16bf16bf16_grouped_256_128_128_1_1_1_t.cu | 40 ++ ...f16bf16bf16_grouped_256_128_128_2_1_1_t.cu | 40 ++ ...f16bf16bf16_grouped_256_128_128_4_1_1_f.cu | 40 ++ ...f16bf16bf16_grouped_256_128_128_4_1_1_t.cu | 40 ++ ...bf16bf16bf16_grouped_256_16_128_1_1_1_f.cu | 40 ++ ...bf16bf16bf16_grouped_256_16_128_2_1_1_f.cu | 40 ++ ...bf16bf16bf16_grouped_256_16_128_4_1_1_f.cu | 40 ++ ...bf16bf16bf16_grouped_256_32_128_1_1_1_f.cu | 40 ++ ...bf16bf16bf16_grouped_256_32_128_2_1_1_f.cu | 40 ++ ...bf16bf16bf16_grouped_256_32_128_4_1_1_f.cu | 40 ++ ...bf16bf16bf16_grouped_256_64_128_1_1_1_f.cu | 40 ++ ...bf16bf16bf16_grouped_256_64_128_2_1_1_f.cu | 40 ++ ...bf16bf16bf16_grouped_256_64_128_4_1_1_f.cu | 40 ++ .../bf16bf16bf16_grouped_manifest.cuh | 443 +++++++++++++++++- 28 files changed, 1414 insertions(+), 31 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_t.cu rename fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/{bf16bf16bf16_grouped_128_32_128_2_1_1_t.cu => bf16bf16bf16_grouped_128_16_128_4_1_1_f.cu} (66%) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_1_1_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_2_1_1_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_4_1_1_f.cu rename fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/{bf16bf16bf16_grouped_128_64_128_2_1_1_t.cu => bf16bf16bf16_grouped_128_64_128_2_1_1_f.cu} (72%) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_2_1_1_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_4_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_4_1_1_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_1_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_2_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_4_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_2_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_4_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_1_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_2_1_1_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_4_1_1_f.cu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu index f1a3ceccc4..6652b404d2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu @@ -37,7 +37,7 @@ get_kernel_via_heuristic(int G, int total_M, int N, int K) { if (total_M <= 128) { return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 256) { - return bf16bf16bf16_grouped_128_32_128_2_1_1_t; + return bf16bf16bf16_grouped_128_32_128_2_1_1_f; } else if (total_M <= 2048) { return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 4096) { @@ -66,13 +66,13 @@ get_kernel_via_heuristic(int G, int total_M, int N, int K) { if (total_M <= 32) { return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 64) { - return bf16bf16bf16_grouped_128_32_128_2_1_1_t; + return bf16bf16bf16_grouped_128_32_128_2_1_1_f; } else if (total_M <= 256) { return bf16bf16bf16_grouped_128_16_128_2_1_1_f; } else if (total_M <= 512) { - return bf16bf16bf16_grouped_128_32_128_2_1_1_t; + return bf16bf16bf16_grouped_128_32_128_2_1_1_f; } else if (total_M <= 1024) { - return bf16bf16bf16_grouped_128_64_128_2_1_1_t; + return bf16bf16bf16_grouped_128_64_128_2_1_1_f; } else { return bf16bf16bf16_grouped_128_256_128_2_1_1_f; } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_t.cu new file mode 100644 index 0000000000..f6b598a3f3 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_t.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 128, + 128, + 1, + 1, + 1, + true>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_f.cu new file mode 100644 index 0000000000..68dc950909 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 128, + 128, + 2, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_f.cu new file mode 100644 index 0000000000..5fa95a7fd6 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 128, + 128, + 4, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_t.cu new file mode 100644 index 0000000000..a36aed7cd6 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_1_1_t.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 128, + 128, + 4, + 1, + 1, + true>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_16_128_4_1_1_f.cu similarity index 66% rename from fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_1_1_t.cu rename to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_16_128_4_1_1_f.cu index 6b31ee8096..c3989aa4bc 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_16_128_4_1_1_f.cu @@ -10,24 +10,31 @@ namespace fbgemm_gpu { -at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_t( +at::Tensor bf16bf16bf16_grouped_128_16_128_4_1_1_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( + return bf16bf16bf16_grouped_impl( X, W, output, zero_start_index_M, M_sizes); } -at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_t( +at::Tensor bf16bf16bf16_grouped_128_16_128_4_1_1_f( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 16, + 128, + 4, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); } } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_1_1_t.cu new file mode 100644 index 0000000000..ea79e70d61 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_1_1_t.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 256, + 128, + 1, + 1, + 1, + true>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_2_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_2_1_1_t.cu new file mode 100644 index 0000000000..cf637757af --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_2_1_1_t.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 256, + 128, + 2, + 1, + 1, + true>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_f.cu new file mode 100644 index 0000000000..91c3ba38a7 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 256, + 128, + 4, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_t.cu new file mode 100644 index 0000000000..03a9b2f987 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_4_1_1_t.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 256, + 128, + 4, + 1, + 1, + true>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_4_1_1_f.cu new file mode 100644 index 0000000000..faec7f4d90 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_4_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_32_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_32_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 32, + 128, + 4, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_1_1_f.cu similarity index 72% rename from fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_1_1_t.cu rename to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_1_1_f.cu index bb275da464..cf58b42427 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_1_1_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_1_1_f.cu @@ -10,24 +10,31 @@ namespace fbgemm_gpu { -at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_t( +at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( + return bf16bf16bf16_grouped_impl( X, W, output, zero_start_index_M, M_sizes); } -at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_t( +at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_f( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 64, + 128, + 2, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); } } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_1_1_f.cu new file mode 100644 index 0000000000..2a813ddd2d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_128_64_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_128_64_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 128, + 64, + 128, + 4, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_f.cu new file mode 100644 index 0000000000..485bdca6c4 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 128, + 128, + 1, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_t.cu new file mode 100644 index 0000000000..665eee817c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_t.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 128, + 128, + 1, + 1, + 1, + true>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_2_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_2_1_1_t.cu new file mode 100644 index 0000000000..e1bb65c9aa --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_2_1_1_t.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_128_128_2_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_128_128_2_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 128, + 128, + 2, + 1, + 1, + true>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_4_1_1_f.cu new file mode 100644 index 0000000000..668c490636 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_4_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_128_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_128_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 128, + 128, + 4, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_4_1_1_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_4_1_1_t.cu new file mode 100644 index 0000000000..fcc3639376 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_4_1_1_t.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_128_128_4_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_128_128_4_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 128, + 128, + 4, + 1, + 1, + true>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_1_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_1_1_1_f.cu new file mode 100644 index 0000000000..797ab2ad34 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_1_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_16_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_16_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 16, + 128, + 1, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_2_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_2_1_1_f.cu new file mode 100644 index 0000000000..2d895aa2e5 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_2_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_16_128_2_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_16_128_2_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 16, + 128, + 2, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_4_1_1_f.cu new file mode 100644 index 0000000000..c8ff2ddbd2 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_16_128_4_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_16_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_16_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 16, + 128, + 4, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_1_1_f.cu new file mode 100644 index 0000000000..4c730eb258 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_32_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_32_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 32, + 128, + 1, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_2_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_2_1_1_f.cu new file mode 100644 index 0000000000..29ddfe76f2 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_2_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_32_128_2_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_32_128_2_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 32, + 128, + 2, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_4_1_1_f.cu new file mode 100644 index 0000000000..3b1852a983 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_4_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_32_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_32_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 32, + 128, + 4, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_1_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_1_1_1_f.cu new file mode 100644 index 0000000000..3f57bb00b9 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_1_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_64_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_64_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 64, + 128, + 1, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_2_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_2_1_1_f.cu new file mode 100644 index 0000000000..eb3cc42712 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_2_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_64_128_2_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_64_128_2_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 64, + 128, + 2, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_4_1_1_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_4_1_1_f.cu new file mode 100644 index 0000000000..e4b036d634 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_4_1_1_f.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_256_64_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl( + X, W, output, zero_start_index_M, M_sizes); +} + +at::Tensor bf16bf16bf16_grouped_256_64_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes) { + return bf16bf16bf16_grouped_impl< + at::TensorList, + 256, + 64, + 128, + 4, + 1, + 1, + false>(X, W, output, zero_start_index_M, M_sizes); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh index 43532fec4c..248851091d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh @@ -12,6 +12,20 @@ namespace fbgemm_gpu { +at::Tensor bf16bf16bf16_grouped_128_16_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_16_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + at::Tensor bf16bf16bf16_grouped_128_16_128_2_1_1_f( at::Tensor X, // BF16 at::Tensor W, // BF16 @@ -26,6 +40,34 @@ at::Tensor bf16bf16bf16_grouped_128_16_128_2_1_1_f( std::optional zero_start_index_M, std::optional M_sizes); +at::Tensor bf16bf16bf16_grouped_128_16_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_16_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_32_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_32_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_f( at::Tensor X, // BF16 at::Tensor W, // BF16 @@ -40,14 +82,98 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_f( std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_t( +at::Tensor bf16bf16bf16_grouped_128_32_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_32_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_64_128_4_1_1_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_t( +at::Tensor bf16bf16bf16_grouped_128_64_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_f( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, @@ -68,84 +194,266 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_t( std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_f( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_t( +at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_t( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_64_128_2_1_1_t( +at::Tensor bf16bf16bf16_grouped_128_128_128_4_1_1_t( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_16_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_16_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_f( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_32_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_t( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_32_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_t( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_f( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_t( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_f( +at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_128_256_128_4_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_16_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_16_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_16_128_2_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_16_128_2_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_16_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_16_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_32_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_32_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_32_128_2_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_32_128_2_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_32_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_32_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_64_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_64_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_64_128_2_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_64_128_2_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_64_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_64_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_t( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, @@ -166,14 +474,42 @@ at::Tensor bf16bf16bf16_grouped_256_128_128_2_1_1_f( std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_f( +at::Tensor bf16bf16bf16_grouped_256_128_128_2_1_1_t( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor output, std::optional zero_start_index_M, std::optional M_sizes); -at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_f( +at::Tensor bf16bf16bf16_grouped_256_128_128_2_1_1_t( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_128_128_4_1_1_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_128_128_4_1_1_f( + at::TensorList X, // BF16 + at::TensorList W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_128_128_4_1_1_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor output, + std::optional zero_start_index_M, + std::optional M_sizes); + +at::Tensor bf16bf16bf16_grouped_256_128_128_4_1_1_t( at::TensorList X, // BF16 at::TensorList W, // BF16 at::Tensor output, @@ -193,7 +529,80 @@ const std::unordered_map>& get_bf16bf16bf16_grouped_kernels() { static const std:: unordered_map> - kernels = {}; + kernels = { + {"bf16bf16bf16_grouped_128_16_128_1_1_1_f", + bf16bf16bf16_grouped_128_16_128_1_1_1_f}, + {"bf16bf16bf16_grouped_128_16_128_2_1_1_f", + bf16bf16bf16_grouped_128_16_128_2_1_1_f}, + {"bf16bf16bf16_grouped_128_16_128_4_1_1_f", + bf16bf16bf16_grouped_128_16_128_4_1_1_f}, + {"bf16bf16bf16_grouped_128_32_128_1_1_1_f", + bf16bf16bf16_grouped_128_32_128_1_1_1_f}, + {"bf16bf16bf16_grouped_128_32_128_2_1_1_f", + bf16bf16bf16_grouped_128_32_128_2_1_1_f}, + {"bf16bf16bf16_grouped_128_32_128_4_1_1_f", + bf16bf16bf16_grouped_128_32_128_4_1_1_f}, + {"bf16bf16bf16_grouped_128_64_128_1_1_1_f", + bf16bf16bf16_grouped_128_64_128_1_1_1_f}, + {"bf16bf16bf16_grouped_128_64_128_2_1_1_f", + bf16bf16bf16_grouped_128_64_128_2_1_1_f}, + {"bf16bf16bf16_grouped_128_64_128_4_1_1_f", + bf16bf16bf16_grouped_128_64_128_4_1_1_f}, + {"bf16bf16bf16_grouped_128_128_128_1_1_1_f", + bf16bf16bf16_grouped_128_128_128_1_1_1_f}, + {"bf16bf16bf16_grouped_128_128_128_1_1_1_t", + bf16bf16bf16_grouped_128_128_128_1_1_1_t}, + {"bf16bf16bf16_grouped_128_128_128_2_1_1_f", + bf16bf16bf16_grouped_128_128_128_2_1_1_f}, + {"bf16bf16bf16_grouped_128_128_128_2_1_1_t", + bf16bf16bf16_grouped_128_128_128_2_1_1_t}, + {"bf16bf16bf16_grouped_128_128_128_4_1_1_f", + bf16bf16bf16_grouped_128_128_128_4_1_1_f}, + {"bf16bf16bf16_grouped_128_128_128_4_1_1_t", + bf16bf16bf16_grouped_128_128_128_4_1_1_t}, + {"bf16bf16bf16_grouped_128_256_128_1_1_1_f", + bf16bf16bf16_grouped_128_256_128_1_1_1_f}, + {"bf16bf16bf16_grouped_128_256_128_1_1_1_t", + bf16bf16bf16_grouped_128_256_128_1_1_1_t}, + {"bf16bf16bf16_grouped_128_256_128_2_1_1_f", + bf16bf16bf16_grouped_128_256_128_2_1_1_f}, + {"bf16bf16bf16_grouped_128_256_128_2_1_1_t", + bf16bf16bf16_grouped_128_256_128_2_1_1_t}, + {"bf16bf16bf16_grouped_128_256_128_4_1_1_f", + bf16bf16bf16_grouped_128_256_128_4_1_1_f}, + {"bf16bf16bf16_grouped_128_256_128_4_1_1_t", + bf16bf16bf16_grouped_128_256_128_4_1_1_t}, + {"bf16bf16bf16_grouped_256_16_128_1_1_1_f", + bf16bf16bf16_grouped_256_16_128_1_1_1_f}, + {"bf16bf16bf16_grouped_256_16_128_2_1_1_f", + bf16bf16bf16_grouped_256_16_128_2_1_1_f}, + {"bf16bf16bf16_grouped_256_16_128_4_1_1_f", + bf16bf16bf16_grouped_256_16_128_4_1_1_f}, + {"bf16bf16bf16_grouped_256_32_128_1_1_1_f", + bf16bf16bf16_grouped_256_32_128_1_1_1_f}, + {"bf16bf16bf16_grouped_256_32_128_2_1_1_f", + bf16bf16bf16_grouped_256_32_128_2_1_1_f}, + {"bf16bf16bf16_grouped_256_32_128_4_1_1_f", + bf16bf16bf16_grouped_256_32_128_4_1_1_f}, + {"bf16bf16bf16_grouped_256_64_128_1_1_1_f", + bf16bf16bf16_grouped_256_64_128_1_1_1_f}, + {"bf16bf16bf16_grouped_256_64_128_2_1_1_f", + bf16bf16bf16_grouped_256_64_128_2_1_1_f}, + {"bf16bf16bf16_grouped_256_64_128_4_1_1_f", + bf16bf16bf16_grouped_256_64_128_4_1_1_f}, + {"bf16bf16bf16_grouped_256_128_128_1_1_1_f", + bf16bf16bf16_grouped_256_128_128_1_1_1_f}, + {"bf16bf16bf16_grouped_256_128_128_1_1_1_t", + bf16bf16bf16_grouped_256_128_128_1_1_1_t}, + {"bf16bf16bf16_grouped_256_128_128_2_1_1_f", + bf16bf16bf16_grouped_256_128_128_2_1_1_f}, + {"bf16bf16bf16_grouped_256_128_128_2_1_1_t", + bf16bf16bf16_grouped_256_128_128_2_1_1_t}, + {"bf16bf16bf16_grouped_256_128_128_4_1_1_f", + bf16bf16bf16_grouped_256_128_128_4_1_1_f}, + {"bf16bf16bf16_grouped_256_128_128_4_1_1_t", + bf16bf16bf16_grouped_256_128_128_4_1_1_t}, + }; return kernels; } From 026366c1985d64d857528e8242478c0fd9489d9a Mon Sep 17 00:00:00 2001 From: Chris Thi Date: Fri, 13 Jun 2025 14:21:29 -0700 Subject: [PATCH 4/4] Support tuning cache for Cutlass FP8 GEMM (#4301) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/4301 X-link: https://github.com/facebookresearch/FBGEMM/pull/1377 This diff adds support for the tuning cache to the kernel. There should be no performance changes to the existing heuristics. - I refactored the kernel dispatch logic to instead return the kernel function, as it removes some duplication of the kernel invoke. - The next diff in this stack will add the new kernels D75820688, to make the review easier - Note that we are having some issues with adding the new kernels, as I have found this kernel is actually compiling 12 variants for each configuration, see D75820688 for more context. So for now we won't add the new kernels in D75820688, but we can just onboard it to auto tuning incase someone wants to compile them locally. Will revisit D75820688 later. Reviewed By: q10, jiawenliu64 Differential Revision: D75541025 --- .../cutlass_extensions/f8f8bf16_rowwise.cu | 216 ++++++++++-------- .../f8f8bf16_rowwise_manifest.cuh | 23 ++ 2 files changed, 147 insertions(+), 92 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu index da4a650f78..4c292ce02d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu @@ -12,164 +12,196 @@ // clang-format on #include "f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh" +#include "fbgemm_gpu/quantize/tuning_cache.hpp" +#include "fbgemm_gpu/quantize/utils.h" namespace fbgemm_gpu { #if CUDART_VERSION >= 12000 // FP8 Rowwise Cutlass kernel dispatch. -at::Tensor dispatch_fp8_rowwise_kernel( - at::Tensor XQ, - at::Tensor WQ, - at::Tensor x_scale, - at::Tensor w_scale, - bool use_fast_accum, - std::optional bias = std::nullopt, - std::optional output = std::nullopt) { - int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); - int N = size_to_dim_(WQ.dim() - 1, WQ.sizes()); - int K = XQ.size(-1); - static int arch = -1; - // Avoid expensive cudaGetDeviceProperties call. - if (arch < 0) { - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - if (prop.major >= 10) { - arch = 10; - int runtimeVersion; - C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion)); - TORCH_CHECK( - runtimeVersion >= 12080, - "FP8 GEMM on sm100a or above requires cuda >= 12.8"); - } else { - arch = 9; - } - } - +Kernel_f8f8bf16_rowwise +get_kernel_via_heuristic(int arch, int M, int N, int K, bool use_fast_accum) { // Use shape heuristics to dispatch to optimized kernel configuration. if (arch == 10) { if (M <= 128) { if (N <= 1024) { - return f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f; } else { - return f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f; } } else if (M <= 1024) { if (N <= 1024) { - return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f; } else { - return f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f; } } else if (M <= 2048) { - return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f; } else { if (N <= 1024) { - return f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f; } else { - return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f; } } } else { if (M <= 16) { - return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f; } else if (M <= 32) { if (N <= 4096) { - return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f; } else { - return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f; } } else if (M <= 64) { if (N <= 2048) { - return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f; } else if (N <= 4096) { - return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f; } else { - return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f; } } else if (M <= 128) { if (N <= 1024) { - return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f; } else if (N <= 2048) { - return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f; } else if (N <= 4096) { - return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f; } else { - return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f; } } else if (M <= 256) { if (N <= 1024) { - return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f; } else if (N <= 2048) { - return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f; } else if (N <= 4096) { - return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f; } else { - return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f; } } else if (M <= 512) { if (N <= 1024) { - return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f; } else if (N <= 2048) { - return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f; } else if (N <= 4096 || use_fast_accum == false) { - return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f; } else { - return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t; } } else if (M <= 1024) { if (N <= 1024) { - return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f; } else if (N <= 2048 || use_fast_accum == false) { - return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f; } else { - return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t; } } else { if (M <= 2048 && N <= 1024) { - return f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f; } else if (K <= 4096 || use_fast_accum == false) { - return f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f; } else if (M > 8192 && N > 8192) { - return f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t; } else { - return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t( - XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); + return f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t; } } } } +Kernel_f8f8bf16_rowwise get_kernel_via_tuning( + int arch, + int M, + int N, + int K, + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + bool use_fast_accum, + std::optional bias = std::nullopt, + std::optional output = std::nullopt) { + // One cache per kernel type + static TuningCache cache("f8f8bf16_rowwise"); + + // Reducing amount of auto tuning by rounding up M to next power of 2. + M = nextPowerOf2(M); + // Use (M, N, K) shape as the key. + const std::string shape_key = + std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K); + const auto& kernels = get_f8f8bf16_rowwise_kernels(arch); + auto kernel = cache.findBestKernelMaybeAutotune( + shape_key, + kernels, + XQ, + WQ, + x_scale, + w_scale, + use_fast_accum, + bias, + output); + + return kernel; +} + +// FP8 Rowwise Cutlass kernel dispatch. +at::Tensor dispatch_fp8_rowwise_kernel( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + bool use_fast_accum, + std::optional bias = std::nullopt, + std::optional output = std::nullopt) { + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = size_to_dim_(WQ.dim() - 1, WQ.sizes()); + int K = XQ.size(-1); + + static int arch = -1; + // Avoid expensive cudaGetDeviceProperties call. + if (arch < 0) { + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + if (prop.major >= 10) { + arch = 10; + int runtimeVersion; + C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion)); + TORCH_CHECK( + runtimeVersion >= 12080, + "FP8 GEMM on sm100a or above requires cuda >= 12.8"); + } else { + arch = 9; + } + } + + // Select kernel to run via heuristics or tuning. + auto kernel = [&]() { + if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) { + return get_kernel_via_tuning( + arch, + M, + N, + K, + XQ, + WQ, + x_scale, + w_scale, + use_fast_accum, + bias, + output); + } else { + return get_kernel_via_heuristic(arch, M, N, K, use_fast_accum); + } + }(); + // Invoke kernel + return kernel(XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); +} + void f8f8bf16_rowwise_out( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh index 676a1d66b4..ecbdd72c85 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh @@ -135,4 +135,27 @@ at::Tensor f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f( bool use_fast_accum = true, std::optional bias = std::nullopt, std::optional output = std::nullopt); + +using Kernel_f8f8bf16_rowwise = at::Tensor (*)( + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + bool, + std::optional, + std::optional); + +inline const std::unordered_map& +get_f8f8bf16_rowwise_kernels(int arch) { + static const std::unordered_map + kernelsSM90 = {}; + static const std::unordered_map + kernelsSM100 = {}; + if (arch == 10) { + return kernelsSM100; + } else { + return kernelsSM90; + } +} + } // namespace fbgemm_gpu