diff --git a/.gitignore b/.gitignore
index ccc868826..a30a2dc84 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,5 @@
 ShareGPT_V3_unfiltered_cleaned_split.json
 .vscode/settings.json
+ibm-triton-lib/ibm_triton_lib.egg-info/
+
diff --git a/.gitmodules b/.gitmodules
index 9880c2f9b..2c9aa057e 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -7,3 +7,6 @@
 [submodule "vllm"]
 	path = vllm
 	url = https://github.com/vllm-project/vllm.git
+[submodule "third_party/fmwork"]
+	path = third_party/fmwork
+	url = git@github.com:bringlein/fmwork.git
diff --git a/Dockerfile b/Dockerfile
index 38274a7cc..0e5fe148b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -3,6 +3,7 @@ ARG BASE_UBI_IMAGE_TAG=9.4
 ARG PYTHON_VERSION=3.12
 ARG MAX_JOBS=64
 ARG PIP_VLLM_VERSION=0.8.1
+# TODO add ARG CUDA_VERSION=12-8
 
 ARG VLLM_SOURCE=pip
 # or VLLM_SOURCE=custom
@@ -122,6 +123,31 @@ ENV CCACHE_DIR=/root/.cache/ccache
 RUN --mount=type=cache,target=/root/.cache/ccache \
     python3 setup.py bdist_wheel --dist-dir=/workspace/
 
+# ## flashinfer Builder #################################################################
+# FROM vllm-builder_custom AS flashinfer-builder
+# ARG MAX_JOBS
+#
+# # # build deps?
+# # RUN --mount=type=cache,target=/root/.cache/pip \
+# #     --mount=type=cache,target=/root/.cache/uv \
+# #     uv pip install ninja cmake wheel pybind11 setuptools
+#
+# WORKDIR /workspace/flashinfer
+# RUN git clone --recursive https://github.com/flashinfer-ai/flashinfer.git
+#
+# ENV TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'
+# ENV FLASHINFER_ENABLE_SM90=1
+# RUN --mount=type=cache,target=/root/.cache/pip \
+#     cd flashinfer \
+#     && export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} export FLASHINFER_ENABLE_SM90=${FLASHINFER_ENABLE_SM90} \
+#     && python -m flashinfer.aot \
+#     && python -m build --no-isolation --wheel
+#
+# # uv pip install \
+# #     --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1"
+#
+# RUN ls -al /workspace/flashinfer/flashinfer/dist
+
 ## Runtime #################################################################
 FROM base AS runtime
 
@@ -227,20 +253,54 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     uv pip install pytest llnl-hatchet debugpy
 
 # Install FlashInfer
-RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
-    echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
-
-RUN --mount=type=cache,target=/root/.cache/pip \
-    . /etc/environment && \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
+# RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
+#     echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
+# RUN --mount=type=cache,target=/root/.cache/pip \
+#     . /etc/environment && \
+#     python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
+# RUN --mount=type=cache,target=/root/.cache/pip \
+#     . /etc/environment && \
+#     uv pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
+# RUN --mount=type=cache,target=/root/.cache/pip \
+#     uv pip install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/ --no-deps
+# RUN --mount=type=cache,target=/root/.cache/pip \
+#     --mount=type=cache,target=/root/.cache/uv \
+#     uv pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.5/flashinfer_python-0.2.5+cu124torch2.6-cp38-abi3-linux_x86_64.whl#sha256=43d767b912c0c43a04be99595e0123eab9385fc72530a2874b5fb08e3145c0be
+# RUN --mount=type=cache,target=/root/.cache/pip \
+#     --mount=type=cache,target=/root/.cache/uv \
+#     uv pip install torch==2.7.0
+# RUN --mount=type=cache,target=/root/.cache/pip \
+#     --mount=type=cache,target=/root/.cache/uv \
+#     uv pip install https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.5%2Bcu128torch2.7-cp38-abi3-linux_x86_64.whl
+# RUN mkdir /workspace/flashinfer_dist && ls -al /workspace/flashinfer_dist
+# COPY --from=flashinfer-builder /workspace/*.whl /workspace/flashinfer_dist
+# RUN --mount=type=cache,target=/root/.cache/pip \
+#     --mount=type=cache,target=/root/.cache/uv \
+#     uv pip install /workspace/flashinfer_dist/*.whl
+# TODO: we need nvcc for flashinfer installation...custom build fails, see above
+RUN curl -Lo /etc/yum.repos.d/cuda-rhel9.repo \
+    https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo
+RUN microdnf install -y \
+    cuda-nvcc-12-8 cuda-nvtx-12-8 cuda-libraries-devel-12-8 && \
+    microdnf clean all
+ENV CUDA_HOME="/usr/local/cuda" \
+    PATH="${CUDA_HOME}/bin:${PATH}" \
+    LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH}"
+ENV TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'
+ENV FLASHINFER_ENABLE_SM90=1
+RUN TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} FLASHINFER_ENABLE_SM90=${FLASHINFER_ENABLE_SM90} uv pip install \
+    --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1"
+
 RUN ln -s ${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_cupti/lib/libcupti.so.12 ${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_cupti/lib/libcupti.so
 
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=cache,target=/root/.cache/uv \
     git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness && cd lm-evaluation-harness && uv pip install .
-RUN git clone --depth 1 https://github.com/IBM/fmwork.git +# RUN git clone --depth 1 https://github.com/IBM/fmwork.git +# RUN git clone --depth 1 https://github.com/IBM/fmwork.git +COPY third_party/fmwork fmwork ENV STORE_TEST_RESULT_PATH=/results @@ -250,7 +310,7 @@ COPY vllm/tests tests COPY ShareGPT_V3_unfiltered_cleaned_split.json ShareGPT_V3_unfiltered_cleaned_split.json # Copy thid-party kernels and insert into path -COPY third_party third_party +COPY third_party/kernels third_party ENV PYTHONPATH /workspace # see https://github.com/IBM/triton-dejavu?tab=readme-ov-file#environment-variables diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_bmm_chunk_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-25b6b5e18b4b4e9d94bc6cfc6e07052ef952503581ca3a6592f943790d859cd8/tune_features-b815cf0dca1de8dc8520ba45f9861122ec38d2b40655a5044d5da8dee5b249cf/kernel_configs-a6c5e7946f5d4b0ba6fa79217784e3780477be6b4708bab85d511e2f96fb9381/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_bmm_chunk_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-25b6b5e18b4b4e9d94bc6cfc6e07052ef952503581ca3a6592f943790d859cd8/tune_features-b815cf0dca1de8dc8520ba45f9861122ec38d2b40655a5044d5da8dee5b249cf/kernel_configs-a6c5e7946f5d4b0ba6fa79217784e3780477be6b4708bab85d511e2f96fb9381/default/cache.json new file mode 100755 index 000000000..efcde2e45 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_bmm_chunk_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-25b6b5e18b4b4e9d94bc6cfc6e07052ef952503581ca3a6592f943790d859cd8/tune_features-b815cf0dca1de8dc8520ba45f9861122ec38d2b40655a5044d5da8dee5b249cf/kernel_configs-a6c5e7946f5d4b0ba6fa79217784e3780477be6b4708bab85d511e2f96fb9381/default/cache.json @@ -0,0 +1,26 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_bmm:_bmm_chunk_fwd_kernel)", + "total_bench_time_s": 4.903317928314209, + "evaluated_configs": 9, + "keys": [ + "chunk_size", + "K", + "IS_CAUSAL" + ], + "cache": { + "('256', '128', 'False', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.int32')": "BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, num_warps: 2, num_ctas: 1, num_stages: 5, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('256', '128', 'False', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.int32')": [ + 0.007391999941319227 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": false + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_bmm_chunk_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-25b6b5e18b4b4e9d94bc6cfc6e07052ef952503581ca3a6592f943790d859cd8/tune_features-b815cf0dca1de8dc8520ba45f9861122ec38d2b40655a5044d5da8dee5b249cf/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json 
b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_bmm_chunk_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-25b6b5e18b4b4e9d94bc6cfc6e07052ef952503581ca3a6592f943790d859cd8/tune_features-b815cf0dca1de8dc8520ba45f9861122ec38d2b40655a5044d5da8dee5b249cf/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json new file mode 100755 index 000000000..0312c1d6e --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_bmm_chunk_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-25b6b5e18b4b4e9d94bc6cfc6e07052ef952503581ca3a6592f943790d859cd8/tune_features-b815cf0dca1de8dc8520ba45f9861122ec38d2b40655a5044d5da8dee5b249cf/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json @@ -0,0 +1,26 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_bmm:_bmm_chunk_fwd_kernel)", + "total_bench_time_s": 10756.567904472351, + "evaluated_configs": 2625, + "keys": [ + "chunk_size", + "K", + "IS_CAUSAL" + ], + "cache": { + "('256', '128', 'False', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.int32')": "BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('256', '128', 'False', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.int32')": [ + 0.002230335958302021 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": true + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_cumsum_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-2fa507d0842a5f6a78eee941dc3c3a68f89756b47913aff39d4208afafb074fa/tune_features-604fd79069d101d891a5ad1f1f001551ff096d4dea3dc2c159faa57a9430d214/kernel_configs-86c110801e8443207d93837dc53554c59f26ccf5a1a04c352ea7e8587c82d89e/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_cumsum_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-2fa507d0842a5f6a78eee941dc3c3a68f89756b47913aff39d4208afafb074fa/tune_features-604fd79069d101d891a5ad1f1f001551ff096d4dea3dc2c159faa57a9430d214/kernel_configs-86c110801e8443207d93837dc53554c59f26ccf5a1a04c352ea7e8587c82d89e/default/cache.json new file mode 100755 index 000000000..df74c1fe3 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_cumsum_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-2fa507d0842a5f6a78eee941dc3c3a68f89756b47913aff39d4208afafb074fa/tune_features-604fd79069d101d891a5ad1f1f001551ff096d4dea3dc2c159faa57a9430d214/kernel_configs-86c110801e8443207d93837dc53554c59f26ccf5a1a04c352ea7e8587c82d89e/default/cache.json @@ -0,0 +1,25 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_chunk_state:_chunk_cumsum_fwd_kernel)", + "total_bench_time_s": 7.295067548751831, + "evaluated_configs": 7, + "keys": [ + "chunk_size", + "nheads" + ], + "cache": { + "('256', '128', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32')": 
"BLOCK_SIZE_H: 2, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('256', '128', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32')": [ + 0.007071999832987785 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": false + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_cumsum_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-2fa507d0842a5f6a78eee941dc3c3a68f89756b47913aff39d4208afafb074fa/tune_features-604fd79069d101d891a5ad1f1f001551ff096d4dea3dc2c159faa57a9430d214/kernel_configs-86c110801e8443207d93837dc53554c59f26ccf5a1a04c352ea7e8587c82d89e/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_cumsum_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-2fa507d0842a5f6a78eee941dc3c3a68f89756b47913aff39d4208afafb074fa/tune_features-604fd79069d101d891a5ad1f1f001551ff096d4dea3dc2c159faa57a9430d214/kernel_configs-86c110801e8443207d93837dc53554c59f26ccf5a1a04c352ea7e8587c82d89e/default/cache.json new file mode 100755 index 000000000..e6e0dc8a0 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_cumsum_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-2fa507d0842a5f6a78eee941dc3c3a68f89756b47913aff39d4208afafb074fa/tune_features-604fd79069d101d891a5ad1f1f001551ff096d4dea3dc2c159faa57a9430d214/kernel_configs-86c110801e8443207d93837dc53554c59f26ccf5a1a04c352ea7e8587c82d89e/default/cache.json @@ -0,0 +1,25 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_chunk_state:_chunk_cumsum_fwd_kernel)", + "total_bench_time_s": 7.361271619796753, + "evaluated_configs": 7, + "keys": [ + "chunk_size", + "nheads" + ], + "cache": { + "('256', '128', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32')": "BLOCK_SIZE_H: 2, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('256', '128', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32')": [ + 0.002133406000211835 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": true + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_scan_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-3a41493c29184793fa894c5d134a5c291430843f2ca1b798ab5c9e58228d1814/tune_features-3e88866b92d333f029bc0ae6410b8ce764620f4a7514b0062dd8c43c8e63e3e1/kernel_configs-e1d63b4ce9f3ae5e2f38b68d3d8257474338c0a672ac24128b374d342459d7e1/default/cache.json 
b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_scan_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-3a41493c29184793fa894c5d134a5c291430843f2ca1b798ab5c9e58228d1814/tune_features-3e88866b92d333f029bc0ae6410b8ce764620f4a7514b0062dd8c43c8e63e3e1/kernel_configs-e1d63b4ce9f3ae5e2f38b68d3d8257474338c0a672ac24128b374d342459d7e1/default/cache.json new file mode 100755 index 000000000..fb9768114 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_scan_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-3a41493c29184793fa894c5d134a5c291430843f2ca1b798ab5c9e58228d1814/tune_features-3e88866b92d333f029bc0ae6410b8ce764620f4a7514b0062dd8c43c8e63e3e1/kernel_configs-e1d63b4ce9f3ae5e2f38b68d3d8257474338c0a672ac24128b374d342459d7e1/default/cache.json @@ -0,0 +1,31 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_chunk_scan:_chunk_scan_fwd_kernel)", + "total_bench_time_s": 22.759257316589355, + "evaluated_configs": 11, + "keys": [ + "chunk_size", + "hdim", + "dstate", + "IS_CAUSAL" + ], + "cache": { + "('256', '64', '128', 'True', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16')": "BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", + "('256', '64', '128', 'True', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32')": "BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, num_warps: 2, num_ctas: 1, num_stages: 5, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('256', '64', '128', 'True', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16')": [ + 0.014240000396966934 + ], + "('256', '64', '128', 'True', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32')": [ + 0.8048959970474243 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": false + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_scan_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-3a41493c29184793fa894c5d134a5c291430843f2ca1b798ab5c9e58228d1814/tune_features-3e88866b92d333f029bc0ae6410b8ce764620f4a7514b0062dd8c43c8e63e3e1/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json 
b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_scan_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-3a41493c29184793fa894c5d134a5c291430843f2ca1b798ab5c9e58228d1814/tune_features-3e88866b92d333f029bc0ae6410b8ce764620f4a7514b0062dd8c43c8e63e3e1/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json new file mode 100755 index 000000000..dd9c29f78 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_scan_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-3a41493c29184793fa894c5d134a5c291430843f2ca1b798ab5c9e58228d1814/tune_features-3e88866b92d333f029bc0ae6410b8ce764620f4a7514b0062dd8c43c8e63e3e1/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json @@ -0,0 +1,27 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_chunk_scan:_chunk_scan_fwd_kernel)", + "total_bench_time_s": 15278.822125434875, + "evaluated_configs": 2625, + "keys": [ + "chunk_size", + "hdim", + "dstate", + "IS_CAUSAL" + ], + "cache": { + "('256', '64', '128', 'True', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16')": "BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 16, num_warps: 2, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('256', '64', '128', 'True', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16')": [ + 0.014237518422305584 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": true + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-339ef229a46cc5e4fefcebbabe32af549b053e9d045b9c4c60da297149a339c9/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-a6c5e7946f5d4b0ba6fa79217784e3780477be6b4708bab85d511e2f96fb9381/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-339ef229a46cc5e4fefcebbabe32af549b053e9d045b9c4c60da297149a339c9/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-a6c5e7946f5d4b0ba6fa79217784e3780477be6b4708bab85d511e2f96fb9381/default/cache.json new file mode 100755 index 000000000..010c85ff2 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-339ef229a46cc5e4fefcebbabe32af549b053e9d045b9c4c60da297149a339c9/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-a6c5e7946f5d4b0ba6fa79217784e3780477be6b4708bab85d511e2f96fb9381/default/cache.json @@ -0,0 +1,26 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_chunk_state:_chunk_state_fwd_kernel)", + "total_bench_time_s": 
5.0212812423706055, + "evaluated_configs": 9, + "keys": [ + "hdim", + "dstate", + "chunk_size" + ], + "cache": { + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.int32')": "BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, num_warps: 2, num_ctas: 1, num_stages: 5, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.int32')": [ + 0.009247999638319016 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": false + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-339ef229a46cc5e4fefcebbabe32af549b053e9d045b9c4c60da297149a339c9/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-339ef229a46cc5e4fefcebbabe32af549b053e9d045b9c4c60da297149a339c9/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json new file mode 100755 index 000000000..68505b261 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-339ef229a46cc5e4fefcebbabe32af549b053e9d045b9c4c60da297149a339c9/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json @@ -0,0 +1,26 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_chunk_state:_chunk_state_fwd_kernel)", + "total_bench_time_s": 9348.028031349182, + "evaluated_configs": 2625, + "keys": [ + "hdim", + "dstate", + "chunk_size" + ], + "cache": { + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.int32')": "BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 64, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.int32')": [ + 0.003924777265638113 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": true + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_varlen_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-f10105bbcf94b3788568aecfef8eb69570d7757afd57bef99faf7bf930a4edcf/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json 
b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_varlen_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-f10105bbcf94b3788568aecfef8eb69570d7757afd57bef99faf7bf930a4edcf/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json new file mode 100755 index 000000000..ee569dcb6 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_varlen_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-f10105bbcf94b3788568aecfef8eb69570d7757afd57bef99faf7bf930a4edcf/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json @@ -0,0 +1,8 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_chunk_state:_chunk_state_varlen_kernel)", + "total_bench_time_s": 0.0, + "evaluated_configs": 0, + "keys": null, + "cache": {}, + "timings": {} +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_varlen_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-f10105bbcf94b3788568aecfef8eb69570d7757afd57bef99faf7bf930a4edcf/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-a6c5e7946f5d4b0ba6fa79217784e3780477be6b4708bab85d511e2f96fb9381/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_varlen_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-f10105bbcf94b3788568aecfef8eb69570d7757afd57bef99faf7bf930a4edcf/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-a6c5e7946f5d4b0ba6fa79217784e3780477be6b4708bab85d511e2f96fb9381/default/cache.json new file mode 100755 index 000000000..a81672d35 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_varlen_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-f10105bbcf94b3788568aecfef8eb69570d7757afd57bef99faf7bf930a4edcf/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-a6c5e7946f5d4b0ba6fa79217784e3780477be6b4708bab85d511e2f96fb9381/default/cache.json @@ -0,0 +1,30 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_chunk_state:_chunk_state_varlen_kernel)", + "total_bench_time_s": 17.040932178497314, + "evaluated_configs": 9, + "keys": [ + "hdim", + "dstate", + "chunk_size" + ], + "cache": { + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32', 'torch.bfloat16')": "BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, num_warps: 2, num_ctas: 1, num_stages: 5, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32', 'torch.bfloat16', 'torch.bfloat16')": "BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, num_warps: 2, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, 
reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32', 'torch.bfloat16')": [ + 0.009184000082314014 + ], + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32', 'torch.bfloat16', 'torch.bfloat16')": [ + 0.009184000082314014 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": false + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_varlen_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-f10105bbcf94b3788568aecfef8eb69570d7757afd57bef99faf7bf930a4edcf/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_varlen_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-f10105bbcf94b3788568aecfef8eb69570d7757afd57bef99faf7bf930a4edcf/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json new file mode 100755 index 000000000..06f0a4220 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_chunk_state_varlen_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-f10105bbcf94b3788568aecfef8eb69570d7757afd57bef99faf7bf930a4edcf/tune_features-a17bcb1c348fee486b4e400e9ec475828d4f0d3118d72067b1bc6f94903360fa/kernel_configs-31086bbabdaa5bbed7ee80f8c2feb8195925fe0fe23a8fdfe525b114e663bdea/default/cache.json @@ -0,0 +1,26 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_chunk_state:_chunk_state_varlen_kernel)", + "total_bench_time_s": 19485.390374183655, + "evaluated_configs": 2625, + "keys": [ + "hdim", + "dstate", + "chunk_size" + ], + "cache": { + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32', 'torch.bfloat16')": "BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 16, BLOCK_SIZE_K: 16, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('64', '128', '256', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32', 'torch.bfloat16')": [ + NaN + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": true + } +} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/_selective_scan_update_kernel/autotune_config-90178d0ab8e71db9cd16710d562763dd010643f28cd21980d5064c3ab782ecaa/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json 
b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-12ef9e4125a78d954cad03c22e7b626a75d6e484131a7b8653f8b7d84d9f78f3/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-4452dd34c8d5c1eade558a6589c89cd1205e0da4d4ef8a72ee7c4c702061e9ba/default/cache.json similarity index 69% rename from ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/_selective_scan_update_kernel/autotune_config-90178d0ab8e71db9cd16710d562763dd010643f28cd21980d5064c3ab782ecaa/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json rename to g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-12ef9e4125a78d954cad03c22e7b626a75d6e484131a7b8653f8b7d84d9f78f3/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-4452dd34c8d5c1eade558a6589c89cd1205e0da4d4ef8a72ee7c4c702061e9ba/default/cache.json index d6bd3e752..466963e92 100755 --- a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/_selective_scan_update_kernel/autotune_config-90178d0ab8e71db9cd16710d562763dd010643f28cd21980d5064c3ab782ecaa/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-12ef9e4125a78d954cad03c22e7b626a75d6e484131a7b8653f8b7d84d9f78f3/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-4452dd34c8d5c1eade558a6589c89cd1205e0da4d4ef8a72ee7c4c702061e9ba/default/cache.json @@ -1,7 +1,7 @@ { - "signature": "JITFunction(ibm_triton_lib.kernels.mamba_ssm:_selective_scan_update_kernel)", - "total_bench_time_s": 113.2074065208435, - "evaluated_configs": 75, + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.mamba_ssm:_selective_scan_update_kernel)", + "total_bench_time_s": 201.7921507358551, + "evaluated_configs": 105, "keys": [ "dstate", "BLOCK_SIZE_DSTATE", @@ -9,11 +9,11 @@ "nheads_ngroups_ratio" ], "cache": { - "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": "BLOCK_SIZE_M: 16, num_warps: 4, num_ctas: 1, num_stages: 6, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": "BLOCK_SIZE_M: 64, num_warps: 2, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, 
num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" }, "timings": { "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": [ - 0.0050251600332558155 + 0.05485290288925171 ] }, "timings_data": { diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-21ff5d19d1819793851ad7c7a60e8f4d7bd7bc84238d0302676bb9e213122e34/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-21ff5d19d1819793851ad7c7a60e8f4d7bd7bc84238d0302676bb9e213122e34/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json new file mode 100755 index 000000000..c7fb158cf --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-21ff5d19d1819793851ad7c7a60e8f4d7bd7bc84238d0302676bb9e213122e34/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json @@ -0,0 +1,27 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.mamba_ssm:_selective_scan_update_kernel)", + "total_bench_time_s": 154.3796603679657, + "evaluated_configs": 75, + "keys": [ + "dstate", + "BLOCK_SIZE_DSTATE", + "dim", + "nheads_ngroups_ratio" + ], + "cache": { + "('128', '128', '64', '128')": "BLOCK_SIZE_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('128', '128', '64', '128')": [ + 1.7349423170089722 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": true + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_state_passing_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-55db57c88b8fd2c2a9e9560aeb5afd5b585cf3507fa5eed7a0909f4d26b7cd86/tune_features-c5d4b45934fe1d9c636d8b0b8f49b5a26c5fc7064fb2bda916fe2743b77fcdc1/kernel_configs-68916ac9231d70c9dfa4b1081268470f5b25a8dbabb73d3818ba7e74c7fdc03c/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_state_passing_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-55db57c88b8fd2c2a9e9560aeb5afd5b585cf3507fa5eed7a0909f4d26b7cd86/tune_features-c5d4b45934fe1d9c636d8b0b8f49b5a26c5fc7064fb2bda916fe2743b77fcdc1/kernel_configs-68916ac9231d70c9dfa4b1081268470f5b25a8dbabb73d3818ba7e74c7fdc03c/default/cache.json new file mode 100755 index 000000000..634fae182 --- /dev/null +++ 
b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_state_passing_fwd_kernel/autotune_config-215d0c7082adf7c6c8ae2a767088f42b44e6432715b0c6760f5f8e5d4e8371ff/code_version-55db57c88b8fd2c2a9e9560aeb5afd5b585cf3507fa5eed7a0909f4d26b7cd86/tune_features-c5d4b45934fe1d9c636d8b0b8f49b5a26c5fc7064fb2bda916fe2743b77fcdc1/kernel_configs-68916ac9231d70c9dfa4b1081268470f5b25a8dbabb73d3818ba7e74c7fdc03c/default/cache.json @@ -0,0 +1,28 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_state_passing:_state_passing_fwd_kernel)", + "total_bench_time_s": 6.713695287704468, + "evaluated_configs": 6, + "keys": [ + "dim" + ], + "cache": { + "('8192', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32')": "BLOCK_SIZE: 2048, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", + "('8192', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32')": "BLOCK_SIZE: 512, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('8192', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32')": [ + 0.009664000011980534 + ], + "('8192', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32')": [ + 0.1367039978504181 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": false + } +} \ No newline at end of file diff --git a/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_state_passing_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-55db57c88b8fd2c2a9e9560aeb5afd5b585cf3507fa5eed7a0909f4d26b7cd86/tune_features-c5d4b45934fe1d9c636d8b0b8f49b5a26c5fc7064fb2bda916fe2743b77fcdc1/kernel_configs-c4fc6831bf929bccf1df2dabf2b7a316d7b0f7d0a3da7ec749b2f343f3ffe760/default/cache.json b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_state_passing_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-55db57c88b8fd2c2a9e9560aeb5afd5b585cf3507fa5eed7a0909f4d26b7cd86/tune_features-c5d4b45934fe1d9c636d8b0b8f49b5a26c5fc7064fb2bda916fe2743b77fcdc1/kernel_configs-c4fc6831bf929bccf1df2dabf2b7a316d7b0f7d0a3da7ec749b2f343f3ffe760/default/cache.json new file mode 100755 index 000000000..4f831cc77 --- /dev/null +++ b/g4_tuning_data/dejavu_0.7/triton_3.3.1/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_state_passing_fwd_kernel/autotune_config-bef61f0485b4347899c813bd65c9c1d763e62f3d6b5fda018baf600097187c0a/code_version-55db57c88b8fd2c2a9e9560aeb5afd5b585cf3507fa5eed7a0909f4d26b7cd86/tune_features-c5d4b45934fe1d9c636d8b0b8f49b5a26c5fc7064fb2bda916fe2743b77fcdc1/kernel_configs-c4fc6831bf929bccf1df2dabf2b7a316d7b0f7d0a3da7ec749b2f343f3ffe760/default/cache.json @@ -0,0 +1,28 @@ +{ + "signature": "JITFunction(vllm.model_executor.layers.mamba.ops.ssd_state_passing:_state_passing_fwd_kernel)", + "total_bench_time_s": 607.0304324626923, + "evaluated_configs": 168, + "keys": [ + "dim" + ], + "cache": { + "('8192', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32')": "BLOCK_SIZE: 512, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 
0, maxnreg: None", + "('8192', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32')": "BLOCK_SIZE: 512, num_warps: 2, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" + }, + "timings": { + "('8192', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.int32')": [ + 0.0030820679385215044 + ], + "('8192', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.bfloat16', 'torch.int32')": [ + 0.13190822303295135 + ] + }, + "timings_data": { + "labels": [ + "ms" + ], + "rep_t_ms": 100, + "warmup_t_ms": 25, + "cuda_graphs": true + } +} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/__init__.py b/ibm-triton-lib/ibm_triton_lib/kernels/__init__.py index a78522fc2..722e0507b 100644 --- a/ibm-triton-lib/ibm_triton_lib/kernels/__init__.py +++ b/ibm-triton-lib/ibm_triton_lib/kernels/__init__.py @@ -67,5 +67,11 @@ def ConfigSpace( ) from .triton_unified_attention import unified_attention +from .triton_unified_attention_simple import unified_attention as unified_attention_simple +from .triton_unified_newtiles import unified_attention as unified_attention_newtiles +from .triton_unified_attention_tuned import unified_attention as unified_attention_tuned +from .triton_unified_grid import unified_attention as unified_attention_grid from .mamba_ssm import selective_state_update + +# from .fused_moe import fused_moe diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=62,N=256,device_name=NVIDIA_H100_80GB_HBM3.json b/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=62,N=256,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..147a83660 --- /dev/null +++ b/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=62,N=256,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 
8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=62,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=62,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..a01e9c317 --- /dev/null +++ b/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=62,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=72,N=384,device_name=NVIDIA_H100_80GB_HBM3.json b/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=72,N=384,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..a7cfd175d --- /dev/null +++ b/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=72,N=384,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=72,N=768,device_name=NVIDIA_H100_80GB_HBM3.json b/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=72,N=768,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..3caae02cb --- /dev/null +++ b/ibm-triton-lib/ibm_triton_lib/kernels/configs/E=72,N=768,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/cuda_12.4/gpu_NVIDIA_A100-SXM4-80GB/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8/default/cache.json b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/cuda_12.4/gpu_NVIDIA_A100-SXM4-80GB/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8/default/cache.json deleted file mode 100755 index 19e6fc76c..000000000 --- 
a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/cuda_12.4/gpu_NVIDIA_A100-SXM4-80GB/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8/default/cache.json +++ /dev/null @@ -1,110 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.triton_flash_attention:attn_fwd)", - "total_bench_time_s": 211706.17069911957, - "evaluated_configs": 450, - "keys": [ - "HQ", - "HK", - "IS_CAUSAL", - "dropout_p", - "BLOCK_DMODEL", - "stride_qz", - "stride_qh", - "stride_qm", - "stride_qk", - "stride_kz", - "stride_kh", - "stride_kn", - "stride_kk", - "stride_vz", - "stride_vh", - "stride_vn", - "stride_vk", - "stride_oz", - "stride_oh", - "stride_om", - "stride_on", - "stride_bz", - "stride_bh", - "stride_bm", - "stride_bn", - "stride_az", - "stride_ah", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "VARLEN", - "ACTUAL_BLOCK_DMODEL" - ], - "cache": { - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 
'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 64, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, 
reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.005401020869612694 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.005471085663884878 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.0075958045199513435 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 
'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.007605006452649832 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.011812349781394005 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.011950820684432983 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.019297460094094276 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.017475301399827003 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.038042228668928146 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.038091544061899185 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.10096532106399536 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.09481953084468842 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.2949035167694092 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 
'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.29237720370292664 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.9560787677764893 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-a70f97e8b3e7aaf9f4a4f7e850b935d2d1b3ad8cd6ad1d0843bb426e13694ae9/default/cache.json b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-a70f97e8b3e7aaf9f4a4f7e850b935d2d1b3ad8cd6ad1d0843bb426e13694ae9/default/cache.json deleted file mode 100755 index a7b0d4282..000000000 --- a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-a70f97e8b3e7aaf9f4a4f7e850b935d2d1b3ad8cd6ad1d0843bb426e13694ae9/default/cache.json +++ /dev/null @@ -1,110 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.triton_flash_attention:attn_fwd)", - "total_bench_time_s": 86841.6919836998, - "evaluated_configs": 240, - "keys": [ - "HQ", - "HK", - "IS_CAUSAL", - "dropout_p", - "BLOCK_DMODEL", - "stride_qz", - "stride_qh", - "stride_qm", - "stride_qk", - "stride_kz", - "stride_kh", - "stride_kn", - "stride_kk", - "stride_vz", - "stride_vh", - "stride_vn", - "stride_vk", - "stride_oz", - "stride_oh", - "stride_om", - "stride_on", - "stride_bz", - "stride_bh", - "stride_bm", - "stride_bn", - "stride_az", - "stride_ah", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "VARLEN", - "ACTUAL_BLOCK_DMODEL" - ], - "cache": { - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 
'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 32, BLOCK_N: 32, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, 
reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 
'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 128, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.0036645731888711452 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '32', '32', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.0036076440010219812 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.00487453443929553 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '64', '64', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.0048555657267570496 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.006982282269746065 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.006992792245000601 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.010331092402338982 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '256', '256', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.010227189399302006 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.015056964010000229 - ], 
- "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.014920394867658615 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.04663630574941635 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.04339428246021271 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.1311214417219162 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.12436506152153015 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.39030927419662476 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/rocm_6.3.1/gpu_AMD_Instinct_MI250X_MI250/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8/default/cache.json b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/rocm_6.3.1/gpu_AMD_Instinct_MI250X_MI250/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8/default/cache.json deleted file mode 100755 index a7669881a..000000000 --- 
a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.2.0/rocm_6.3.1/gpu_AMD_Instinct_MI250X_MI250/attn_fwd/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-0a43fd896fb3d6519678247aeba94610b596378a3138e88995ca3569d6672a96/tune_features-df62f53ce178f143b59631de953c946e43811ff1b34cd71e422dfdf14ac35bb9/kernel_configs-1f316f0fbddd51d950280abb53d67b60494f0cf2c02eeb1b551b0356a33a7dc8/default/cache.json +++ /dev/null @@ -1,98 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.triton_flash_attention:attn_fwd)", - "total_bench_time_s": 86906.62447404861, - "evaluated_configs": 450, - "keys": [ - "HQ", - "HK", - "IS_CAUSAL", - "dropout_p", - "BLOCK_DMODEL", - "stride_qz", - "stride_qh", - "stride_qm", - "stride_qk", - "stride_kz", - "stride_kh", - "stride_kn", - "stride_kk", - "stride_vz", - "stride_vh", - "stride_vn", - "stride_vk", - "stride_oz", - "stride_oh", - "stride_om", - "stride_on", - "stride_bz", - "stride_bh", - "stride_bm", - "stride_bn", - "stride_az", - "stride_ah", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "VARLEN", - "ACTUAL_BLOCK_DMODEL" - ], - "cache": { - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '16', '16', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 16, BLOCK_N: 16, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '16', '16', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 16, BLOCK_N: 16, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 16, BLOCK_N: 16, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 16, BLOCK_N: 16, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 
'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 256, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 256, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 64, BLOCK_N: 64, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 64, BLOCK_N: 64, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 64, BLOCK_N: 64, PRE_LOAD_V: True, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 64, BLOCK_N: 64, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 256, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, 
reg_inc_consumer: 0, maxnreg: None", - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": "BLOCK_M: 256, BLOCK_N: 128, PRE_LOAD_V: False, GRID_CU_MULTIP: 2, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '16', '16', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.004207286983728409 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '16', '16', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.004182395525276661 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.01809287816286087 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '128', '128', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.017839614301919937 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.09088581800460815 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '512', '512', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.088987797498703 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.23396557569503784 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '1024', '1024', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.23347480595111847 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', 
'0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.6691922545433044 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '2048', '2048', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 0.6695101261138916 - ], - "('32', '32', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '4096', '1', '0', '128', '1', '4096', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 2.025791645050049 - ], - "('32', '8', 'True', '0.0', '128', '0', '128', '4096', '1', '0', '128', '1024', '1', '0', '128', '1', '1024', '0', '128', '4096', '1', '0', '0', '0', '0', '0', '0', '4096', '4096', 'True', '128', 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float16', 'torch.int32', 'torch.int32', 'torch.int32')": [ - 2.01798415184021 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json deleted file mode 100755 index 60a6d6935..000000000 --- a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/_selective_scan_update_kernel/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-669be673bf919df57c10083821a49ac5e1e5629db08d0501c1c298603ad4ecb8/tune_features-93313ae47bf85925b0b3b8a0af710ff4a94421cf3e6ebd1a348e74369ddc45e8/kernel_configs-85691372c5ea21c12337d65667ec842af16b51057ec486e7af706471f7a50309/default/cache.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.mamba_ssm:_selective_scan_update_kernel)", - "total_bench_time_s": 58.42541313171387, - "evaluated_configs": 75, - "keys": [ - "dstate", - "BLOCK_SIZE_DSTATE", - "dim", - "nheads_ngroups_ratio" - ], - "cache": { - "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": "BLOCK_SIZE_M: 8, num_warps: 2, num_ctas: 1, 
num_stages: 6, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('128', '128', '64', '128', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32')": [ - 0.003274054965004325 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-2e68df1b2ccc61cd52696753033f640191f6d65a4eba454efdb10ac09cee2f95/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-2e68df1b2ccc61cd52696753033f640191f6d65a4eba454efdb10ac09cee2f95/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json deleted file mode 100755 index 04eb1f234..000000000 --- a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-2e68df1b2ccc61cd52696753033f640191f6d65a4eba454efdb10ac09cee2f95/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json +++ /dev/null @@ -1,347 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.triton_unified_attention:kernel_unified_attention_2d)", - "total_bench_time_s": 34544.99443292618, - "evaluated_configs": 540, - "keys": [ - "MAX_SEQ_Q", - "MAX_SEQ_K", - "AVG_SEQ_Q", - "AVG_SEQ_K", - "num_query_heads", - "num_queries_per_kv", - "BLOCK_SIZE", - "HEAD_SIZE", - "HEAD_SIZE_PADDED", - "SLIDING_WINDOW", - "stride_k_cache_3", - "stride_v_cache_3" - ], - "cache": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '512', '512', '32', '4', 
'16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '1024', '1', 
'1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, 
reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, 
reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 
0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 
0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003466148627921939 - ], - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003575095208361745 - ], - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004993442911654711 - ], - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006109926383942366 - ], - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03988393768668175 - ], - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.09943539649248123 - ], - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3283151388168335 - ], - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.0377004146575928 - ], - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033776038326323032 - ], - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003488453570753336 - ], - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033901487477123737 - ], - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0032401704229414463 - ], - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004394480027258396 - ], - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004883989226073027 - ], - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0045789312571287155 - ], - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006259772460907698 - ], - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.010929320007562637 - ], - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.040549296885728836 - ], - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.02016238309442997 - ], - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.1051921397447586 - ], - 
"('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03749670833349228 - ], - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3411431908607483 - ], - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0701025053858757 - ], - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.0497854948043823 - ], - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0034944734070450068 - ], - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0042336732149124146 - ], - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005933090578764677 - ], - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.026846082881093025 - ], - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07565699517726898 - ], - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.2685732841491699 - ], - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.8566849827766418 - ], - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003527216147631407 - ], - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004583046771585941 - ], - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0060236589051783085 - ], - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.026979871094226837 - ], - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08126690983772278 - ], - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.2932415306568146 - ], - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.8659728765487671 - ], - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00306075531989336 - ], - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0034781373105943203 - ], - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003616524860262871 - ], - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0030675148591399193 - ], - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0038118616212159395 - ], - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003134604310616851 - ], - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0055700079537928104 - ], - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.009849821217358112 - ], - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.014783395454287529 - ], - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04928915575146675 - ], - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.15255023539066315 - ], - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.013137963600456715 - ], - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.4398653507232666 - ], - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.4163719415664673 - ], - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 
0.0033607585355639458 - ], - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0038107747677713633 - ], - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004322108346968889 - ], - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033715730533003807 - ], - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004160675685852766 - ], - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004942106083035469 - ], - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00334966741502285 - ], - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0050212424248456955 - ], - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007804282940924168 - ], - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007798833306878805 - ], - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.014028973877429962 - ], - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03204701468348503 - ], - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08394649624824524 - ], - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08103202283382416 - ], - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.23096241056919098 - ], - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006906270515173674 - ], - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.23079754412174225 - ], - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.7025490999221802 - ], - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6989444494247437 - ], - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 2.3537752628326416 - ], - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004250869620591402 - ], - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005911743268370628 - ], - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.011380953714251518 - ], - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.05582933872938156 - ], - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.16943588852882385 - ], - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.4909878969192505 - ], - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.5911381244659424 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-5929ad03b9fa9764bf7161e5d9bf068628b7668ea2c33d6b1c3d10ebc8b7a0a6/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json 
b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-5929ad03b9fa9764bf7161e5d9bf068628b7668ea2c33d6b1c3d10ebc8b7a0a6/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json deleted file mode 100755 index 1a8388dae..000000000 --- a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-5929ad03b9fa9764bf7161e5d9bf068628b7668ea2c33d6b1c3d10ebc8b7a0a6/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json +++ /dev/null @@ -1,387 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.triton_unified_attention:kernel_unified_attention_2d)", - "total_bench_time_s": 67657.00523352623, - "evaluated_configs": 540, - "keys": [ - "MAX_SEQ_Q", - "MAX_SEQ_K", - "AVG_SEQ_Q", - "AVG_SEQ_K", - "num_query_heads", - "num_queries_per_kv", - "BLOCK_SIZE", - "HEAD_SIZE", - "HEAD_SIZE_PADDED", - "SLIDING_WINDOW", - "stride_k_cache_3", - "stride_v_cache_3" - ], - "cache": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '16', 
'1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 64, num_warps: 8, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", 
- "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 64, num_warps: 8, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, num_consumer_groups: 0, 
reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, 
num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 32, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, 
num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '8', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '16', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '32', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '64', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '128', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '256', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, 
num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '512', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '1024', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2048', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4096', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 128, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0034347970504313707 - ], - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0035579479299485683 - ], - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00523252971470356 - ], - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006011391524225473 - ], - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.023085465654730797 - ], - "('1024', '1024', '1024', '1024', '32', 
'4', '16', '128', '128', '0', '1', '1')": [ - 0.08206301927566528 - ], - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3279804289340973 - ], - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.1915172338485718 - ], - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033755453769117594 - ], - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003468221053481102 - ], - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00334682478569448 - ], - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0035435776226222515 - ], - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004342962987720966 - ], - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00496680336073041 - ], - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004553888458758593 - ], - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007391158025711775 - ], - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.011154169216752052 - ], - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04036085680127144 - ], - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.019932862371206284 - ], - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08319558948278427 - ], - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03744187951087952 - ], - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3325899839401245 - ], - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.06968305259943008 - ], - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.184262990951538 - ], - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003470577532425523 - ], - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004544882569462061 - ], - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00577146140858531 - ], - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.022477485239505768 - ], - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04180074483156204 - ], - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.16259081661701202 - ], - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6357383131980896 - ], - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0034817454870790243 - ], - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00421161251142621 - ], - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00583713548257947 - ], - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.02271271124482155 - ], - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07548002898693085 - ], - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.17187528312206268 - ], - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6434140801429749 - ], - "('2', '2', 
'2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033293836750090122 - ], - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003431792138144374 - ], - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003589486936107278 - ], - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003379078349098563 - ], - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0041108024306595325 - ], - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033878879621624947 - ], - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006029331590980291 - ], - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008353302255272865 - ], - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.013032807968556881 - ], - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04468222334980965 - ], - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.1537272334098816 - ], - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.01300885435193777 - ], - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.48241302371025085 - ], - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.7054001092910767 - ], - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033725856337696314 - ], - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0037622733507305384 - ], - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004256599582731724 - ], - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00334113254211843 - ], - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004093301948159933 - ], - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004860257264226675 - ], - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003374352352693677 - ], - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005010899156332016 - ], - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007828187197446823 - ], - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007898394018411636 - ], - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.014706183224916458 - ], - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03305657580494881 - ], - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08440500497817993 - ], - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08125007152557373 - ], - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.2514193058013916 - ], - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006724500097334385 - ], - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.22513994574546814 - ], - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.8429425954818726 - ], - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6514143943786621 - ], - "('4096', '4096', '1024', '2048', 
'32', '4', '16', '128', '128', '0', '1', '1')": [ - 3.03377103805542 - ], - "('1', '8', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033735581673681736 - ], - "('1', '16', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003457766491919756 - ], - "('1', '32', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003451892174780369 - ], - "('1', '64', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004423843696713448 - ], - "('1', '128', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004672772716730833 - ], - "('1', '256', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006518691312521696 - ], - "('1', '512', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.010816759429872036 - ], - "('1', '1024', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.01876869797706604 - ], - "('1', '2048', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03477397561073303 - ], - "('1', '4096', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07260602712631226 - ], - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004245477728545666 - ], - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006100499536842108 - ], - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008639966137707233 - ], - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04726530611515045 - ], - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.14509893953800201 - ], - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.4709869623184204 - ], - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.6025410890579224 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-67c5278a57a01b9e312f17a648cae5031730e47c496c02f3a23832e14fc93b14/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-67c5278a57a01b9e312f17a648cae5031730e47c496c02f3a23832e14fc93b14/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json deleted file mode 100755 index 04eb1f234..000000000 --- a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/cuda_12.4/gpu_NVIDIA_H100_80GB_HBM3/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-67c5278a57a01b9e312f17a648cae5031730e47c496c02f3a23832e14fc93b14/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json +++ /dev/null @@ 
-1,347 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.triton_unified_attention:kernel_unified_attention_2d)", - "total_bench_time_s": 34544.99443292618, - "evaluated_configs": 540, - "keys": [ - "MAX_SEQ_Q", - "MAX_SEQ_K", - "AVG_SEQ_Q", - "AVG_SEQ_K", - "num_query_heads", - "num_queries_per_kv", - "BLOCK_SIZE", - "HEAD_SIZE", - "HEAD_SIZE_PADDED", - "SLIDING_WINDOW", - "stride_k_cache_3", - "stride_v_cache_3" - ], - "cache": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '64', '1', '64', '32', '4', '16', 
'128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 8, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', 
'128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, 
reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, 
reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, 
reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003466148627921939 - ], - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003575095208361745 - ], - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004993442911654711 - ], - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006109926383942366 - ], - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03988393768668175 - ], - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 
0.09943539649248123 - ], - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3283151388168335 - ], - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.0377004146575928 - ], - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033776038326323032 - ], - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003488453570753336 - ], - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033901487477123737 - ], - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0032401704229414463 - ], - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004394480027258396 - ], - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004883989226073027 - ], - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0045789312571287155 - ], - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006259772460907698 - ], - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.010929320007562637 - ], - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.040549296885728836 - ], - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.02016238309442997 - ], - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.1051921397447586 - ], - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03749670833349228 - ], - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3411431908607483 - ], - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0701025053858757 - ], - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.0497854948043823 - ], - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0034944734070450068 - ], - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0042336732149124146 - ], - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005933090578764677 - ], - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.026846082881093025 - ], - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07565699517726898 - ], - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.2685732841491699 - ], - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.8566849827766418 - ], - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003527216147631407 - ], - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004583046771585941 - ], - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0060236589051783085 - ], - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.026979871094226837 - ], - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08126690983772278 - ], - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.2932415306568146 - ], - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.8659728765487671 - ], - "('2', '2', '2', '2', '32', '4', '16', '128', '128', 
'0', '1', '1')": [ - 0.00306075531989336 - ], - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0034781373105943203 - ], - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003616524860262871 - ], - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0030675148591399193 - ], - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0038118616212159395 - ], - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.003134604310616851 - ], - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0055700079537928104 - ], - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.009849821217358112 - ], - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.014783395454287529 - ], - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04928915575146675 - ], - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.15255023539066315 - ], - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.013137963600456715 - ], - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.4398653507232666 - ], - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.4163719415664673 - ], - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033607585355639458 - ], - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0038107747677713633 - ], - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004322108346968889 - ], - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0033715730533003807 - ], - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004160675685852766 - ], - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004942106083035469 - ], - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00334966741502285 - ], - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0050212424248456955 - ], - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007804282940924168 - ], - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007798833306878805 - ], - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.014028973877429962 - ], - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03204701468348503 - ], - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08394649624824524 - ], - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08103202283382416 - ], - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.23096241056919098 - ], - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006906270515173674 - ], - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.23079754412174225 - ], - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.7025490999221802 - ], - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6989444494247437 - ], - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', 
'1')": [ - 2.3537752628326416 - ], - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.004250869620591402 - ], - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005911743268370628 - ], - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.011380953714251518 - ], - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.05582933872938156 - ], - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.16943588852882385 - ], - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.4909878969192505 - ], - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.5911381244659424 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-2e68df1b2ccc61cd52696753033f640191f6d65a4eba454efdb10ac09cee2f95/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-2e68df1b2ccc61cd52696753033f640191f6d65a4eba454efdb10ac09cee2f95/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json deleted file mode 100755 index db665c68f..000000000 --- a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-2e68df1b2ccc61cd52696753033f640191f6d65a4eba454efdb10ac09cee2f95/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json +++ /dev/null @@ -1,347 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.triton_unified_attention:kernel_unified_attention_2d)", - "total_bench_time_s": 72002.96068787575, - "evaluated_configs": 540, - "keys": [ - "MAX_SEQ_Q", - "MAX_SEQ_K", - "AVG_SEQ_Q", - "AVG_SEQ_K", - "num_query_heads", - "num_queries_per_kv", - "BLOCK_SIZE", - "HEAD_SIZE", - "HEAD_SIZE_PADDED", - "SLIDING_WINDOW", - "stride_k_cache_3", - "stride_v_cache_3" - ], - "cache": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 
32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 32, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', 
'1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', 
'4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', 
'32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '8', 
'1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 32, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, 
reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006003436166793108 - ], - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006077692378312349 - ], - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0066948747262358665 - ], - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008714776486158371 - ], - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03953208029270172 - ], - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08529671281576157 - ], - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.26893165707588196 - ], - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.7998318672180176 - ], - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00574119808152318 - ], - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006026116665452719 - ], - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005752653814852238 - ], - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00608863914385438 - ], - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006379257421940565 - ], - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006695704068988562 - ], - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007991316728293896 - ], - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 
0.00874169822782278 - ], - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.021478423848748207 - ], - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.038848876953125 - ], - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03919544070959091 - ], - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08279953896999359 - ], - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07393984496593475 - ], - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.26520422101020813 - ], - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.143253892660141 - ], - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.8069456219673157 - ], - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006098074372857809 - ], - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006664188578724861 - ], - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008316880092024803 - ], - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.032703448086977005 - ], - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07349277287721634 - ], - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.17093537747859955 - ], - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6028901934623718 - ], - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006040927022695541 - ], - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006674066185951233 - ], - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008359000086784363 - ], - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.033145882189273834 - ], - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0726323127746582 - ], - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.16725540161132812 - ], - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6085386872291565 - ], - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00583583302795887 - ], - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00593462772667408 - ], - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006117511540651321 - ], - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0059266164898872375 - ], - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006205248646438122 - ], - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005945528391748667 - ], - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0069659799337387085 - ], - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.010612651705741882 - ], - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.01373966969549656 - ], - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04602960869669914 - ], - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": 
[ - 0.12627318501472473 - ], - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.014789633452892303 - ], - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3502292037010193 - ], - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.0954514741897583 - ], - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005718982312828302 - ], - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006129336543381214 - ], - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006283498369157314 - ], - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0057284715585410595 - ], - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0061799646355211735 - ], - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007406504824757576 - ], - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005748743191361427 - ], - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006614300422370434 - ], - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008334673009812832 - ], - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.010265326127409935 - ], - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.015284508466720581 - ], - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03939511626958847 - ], - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07506544888019562 - ], - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08072267472743988 - ], - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.1980127990245819 - ], - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.011478512547910213 - ], - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.21105918288230896 - ], - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.5597497224807739 - ], - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.5454477071762085 - ], - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.9615601301193237 - ], - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00629243953153491 - ], - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008062037639319897 - ], - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.01422079000622034 - ], - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0551898293197155 - ], - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.14126861095428467 - ], - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3813389539718628 - ], - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.2401379346847534 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git 
a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-5929ad03b9fa9764bf7161e5d9bf068628b7668ea2c33d6b1c3d10ebc8b7a0a6/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-5929ad03b9fa9764bf7161e5d9bf068628b7668ea2c33d6b1c3d10ebc8b7a0a6/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json deleted file mode 100755 index 5e025265d..000000000 --- a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/kernel_unified_attention_2d/autotune_config-356e536ec49f15d95d2a2610df8277796c9330d647b924736ed5c106312d4227/code_version-5929ad03b9fa9764bf7161e5d9bf068628b7668ea2c33d6b1c3d10ebc8b7a0a6/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json +++ /dev/null @@ -1,387 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.triton_unified_attention:kernel_unified_attention_2d)", - "total_bench_time_s": 81407.73767566681, - "evaluated_configs": 540, - "keys": [ - "MAX_SEQ_Q", - "MAX_SEQ_K", - "AVG_SEQ_Q", - "AVG_SEQ_K", - "num_query_heads", - "num_queries_per_kv", - "BLOCK_SIZE", - "HEAD_SIZE", - "HEAD_SIZE_PADDED", - "SLIDING_WINDOW", - "stride_k_cache_3", - "stride_v_cache_3" - ], - "cache": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', 
'2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: 
None", - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 256, BLOCK_M: 32, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, 
reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 256, BLOCK_M: 32, num_warps: 4, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, 
num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, 
num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 32, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '8', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '16', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '32', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '64', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 
4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '128', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '256', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '512', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '1024', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2048', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4096', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0060075013898313046 - 
], - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006072512362152338 - ], - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00672190822660923 - ], - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008806715719401836 - ], - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04485657438635826 - ], - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.09946674853563309 - ], - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.35092800855636597 - ], - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.324418544769287 - ], - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0057691833935678005 - ], - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006055567879229784 - ], - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005804183427244425 - ], - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006106226239353418 - ], - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006440665107220411 - ], - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006741056218743324 - ], - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007889878936111927 - ], - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008913432247936726 - ], - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.021346861496567726 - ], - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04106005281209946 - ], - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03879227116703987 - ], - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0952981486916542 - ], - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0731193870306015 - ], - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3475594222545624 - ], - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.14168496429920197 - ], - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.324677586555481 - ], - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0060554975643754005 - ], - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006669852416962385 - ], - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008174276910722256 - ], - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03536117449402809 - ], - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07847916334867477 - ], - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.18417692184448242 - ], - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6875757575035095 - ], - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006102146580815315 - ], - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006687485612928867 - ], - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 
0.0084276357665658 - ], - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03678948059678078 - ], - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07642015814781189 - ], - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.18387676775455475 - ], - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6868319511413574 - ], - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005820533260703087 - ], - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0059619504027068615 - ], - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006105729844421148 - ], - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005979663692414761 - ], - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0062386938370764256 - ], - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005969700403511524 - ], - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007005539257079363 - ], - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.011318272911012173 - ], - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.01767335832118988 - ], - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.048929426819086075 - ], - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.1755041629076004 - ], - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.01716405153274536 - ], - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.5103733539581299 - ], - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.8636406660079956 - ], - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0057952022179961205 - ], - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006148397456854582 - ], - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006287233904004097 - ], - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005749743431806564 - ], - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006230741273611784 - ], - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007458249572664499 - ], - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00579081941395998 - ], - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006615426391363144 - ], - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00870793592184782 - ], - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.01026986539363861 - ], - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.015668710693717003 - ], - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.040304314345121384 - ], - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0959310457110405 - ], - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0849064514040947 - ], - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.2615358829498291 - ], - 
"('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.011502742767333984 - ], - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.25011205673217773 - ], - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.8817259073257446 - ], - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.7242566347122192 - ], - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 3.2800190448760986 - ], - "('1', '8', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00581999821588397 - ], - "('1', '16', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0058884210884571075 - ], - "('1', '32', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0058985608629882336 - ], - "('1', '64', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0065222084522247314 - ], - "('1', '128', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008244817145168781 - ], - "('1', '256', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.011564841493964195 - ], - "('1', '512', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.021496908739209175 - ], - "('1', '1024', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.038903381675481796 - ], - "('1', '2048', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07334144413471222 - ], - "('1', '4096', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.1418607085943222 - ], - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006298307329416275 - ], - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008144522085785866 - ], - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.014301695860922337 - ], - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.06052287295460701 - ], - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.1740308254957199 - ], - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.4944685995578766 - ], - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.7257815599441528 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/kernel_unified_attention_2d/autotune_config-eff99677f7c0c1715ee99c9f1c8cf2a597630dd934ea82c3a3f4cdcd26d2e859/code_version-67c5278a57a01b9e312f17a648cae5031730e47c496c02f3a23832e14fc93b14/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json b/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/kernel_unified_attention_2d/autotune_config-eff99677f7c0c1715ee99c9f1c8cf2a597630dd934ea82c3a3f4cdcd26d2e859/code_version-67c5278a57a01b9e312f17a648cae5031730e47c496c02f3a23832e14fc93b14/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json deleted file mode 100755 index 
db665c68f..000000000 --- a/ibm-triton-lib/ibm_triton_lib/kernels/dejavu_data/dejavu_0.7/triton_3.3.0/rocm_torch_6.2.41134-65d174c3e/gpu_AMD_Instinct_MI300X/kernel_unified_attention_2d/autotune_config-eff99677f7c0c1715ee99c9f1c8cf2a597630dd934ea82c3a3f4cdcd26d2e859/code_version-67c5278a57a01b9e312f17a648cae5031730e47c496c02f3a23832e14fc93b14/tune_features-1951755092d3da5141f4b15aeee3b864a29766ecdb441f9f148e955fcfae08c6/kernel_configs-5519d9b1918ec274a537269f5fbd0ad024b0e4043a66d66c7a04f6cac9f334e4/default/cache.json +++ /dev/null @@ -1,347 +0,0 @@ -{ - "signature": "JITFunction(ibm_triton_lib.kernels.triton_unified_attention:kernel_unified_attention_2d)", - "total_bench_time_s": 72002.96068787575, - "evaluated_configs": 540, - "keys": [ - "MAX_SEQ_Q", - "MAX_SEQ_K", - "AVG_SEQ_Q", - "AVG_SEQ_K", - "num_query_heads", - "num_queries_per_kv", - "BLOCK_SIZE", - "HEAD_SIZE", - "HEAD_SIZE_PADDED", - "SLIDING_WINDOW", - "stride_k_cache_3", - "stride_v_cache_3" - ], - "cache": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 32, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', 
'32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, 
maxnreg: None", - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 128, BLOCK_M: 16, num_warps: 8, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 
0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, 
num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 32, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, 
num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 32, BLOCK_M: 16, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 16, BLOCK_M: 64, num_warps: 4, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None", - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": "BLOCK_N: 64, BLOCK_M: 64, num_warps: 2, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None" - }, - "timings": { - "('16', '16', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006003436166793108 - 
], - "('32', '32', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006077692378312349 - ], - "('64', '64', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0066948747262358665 - ], - "('128', '128', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008714776486158371 - ], - "('512', '512', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03953208029270172 - ], - "('1024', '1024', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08529671281576157 - ], - "('2048', '2048', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.26893165707588196 - ], - "('4096', '4096', '4096', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.7998318672180176 - ], - "('1', '16', '1', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00574119808152318 - ], - "('16', '16', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006026116665452719 - ], - "('1', '32', '1', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005752653814852238 - ], - "('32', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00608863914385438 - ], - "('1', '64', '1', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006379257421940565 - ], - "('64', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006695704068988562 - ], - "('1', '128', '1', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007991316728293896 - ], - "('128', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00874169822782278 - ], - "('1', '512', '1', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.021478423848748207 - ], - "('512', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.038848876953125 - ], - "('1', '1024', '1', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03919544070959091 - ], - "('1024', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08279953896999359 - ], - "('1', '2048', '1', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07393984496593475 - ], - "('2048', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.26520422101020813 - ], - "('1', '4096', '1', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.143253892660141 - ], - "('4096', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.8069456219673157 - ], - "('16', '32', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006098074372857809 - ], - "('32', '64', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006664188578724861 - ], - "('64', '128', '64', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008316880092024803 - ], - "('256', '512', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.032703448086977005 - ], - "('512', '1024', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07349277287721634 - ], - "('1024', '2048', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.17093537747859955 - ], - "('2048', '4096', '2048', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6028901934623718 - ], - "('16', '32', '8', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006040927022695541 - ], - "('32', '64', '16', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006674066185951233 - ], - "('64', '128', '32', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 
0.008359000086784363 - ], - "('256', '512', '128', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.033145882189273834 - ], - "('512', '1024', '256', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0726323127746582 - ], - "('1024', '2048', '512', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.16725540161132812 - ], - "('2048', '4096', '1024', '4096', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.6085386872291565 - ], - "('2', '2', '2', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00583583302795887 - ], - "('8', '8', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00593462772667408 - ], - "('16', '16', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006117511540651321 - ], - "('4', '4', '4', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0059266164898872375 - ], - "('32', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006205248646438122 - ], - "('8', '8', '8', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005945528391748667 - ], - "('64', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0069659799337387085 - ], - "('128', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.010612651705741882 - ], - "('256', '256', '128', '128', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.01373966969549656 - ], - "('512', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.04602960869669914 - ], - "('1024', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.12627318501472473 - ], - "('256', '256', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.014789633452892303 - ], - "('2048', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3502292037010193 - ], - "('4096', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.0954514741897583 - ], - "('1', '2', '1', '2', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005718982312828302 - ], - "('8', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006129336543381214 - ], - "('16', '16', '4', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006283498369157314 - ], - "('1', '4', '1', '4', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0057284715585410595 - ], - "('16', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0061799646355211735 - ], - "('32', '32', '8', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.007406504824757576 - ], - "('1', '8', '1', '8', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.005748743191361427 - ], - "('32', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.006614300422370434 - ], - "('64', '64', '16', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008334673009812832 - ], - "('64', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.010265326127409935 - ], - "('128', '128', '32', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.015284508466720581 - ], - "('256', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.03939511626958847 - ], - "('512', '512', '128', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.07506544888019562 - ], - "('512', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.08072267472743988 - ], - "('1024', '1024', '256', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.1980127990245819 - ], 
- "('1', '256', '1', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.011478512547910213 - ], - "('1024', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.21105918288230896 - ], - "('2048', '2048', '512', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.5597497224807739 - ], - "('2048', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.5454477071762085 - ], - "('4096', '4096', '1024', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.9615601301193237 - ], - "('16', '32', '16', '16', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.00629243953153491 - ], - "('32', '64', '32', '32', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.008062037639319897 - ], - "('64', '128', '64', '64', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.01422079000622034 - ], - "('256', '512', '256', '256', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.0551898293197155 - ], - "('512', '1024', '512', '512', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.14126861095428467 - ], - "('1024', '2048', '1024', '1024', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 0.3813389539718628 - ], - "('2048', '4096', '2048', '2048', '32', '4', '16', '128', '128', '0', '1', '1')": [ - 1.2401379346847534 - ] - }, - "timings_data": { - "labels": [ - "ms" - ], - "rep_t_ms": 100, - "warmup_t_ms": 25, - "cuda_graphs": true - } -} \ No newline at end of file diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/fused_moe.py b/ibm-triton-lib/ibm_triton_lib/kernels/fused_moe.py new file mode 100644 index 000000000..03230e1f6 --- /dev/null +++ b/ibm-triton-lib/ibm_triton_lib/kernels/fused_moe.py @@ -0,0 +1,1820 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Callable, Optional + +import torch + +import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops +from vllm.logger import init_logger +# yapf: disable +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, get_config_quant_dtype) +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + _valid_cutlass_block_scaled_grouped_gemm, + run_cutlass_block_scaled_fused_experts) +# yapf: enable +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm, deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, moe_kernel_quantize_input) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + dequant_mxfp4) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils import direct_register_custom_op +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used + +# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled + +logger = init_logger(__name__) + + +@triton.jit +def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, + token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, + compute_type): + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) + offs_cn = pid_n * 
BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
+ # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, + offs_token, token_mask, BLOCK_SIZE_M, + BLOCK_SIZE_N, compute_type) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + if use_int4_w4a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \ + stride_bn + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ + offs_bn[None, :] * stride_bsn + \ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \ + stride_bsk + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + (offs_bn[None, :] // 2) * stride_bzn + \ + offs_k_true * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = ((b_zp >> b_zp_shifter) & 0xF) + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + offs_bn[None, :] * stride_bzn + \ + offs_k_true * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. 
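+ # Illustrative note (not part of the kernel logic): for 4-bit GPTQ/AWQ
+ # weights the unpacked nibble is dequantized as w = (q - zp) * scale
+ # before entering tl.dot. E.g. with q = 11, the default zero-point 8
+ # (used when has_zp is False) and scale = 0.02, the reconstructed
+ # weight is (11 - 8) * 0.02 = 0.06.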
+ if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. 
+ # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, + offs_token, token_mask, BLOCK_SIZE_M, + BLOCK_SIZE_N, compute_type) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + if use_int8_w8a16: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_fp8_w8a8 or use_int8_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + + offs_bsn * stride_bsn) + # channel-wise + elif per_channel_quant: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, + None] + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. 
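+ # Worked example (illustrative numbers): in the block-wise quantized case
+ # (group_k > 0 and group_n > 0, e.g. both 128) each K-block's partial
+ # product is rescaled elementwise as
+ #     acc += tl.dot(a_q, b_q) * a_scale[:, None] * b_scale[None, :]
+ # so a raw dot value of 40.0 with a_scale = 0.5 and b_scale = 0.25
+ # contributes 40.0 * 0.5 * 0.25 = 5.0 to the fp32 accumulator.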
+ if use_int8_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=token_mask, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + if use_fp8_w8a8: + # acc used to enable fp8_fast_accum + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + if use_int8_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def invoke_fused_moe_kernel(A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: Optional[list[int]] = None) -> None: + assert topk_weights is not None or not mul_routed_weight + assert topk_weights is None or topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + if use_fp8_w8a8 or use_int8_w8a8: + assert B_scale is not None + assert (block_shape is None + or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)) + assert (block_shape is None + or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + + M = A.size(0) + num_tokens = M * top_k + + EM = sorted_token_ids.size(0) + if A.size(0) < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. 
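+ # Illustrative example (assumed numbers): with A.size(0) = 2 tokens,
+ # top_k = 2 and BLOCK_SIZE_M = 16, at most 2 * 2 = 4 blocks can hold
+ # real tokens, so EM is capped at 2 * 2 * 16 = 64 below even if
+ # moe_align_block_size padded sorted_token_ids to a much larger length.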
+ EM = min(sorted_token_ids.size(0), + A.size(0) * top_k * config['BLOCK_SIZE_M']) + grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( + B.size(1), META['BLOCK_SIZE_N']), ) + + if (use_int8_w8a16 or use_int4_w4a16) and \ + block_shape is not None and block_shape[1] > 0: + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + use_moe_wna16_cuda = should_moe_wna16_use_cuda( + num_valid_tokens=num_tokens, + group_size=block_shape[1], + num_experts=B.size(0), + bit=4 if use_int4_w4a16 else 8) + config = config.copy() + config.update( + get_moe_wna16_block_config(config=config, + use_moe_wna16_cuda=use_moe_wna16_cuda, + num_valid_tokens=num_tokens, + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), + group_size=block_shape[1], + real_top_k=top_k, + block_size_m=config["BLOCK_SIZE_M"])) + + if use_moe_wna16_cuda: + bit = 4 if use_int4_w4a16 else 8 + ops.moe_wna16_gemm(A, C, B, B_scale, B_zp, + topk_weights if mul_routed_weight else None, + sorted_token_ids, expert_ids, + num_tokens_post_padded, top_k, + config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], bit) + return + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.size(1), + A.size(1), + EM, + num_tokens, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + else: + config = config.copy() + BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") + if block_shape is not None: + BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], + block_shape[1])) + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.size(1), + B.size(2), + EM, + num_tokens, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) + if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) + if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) + if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) + if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) + if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + per_channel_quant=per_channel_quant, + BLOCK_SIZE_K=BLOCK_SIZE_K, + **config, + ) + + +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name(E: int, + N: int, + dtype: Optional[str], + block_shape: Optional[list[int]] = None) -> str: + device_name = current_platform.get_device_name().replace(" ", "_") + dtype_selector = "" if not dtype else f",dtype={dtype}" + block_shape_selector = ("" if not block_shape or not all(block_shape) else + 
f",block_shape={block_shape}").replace(" ", "") + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 + + +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +@functools.lru_cache +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, +) -> Optional[dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the fused_moe kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape) + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + print("Using configuration from %s for MoE layer.", + config_file_path) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + if envs.VLLM_ENABLE_FUSED_MOE_CONFIG_HEURISTICS: + config_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs") + all_config_files = [f for f in os.listdir(config_folder) if os.path.isfile(os.path.join(config_folder, f))] + applicable_name_part = config_file_path.split(",device")[1] + all_applicable_files = [f for f in all_config_files if applicable_name_part in f] + # N given E + available_E = list(set([int(f.split("E=")[1].split(",N=")[0]) for f in all_applicable_files])) + next_best_e = min(available_E, key=lambda x: abs(x - E)) + all_applicable_n = [f for f in all_applicable_files if f"E={next_best_e}" in f] + available_N_given_e = list(set([int(f.split("N=")[1].split(",device")[0]) for f in all_applicable_n])) + next_best_n = min(available_N_given_e, key=lambda x: abs(x - N)) + # E given N + # available_N = list(set([int(f.split("N=")[1].split(",device")[0]) for f in all_applicable_files])) + # next_best_n = min(available_N, key=lambda x: abs(x - N)) + # all_applicable_e = [f for f in all_applicable_files if f"N={next_best_n}" in f] + # available_E_given_n = list(set([int(f.split("E=")[1].split(",N=")[0]) for f in all_applicable_e])) + # next_best_e = min(available_E_given_n, key=lambda x: abs(x - E)) + + fallback_json_file_name = get_config_file_name(next_best_e, next_best_n, dtype, block_shape) + fallback_config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", fallback_json_file_name) + if os.path.exists(fallback_config_file_path): + with open(fallback_config_file_path) as f: + print(("Config file not found at %s. Trying to use next" \ + " best config at %s for MoE layer. Performance" + " might still be sub-optimal."), + config_file_path, fallback_config_file_path) + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available (and heuristics is disabled), + # we will use the default configuration + print( + ("Using default MoE config. Performance might be sub-optimal! 
" + "Config file not found at %s"), config_file_path) + return None + + +def get_moe_wna16_block_config(config: dict[str, + int], use_moe_wna16_cuda: bool, + num_valid_tokens: int, size_k: int, size_n: int, + num_experts: int, group_size: int, + real_top_k: int, block_size_m: int): + if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: + # optimal block config is set + return {} + if not use_moe_wna16_cuda: + # triton moe wna16 kernel + if num_valid_tokens // real_top_k == 1: + # if bs=1, use a smaller BLOCK_SIZE_N + return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64} + else: + return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32} + else: + # cuda moe wna16 kernel + # set default block_size 128, and increase them when num_blocks + # is too large. + block_size_n = 128 + block_size_k = 128 + if block_size_k <= group_size: + block_size_k = group_size + + num_n_blocks = size_k // block_size_k + num_k_blocks = size_n // block_size_k + num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \ + num_experts + if num_valid_tokens // real_top_k <= block_size_m: + num_m_blocks = min(num_m_blocks, num_valid_tokens) + num_blocks = num_m_blocks * num_n_blocks * num_k_blocks + + if size_k % 256 == 0 and num_blocks >= 256 and \ + block_size_k < 256: + block_size_k = 256 + num_blocks = num_blocks // (256 // block_size_k) + + if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \ + size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \ + num_blocks >= 512: + block_size_k = block_size_k * 2 + num_blocks = num_blocks // 2 + + if num_blocks > 1024: + block_size_n = 256 + num_n_blocks = num_n_blocks // 2 + num_blocks = num_blocks // 2 + + if size_n <= 1024 and num_blocks >= 1024: + # The kernel performance got much better with BLOCK_SIZE_N=1024 + # when num_blocks is large, event when N is small. + # Not sure why, maybe it force the CUDA SM process only one block + # at the same time. + block_size_n = 1024 + + return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} + + +def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int, + num_experts: int, bit: int): + return bit == 4 and group_size in [32, 64, 128] and \ + num_valid_tokens / num_experts <= 6 + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, + block_shape: Optional[list[int]] = None, +) -> dict[str, int]: + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] + # num_stages=3 can cause triton.runtime.errors.OutOfResources + # on ROCm, set it to 2 instead. 
+ config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 if not current_platform.is_rocm() else 2, + } + elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: + # moe wna16 kernels + # only set BLOCK_SIZE_M + # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later + bit = 4 if dtype == "int4_w4a16" else 8 + use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, + block_shape[1], E, bit) + if use_moe_wna16_cuda: + config = {"BLOCK_SIZE_M": min(16, M)} + elif M <= 20: + config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1} + elif M <= 40: + config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} + else: + config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} + elif is_marlin: + for block_size_m in [8, 16, 32, 48, 64]: + if M * topk / E / block_size_m < 0.9: + break + return {"BLOCK_SIZE_M": block_size_m} + elif M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + return config + + +def try_get_optimal_moe_config( + w1_shape: tuple[int, ...], + w2_shape: tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + is_marlin: bool = False, + block_shape: Optional[list[int]] = None, + force_default=False, +) -> dict[str, int]: + from vllm.model_executor.layers.fused_moe import get_config + override_config = get_config() + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + if dtype == "int4_w4a16": + N = N * 2 + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) + + if configs and not force_default: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, + is_marlin, block_shape) + return config + + +def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> tuple[torch.Tensor, ...]: + ops.topk_softmax( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + ) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_indices + + +def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: + return vllm_topk_softmax + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + indices_type: Optional[torch.dtype] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + + M, _ = hidden_states.size() + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty( + M, + topk, + dtype=torch.int32 if indices_type is None else indices_type, + device=hidden_states.device) + token_expert_indices = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + + gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. 
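+ # Shape sketch (illustrative values): for M = 4 tokens, 8 experts and
+ # topk = 2, gating_output is (4, 8); topk_softmax fills topk_weights
+ # (4, 2, float32) and topk_ids (4, 2, int32) with the two
+ # highest-probability experts per token, and renormalize=True rescales
+ # each row of topk_weights to sum to 1.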
+ + topk_func = dispatch_topk_func() + topk_weights, topk_ids = topk_func(topk_weights, topk_ids, + token_expert_indices, + gating_output_float, renormalize) + + return topk_weights, topk_ids, token_expert_indices + + +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.size(0) + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def get_config_dtype_str( + dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, + use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False, + use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]: + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w4a16" + elif use_mxfp4_w4a4: + return "mxfp4_w4a4" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +# def inplace_fused_experts(hidden_states: torch.Tensor, +# w1: torch.Tensor, +# w2: torch.Tensor, +# topk_weights: torch.Tensor, +# topk_ids: torch.Tensor, +# activation: str = "silu", +# apply_router_weight_on_input: bool = False, +# use_fp8_w8a8: bool = False, +# use_int8_w8a8: bool = False, +# use_int8_w8a16: bool = False, +# use_int4_w4a16: bool = False, +# use_mxfp4_w4a4: bool = False, +# per_channel_quant: bool = False, +# global_num_experts: int = -1, +# expert_map: Optional[torch.Tensor] = None, +# w1_scale: Optional[torch.Tensor] = None, +# w2_scale: Optional[torch.Tensor] = 
None, +# w1_zp: Optional[torch.Tensor] = None, +# w2_zp: Optional[torch.Tensor] = None, +# a1_scale: Optional[torch.Tensor] = None, +# a2_scale: Optional[torch.Tensor] = None, +# block_shape: Optional[list[int]] = None) -> None: +# fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, +# activation, apply_router_weight_on_input, use_fp8_w8a8, +# use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, +# use_mxfp4_w4a4, per_channel_quant, global_num_experts, +# expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, +# a2_scale, block_shape) +# +# +# def inplace_fused_experts_fake( +# hidden_states: torch.Tensor, +# w1: torch.Tensor, +# w2: torch.Tensor, +# topk_weights: torch.Tensor, +# topk_ids: torch.Tensor, +# activation: str = "silu", +# apply_router_weight_on_input: bool = False, +# use_fp8_w8a8: bool = False, +# use_int8_w8a8: bool = False, +# use_int8_w8a16: bool = False, +# use_int4_w4a16: bool = False, +# use_mxfp4_w4a4: bool = False, +# per_channel_quant: bool = False, +# global_num_experts: int = -1, +# expert_map: Optional[torch.Tensor] = None, +# w1_scale: Optional[torch.Tensor] = None, +# w2_scale: Optional[torch.Tensor] = None, +# w1_zp: Optional[torch.Tensor] = None, +# w2_zp: Optional[torch.Tensor] = None, +# a1_scale: Optional[torch.Tensor] = None, +# a2_scale: Optional[torch.Tensor] = None, +# block_shape: Optional[list[int]] = None) -> None: +# pass +# +# +# direct_register_custom_op( +# op_name="inplace_fused_experts", +# op_func=inplace_fused_experts, +# mutates_args=["hidden_states"], +# fake_impl=inplace_fused_experts_fake, +# tags=(torch.Tag.needs_fixed_stride_order, ), +# ) +# +# +# def outplace_fused_experts( +# hidden_states: torch.Tensor, +# w1: torch.Tensor, +# w2: torch.Tensor, +# topk_weights: torch.Tensor, +# topk_ids: torch.Tensor, +# activation: str = "silu", +# apply_router_weight_on_input: bool = False, +# use_fp8_w8a8: bool = False, +# use_int8_w8a8: bool = False, +# use_int8_w8a16: bool = False, +# use_int4_w4a16: bool = False, +# use_mxfp4_w4a4: bool = False, +# per_channel_quant: bool = False, +# global_num_experts: int = -1, +# expert_map: Optional[torch.Tensor] = None, +# w1_scale: Optional[torch.Tensor] = None, +# w2_scale: Optional[torch.Tensor] = None, +# w1_zp: Optional[torch.Tensor] = None, +# w2_zp: Optional[torch.Tensor] = None, +# a1_scale: Optional[torch.Tensor] = None, +# a2_scale: Optional[torch.Tensor] = None, +# block_shape: Optional[list[int]] = None) -> torch.Tensor: +# return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, +# False, activation, apply_router_weight_on_input, +# use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, +# use_int4_w4a16, use_mxfp4_w4a4, +# per_channel_quant, global_num_experts, +# expert_map, w1_scale, w2_scale, w1_zp, w2_zp, +# a1_scale, a2_scale, block_shape) +# +# +# def outplace_fused_experts_fake( +# hidden_states: torch.Tensor, +# w1: torch.Tensor, +# w2: torch.Tensor, +# topk_weights: torch.Tensor, +# topk_ids: torch.Tensor, +# activation: str = "silu", +# use_fp8_w8a8: bool = False, +# use_int8_w8a8: bool = False, +# use_int8_w8a16: bool = False, +# use_int4_w4a16: bool = False, +# use_mxfp4_w4a4: bool = False, +# per_channel_quant: bool = False, +# global_num_experts: int = -1, +# expert_map: Optional[torch.Tensor] = None, +# w1_scale: Optional[torch.Tensor] = None, +# w2_scale: Optional[torch.Tensor] = None, +# w1_zp: Optional[torch.Tensor] = None, +# w2_zp: Optional[torch.Tensor] = None, +# a1_scale: Optional[torch.Tensor] = None, +# a2_scale: Optional[torch.Tensor] = None, 
+# block_shape: Optional[list[int]] = None) -> torch.Tensor: +# return torch.empty_like(hidden_states) +# +# +# direct_register_custom_op( +# op_name="outplace_fused_experts", +# op_func=outplace_fused_experts, +# mutates_args=[], +# fake_impl=outplace_fused_experts_fake, +# tags=(torch.Tag.needs_fixed_stride_order, ), +# ) +# +# +# def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor: +# torch.ops.vllm.inplace_fused_experts(**kwargs) +# hidden_states = kwargs['hidden_states'] +# return hidden_states +# +# +# def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: +# return torch.ops.vllm.outplace_fused_experts(**kwargs) +# +# +# def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: +# if inplace: +# return torch_vllm_inplace_fused_experts +# return torch_vllm_outplace_fused_experts + + +# # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace +# # torch ops. +# def fused_experts( +# hidden_states: torch.Tensor, +# w1: torch.Tensor, +# w2: torch.Tensor, +# topk_weights: torch.Tensor, +# topk_ids: torch.Tensor, +# inplace: bool = False, +# activation: str = "silu", +# apply_router_weight_on_input: bool = False, +# use_fp8_w8a8: bool = False, +# use_int8_w8a8: bool = False, +# use_int8_w8a16: bool = False, +# use_int4_w4a16: bool = False, +# use_mxfp4_w4a4: bool = False, +# per_channel_quant: bool = False, +# global_num_experts: int = -1, +# expert_map: Optional[torch.Tensor] = None, +# w1_scale: Optional[torch.Tensor] = None, +# w2_scale: Optional[torch.Tensor] = None, +# w1_zp: Optional[torch.Tensor] = None, +# w2_zp: Optional[torch.Tensor] = None, +# a1_scale: Optional[torch.Tensor] = None, +# a2_scale: Optional[torch.Tensor] = None, +# block_shape: Optional[list[int]] = None, +# allow_deep_gemm: bool = False, +# allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor: +# # For now, disable DeepGemm for small N (<= 512) until better +# # permute/unpermute ops are available. +# # However, on B200, we use DeepGemm for all cases because they only support +# # E8M0 scale, which means we requantize the weight and input to the specific +# # scale. Fallen back to cutlass or triton for some cases would cause +# # accuracy issue. 
+# N = w1.size(1) +# should_use_deep_gemm = ((N > 512 +# and _valid_deep_gemm(hidden_states, w1, w2)) +# or is_blackwell_deep_gemm_used()) +# if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm): +# assert apply_router_weight_on_input is False +# return deep_gemm_moe_fp8( +# hidden_states=hidden_states, +# w1=w1, +# w2=w2, +# topk_weights=topk_weights, +# topk_ids=topk_ids, +# inplace=inplace, +# activation=activation, +# global_num_experts=global_num_experts, +# expert_map=expert_map, +# w1_scale=w1_scale, +# w2_scale=w2_scale, +# a1_scale=a1_scale, +# a2_scale=a2_scale, +# apply_router_weight_on_input=apply_router_weight_on_input, +# ) +# elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 +# and _valid_cutlass_block_scaled_grouped_gemm( +# w1, w2, inplace, activation, apply_router_weight_on_input, +# expert_map)): +# return run_cutlass_block_scaled_fused_experts( +# a=hidden_states, +# w1=w1, +# w2=w2, +# w1_scale=w1_scale, +# w2_scale=w2_scale, +# topk_weights=topk_weights, +# topk_ids=topk_ids) +# else: +# return dispatch_fused_experts_func(inplace)( +# hidden_states=hidden_states, +# w1=w1, +# w2=w2, +# topk_weights=topk_weights, +# topk_ids=topk_ids, +# activation=activation, +# apply_router_weight_on_input=apply_router_weight_on_input, +# use_fp8_w8a8=use_fp8_w8a8, +# use_int8_w8a8=use_int8_w8a8, +# use_int8_w8a16=use_int8_w8a16, +# use_int4_w4a16=use_int4_w4a16, +# use_mxfp4_w4a4=use_mxfp4_w4a4, +# per_channel_quant=per_channel_quant, +# global_num_experts=global_num_experts, +# expert_map=expert_map, +# w1_scale=w1_scale, +# w2_scale=w2_scale, +# w1_zp=w1_zp, +# w2_zp=w2_zp, +# a1_scale=a1_scale, +# a2_scale=a2_scale, +# block_shape=block_shape) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + use_default_config = False, +) -> torch.Tensor: + # Check constraints. 
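+ # Shape summary, derived from the asserts and sizes used below
+ # (K is the hidden size; N is twice the MLP intermediate size because of
+ # the gated activation; 4-bit weights pack two values per byte, hence the
+ # `// 2` checks):
+ #   hidden_states:            (num_tokens, K)
+ #   w1:                       (E, N, K)
+ #   w2:                       (E, K, N // 2)
+ #   topk_weights / topk_ids:  (num_tokens, top_k)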
+ if use_int4_w4a16: + assert hidden_states.size(1) // 2 == w1.size(2), ( + "Hidden size mismatch") + elif use_mxfp4_w4a4: + # 16bit activation and fp4x2 packed weight + assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch" + else: + assert hidden_states.size(1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") + + assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + num_tokens = hidden_states.size(0) + E, N, _ = w1.size() + K = w2.size(1) + if global_num_experts == -1: + global_num_experts = E + top_k_num = topk_ids.size(1) + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, + dtype=hidden_states.dtype) + + qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.size(), + w2.size(), + top_k_num, + config_dtype, + block_shape=block_shape, + force_default=use_default_config, + ) + + config = get_config_func(M) + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(M * top_k_num * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) + intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) + + # This needs separate memory since it's used concurrently with cache1 + intermediate_cache2 = torch.empty((M * top_k_num, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + if use_mxfp4_w4a4: + # Weight has to be dequantized for mxfp4 emulation. + w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) + w1_scale = None + w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) + w2_scale = None + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.size() + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. 
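+ # Example (assumed numbers): num_tokens = 5000 with CHUNK_SIZE = 2048
+ # gives chunks of 2048, 2048 and 904 tokens; only the final 904-token
+ # chunk enters this branch, shrinking the caches and re-resolving the
+ # config for the smaller M.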
+ intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * + topk_ids.size(1)] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( + A=curr_hidden_states, + A_scale=a1_scale, + quant_dtype=qtype, + per_act_token_quant=per_channel_quant, + block_shape=block_shape) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) + + invoke_fused_moe_kernel(qcurr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape) + + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + A=intermediate_cache2, + A_scale=a2_scale, + quant_dtype=qtype, + per_act_token_quant=per_channel_quant, + block_shape=block_shape) + + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape) + + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), + out_hidden_states[begin_chunk_idx:end_chunk_idx]) + + return out_hidden_states + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + activation: str = "silu", + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + use_default_config=True, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. 
+ - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk + - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and + OCP MXFP4 activation to compute the inner products for w1 and w2. + Defaults to False. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[list[int]]): Optional block size for block-wise + quantization. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. 
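+
+ Example (illustrative shapes only, not taken from a real model):
+ hidden_states: (num_tokens, hidden) = (8, 1024)
+ w1: (num_experts, 2 * intermediate, hidden) = (16, 2816, 1024)
+ w2: (num_experts, hidden, intermediate) = (16, 1024, 1408)
+ gating_output: (num_tokens, num_experts) = (8, 16)
+ out = fused_moe(hidden_states, w1, w2, gating_output, topk=2,
+ renormalize=True) # out has shape (8, 1024)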
+ """ + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, + topk, renormalize, + num_expert_group, topk_group) + elif custom_routing_function is None: + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, gating_output, topk, renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) + + return fused_experts_impl(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + activation=activation, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + use_default_config=use_default_config, + ) + + +# class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): +# +# def __init__( +# self, +# use_fp8_w8a8: bool = False, +# use_int8_w8a8: bool = False, +# use_int8_w8a16: bool = False, +# use_int4_w4a16: bool = False, +# use_mxfp4_w4a4: bool = False, +# per_act_token_quant: bool = False, +# block_shape: Optional[list[int]] = None, +# ): +# super().__init__( +# FusedMoEQuantConfig.make( +# use_fp8_w8a8=use_fp8_w8a8, +# use_int8_w8a8=use_int8_w8a8, +# use_int8_w8a16=use_int8_w8a16, +# use_int4_w4a16=use_int4_w4a16, +# use_mxfp4_w4a4=use_mxfp4_w4a4, +# per_act_token_quant=per_act_token_quant, +# block_shape=block_shape, +# )) +# +# self.use_fp8_w8a8 = use_fp8_w8a8 +# self.use_int4_w4a16 = use_int4_w4a16 +# self.use_int8_w8a8 = use_int8_w8a8 +# self.use_int8_w8a16 = use_int8_w8a16 +# self.use_mxfp4_w4a4 = use_mxfp4_w4a4 +# +# @property +# def activation_formats( +# self +# ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: +# return (mk.FusedMoEActivationFormat.Standard, +# mk.FusedMoEActivationFormat.Standard) +# +# def supports_chunking(self) -> bool: +# return True +# +# def supports_expert_map(self) -> bool: +# return True +# +# def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: +# return TopKWeightAndReduceNoOP() +# +# def workspace_shapes( +# self, +# a: torch.Tensor, +# aq: torch.Tensor, +# M: int, +# N: int, +# K: int, +# topk: int, +# global_num_experts: int, +# local_num_experts: int, +# expert_tokens_meta: Optional[mk.ExpertTokensMetadata], +# ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: +# workspace1 = (M, topk, max(N // 2, K)) +# workspace2 = (M, topk, max(N, K)) +# output = (M, K) +# return (workspace1, workspace2, output, a.dtype) +# +# def apply( +# self, +# output: torch.Tensor, +# hidden_states: torch.Tensor, +# w1: torch.Tensor, +# w2: torch.Tensor, +# topk_weights: torch.Tensor, +# topk_ids: torch.Tensor, +# activation: str, +# global_num_experts: int, +# expert_map: Optional[torch.Tensor], +# w1_scale: Optional[torch.Tensor], +# w2_scale: Optional[torch.Tensor], +# w1_zp: Optional[torch.Tensor], +# w2_zp: Optional[torch.Tensor], +# a1q_scale: Optional[torch.Tensor], +# a2_scale: Optional[torch.Tensor], +# workspace13: torch.Tensor, +# workspace2: torch.Tensor, +# expert_tokens_meta: Optional[mk.ExpertTokensMetadata], +# apply_router_weight_on_input: bool, +# ): +# # Check constraints. 
+# if self.use_int4_w4a16: +# assert hidden_states.size(-1) // 2 == w1.size(2), ( +# "Hidden size mismatch") +# else: +# assert hidden_states.size(-1) == w1.size(2), \ +# (f"Hidden size mismatch {hidden_states.size(-1)} " +# f"!= {w1.size(2)}") +# +# assert hidden_states.is_contiguous( +# ), "Hidden_states must be contiguous" +# assert hidden_states.dim() == 2 +# assert w1.stride(-1) == 1, "Stride of last dimension must be 1" +# assert w2.stride(-1) == 1, "Stride of last dimension must be 1" +# assert hidden_states.dtype in [ +# torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn +# ] +# +# E, num_tokens, N, K, top_k_num = mk._moe_problem_size( +# hidden_states, w1, w2, topk_ids) +# +# if global_num_experts == -1: +# global_num_experts = E +# +# config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, +# use_int8_w8a16=self.use_int8_w8a16, +# use_int4_w4a16=self.use_int4_w4a16, +# use_mxfp4_w4a4=self.use_mxfp4_w4a4, +# dtype=hidden_states.dtype) +# +# config = try_get_optimal_moe_config( +# w1.size(), +# w2.size(), +# top_k_num, +# config_dtype, +# num_tokens, +# block_shape=self.block_shape, +# ) +# +# if hidden_states.dtype == torch.bfloat16: +# compute_type = tl.bfloat16 +# elif hidden_states.dtype == torch.float16: +# compute_type = tl.float16 +# elif hidden_states.dtype == torch.float32: +# compute_type = tl.float32 +# elif hidden_states.dtype == torch.float8_e4m3fn: +# compute_type = tl.bfloat16 +# else: +# raise ValueError( +# f"Unsupported compute_type: {hidden_states.dtype}") +# +# # Note that the output tensor might be in workspace1 +# intermediate_cache1 = _resize_cache(workspace2, +# (num_tokens, top_k_num, N)) +# intermediate_cache2 = _resize_cache(workspace13, +# (num_tokens * top_k_num, N // 2)) +# intermediate_cache3 = _resize_cache(workspace2, +# (num_tokens, top_k_num, K)) +# +# sorted_token_ids, expert_ids, num_tokens_post_padded = ( +# moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], +# global_num_experts, expert_map)) +# +# invoke_fused_moe_kernel( +# hidden_states, +# w1, +# intermediate_cache1, +# a1q_scale, +# w1_scale, +# w1_zp, +# None, # topk_weights +# sorted_token_ids, +# expert_ids, +# num_tokens_post_padded, +# False, # mul_routed_weights +# top_k_num, +# config, +# compute_type=compute_type, +# use_fp8_w8a8=self.use_fp8_w8a8, +# use_int8_w8a8=self.use_int8_w8a8, +# use_int8_w8a16=self.use_int8_w8a16, +# use_int4_w4a16=self.use_int4_w4a16, +# per_channel_quant=self.per_act_token_quant, +# block_shape=self.block_shape) +# +# self.activation(activation, intermediate_cache2, +# intermediate_cache1.view(-1, N)) +# +# a2q_scale: Optional[torch.Tensor] = None +# +# qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( +# intermediate_cache2, a2_scale, self.quant_dtype, +# self.per_act_token_quant, self.block_shape) +# +# invoke_fused_moe_kernel(qintermediate_cache2, +# w2, +# intermediate_cache3, +# a2q_scale, +# w2_scale, +# w2_zp, +# topk_weights, +# sorted_token_ids, +# expert_ids, +# num_tokens_post_padded, +# not apply_router_weight_on_input, +# 1, +# config, +# compute_type=compute_type, +# use_fp8_w8a8=self.use_fp8_w8a8, +# use_int8_w8a8=self.use_int8_w8a8, +# use_int8_w8a16=self.use_int8_w8a16, +# use_int4_w4a16=self.use_int4_w4a16, +# per_channel_quant=self.per_act_token_quant, +# block_shape=self.block_shape) +# +# ops.moe_sum(intermediate_cache3, output) +# + +# def modular_triton_fused_moe( +# use_fp8_w8a8: bool, +# use_int8_w8a8: bool, +# use_int8_w8a16: bool, +# use_int4_w4a16: bool, +# use_mxfp4_w4a4: bool, +# 
per_act_token_quant: bool, +# block_shape: Optional[list[int]] = None, +# ) -> mk.FusedMoEModularKernel: +# return mk.FusedMoEModularKernel( +# MoEPrepareAndFinalizeNoEP(), +# TritonExperts( +# use_fp8_w8a8=use_fp8_w8a8, +# use_int8_w8a8=use_int8_w8a8, +# use_int8_w8a16=use_int8_w8a16, +# use_int4_w4a16=use_int4_w4a16, +# use_mxfp4_w4a4=use_mxfp4_w4a4, +# per_act_token_quant=per_act_token_quant, +# block_shape=block_shape, +# ), +# ) diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_chunked_prefill_paged_decode.py b/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_chunked_prefill_paged_decode.py index f09b6ba9f..beb911396 100644 --- a/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_chunked_prefill_paged_decode.py +++ b/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_chunked_prefill_paged_decode.py @@ -102,16 +102,17 @@ def chunked_prefill_paged_decode( HEAD_SIZE_PADDED=next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, SLIDING_WINDOW=sliding_window_int, - x=key_cache.shape[4], + x=key_cache.shape[4] if len(key_cache.shape) == 5 else 1, stride_k_cache_0=key_cache.stride(0), stride_k_cache_1=key_cache.stride(1), stride_k_cache_2=key_cache.stride(2), stride_k_cache_3=key_cache.stride(3), - stride_k_cache_4=key_cache.stride(4), + stride_k_cache_4=key_cache.stride(4) if len(key_cache.shape) == 5 else 1, stride_v_cache_0=value_cache.stride(0), stride_v_cache_1=value_cache.stride(1), stride_v_cache_2=value_cache.stride(2), stride_v_cache_3=value_cache.stride(3), filter_by_query_len=True, query_start_len_ptr=query_start_loc, + # num_seqs=num_seqs, ) diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_paged_decode_attention_2d.py b/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_paged_decode_attention_2d.py index eb03b53c2..410a3ddeb 100644 --- a/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_paged_decode_attention_2d.py +++ b/ibm-triton-lib/ibm_triton_lib/kernels/legacy/triton_paged_decode_attention_2d.py @@ -71,31 +71,32 @@ def cdiv_fn(x, y): return (x + y - 1) // y -@triton_dejavu.jitcache( - # remove cache_lock if dynamic cache mode should be used - cache_lock=global_cache_lock, - # list of `tl.constexpr` that should be used as cache index - check_keys=["USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"], - check_specialization=["num_seqs", "stride_k_cache_3", "stride_v_cache_3"], - assume_const=[ - "scale", - "k_scale", - "v_scale", - "query_stride_1", - "output_stride_1", - "stride_k_cache_0", - "stride_k_cache_1", - "stride_k_cache_2", - "stride_k_cache_4", - "stride_v_cache_0", - "stride_v_cache_1", - "stride_v_cache_2", - ], - # besides this checks and assumed constants, - # the cache just binds all non_const_expr - cache_launch_grid=True, -) -@triton.jit(launch_metadata=metadata_fn) +# @triton_dejavu.jitcache( +# # remove cache_lock if dynamic cache mode should be used +# cache_lock=global_cache_lock, +# # list of `tl.constexpr` that should be used as cache index +# check_keys=["USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"], +# check_specialization=["num_seqs", "stride_k_cache_3", "stride_v_cache_3"], +# assume_const=[ +# "scale", +# "k_scale", +# "v_scale", +# "query_stride_1", +# "output_stride_1", +# "stride_k_cache_0", +# "stride_k_cache_1", +# "stride_k_cache_2", +# "stride_k_cache_4", +# "stride_v_cache_0", +# "stride_v_cache_1", +# "stride_v_cache_2", +# ], +# # besides this checks and assumed constants, +# # the cache just binds all non_const_expr +# cache_launch_grid=True, +# ) +# 
@triton.jit(launch_metadata=metadata_fn) +@triton.jit def kernel_paged_attention_2d( # TODO: as soon as fixed in triton: add tl.pointer_type annotation output_ptr, #: tl.pointer_type, # [num_tokens, num_query_heads, head_size] @@ -133,11 +134,11 @@ def kernel_paged_attention_2d( stride_v_cache_3: tl.int64, # int filter_by_query_len: tl.constexpr, # bool query_start_len_ptr, #: tl.pointer_type, # [num_seqs+1] - num_seqs: int, + # num_seqs: int, ): seq_idx = tl.program_id(0) - if seq_idx >= num_seqs: - return + # if seq_idx >= num_seqs: + # return kv_head_idx = tl.program_id(1) if filter_by_query_len: @@ -352,10 +353,10 @@ def paged_attention_triton_2d( num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16) - assert num_seqs <= 4096 + # assert num_seqs <= 4096 kernel_paged_attention_2d[ ( - 4096, + num_seqs, num_kv_heads, ) ]( @@ -394,5 +395,5 @@ def paged_attention_triton_2d( stride_v_cache_3=value_cache.stride(3), filter_by_query_len=False, query_start_len_ptr=None, - num_seqs=num_seqs, + # num_seqs=num_seqs, ) diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/mamba_ssm.py b/ibm-triton-lib/ibm_triton_lib/kernels/mamba_ssm.py index 79e519b2d..0bcbfdfb5 100644 --- a/ibm-triton-lib/ibm_triton_lib/kernels/mamba_ssm.py +++ b/ibm-triton-lib/ibm_triton_lib/kernels/mamba_ssm.py @@ -85,7 +85,7 @@ def fallback_heuristic_simple(key): config_space=triton_dejavu.ConfigSpace( {"BLOCK_SIZE_M": [4, 8, 16, 32, 64]}, num_warps=[2, 4, 8], - num_stages=[1, 2, 4, 6, 8], + num_stages=[1, 2, 3, 4, 5, 6, 8], ), key=[ "dstate", diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention.py b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention.py index 2f6911317..713db70f6 100644 --- a/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention.py +++ b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention.py @@ -8,12 +8,11 @@ # - Thomas Parnell import torch -import triton -import triton.language as tl -import os -import triton_dejavu -import functools +from vllm.logger import init_logger +from vllm.triton_utils import tl, triton + +logger = init_logger(__name__) @triton.jit @@ -30,13 +29,8 @@ def apply_softcap(S, x): @triton.jit -def find_seq_idx( - query_start_len_ptr, - target_idx, - num_seqs, - BLOCK_Q: tl.constexpr, - use_q_block_mode: tl.constexpr, -): +def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, + BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): left: tl.int32 = 0 right = num_seqs while left < right: @@ -52,288 +46,61 @@ def find_seq_idx( return left - 1 -# not as lambda, for python3.9 -def fallback_heuristic_dt2(key): - tpa_test_q = key[1] - tpa_test_k = key[2] - # Model trained on max - if tpa_test_q < 1024: - BLOCK_M = 16 - else: - BLOCK_M = 64 - - if tpa_test_k < 64: - if tpa_test_k < 32: - BLOCK_N = 16 - else: - BLOCK_N = 32 - else: - if tpa_test_q < 256: - BLOCK_N = 128 - else: - BLOCK_N = 64 - ret = triton.Config( - {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N}, num_stages=2, num_warps=8 - ) - # num stages = 2, to be on the safe side for MI300 - return ret - - -def informed_fallback_next(key, cache): - # key[0] = max q - # key[2] = avg q - ret = cache[min(cache.keys(), key=lambda x: abs(x - key[0]))] - return ret - - -def prepare_informed_fallback(cache): - ret = {int(k[0]): c for k, c in cache.items()} - return ret - - -@functools.lru_cache -def prefill_heuristics_2d(MAX_SEQ_Q, MAX_SEQ_K, AVG_SEQ_Q, AVG_SEQ_K): - gpu_name = torch.cuda.get_device_name() - # print(f"MAX_SEQ_Q {MAX_SEQ_Q}, MAX_SEQ_K {MAX_SEQ_K}, AVG_SEQ_Q 
{AVG_SEQ_Q}, AVG_SEQ_K {AVG_SEQ_K}") - if "NVIDIA H100" in gpu_name: - # # TPA original heuristic - # if MAX_SEQ_Q < 1024: - # BLOCK_M = 16 - # else: - # BLOCK_M = 64 - # if MAX_SEQ_K < 64: - # if MAX_SEQ_K < 32: - # BLOCK_N = 16 - # else: - # BLOCK_N = 32 - # else: - # if MAX_SEQ_Q < 256: - # BLOCK_N = 128 - # else: - # BLOCK_N = 64 - # config = {'num_stages': 3, 'num_warps': 4, - # 'BLOCK_N': BLOCK_N, 'BLOCK_M': BLOCK_M} - # dejavu with microbenchmarks - # TODO: update to latest tuning with AVG - if MAX_SEQ_K <= 96: - config = {"num_stages": 4, "num_warps": 4, "BLOCK_N": 32, "BLOCK_M": 16} - else: - if MAX_SEQ_Q <= 192: - if MAX_SEQ_K <= 1536: - config = { - "num_stages": 2, - "num_warps": 8, - "BLOCK_N": 128, - "BLOCK_M": 16, - } - else: - config = { - "num_stages": 8, - "num_warps": 8, - "BLOCK_N": 128, - "BLOCK_M": 16, - } - else: - config = { - "num_stages": 1, - "num_warps": 8, - "BLOCK_N": 128, - "BLOCK_M": 128, - } - elif "AMD Instinct MI300" in gpu_name: - # dejavu with microbenchmarks - # TODO: update to latest tuning with AVG - if MAX_SEQ_Q <= 384: - if MAX_SEQ_K <= 96: - config = {"num_stages": 4, "num_warps": 4, "BLOCK_N": 32, "BLOCK_M": 16} - else: - if MAX_SEQ_K <= 192: - if MAX_SEQ_Q <= 96: - config = { - "num_stages": 2, - "num_warps": 8, - "BLOCK_N": 128, - "BLOCK_M": 16, - } - else: - config = { - "num_stages": 4, - "num_warps": 4, - "BLOCK_N": 32, - "BLOCK_M": 16, - } - else: - if MAX_SEQ_Q <= 128: - config = { - "num_stages": 4, - "num_warps": 4, - "BLOCK_N": 32, - "BLOCK_M": 16, - } - else: - if MAX_SEQ_K <= 384: - config = { - "num_stages": 4, - "num_warps": 4, - "BLOCK_N": 32, - "BLOCK_M": 16, - } - else: - config = { - "num_stages": 1, - "num_warps": 4, - "BLOCK_N": 256, - "BLOCK_M": 32, - } - else: - if MAX_SEQ_K <= 768: - config = {"num_stages": 4, "num_warps": 4, "BLOCK_N": 16, "BLOCK_M": 64} - else: - config = {"num_stages": 1, "num_warps": 2, "BLOCK_N": 64, "BLOCK_M": 64} - else: - # default - config = { - "BLOCK_M": 64 if MAX_SEQ_Q > 1 and AVG_SEQ_Q >= 4096 else 16, - "BLOCK_N": 16 if MAX_SEQ_K < 128 and AVG_SEQ_Q <= 4096 else 64, - "num_warps": 4, - "num_stages": 3, - } - # print(config) - return config - - -@triton_dejavu.jitcache( - # this list is shorter, since it will be called only within one model - check_keys=[ - "MAX_SEQ_Q", - "MAX_SEQ_K", - "AVG_SEQ_Q", - "AVG_SEQ_K", - "stride_k_cache_3", - "stride_v_cache_3", - ], - check_specialization=["num_seqs"], - assume_const=[ - "scale", - "k_scale", - "v_scale", - "query_stride_1", - "output_stride_1", - "stride_k_cache_0", - "stride_k_cache_1", - "stride_k_cache_2", - "stride_k_cache_4", - "stride_v_cache_0", - "stride_v_cache_1", - "stride_v_cache_2", - ], - autotuner_args=["BLOCK_N", "BLOCK_M"], -) -@triton_dejavu.autotune( - config_space=triton_dejavu.ConfigSpace( - { - "BLOCK_N": [16, 32, 64, 128, 256, 512], - "BLOCK_M": [16, 32, 64, 128, 256, 512], - }, - num_warps=[2, 4, 8], - num_stages=[1, 2, 4, 6, 8], - ), - # this list is longer, since it would be used for multiple models - key=[ - "MAX_SEQ_Q", - "MAX_SEQ_K", - "AVG_SEQ_Q", - "AVG_SEQ_K", - "num_query_heads", - "num_queries_per_kv", - "BLOCK_SIZE", - "HEAD_SIZE", - "HEAD_SIZE_PADDED", - "SLIDING_WINDOW", - "stride_k_cache_3", - "stride_v_cache_3", - ], - custom_data_storage=os.path.abspath( - os.path.join(os.path.dirname(__file__), "dejavu_data") - ), - use_cuda_graph=True, - use_bo=True, - search_max_search_t=360, - informed_fallback=informed_fallback_next, - prepare_informed_fallback=prepare_informed_fallback, - 
fallback_heuristic=fallback_heuristic_dt2, - ignore_dtypes=True, -) -# @triton.heuristics( -# { -# "BLOCK_M": lambda args: prefill_heuristics_2d(args['MAX_SEQ_Q'], args['MAX_SEQ_K'], args['AVG_SEQ_Q'], args['AVG_SEQ_K'])['BLOCK_M'], -# "BLOCK_N": lambda args: prefill_heuristics_2d(args['MAX_SEQ_Q'], args['MAX_SEQ_K'], args['AVG_SEQ_Q'], args['AVG_SEQ_K'])['BLOCK_N'], -# "num_warps": lambda args: prefill_heuristics_2d(args['MAX_SEQ_Q'], args['MAX_SEQ_K'], args['AVG_SEQ_Q'], args['AVG_SEQ_K'])['num_warps'], -# "num_stages": lambda args: prefill_heuristics_2d(args['MAX_SEQ_Q'], args['MAX_SEQ_K'], args['AVG_SEQ_Q'], args['AVG_SEQ_K'])['num_stages'], -# } -# ) @triton.jit def kernel_unified_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - num_seqs: tl.int32, - # used as input to the autotuner/heuristics - MAX_SEQ_Q: tl.constexpr, - MAX_SEQ_K: tl.constexpr, - AVG_SEQ_Q: tl.constexpr, - AVG_SEQ_K: tl.constexpr, - # autotuner args - BLOCK_M: tl.constexpr, # int - BLOCK_N: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + 
stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int ): - q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) - BLOCK_Q = BLOCK_M // num_queries_per_kv - seq_idx = find_seq_idx( - query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True - ) + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, + BLOCK_Q, True) - q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: return @@ -343,12 +110,10 @@ def kernel_unified_attention_2d( query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv - query_offset = ( - query_offset_0[:, None] * query_stride_0 - + query_offset_1[:, None] * query_stride_1 - + offs_d[None, :] - ) + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) @@ -375,61 +140,45 @@ def kernel_unified_attention_2d( # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load( - alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 - ) + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) # compute the length of the longest sequence prefix spanned by any # query token in the current q_block (q_block_local_idx) - max_seq_prefix_len = ( - context_len - + q_block_local_idx * BLOCK_Q - + (BLOCK_M - 1) // num_queries_per_kv - + 1 - ) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 # adjust for potential padding in the last q_block by considering the # actual sequence length max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) - offs_n = tl.arange(0, BLOCK_N) - - # iterate through tiles (below the mask) - # The loop iterates only until the longest sequence. Due to causal - # masking, blocks beyond this prefix can be skipped. 
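+    # Worked example of the prefix bound computed above (illustrative values
+    # only, not from the original patch): assuming context_len=100,
+    # q_block_local_idx=3, BLOCK_Q=8, BLOCK_M=16 and num_queries_per_kv=2,
+    # max_seq_prefix_len = 100 + 3*8 + (16-1)//2 + 1 = 132 (then capped at
+    # seq_len), so with BLOCK_SIZE=16 the loop below only needs to visit
+    # cdiv(132, 16) = 9 KV blocks; causal masking lets all later blocks be
+    # skipped.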
- for start_n in range(0, max_seq_prefix_len, BLOCK_N): + # calculate the number of tiles (blocks) that need to be processed to + # cover the longest sequence prefix (due to causal masking, blocks beyond + # this prefix can be skipped) + num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE) - start_n = tl.multiple_of(start_n, BLOCK_N) + # iterate through tiles + for j in range(0, num_blocks): - physical_block_idx = tl.load( - block_tables_ptr + block_table_offset + (start_n + offs_n) // BLOCK_SIZE, - mask=(start_n + offs_n) < seq_len, - other=0, - ) + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) - v_offset = ( - physical_block_idx[:, None] * stride_v_cache_0 - + kv_head_idx * stride_v_cache_2 - + offs_d[None, :] * stride_v_cache_3 - + (offs_n[:, None] % BLOCK_SIZE) * stride_v_cache_1 - ) + offs_n = tl.arange(0, BLOCK_SIZE) - k_offset = ( - physical_block_idx[None, :] * stride_k_cache_0 - + kv_head_idx * stride_k_cache_2 - + offs_d[:, None] * stride_k_cache_3 - + (offs_n[None, :] % BLOCK_SIZE) * stride_k_cache_1 - ) + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1) - seq_offset_load = start_n + offs_n - load_mask = seq_offset_load < max_seq_prefix_len + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE_PADDED, BLOCK_N) - K_load = tl.load( - key_cache_ptr + k_offset, - mask=dim_mask[:, None] & load_mask[None, :], - other=0.0, - ) + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) if K_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -439,12 +188,10 @@ def kernel_unified_attention_2d( else: K = K_load - # V : (BLOCK_N, HEAD_SIZE_PADDED) - V_load = tl.load( - value_cache_ptr + v_offset, - mask=dim_mask[None, :] & load_mask[:, None], - other=0.0, - ) + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) if V_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -454,29 +201,24 @@ def kernel_unified_attention_2d( else: V = V_load - seq_offset = start_n + tl.arange(0, BLOCK_N) + seq_offset = j * BLOCK_SIZE + offs_n - # seq_mask: (BLOCK_M, BLOCK_N) seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_N) - S = tl.zeros(shape=(BLOCK_M, BLOCK_N), dtype=tl.float32) + # S : (BLOCK_M, BLOCK_SIZE) + S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where( - query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") - ) + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) if SLIDING_WINDOW > 0: - S = tl.where( - (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, - S, - float("-inf"), - ) + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -488,7 +230,7 @@ def kernel_unified_attention_2d( # the entire row. 
In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_N) + # P : (BLOCK_M, BLOCK_SIZE) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -510,11 +252,9 @@ def kernel_unified_attention_2d( # epilogue acc = acc / L[:, None] - output_offset = ( - query_offset_0[:, None] * output_stride_0 - + query_offset_1[:, None] * output_stride_1 - + offs_d[None, :] - ) + output_offset = (query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :]) tl.store( output_ptr + output_offset, @@ -525,61 +265,62 @@ def kernel_unified_attention_2d( @triton.jit def kernel_unified_attention_3d( - segm_output_ptr, - # [num_tokens, num_query_heads, num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] - value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int - NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + 
stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) - seq_idx = find_seq_idx( - query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True - ) + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, + BLOCK_Q, True) - q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: return @@ -600,13 +341,11 @@ def kernel_unified_attention_3d( query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv - query_offset = ( - query_offset_0[:, None] * query_stride_0 - + query_offset_1[:, None] * query_stride_1 - + offs_d[None, :] - ) + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) @@ -630,37 +369,35 @@ def kernel_unified_attention_3d( # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load( - alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 - ) + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) # iterate through tiles within current segment for j in range( - segm_idx * blocks_per_segment, - min((segm_idx + 1) * blocks_per_segment, num_blocks), + segm_idx * blocks_per_segment, + min((segm_idx + 1) * blocks_per_segment, num_blocks), ): physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) offs_n = tl.arange(0, BLOCK_SIZE) - v_offset = ( - physical_block_idx * stride_v_cache_0 - + kv_head_idx * stride_v_cache_2 - + offs_d[None, :] * stride_v_cache_3 - + offs_n[:, None] * stride_v_cache_1 - ) + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1) - k_offset = ( - physical_block_idx * stride_k_cache_0 - + kv_head_idx * stride_k_cache_2 - + offs_d[:, None] * stride_k_cache_3 - + offs_n[None, :] * stride_k_cache_1 - ) + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) # K : (HEAD_SIZE, BLOCK_SIZE) - K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) if K_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -671,7 +408,9 @@ def kernel_unified_attention_3d( K = K_load # V : (BLOCK_SIZE, HEAD_SIZE) - V_load = 
tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) if V_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -693,16 +432,12 @@ def kernel_unified_attention_3d( if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where( - query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") - ) + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) if SLIDING_WINDOW > 0: - S = tl.where( - (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, - S, - float("-inf"), - ) + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -734,52 +469,49 @@ def kernel_unified_attention_3d( acc += tl.dot(P.to(V.dtype), V) segm_output_offset = ( - query_offset_0[:, None].to(tl.int64) - * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) - + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) - + segm_idx * HEAD_SIZE_PADDED - + tl.arange(0, HEAD_SIZE_PADDED)[None, :] - ) + query_offset_0[:, None].to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) tl.store( segm_output_ptr + segm_output_offset, acc, mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ) - segm_offset = ( - query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) - + query_offset_1 * NUM_SEGMENTS_PER_SEQ - + segm_idx - ) + segm_offset = (query_offset_0.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) - tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, + L, + mask=query_mask_0 & query_mask_1) @triton.jit def reduce_segments( - output_ptr, # [num_tokens, num_query_heads, head_size] - segm_output_ptr, - # [num_tokens, num_query_heads, max_num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] - seq_lens_ptr, # [num_seqs] - num_seqs, # int - num_query_heads: tl.constexpr, # int - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - block_table_stride: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int, must be power of 2 - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + #[num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + 
BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) - seq_idx = find_seq_idx( - query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False - ) + seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, + BLOCK_Q, False) # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -791,32 +523,34 @@ def reduce_segments( # create masks for subsequent loads act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( - [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32 - ) - dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1) + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) # load segment maxima - segm_offset = ( - query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) - + query_head_idx * NUM_SEGMENTS_PER_SEQ - + tl.arange(0, NUM_SEGMENTS_PER_SEQ) - ) - segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf")) + segm_offset = (query_token_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)) + segm_max = tl.load(segm_max_ptr + segm_offset, + mask=segm_mask, + other=float("-inf")) overall_max = tl.max(segm_max) # load and rescale segment exp sums - segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0) + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, + mask=segm_mask, + other=0.0) segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) overall_expsum = tl.sum(segm_expsum) # load, rescale, and add segment attention outputs segm_output_offset = ( - query_token_idx.to(tl.int64) - * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) - + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) - + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED - + tl.arange(0, HEAD_SIZE_PADDED)[None, :] - ) + query_token_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) segm_output = tl.load( segm_output_ptr + segm_output_offset, mask=segm_mask[:, None] & dim_mask[None, :], @@ -828,11 +562,9 @@ def reduce_segments( acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) # write result - output_offset = ( - query_token_idx * output_stride_0 - + query_head_idx * output_stride_1 - + tl.arange(0, HEAD_SIZE_PADDED) - ) + output_offset = (query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED)) tl.store(output_ptr + output_offset, acc, mask=dim_mask) @@ -845,8 +577,6 @@ def unified_attention( max_seqlen_q, seqused_k, max_seqlen_k, - avg_seqlen_q, - avg_seqlen_k, softmax_scale, causal, window_size, @@ -862,9 +592,8 @@ def unified_attention( assert q_descale is None, "Q scales not supported" block_size = v.shape[1] - assert ( - q.element_size() >= 2 or block_size >= 32 - ), "Block size must be at least 32 for fp8" + assert q.element_size() >= 2 or block_size >= 32, \ + "Block size must be at least 32 for fp8" use_alibi_slopes = alibi_slopes is not None @@ -875,20 +604,27 @@ def unified_attention( num_queries_per_kv = num_query_heads // 
num_kv_heads head_size = q.shape[2] - MAX_SEQ_Q = triton.next_power_of_2(int(max_seqlen_q)) - MAX_SEQ_K = triton.next_power_of_2(int(max_seqlen_k)) - AVG_SEQ_Q = triton.next_power_of_2(int(avg_seqlen_q)) - AVG_SEQ_K = triton.next_power_of_2(int(avg_seqlen_k)) + BLOCK_M = 16 + BLOCK_Q = BLOCK_M // num_queries_per_kv - # if batch contains a prefill - if max_seqlen_q > 1 or force_selection == 2 and force_selection != 3: + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. + # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs - grid = lambda META: ( - q.shape[0] // (META["BLOCK_M"] // num_queries_per_kv) + num_seqs, + # if batch contains a prefill + # if (max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128 or force_selection == 2) and force_selection != 3: + if force_selection == 2: + kernel_unified_attention_2d[( + total_num_q_blocks, num_kv_heads, - ) - - kernel_unified_attention_2d[grid]( + )]( output_ptr=out, query_ptr=q, key_cache_ptr=k, @@ -922,27 +658,11 @@ def unified_attention( stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, - MAX_SEQ_Q=MAX_SEQ_Q, - MAX_SEQ_K=MAX_SEQ_K, - AVG_SEQ_Q=AVG_SEQ_Q, - AVG_SEQ_K=AVG_SEQ_K, + BLOCK_M=BLOCK_M, ) - else: - BLOCK_M = 64 if max_seqlen_q > 1 and avg_seqlen_q >= 4096 else 16 - BLOCK_Q = BLOCK_M // num_queries_per_kv - - # Ideally we would launch with kernel with: - # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. - # However, it is slow to realize the query_lens on cpu. 
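+    # Worked example of the total_num_q_blocks upper bound (illustrative
+    # values only): for two sequences with query length 8 each and BLOCK_Q=8,
+    # the exact grid needs ceil(8/8) + ceil(8/8) = 2 q-blocks, while the bound
+    # gives (8 + 8) // 8 + num_seqs = 4; the surplus blocks simply return
+    # early at the `q_block_local_idx * BLOCK_Q >= cur_batch_query_len` check.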
- # Instead we use upper-bound: - # \sum_i[ceil(query_len[i] / BLOCK_Q)] - # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] - # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs - # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs - # = floor(q.shape[0] / BLOCK_Q) + num_seqs - total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs - + elif force_selection == 3: # for initial version, NUM_SEGMENTS = 16 is chosen as a default # value that showed good performance in tests NUM_SEGMENTS = 16 @@ -970,45 +690,46 @@ def unified_attention( device=q.device, ) - kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, - query_ptr=q, - key_cache_ptr=k, - value_cache_ptr=v, - block_tables_ptr=block_table, - seq_lens_ptr=seqused_k, - alibi_slopes_ptr=alibi_slopes, - scale=softmax_scale, - k_scale=k_descale, - v_scale=v_descale, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_table.stride(0), - query_stride_0=q.stride(0), - query_stride_1=q.stride(1), - BLOCK_SIZE=block_size, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - USE_SOFTCAP=(softcap > 0), - SLIDING_WINDOW=(1 + window_size[0]), - stride_k_cache_0=k.stride(0), - stride_k_cache_1=k.stride(1), - stride_k_cache_2=k.stride(2), - stride_k_cache_3=k.stride(3), - stride_v_cache_0=v.stride(0), - stride_v_cache_1=v.stride(1), - stride_v_cache_2=v.stride(2), - stride_v_cache_3=v.stride(3), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, - BLOCK_M=BLOCK_M, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - ) + kernel_unified_attention_3d[( + total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, @@ -1028,3 +749,6 @@ def unified_attention( BLOCK_Q=BLOCK_Q, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ) + else: + raise RuntimeError("currently, we need to force a kernel selection") + diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention_simple.py b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention_simple.py new file mode 100644 index 000000000..cbbc52250 --- /dev/null +++ b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention_simple.py @@ -0,0 +1,756 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch + +from vllm.logger import init_logger +from vllm.triton_utils import tl, triton + +logger = init_logger(__name__) + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.exp(Sdiv) + p2 = tl.exp(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@triton.jit +def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, + BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + + return left - 1 + + +@triton.jit +def kernel_unified_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + BLOCK_N: tl.constexpr, # int +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, + BLOCK_Q, True) + + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < 
cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + offs_n = tl.arange(0, BLOCK_N) + + # iterate through tiles (below the mask) + # The loop iterates only until the longest sequence. Due to causal + # masking, blocks beyond this prefix can be skipped. + for start_n in range(0, max_seq_prefix_len, BLOCK_N): + + start_n = tl.multiple_of(start_n, BLOCK_N) + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + (start_n + offs_n) // BLOCK_SIZE, + mask=(start_n + offs_n) < seq_len, + other=0) + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (offs_n[:, None] % BLOCK_SIZE) * stride_v_cache_1) + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (offs_n[None, :] % BLOCK_SIZE) * stride_k_cache_1) + + # K : (HEAD_SIZE_PADDED, BLOCK_N) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_N, HEAD_SIZE_PADDED) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = start_n + tl.arange(0, BLOCK_N) + + # seq_mask: (BLOCK_M, BLOCK_N) + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, BLOCK_N) + S = tl.zeros(shape=(BLOCK_M, BLOCK_N), dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. 
In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, BLOCK_N) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :]) + + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +@triton.jit +def kernel_unified_attention_3d( + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + segm_idx = tl.program_id(2) + + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, + BLOCK_Q, True) + + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + + if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = 
kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles within current segment + for j in range( + segm_idx * blocks_per_segment, + min((segm_idx + 1) * blocks_per_segment, num_blocks), + ): + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + offs_n + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, BLOCK_SIZE) + S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. 
In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, BLOCK_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + segm_output_offset = ( + query_offset_0[:, None].to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + tl.store( + segm_output_ptr + segm_output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + segm_offset = (query_offset_0.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) + tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, + L, + mask=query_mask_0 & query_mask_1) + + +@triton.jit +def reduce_segments( + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + #[num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + query_token_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, + BLOCK_Q, False) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + + # create masks for subsequent loads + act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) + segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # load segment maxima + segm_offset = (query_token_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)) + segm_max = tl.load(segm_max_ptr + segm_offset, + mask=segm_mask, + other=float("-inf")) + overall_max = tl.max(segm_max) + + # load and rescale segment exp sums + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, + mask=segm_mask, + other=0.0) + segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) + overall_expsum = tl.sum(segm_expsum) + + # load, rescale, and add segment attention outputs + segm_output_offset = ( + query_token_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * 
HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + segm_output = tl.load( + segm_output_ptr + segm_output_offset, + mask=segm_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + segm_output *= tl.exp(segm_max - overall_max)[:, None] + acc_sum = tl.sum(segm_output, axis=0) + # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 + acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + + # write result + output_offset = (query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED)) + tl.store(output_ptr + output_offset, acc, mask=dim_mask) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None, +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + block_size = v.shape[1] + assert q.element_size() >= 2 or block_size >= 32, \ + "Block size must be at least 32 for fp8" + + use_alibi_slopes = alibi_slopes is not None + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + # balancing the blocksizes for short and long prompts + BLOCK_M = 16 + BLOCK_N = block_size + BLOCK_Q = BLOCK_M // num_queries_per_kv + + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. + # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + # if batch contains a prefill + grid = lambda META: (q.shape[0] // (META[ + 'BLOCK_M'] // num_queries_per_kv) + num_seqs, num_kv_heads) + + kernel_unified_attention_2d[grid]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_stages=4, + ) + # else: + # # for initial version, NUM_SEGMENTS = 16 is chosen as a default + # # value that showed good performance in tests + # NUM_SEGMENTS = 16 + + # segm_output = torch.empty( + # q.shape[0], + # num_query_heads, + # NUM_SEGMENTS, + # triton.next_power_of_2(head_size), + # dtype=torch.float32, 
+ # device=q.device, + # ) + # segm_max = torch.empty( + # q.shape[0], + # num_query_heads, + # NUM_SEGMENTS, + # dtype=torch.float32, + # device=q.device, + # ) + # segm_expsum = torch.empty( + # q.shape[0], + # num_query_heads, + # NUM_SEGMENTS, + # dtype=torch.float32, + # device=q.device, + # ) + + # kernel_unified_attention_3d[( + # total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + # segm_output_ptr=segm_output, + # segm_max_ptr=segm_max, + # segm_expsum_ptr=segm_expsum, + # query_ptr=q, + # key_cache_ptr=k, + # value_cache_ptr=v, + # block_tables_ptr=block_table, + # seq_lens_ptr=seqused_k, + # alibi_slopes_ptr=alibi_slopes, + # scale=softmax_scale, + # k_scale=k_descale, + # v_scale=v_descale, + # softcap=softcap, + # num_query_heads=num_query_heads, + # num_queries_per_kv=num_queries_per_kv, + # block_table_stride=block_table.stride(0), + # query_stride_0=q.stride(0), + # query_stride_1=q.stride(1), + # BLOCK_SIZE=block_size, + # HEAD_SIZE=head_size, + # HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + # USE_ALIBI_SLOPES=use_alibi_slopes, + # USE_SOFTCAP=(softcap > 0), + # SLIDING_WINDOW=(1 + window_size[0]), + # stride_k_cache_0=k.stride(0), + # stride_k_cache_1=k.stride(1), + # stride_k_cache_2=k.stride(2), + # stride_k_cache_3=k.stride(3), + # stride_v_cache_0=v.stride(0), + # stride_v_cache_1=v.stride(1), + # stride_v_cache_2=v.stride(2), + # stride_v_cache_3=v.stride(3), + # query_start_len_ptr=cu_seqlens_q, + # BLOCK_Q=BLOCK_Q, + # num_seqs=num_seqs, + # BLOCK_M=BLOCK_M, + # NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + # ) + + # reduce_segments[(q.shape[0], num_query_heads)]( + # output_ptr=out, + # segm_output_ptr=segm_output, + # segm_max_ptr=segm_max, + # segm_expsum_ptr=segm_expsum, + # seq_lens_ptr=seqused_k, + # num_seqs=num_seqs, + # num_query_heads=num_query_heads, + # output_stride_0=out.stride(0), + # output_stride_1=out.stride(1), + # block_table_stride=block_table.stride(0), + # BLOCK_SIZE=block_size, + # HEAD_SIZE=head_size, + # HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + # query_start_len_ptr=cu_seqlens_q, + # BLOCK_Q=BLOCK_Q, + # NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + # ) diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention_tuned.py b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention_tuned.py new file mode 100644 index 000000000..896114a7d --- /dev/null +++ b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_attention_tuned.py @@ -0,0 +1,1045 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch +import triton +import triton.language as tl + +import os +import triton_dejavu +import functools + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.exp(Sdiv) + p2 = tl.exp(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@triton.jit +def find_seq_idx( + query_start_len_ptr, + target_idx, + num_seqs, + BLOCK_Q: tl.constexpr, + use_q_block_mode: tl.constexpr, +): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + + return left - 1 + + +# not as lambda, for python3.9 +def fallback_heuristic_dt2(key): + tpa_test_q = key[1] + tpa_test_k = key[2] + # Model trained on max 
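+    # Fallback heuristic wired into the triton_dejavu autotuner below via
+    # `fallback_heuristic=fallback_heuristic_dt2`; presumably it is only
+    # consulted when no tuned configuration is available for the current key.
+    # The thresholds appear to mirror the "TPA original heuristic" kept as a
+    # comment in prefill_heuristics_2d: short query lengths get BLOCK_M=16,
+    # long ones BLOCK_M=64, and BLOCK_N grows with the key/value length.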
+ if tpa_test_q < 1024: + BLOCK_M = 16 + else: + BLOCK_M = 64 + + if tpa_test_k < 64: + if tpa_test_k < 32: + BLOCK_N = 16 + else: + BLOCK_N = 32 + else: + if tpa_test_q < 256: + BLOCK_N = 128 + else: + BLOCK_N = 64 + ret = triton.Config( + {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N}, num_stages=2, num_warps=8 + ) + # num stages = 2, to be on the safe side for MI300 + return ret + + +def informed_fallback_next(key, cache): + # key[0] = max q + # key[2] = avg q + ret = cache[min(cache.keys(), key=lambda x: abs(x - key[0]))] + return ret + + +def prepare_informed_fallback(cache): + ret = {int(k[0]): c for k, c in cache.items()} + return ret + + +@functools.lru_cache +def prefill_heuristics_2d(MAX_SEQ_Q, MAX_SEQ_K, AVG_SEQ_Q, AVG_SEQ_K): + gpu_name = torch.cuda.get_device_name() + # print(f"MAX_SEQ_Q {MAX_SEQ_Q}, MAX_SEQ_K {MAX_SEQ_K}, AVG_SEQ_Q {AVG_SEQ_Q}, AVG_SEQ_K {AVG_SEQ_K}") + if "NVIDIA H100" in gpu_name: + # # TPA original heuristic + # if MAX_SEQ_Q < 1024: + # BLOCK_M = 16 + # else: + # BLOCK_M = 64 + # if MAX_SEQ_K < 64: + # if MAX_SEQ_K < 32: + # BLOCK_N = 16 + # else: + # BLOCK_N = 32 + # else: + # if MAX_SEQ_Q < 256: + # BLOCK_N = 128 + # else: + # BLOCK_N = 64 + # config = {'num_stages': 3, 'num_warps': 4, + # 'BLOCK_N': BLOCK_N, 'BLOCK_M': BLOCK_M} + # dejavu with microbenchmarks + # TODO: update to latest tuning with AVG + if MAX_SEQ_K <= 96: + config = {"num_stages": 4, "num_warps": 4, "BLOCK_N": 32, "BLOCK_M": 16} + else: + if MAX_SEQ_Q <= 192: + if MAX_SEQ_K <= 1536: + config = { + "num_stages": 2, + "num_warps": 8, + "BLOCK_N": 128, + "BLOCK_M": 16, + } + else: + config = { + "num_stages": 8, + "num_warps": 8, + "BLOCK_N": 128, + "BLOCK_M": 16, + } + else: + config = { + "num_stages": 1, + "num_warps": 8, + "BLOCK_N": 128, + "BLOCK_M": 128, + } + elif "AMD Instinct MI300" in gpu_name: + # dejavu with microbenchmarks + # TODO: update to latest tuning with AVG + if MAX_SEQ_Q <= 384: + if MAX_SEQ_K <= 96: + config = {"num_stages": 4, "num_warps": 4, "BLOCK_N": 32, "BLOCK_M": 16} + else: + if MAX_SEQ_K <= 192: + if MAX_SEQ_Q <= 96: + config = { + "num_stages": 2, + "num_warps": 8, + "BLOCK_N": 128, + "BLOCK_M": 16, + } + else: + config = { + "num_stages": 4, + "num_warps": 4, + "BLOCK_N": 32, + "BLOCK_M": 16, + } + else: + if MAX_SEQ_Q <= 128: + config = { + "num_stages": 4, + "num_warps": 4, + "BLOCK_N": 32, + "BLOCK_M": 16, + } + else: + if MAX_SEQ_K <= 384: + config = { + "num_stages": 4, + "num_warps": 4, + "BLOCK_N": 32, + "BLOCK_M": 16, + } + else: + config = { + "num_stages": 1, + "num_warps": 4, + "BLOCK_N": 256, + "BLOCK_M": 32, + } + else: + if MAX_SEQ_K <= 768: + config = {"num_stages": 4, "num_warps": 4, "BLOCK_N": 16, "BLOCK_M": 64} + else: + config = {"num_stages": 1, "num_warps": 2, "BLOCK_N": 64, "BLOCK_M": 64} + else: + # default + config = { + "BLOCK_M": 64 if MAX_SEQ_Q > 1 and AVG_SEQ_Q >= 4096 else 16, + "BLOCK_N": 16 if MAX_SEQ_K < 128 and AVG_SEQ_Q <= 4096 else 64, + "num_warps": 4, + "num_stages": 3, + } + # print(config) + return config + + +# @triton_dejavu.jitcache( +# # this list is shorter, since it will be called only within one model +# check_keys=[ +# "MAX_SEQ_Q", +# "MAX_SEQ_K", +# "AVG_SEQ_Q", +# "AVG_SEQ_K", +# "stride_k_cache_3", +# "stride_v_cache_3", +# ], +# check_specialization=["num_seqs"], +# assume_const=[ +# "scale", +# "k_scale", +# "v_scale", +# "query_stride_1", +# "output_stride_1", +# "stride_k_cache_0", +# "stride_k_cache_1", +# "stride_k_cache_2", +# "stride_k_cache_4", +# "stride_v_cache_0", +# "stride_v_cache_1", +# 
"stride_v_cache_2", +# ], +# autotuner_args=["BLOCK_N", "BLOCK_M"], +# ) +@triton_dejavu.autotune( + config_space=triton_dejavu.ConfigSpace( + { + "BLOCK_N": [16, 32, 64, 128, 256, 512], + "BLOCK_M": [16, 32, 64, 128, 256, 512], + }, + num_warps=[2, 4, 8], + num_stages=[1, 2, 4, 6, 8], + # num_consumer_groups=[0, 2, 4], + # num_buffers_warp_spec=[0, 3, 6], + num_consumer_groups=[2, 4], + num_buffers_warp_spec=[3, 6], + conditions=[ + # ensure consistency for ws + lambda c: (c.num_consumer_groups !=0 and c.num_buffers_warp_spec != 0) \ + or (c.num_consumer_groups == 0 and c.num_buffers_warp_spec == 0), + ] + ), + # this list is longer, since it would be used for multiple models + key=[ + "MAX_SEQ_Q", + "MAX_SEQ_K", + "AVG_SEQ_Q", + "AVG_SEQ_K", + "num_query_heads", + "num_queries_per_kv", + "BLOCK_SIZE", + "HEAD_SIZE", + "HEAD_SIZE_PADDED", + "SLIDING_WINDOW", + "stride_k_cache_3", + "stride_v_cache_3", + ], + custom_data_storage=os.path.abspath( + os.path.join(os.path.dirname(__file__), "dejavu_data") + ), + use_cuda_graph=True, + use_bo=True, + search_max_search_t=360, + informed_fallback=informed_fallback_next, + prepare_informed_fallback=prepare_informed_fallback, + fallback_heuristic=fallback_heuristic_dt2, + ignore_dtypes=True, +) +# @triton.heuristics( +# { +# "BLOCK_M": lambda args: prefill_heuristics_2d(args['MAX_SEQ_Q'], args['MAX_SEQ_K'], args['AVG_SEQ_Q'], args['AVG_SEQ_K'])['BLOCK_M'], +# "BLOCK_N": lambda args: prefill_heuristics_2d(args['MAX_SEQ_Q'], args['MAX_SEQ_K'], args['AVG_SEQ_Q'], args['AVG_SEQ_K'])['BLOCK_N'], +# "num_warps": lambda args: prefill_heuristics_2d(args['MAX_SEQ_Q'], args['MAX_SEQ_K'], args['AVG_SEQ_Q'], args['AVG_SEQ_K'])['num_warps'], +# "num_stages": lambda args: prefill_heuristics_2d(args['MAX_SEQ_Q'], args['MAX_SEQ_K'], args['AVG_SEQ_Q'], args['AVG_SEQ_K'])['num_stages'], +# } +# ) +@triton.jit +def kernel_unified_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + num_seqs: tl.int32, + # used as input to the autotuner/heuristics + MAX_SEQ_Q: tl.constexpr, + MAX_SEQ_K: tl.constexpr, + AVG_SEQ_Q: tl.constexpr, + AVG_SEQ_K: tl.constexpr, + # autotuner args + BLOCK_M: tl.constexpr, # int + BLOCK_N: tl.constexpr, # int +): + + q_block_global_idx = tl.program_id(0) + 
kv_head_idx = tl.program_id(1) + BLOCK_Q = BLOCK_M // num_queries_per_kv + + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) + + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + offs_n = tl.arange(0, BLOCK_N) + + # iterate through tiles (below the mask) + # The loop iterates only until the longest sequence. Due to causal + # masking, blocks beyond this prefix can be skipped. 
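+    # Sketch of the streaming-softmax recurrence applied once per BLOCK_N tile
+    # below (the standard flash-attention update, with running max M, running
+    # exp-sum L, and output accumulator acc):
+    #   m_new = max(M, rowmax(S))       # updated running row maximum
+    #   P     = exp(S - m_new)          # probabilities for this tile
+    #   alpha = exp(M - m_new)          # rescaling factor for previous tiles
+    #   acc   = acc * alpha + P @ V     # rescale old output, add new tile
+    #   L     = L * alpha + rowsum(P)   # running normalizer
+    #   M     = m_new
+    # The epilogue then divides acc by L to yield softmax(scale * Q K^T) @ V.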
+ for start_n in range(0, max_seq_prefix_len, BLOCK_N): + + start_n = tl.multiple_of(start_n, BLOCK_N) + + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + (start_n + offs_n) // BLOCK_SIZE, + mask=(start_n + offs_n) < seq_len, + other=0, + ) + + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (offs_n[:, None] % BLOCK_SIZE) * stride_v_cache_1 + ) + + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (offs_n[None, :] % BLOCK_SIZE) * stride_k_cache_1 + ) + + seq_offset_load = start_n + offs_n + load_mask = seq_offset_load < max_seq_prefix_len + + # K : (HEAD_SIZE_PADDED, BLOCK_N) + K_load = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & load_mask[None, :], + other=0.0, + ) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_N, HEAD_SIZE_PADDED) + V_load = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & load_mask[:, None], + other=0.0, + ) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = start_n + tl.arange(0, BLOCK_N) + + # seq_mask: (BLOCK_M, BLOCK_N) + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, BLOCK_N) + S = tl.zeros(shape=(BLOCK_M, BLOCK_N), dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) + + if SLIDING_WINDOW > 0: + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. 
In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, BLOCK_N) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = ( + query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :] + ) + + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +@triton.jit +def kernel_unified_attention_3d( + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + segm_idx = tl.program_id(2) + + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + + if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = 
kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles within current segment + for j in range( + segm_idx * blocks_per_segment, + min((segm_idx + 1) * blocks_per_segment, num_blocks), + ): + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + + v_offset = ( + physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1 + ) + + k_offset = ( + physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1 + ) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + offs_n + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, BLOCK_SIZE) + S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) + + if SLIDING_WINDOW > 0: + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. 
In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, BLOCK_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + segm_output_offset = ( + query_offset_0[:, None].to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) + tl.store( + segm_output_ptr + segm_output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + segm_offset = ( + query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + + segm_idx + ) + tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1) + + +@triton.jit +def reduce_segments( + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + # [num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + query_token_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx( + query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False + ) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + + # create masks for subsequent loads + act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) + segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32 + ) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1) + + # load segment maxima + segm_offset = ( + query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ) + ) + segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf")) + overall_max = tl.max(segm_max) + + # load and rescale segment exp sums + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0) + segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) + overall_expsum = tl.sum(segm_expsum) + + # load, rescale, and add segment attention outputs + segm_output_offset = ( + query_token_idx.to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * 
HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) + segm_output = tl.load( + segm_output_ptr + segm_output_offset, + mask=segm_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + segm_output *= tl.exp(segm_max - overall_max)[:, None] + acc_sum = tl.sum(segm_output, axis=0) + # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 + acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + + # write result + output_offset = ( + query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED) + ) + tl.store(output_ptr + output_offset, acc, mask=dim_mask) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + avg_seqlen_q, + avg_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + MAX_SEQ_Q, + MAX_SEQ_K, + AVG_SEQ_Q, + AVG_SEQ_K, + alibi_slopes=None, + force_selection=None, # None, 2, 3 to select kernel +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + assert force_selection == 2 # only 2d is tuned for now + + block_size = v.shape[1] + assert ( + q.element_size() >= 2 or block_size >= 32 + ), "Block size must be at least 32 for fp8" + + use_alibi_slopes = alibi_slopes is not None + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + # MAX_SEQ_Q = triton.next_power_of_2(int(max_seqlen_q)) + # MAX_SEQ_K = triton.next_power_of_2(int(max_seqlen_k)) + # AVG_SEQ_Q = triton.next_power_of_2(int(avg_seqlen_q)) + # AVG_SEQ_K = triton.next_power_of_2(int(avg_seqlen_k)) + + # if batch contains a prefill + # if (max_seqlen_q > 1 or force_selection == 2) and force_selection != 3: + + grid = lambda META: ( + q.shape[0] // (META["BLOCK_M"] // num_queries_per_kv) + num_seqs, + num_kv_heads, + ) + + kernel_unified_attention_2d[grid]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + num_seqs=num_seqs, + MAX_SEQ_Q=MAX_SEQ_Q, + MAX_SEQ_K=MAX_SEQ_K, + AVG_SEQ_Q=AVG_SEQ_Q, + AVG_SEQ_K=AVG_SEQ_K, + ) + # else: + # BLOCK_M = 64 if max_seqlen_q > 1 and avg_seqlen_q >= 4096 else 16 + # BLOCK_Q = BLOCK_M // num_queries_per_kv + + # # Ideally we would launch with kernel with: + # # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # # However, it is slow to realize the query_lens on cpu. 
+ # # Instead we use upper-bound: + # # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # # = floor(q.shape[0] / BLOCK_Q) + num_seqs + # total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + # # for initial version, NUM_SEGMENTS = 16 is chosen as a default + # # value that showed good performance in tests + # NUM_SEGMENTS = 16 + + # segm_output = torch.empty( + # q.shape[0], + # num_query_heads, + # NUM_SEGMENTS, + # triton.next_power_of_2(head_size), + # dtype=torch.float32, + # device=q.device, + # ) + # segm_max = torch.empty( + # q.shape[0], + # num_query_heads, + # NUM_SEGMENTS, + # dtype=torch.float32, + # device=q.device, + # ) + # segm_expsum = torch.empty( + # q.shape[0], + # num_query_heads, + # NUM_SEGMENTS, + # dtype=torch.float32, + # device=q.device, + # ) + + # kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + # segm_output_ptr=segm_output, + # segm_max_ptr=segm_max, + # segm_expsum_ptr=segm_expsum, + # query_ptr=q, + # key_cache_ptr=k, + # value_cache_ptr=v, + # block_tables_ptr=block_table, + # seq_lens_ptr=seqused_k, + # alibi_slopes_ptr=alibi_slopes, + # scale=softmax_scale, + # k_scale=k_descale, + # v_scale=v_descale, + # softcap=softcap, + # num_query_heads=num_query_heads, + # num_queries_per_kv=num_queries_per_kv, + # block_table_stride=block_table.stride(0), + # query_stride_0=q.stride(0), + # query_stride_1=q.stride(1), + # BLOCK_SIZE=block_size, + # HEAD_SIZE=head_size, + # HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + # USE_ALIBI_SLOPES=use_alibi_slopes, + # USE_SOFTCAP=(softcap > 0), + # SLIDING_WINDOW=(1 + window_size[0]), + # stride_k_cache_0=k.stride(0), + # stride_k_cache_1=k.stride(1), + # stride_k_cache_2=k.stride(2), + # stride_k_cache_3=k.stride(3), + # stride_v_cache_0=v.stride(0), + # stride_v_cache_1=v.stride(1), + # stride_v_cache_2=v.stride(2), + # stride_v_cache_3=v.stride(3), + # query_start_len_ptr=cu_seqlens_q, + # BLOCK_Q=BLOCK_Q, + # num_seqs=num_seqs, + # BLOCK_M=BLOCK_M, + # NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + # ) + + # reduce_segments[(q.shape[0], num_query_heads)]( + # output_ptr=out, + # segm_output_ptr=segm_output, + # segm_max_ptr=segm_max, + # segm_expsum_ptr=segm_expsum, + # seq_lens_ptr=seqused_k, + # num_seqs=num_seqs, + # num_query_heads=num_query_heads, + # output_stride_0=out.stride(0), + # output_stride_1=out.stride(1), + # block_table_stride=block_table.stride(0), + # BLOCK_SIZE=block_size, + # HEAD_SIZE=head_size, + # HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + # query_start_len_ptr=cu_seqlens_q, + # BLOCK_Q=BLOCK_Q, + # NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + # ) diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_grid.py b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_grid.py new file mode 100644 index 000000000..936bf5a5a --- /dev/null +++ b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_grid.py @@ -0,0 +1,1004 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch + +from vllm.logger import init_logger +from vllm.triton_utils import tl, triton + +import triton_dejavu +import os + +logger = init_logger(__name__) + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def apply_softcap(S, x): + Sdiv 
= S / x + p1 = tl.exp(Sdiv) + p2 = tl.exp(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@triton.jit +def find_seq_idx(boundary_ptr, target_idx, num_seqs): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(boundary_ptr + mid) + if val <= target_idx: + left = mid + 1 + else: + right = mid + return left - 1 + + +@triton_dejavu.autotune( + config_space=triton_dejavu.ConfigSpace( + { + "BLOCK_M": [16, 32, 64, 128, 256, 512], + "TILE_SIZE": [16, 32, 64, 128, 256, 512], + }, + num_warps=[2, 4, 8], + num_stages=[1, 2, 4, 6, 8], + # num_consumer_groups=[0, 2, 4, 8], + # num_buffers_warp_spec=[0, 3, 6, 9], + # num_consumer_groups=[2, 4], + # num_buffers_warp_spec=[3, 6], + # conditions=[ + # # ensure consistency for ws + # lambda c: (c.num_consumer_groups != 0 and c.num_buffers_warp_spec != 0) \ + # or (c.num_consumer_groups == 0 and c.num_buffers_warp_spec == 0), + # ] + ), + # this list is longer, since it would be used for multiple models + key=[ + "num_query_heads", + "num_queries_per_kv", + "BLOCK_SIZE", + "HEAD_SIZE", + "HEAD_SIZE_PADDED", + "SLIDING_WINDOW", + "stride_k_cache_3", + "stride_v_cache_3", + "is_prefill", + ], + custom_data_storage=os.path.abspath( + os.path.join(os.path.dirname(__file__), "dejavu_data") + ), + use_cuda_graph=True, + use_bo=True, + # search_max_search_t=360, + # search_max_search_t=720, + # use_random_search=True, + search_max_search_t=1800, + # informed_fallback=informed_fallback_next, + # prepare_informed_fallback=prepare_informed_fallback, + # fallback_heuristic=fallback_heuristic_dt2, + ignore_dtypes=True, +) +@triton.heuristics( + {"BLOCK_Q": lambda args: args['BLOCK_M'] // args['num_queries_per_kv']}, +) +@triton.jit +def kernel_unified_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + num_seqs: tl.int32, + seq_idx_offset, # int + block_q_seq_boundaries_ptr, # [num_prefills] or None + is_prefill: tl.constexpr, + max_q_block_idx: tl.int32, # int + q_block_iterations: tl.int32, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 + BLOCK_Q: tl.constexpr, # int + BLOCK_M: tl.constexpr, # 
int +): + if tl.program_id(0) * q_block_iterations > max_q_block_idx: + return + + for q_block_global_idx in range(tl.program_id(0) * q_block_iterations, min((tl.program_id(0) + 1) * q_block_iterations, max_q_block_idx + 1)): + kv_head_idx = tl.program_id(1) + + if is_prefill: + seq_idx = find_seq_idx(block_q_seq_boundaries_ptr, q_block_global_idx, num_seqs) + q_block_start_idx = tl.load(block_q_seq_boundaries_ptr + seq_idx) + else: + seq_idx = q_block_global_idx + q_block_start_idx = seq_idx + seq_idx = seq_idx + seq_idx_offset + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + #if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + # return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + # query-query attention bias + if USE_QQ_BIAS: + qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] + + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + + # iterate through tiles + for j in range(0, num_tiles): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + + 
k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + + # K : (HEAD_SIZE, TILE_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (TILE_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + if USE_QQ_BIAS: + # compute key positions relative to query section + key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE] + # load bias only for keys that correspond to queries + is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0 + qq_bias = tl.load( + qq_bias_row_ptrs + key_rel_pos[None, :], + mask=is_query_key[None, :], # avoid OOB for context keys + other=0.0, + ) + S += qq_bias + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. 
In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, TILE_SIZE) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :]) + + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +@triton_dejavu.autotune( + config_space=triton_dejavu.ConfigSpace( + { + "BLOCK_M": [16, 32, 64, 128, 256, 512], + "TILE_SIZE": [16, 32, 64, 128, 256, 512], + }, + num_warps=[2, 4, 8], + num_stages=[1, 2, 4, 6, 8], + # num_consumer_groups=[0, 2, 4, 8], + # num_buffers_warp_spec=[0, 3, 6, 9], + # num_consumer_groups=[2, 4], + # num_buffers_warp_spec=[3, 6], + # conditions=[ + # # ensure consistency for ws + # lambda c: (c.num_consumer_groups != 0 and c.num_buffers_warp_spec != 0) \ + # or (c.num_consumer_groups == 0 and c.num_buffers_warp_spec == 0), + # ] + ), + # this list is longer, since it would be used for multiple models + key=[ + "num_query_heads", + "num_queries_per_kv", + "BLOCK_SIZE", + "HEAD_SIZE", + "HEAD_SIZE_PADDED", + "SLIDING_WINDOW", + "stride_k_cache_3", + "stride_v_cache_3", + "NUM_SEGMENTS_PER_SEQ", + ], + custom_data_storage=os.path.abspath( + os.path.join(os.path.dirname(__file__), "dejavu_data") + ), + use_cuda_graph=True, + use_bo=True, + # search_max_search_t=360, + # search_max_search_t=720, + # use_random_search=True, + search_max_search_t=1800, + # informed_fallback=informed_fallback_next, + # prepare_informed_fallback=prepare_informed_fallback, + # fallback_heuristic=fallback_heuristic_dt2, + ignore_dtypes=True, +) +@triton.heuristics( + {"BLOCK_Q": lambda args: args['BLOCK_M'] // args['num_queries_per_kv']}, +) +@triton.jit +def kernel_unified_attention_3d( + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + 
stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + num_seqs: tl.int32, + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + seq_idx_iterations: tl.int32, # int + BLOCK_Q: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 + BLOCK_M: tl.constexpr, # int +): + if tl.program_id(0) * seq_idx_iterations >= num_seqs: + return + + for seq_idx in range(tl.program_id(0) * seq_idx_iterations, min((tl.program_id(0) + 1) * seq_idx_iterations, num_seqs)): + kv_head_idx = tl.program_id(1) + segm_idx = tl.program_id(2) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + + #if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: + # return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) + query_pos = offs_m // num_queries_per_kv + + query_offset_0 = seq_idx + query_pos #cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < 1, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # context length for this particular sequences + context_len = seq_len - 1 + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + # query-query attention bias + if USE_QQ_BIAS: + qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] + + num_tiles = cdiv_fn(seq_len, TILE_SIZE) + + # iterate through tiles within current segment + for j in range( + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), + ): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < seq_len + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + + # K : (HEAD_SIZE, TILE_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (TILE_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :] & 
tile_mask[:, None], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + if USE_QQ_BIAS: + # compute key positions relative to query section + key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE] + # load bias only for keys that correspond to queries + is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0 + qq_bias = tl.load( + qq_bias_row_ptrs + key_rel_pos[None, :], + mask=is_query_key[None, :], # avoid OOB for context keys + other=0.0, + ) + S += qq_bias + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, TILE_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + segm_output_offset = ( + query_offset_0[:, None].to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + tl.store( + segm_output_ptr + segm_output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + segm_offset = (query_offset_0.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) + tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, + L, + mask=query_mask_0 & query_mask_1) + + +@triton_dejavu.autotune( + config_space=triton_dejavu.ConfigSpace( + { + "TILE_SIZE": [16, 32, 64, 128, 256, 512], + }, + num_warps=[2, 4, 8], + num_stages=[1, 2, 4, 6, 8], + # num_consumer_groups=[0, 2, 4, 8], + # num_buffers_warp_spec=[0, 3, 6, 9], + # # num_consumer_groups=[2, 4], + # # num_buffers_warp_spec=[3, 6], + # conditions=[ + # # ensure consistency for ws + # lambda c: (c.num_consumer_groups != 0 and c.num_buffers_warp_spec != 0) \ + # or (c.num_consumer_groups == 0 and c.num_buffers_warp_spec == 0), + # ] + ), + # this list is longer, since it would be used for multiple models + key=[ + "num_query_heads", + "HEAD_SIZE", + "HEAD_SIZE_PADDED", + "NUM_SEGMENTS_PER_SEQ", + ], + custom_data_storage=os.path.abspath( + os.path.join(os.path.dirname(__file__), "dejavu_data") + ), + use_cuda_graph=True, + use_bo=True, + # search_max_search_t=360, + # search_max_search_t=720, + # use_random_search=True, + search_max_search_t=1800, + # informed_fallback=informed_fallback_next, + # 
prepare_informed_fallback=prepare_informed_fallback, + # fallback_heuristic=fallback_heuristic_dt2, + ignore_dtypes=True, +) +@triton.jit +def reduce_segments( + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + #[num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + seq_idx_iterations: tl.int32, # int + TILE_SIZE: tl.constexpr, # int +): + if tl.program_id(0) * seq_idx_iterations >= num_seqs: + return + + for seq_idx in range(tl.program_id(0) * seq_idx_iterations, min((tl.program_id(0) + 1) * seq_idx_iterations, num_seqs)): + query_head_idx = tl.program_id(1) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + + # create masks for subsequent loads + act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) + segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # load segment maxima + segm_offset = (seq_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)) + segm_max = tl.load(segm_max_ptr + segm_offset, + mask=segm_mask, + other=float("-inf")) + overall_max = tl.max(segm_max) + + # load and rescale segment exp sums + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, + mask=segm_mask, + other=0.0) + segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) + overall_expsum = tl.sum(segm_expsum) + + # load, rescale, and add segment attention outputs + segm_output_offset = ( + seq_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + segm_output = tl.load( + segm_output_ptr + segm_output_offset, + mask=segm_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + segm_output *= tl.exp(segm_max - overall_max)[:, None] + acc_sum = tl.sum(segm_output, axis=0) + # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 + acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + + # write result + output_offset = (seq_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED)) + tl.store(output_ptr + output_offset, acc, mask=dim_mask) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + num_decodes, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + use_split_kv, + segm_output, + segm_max, + segm_expsum, + BLOCK_M_PREFILL, + BLOCK_Q_PREFILL, + BLOCK_M_DECODE, + BLOCK_Q_DECODE, + num_q_blocks, + 
block_q_seq_boundaries, + alibi_slopes=None, + qq_bias=None, +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + block_size = v.shape[1] + assert q.element_size() >= 2 or block_size >= 32, \ + "Block size must be at least 32 for fp8" + + use_alibi_slopes = alibi_slopes is not None + use_qq_bias = qq_bias is not None + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + TILE_SIZE_PREFILL = 32 + TILE_SIZE_DECODE = 32 + + LAUNCH_GRID_DIM0_2D_PREFILL = 32 + LAUNCH_GRID_DIM0_2D_DECODE = 32 + LAUNCH_GRID_DIM0_3D_DECODE = 4 + LAUNCH_GRID_DIM0_3D_REDUCE = 4 + + # prefill + if num_seqs > num_decodes: + kernel_unified_attention_2d[( + LAUNCH_GRID_DIM0_2D_PREFILL, #num_q_blocks, + num_kv_heads, + )]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + num_seqs=num_seqs - num_decodes, + seq_idx_offset=num_decodes, + block_q_seq_boundaries_ptr=block_q_seq_boundaries, + is_prefill=True, + max_q_block_idx=num_q_blocks-1, + q_block_iterations=(num_q_blocks + LAUNCH_GRID_DIM0_2D_PREFILL - 1) // LAUNCH_GRID_DIM0_2D_PREFILL + # tunable parameters + # BLOCK_M=BLOCK_M_PREFILL, + # BLOCK_Q=BLOCK_Q_PREFILL, + # TILE_SIZE=TILE_SIZE_PREFILL, + ) + + # decode + if num_decodes > 0: + # select between 2d and 3d (split-kv) kernels + if not use_split_kv: + kernel_unified_attention_2d[( + LAUNCH_GRID_DIM0_2D_DECODE, #num_decodes, + num_kv_heads, + )]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + 
stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + num_seqs=num_decodes, + seq_idx_offset=0, + block_q_seq_boundaries_ptr=None, + is_prefill=False, + max_q_block_idx=num_decodes-1, + q_block_iterations=(num_decodes + LAUNCH_GRID_DIM0_2D_DECODE - 1) // LAUNCH_GRID_DIM0_2D_DECODE + # tunable parameters + # BLOCK_M=BLOCK_M_DECODE, + # BLOCK_Q=BLOCK_Q_DECODE, + # TILE_SIZE=TILE_SIZE_DECODE, + ) + else: + # for initial version, NUM_SEGMENTS = 16 is chosen as a default + # value that showed good performance in tests + NUM_SEGMENTS = 16 + +# segm_output = torch.empty( +# num_decodes, +# num_query_heads, +# NUM_SEGMENTS, +# triton.next_power_of_2(head_size), +# dtype=torch.float32, +# device=q.device, +# ) +# segm_max = torch.empty( +# num_decodes, +# num_query_heads, +# NUM_SEGMENTS, +# dtype=torch.float32, +# device=q.device, +# ) +# segm_expsum = torch.empty( +# num_decodes, +# num_query_heads, +# NUM_SEGMENTS, +# dtype=torch.float32, +# device=q.device, +# ) + + kernel_unified_attention_3d[( + LAUNCH_GRID_DIM0_3D_DECODE, #num_decodes, + num_kv_heads, + NUM_SEGMENTS + )]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + num_seqs=num_decodes, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + seq_idx_iterations=(num_decodes + LAUNCH_GRID_DIM0_3D_DECODE - 1) // LAUNCH_GRID_DIM0_3D_DECODE + # tunable parameters + # BLOCK_Q=BLOCK_Q_DECODE, + # BLOCK_M=BLOCK_M_DECODE, + # TILE_SIZE=TILE_SIZE_DECODE, + ) + reduce_segments[( + LAUNCH_GRID_DIM0_3D_REDUCE, #num_decodes, + num_query_heads + )]( + output_ptr=out, + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_seqs, + num_query_heads=num_query_heads, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + query_start_len_ptr=cu_seqlens_q, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + seq_idx_iterations=(num_decodes + LAUNCH_GRID_DIM0_3D_REDUCE - 1) // LAUNCH_GRID_DIM0_3D_REDUCE + # tunable parameters + # TILE_SIZE=TILE_SIZE_DECODE, + ) diff --git a/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_newtiles.py b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_newtiles.py new file mode 100644 index 000000000..c10e38b08 --- /dev/null +++ b/ibm-triton-lib/ibm_triton_lib/kernels/triton_unified_newtiles.py @@ -0,0 
+1,776 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch + +from vllm.logger import init_logger +from vllm.triton_utils import tl, triton + +logger = init_logger(__name__) + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.exp(Sdiv) + p2 = tl.exp(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@triton.jit +def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, + BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + + return left - 1 + + +@triton.jit +def kernel_unified_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, + BLOCK_Q, True) + + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * 
query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + + # iterate through tiles + for j in range(0, num_tiles): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + + # K : (HEAD_SIZE, TILE_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (TILE_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # For sliding window there's a chance the max is 
-inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, TILE_SIZE) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :]) + + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +@triton.jit +def kernel_unified_attention_3d( + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + segm_idx = tl.program_id(2) + + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, + BLOCK_Q, True) + + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) + query_pos = q_block_local_idx * 
BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + + # iterate through tiles within current segment + for j in range( + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), + ): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + + # K : (HEAD_SIZE, TILE_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (TILE_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, 
None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, TILE_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + #if kv_head_idx == 0: + # print(f"\nq_block_global_idx={q_block_global_idx} segm_idx={segm_idx} j={j} : L={L} M={M}\n") # acc={acc}\n") + + segm_output_offset = ( + query_offset_0[:, None].to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + tl.store( + segm_output_ptr + segm_output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + segm_offset = (query_offset_0.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) + tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, + L, + mask=query_mask_0 & query_mask_1) + + +@triton.jit +def reduce_segments( + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + #[num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + TILE_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + query_token_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, + BLOCK_Q, False) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + + # create masks for subsequent loads + act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) + segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # load segment maxima + segm_offset = (query_token_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)) + segm_max = tl.load(segm_max_ptr + segm_offset, + mask=segm_mask, + other=float("-inf")) + overall_max = tl.max(segm_max) + + # load and rescale segment exp sums + segm_expsum = 
tl.load(segm_expsum_ptr + segm_offset, + mask=segm_mask, + other=0.0) + segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) + overall_expsum = tl.sum(segm_expsum) + + # load, rescale, and add segment attention outputs + segm_output_offset = ( + query_token_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + segm_output = tl.load( + segm_output_ptr + segm_output_offset, + mask=segm_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + segm_output *= tl.exp(segm_max - overall_max)[:, None] + acc_sum = tl.sum(segm_output, axis=0) + # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 + acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + + # write result + output_offset = (query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED)) + tl.store(output_ptr + output_offset, acc, mask=dim_mask) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None, + force_selection=None, # None, 2, 3 to select kernel +): + + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + block_size = v.shape[1] + assert q.element_size() >= 2 or block_size >= 32, \ + "Block size must be at least 32 for fp8" + + use_alibi_slopes = alibi_slopes is not None + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + BLOCK_M = 16 + BLOCK_Q = BLOCK_M // num_queries_per_kv + + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. 
+ # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + TILE_SIZE_PREFILL = 32 + TILE_SIZE_DECODE = 32 + + # if batch contains a prefill + # if (max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128) or force_selection == 2 and force_selection != 3: + if force_selection == 2: + kernel_unified_attention_2d[( + total_num_q_blocks, + num_kv_heads, + )]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + num_seqs=num_seqs, + # tunable parameters + # BLOCK_M=BLOCK_M, + # BLOCK_Q=BLOCK_Q, + # TILE_SIZE=TILE_SIZE_DECODE, + ) + elif force_selection == 3: + # for initial version, NUM_SEGMENTS = 16 is chosen as a default + # value that showed good performance in tests + NUM_SEGMENTS = 16 + + segm_output = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + triton.next_power_of_2(head_size), + dtype=torch.float32, + device=q.device, + ) + segm_max = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + segm_expsum = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + + kernel_unified_attention_3d[( + total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) + 
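+        # --- editorial note (hedged, illustrative only) ---
+        # reduce_segments below merges the per-segment partials written by
+        # kernel_unified_attention_3d: for each (token, head) it rescales every
+        # segment by exp(segm_max - overall_max), sums the rescaled exp-sums and
+        # outputs, and divides. A minimal PyTorch sketch of that merge, assuming
+        # segm_max / segm_expsum are [num_segments] and segm_output is
+        # [num_segments, head_size] for a single token and head:
+        #
+        #   overall_max = segm_max.max()
+        #   seg_scale = torch.exp(segm_max - overall_max)
+        #   overall_expsum = (segm_expsum * seg_scale).sum()
+        #   merged = (segm_output * seg_scale[:, None]).sum(dim=0)
+        #   out_row = torch.where(overall_expsum == 0,
+        #                         torch.zeros_like(merged),
+        #                         merged / overall_expsum)
+        #
+        # The Triton kernel performs the same computation per program id, with
+        # masking for segments beyond the sequence's actual segment count.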
reduce_segments[(q.shape[0], num_query_heads)]( + output_ptr=out, + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_seqs, + num_query_heads=num_query_heads, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) + else: + raise RuntimeError("currently, we need to force a kernel selection") \ No newline at end of file diff --git a/scripts/bench_vllm_latency_range.py b/scripts/bench_vllm_latency_range.py index a12716ae5..f711316f7 100644 --- a/scripts/bench_vllm_latency_range.py +++ b/scripts/bench_vllm_latency_range.py @@ -20,6 +20,28 @@ import sys import torch from datetime import datetime +from itertools import zip_longest, repeat, chain, product + + +# ================= SETUP + +# selected_batch_sizes = [1] # [4, 16, 32] #,128] +# selected_input_lengths = [500] # , 1000, 1500, 2000, 4000, 8000, 16000] +# selected_output_lengths = [10, 100, 200, 400, 800, 1600, 3200, 6400, 12800] +# selected_input_lengths = [64, 128, 512, 1024, 2048, 4096] +# selected_input_lengths = [64, 128, 512, 1024, 2048, 4096, 8192, 31500] +# selected_output_lengths = [1] +selected_batch_sizes = [1, 2, 4, 8, 16, 32, 64] +selected_input_lengths = [128] +selected_output_lengths = [32, 128, 256] + +# use_cross_product = False +use_cross_product = True + +warmup_iterations = 3 +iterations = 5 + +# ================= def create_dir_if_not_exist_recursive(path, mode=0o777): @@ -42,23 +64,19 @@ def create_dir_if_not_exist(path, mode=0o777): print(f"can't set permission of directory {path}: {e}") -if len(sys.argv) < 4: - print(f"Usage: {sys.argv[0]} ") +if len(sys.argv) < 5: + print(f"Usage: {sys.argv[0]} ") + exit(-1) -selected_batch_sizes = [1] # [4, 16, 32] #,128] -selected_input_lengths = [500] # , 1000, 1500, 2000, 4000, 8000, 16000] -selected_output_lengths = [10, 100, 200, 400, 800, 1600, 3200, 6400, 12800] gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("/", "_") # model = "/model/llama3.1-8b/instruct/" model = sys.argv[1] -testcase_name = sys.argv[2] -result_path = os.path.abspath(sys.argv[3]) +tp = int(sys.argv[2]) +testcase_name = sys.argv[3] +result_path = os.path.abspath(sys.argv[4]) -# max_rounds = 128 -max_rounds = 64 -max_num_prompts = 1000 timestamp_f = datetime.now().strftime("%Y-%m-%d_%H%M") @@ -75,21 +93,24 @@ def create_dir_if_not_exist(path, mode=0o777): print(f"can't find benchmark script benchmark_latency.py") exit(-1) -# Assisted by watsonx Code Assistant -from itertools import zip_longest - -zipped_lists = list( - zip_longest( - selected_batch_sizes, - selected_input_lengths, - selected_output_lengths, - fillvalue=None, +if use_cross_product: + zipped_lists = list(product(selected_batch_sizes, selected_input_lengths, selected_output_lengths)) +else: + max_length = max(len(selected_batch_sizes), len(selected_input_lengths), len(selected_output_lengths)) + zipped_lists = list( + zip_longest( + chain(selected_batch_sizes, + repeat(selected_batch_sizes[-1], times=max_length-len(selected_batch_sizes))), + chain(selected_input_lengths, + repeat(selected_input_lengths[-1], times=max_length-len(selected_input_lengths))), + chain(selected_output_lengths, + repeat(selected_output_lengths[-1], times=max_length-len(selected_output_lengths))), + fillvalue=None, + ) 
) -) - print(zipped_lists) - +start_time = datetime.now() for bs, il, ol in zipped_lists: print( f"====== Measuring batch_size {bs}, input length {il}, output length {ol} =====" @@ -99,7 +120,13 @@ def create_dir_if_not_exist(path, mode=0o777): f"VLLM_USE_V1=1 python {bench_script} " f"--model {model} " f"--input-len {il} --output-len {ol} --batch-size {bs} " - f"--output-json {json_file_name}" + f"--output-json {json_file_name} " + f"--num-iters-warmup {warmup_iterations} " + f"--num-iters {iterations} " + f"--tensor-parallel {tp} " + f"--enable-chunked-prefill " + f"--max-num-batched-tokens 16384 " + # f"-O.full_cuda_graph=true" ) print(cmd) rv = os.system(cmd) @@ -107,5 +134,7 @@ def create_dir_if_not_exist(path, mode=0o777): print(f"benchmark command returned {rv}, stopping...") exit(rv) +end_time = datetime.now() print(f"results stored in: {result_dir}") os.system(f"ls -alh {result_dir}") +print(f"Benchmark time: {end_time-start_time}") diff --git a/scripts/bench_vllm_user_range.py b/scripts/bench_vllm_user_range.py index c93c88842..212930fa7 100644 --- a/scripts/bench_vllm_user_range.py +++ b/scripts/bench_vllm_user_range.py @@ -40,29 +40,45 @@ def create_dir_if_not_exist(path, mode=0o777): except PermissionError as e: print(f"can't set permission of directory {path}: {e}") +if len(sys.argv) < 4: + print(f"Usage: {sys.argv[0]} ") + exit(-1) num_users_to_test = [1, 2, 4, 8, 16, 32, 64, 128] gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("/", "_") # model = "/model/llama3.1-8b/instruct/" model = sys.argv[1] -model_path = f"/models/{model}/" testcase_name = sys.argv[2] +result_path = os.path.abspath(sys.argv[3]) # max_rounds = 128 max_rounds = 64 max_num_prompts = 1000 +bench_repetitions = 3 + timestamp_f = datetime.now().strftime("%Y-%m-%d_%H%M") -# result_dir = f"/results/{model.replace('/','-')}/{gpu_name}/{testcase_name}" -result_dir = ( - f"/results/{model.replace('/','-')}/{gpu_name}/{testcase_name}/exp_{timestamp_f}/" -) +# result_dir = ( +# f"/results/{model.replace('/','-')}/{gpu_name}/{testcase_name}/exp_{timestamp_f}/" +# ) +model_print_path = model.replace('/','-') +if model_print_path[0:2] == './': + model_print_path = model_print_path[2:] +result_dir = f"{result_path}/{model_print_path}/{gpu_name}/{testcase_name}/exp_{timestamp_f}/" + +bench_script = "/workspace/benchmarks/benchmark_serving.py" +if not os.path.isfile(bench_script): + bench_script = "./vllm-triton-backend/vllm/benchmarks/benchmark_serving.py" + if not os.path.isfile(bench_script): + print(f"can't find benchmark script benchmark_serving.py") + exit(-1) # os.system(f"mkdir -p {result_dir}") create_dir_if_not_exist_recursive(result_dir) +start_time = datetime.now() for max_concurrency in num_users_to_test: num_prompts = ( max_num_prompts @@ -70,18 +86,27 @@ def create_dir_if_not_exist(path, mode=0o777): else int(max_rounds * max_concurrency) ) cmd = ( - f"VLLM_USE_V1=1 python /workspace/benchmarks/benchmark_serving.py " - f"--model {model_path} " + f"VLLM_USE_V1=1 python {bench_script} " + f"--model {model} " f"--dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json " f"--save-result --result-dir {result_dir} --max-concurrency {max_concurrency} " f"--percentile-metrics ttft,tpot,itl,e2el --metric-percentiles 20,50,80,99 " f"--num-prompts {num_prompts} " + f"--port 8803" ) - print(cmd) - rv = os.system(cmd) + for i in range(bench_repetitions): + print( + f"====== Measuring max concurrency {max_concurrency} with {num_prompts} prompts; repetition {i} =====" + ) + 
print(cmd) + rv = os.system(cmd) + if rv != 0: + print(f"benchmark command returned {rv}, stopping...") + break if rv != 0: - print(f"benchmark command returned {rv}, stopping...") break +end_time = datetime.now() print(f"results stored in: {result_dir}") os.system(f"ls -alh {result_dir}") +print(f"Benchmark time: {end_time-start_time}") diff --git a/scripts/bench_vllm_user_range_random.py b/scripts/bench_vllm_user_range_random.py new file mode 100644 index 000000000..5f16c45b6 --- /dev/null +++ b/scripts/bench_vllm_user_range_random.py @@ -0,0 +1,118 @@ +# /******************************************************************************* +# * Copyright 2025 IBM Corporation +# * +# * Licensed under the Apache License, Version 2.0 (the "License"); +# * you may not use this file except in compliance with the License. +# * You may obtain a copy of the License at +# * +# * http://www.apache.org/licenses/LICENSE-2.0 +# * +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, +# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# * See the License for the specific language governing permissions and +# * limitations under the License. +# *******************************************************************************/ +# + +import os +import sys +import torch +from datetime import datetime + + +def create_dir_if_not_exist_recursive(path, mode=0o777): + norm_path = os.path.normpath(path) + paths_l = norm_path.split(os.sep) + path_walked = f"{os.sep}" + for p in paths_l: + if len(p) == 0: + continue + path_walked = os.path.join(path_walked, p) + create_dir_if_not_exist(path_walked, mode) + + +def create_dir_if_not_exist(path, mode=0o777): + if not os.path.exists(path): + os.mkdir(path) + try: + os.chmod(path, mode) + except PermissionError as e: + print(f"can't set permission of directory {path}: {e}") + +if len(sys.argv) < 6: + print(f"Usage: {sys.argv[0]} ") + exit(-1) + +# num_users_to_test = [1, 2, 4, 8, 16, 32, 64, 128] +num_users_to_test = [1, 2, 4, 8, 16, 32, 64] +gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("/", "_") + +# model = "/model/llama3.1-8b/instruct/" +model = sys.argv[1] +input_len = sys.argv[2] +output_len = sys.argv[3] +testcase_name = sys.argv[4] +result_path = os.path.abspath(sys.argv[5]) + +# max_rounds = 128 +# max_rounds = 64 +# max_rounds = 16 +# max_num_prompts = 1000 +min_num_prompts = 16 + +bench_repetitions = 3 + +timestamp_f = datetime.now().strftime("%Y-%m-%d_%H%M") + +# result_dir = ( +# f"/results/{model.replace('/','-')}/{gpu_name}/{testcase_name}/exp_{timestamp_f}/" +# ) +model_print_path = model.replace('/','-') +if model_print_path[0:2] == './': + model_print_path = model_print_path[2:] +result_dir = f"{result_path}/{model_print_path}/{gpu_name}/{testcase_name}/exp_{timestamp_f}/" + +bench_script = "/workspace/benchmarks/benchmark_serving.py" +if not os.path.isfile(bench_script): + bench_script = "./vllm-triton-backend/vllm/benchmarks/benchmark_serving.py" + if not os.path.isfile(bench_script): + print(f"can't find benchmark script benchmark_serving.py") + exit(-1) + +# os.system(f"mkdir -p {result_dir}") +create_dir_if_not_exist_recursive(result_dir) + +start_time = datetime.now() +for max_concurrency in num_users_to_test: + # num_prompts = ( + # max_num_prompts + # if max_num_prompts // max_concurrency < max_rounds + # else int(max_rounds * max_concurrency) + # ) + num_prompts = int(max(min_num_prompts, 2*max_concurrency)) + cmd 
= ( + f"VLLM_USE_V1=1 python {bench_script} " + f"--model {model} " + f"--dataset-name random --random-input-len={input_len} --random-output-len={output_len} --ignore-eos " + f"--save-result --result-dir {result_dir} --max-concurrency {max_concurrency} " + f"--percentile-metrics ttft,tpot,itl,e2el --metric-percentiles 20,50,80,99 " + f"--num-prompts {num_prompts} " + f"--port 8803" + ) + for i in range(bench_repetitions): + print( + f"====== Measuring max concurrency {max_concurrency} with {num_prompts} prompts; repetition {i} =====" + ) + print(cmd) + rv = os.system(cmd) + if rv != 0: + print(f"benchmark command returned {rv}, stopping...") + break + if rv != 0: + break + +end_time = datetime.now() +print(f"results stored in: {result_dir}") +os.system(f"ls -alh {result_dir}") +print(f"Benchmark time: {end_time-start_time}") diff --git a/scripts/benchmark.py b/scripts/benchmark.py index c1799b268..6d7881855 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -68,6 +68,15 @@ class Implementation(Enum): UNF_TRITON_2D = 11 UNF_TRITON_AUTO = 12 PYTORCH_NATIVE = 13 + TRITON_TUNED = 14 + TRITON_FALLBACK = 15 + UNF_TRITON_2D_SIMPLE = 16 + NT_UNF_TRITON_3D = 17 + NT_UNF_TRITON_2D = 18 + NT_UNF_TRITON_AUTO = 19 + UNF_TRITON_2D_TUNED = 20 + GRID_TRITON_3D = 21 + GRID_TRITON_2D = 22 class BenchmarkMode(Enum): @@ -97,7 +106,8 @@ class BatchComposition(Enum): SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5] -PREFIX_PREFILL_BATCH_COMPOSITION = [BatchComposition.ALTERNATING] +PREFIX_PREFILL_BATCH_COMPOSITION = [BatchComposition.DEC_PRE] +RESERVE_INPUT_TOKEN_LENGTH = [None] HEAD_SIZES = [128] # only powers of 2! for llama2 & 3 # head_size * head_numbers = hidden_size @@ -118,6 +128,13 @@ class BatchComposition(Enum): STATE_N_GROUPS = [1] HAS_INITIAL_STATE = [True] +MOE_NUM_EXPERTS = [8] +MOE_N = [14336] # intermediate size of mixtral-8x7b +MOE_K = [4096] # for mixtral-8x7b +TP_FACTOR = [1, 2] +MOE_TOP_K = [2] # for mixtral-8x7b, mixtral-8x22b + + IMPLEMENTATION_UT = [ Implementation.TRITON_2D, Implementation.TRITON_3D, @@ -167,6 +184,13 @@ class BatchComposition(Enum): "MAX_VALUES", "STATE_DIM", "STATE_N_GROUPS", + "HAS_INITIAL_STATE", + "MOE_NUM_EXPERTS", + "MOE_N", + "MOE_K", + "TP_FACTOR", + "MOE_TOP_K", + "RESERVE_INPUT_TOKEN_LENGTH", ] # "BENCHMARK_MODES", "IMPLEMENTATION_UT" ] debug_env_vars = [ @@ -187,8 +211,12 @@ class BatchComposition(Enum): import json envfile_path = os.path.abspath(envfile_name) - print(f"\nApplied test config: {envfile_path}") + if not os.path.isfile(envfile_path): + raise RuntimeError(f"Test config file {envfile_path} does not exist.") env_setting = dotenv_values(envfile_path) + if len(env_setting) == 0: + raise RuntimeError(f"Test config file {envfile_path} does not contain valid configs.") + print(f"\nApplied test config: {envfile_path}") # filter allowed, convert all to lists env_setting_filtered = { k: json.loads(env_setting[k]) for k in test_setup_vars if k in env_setting @@ -982,6 +1010,7 @@ def test_prefill_vllm_v0_attention( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("implementation", IMPLEMENTATION_UT) @pytest.mark.parametrize("max_value", MAX_VALUES) +@pytest.mark.parametrize("reserved_query_length", RESERVE_INPUT_TOKEN_LENGTH) @pytest.mark.parametrize("benchmark_mode", BENCHMARK_MODES) @torch.inference_mode() def test_prefix_vllm_v1_attention( @@ -1003,6 +1032,7 @@ def test_prefix_vllm_v1_attention( seed, 
implementation, max_value, + reserved_query_length, benchmark_mode, ): my_id = request.node.nodeid.split("::")[-1] @@ -1011,6 +1041,9 @@ def test_prefix_vllm_v1_attention( realistic_prompt_mode = len(prompt_pattern) > 1 gqa_mode = num_heads[0] != num_heads[1] + reserved_query_length = None if reserved_query_length in [None, 'none', -1, 0] else int(reserved_query_length) + skip_ref_impl = True if reserved_query_length is not None else False + if torch.cuda.get_device_capability()[0] < 8: # reduce operations are not supported (?) pytest.skip() @@ -1022,14 +1055,28 @@ def test_prefix_vllm_v1_attention( Implementation.TRITON_2D, Implementation.UNF_TRITON_3D, Implementation.UNF_TRITON_2D, - Implementation.UNF_TRITON_AUTO, + Implementation.UNF_TRITON_2D_SIMPLE, + # Implementation.UNF_TRITON_AUTO, + Implementation.NT_UNF_TRITON_3D, + Implementation.NT_UNF_TRITON_2D, + # Implementation.NT_UNF_TRITON_AUTO, + Implementation.UNF_TRITON_2D_TUNED, + Implementation.GRID_TRITON_2D, + Implementation.GRID_TRITON_3D, ]: pytest.skip() + if implementation == Implementation.GRID_TRITON_3D and decode_share != 1.0: + pytest.skip("not supported") + + + if batch_composition == BatchComposition.ALTERNATING and implementation == Implementation.FLASH_ATTN: + pytest.skip("not supported") + # TODO: Error: "Offset increment outside graph capture" # for triton and flash_attn - if benchmark_mode == BenchmarkMode.CUDA_GRAPHS: - pytest.skip("not supported") + # if benchmark_mode == BenchmarkMode.CUDA_GRAPHS: + # pytest.skip("not supported") # TODO # RTOL = 0 @@ -1165,7 +1212,10 @@ def test_prefix_vllm_v1_attention( inner_exception = None try: - query = torch.empty(total_query_tokens, num_query_heads, head_size, dtype=dtype) + query_tensor_num_tokens_reserved = total_query_tokens + if reserved_query_length is not None: + query_tensor_num_tokens_reserved = reserved_query_length + query = torch.empty(query_tensor_num_tokens_reserved, num_query_heads, head_size, dtype=dtype) query.uniform_(-max_value, max_value) key = torch.empty(total_token_num, num_kv_heads, head_size, dtype=dtype) @@ -1224,41 +1274,42 @@ def test_prefix_vllm_v1_attention( slot_mapping_lst.extend(slot_mapping_i) slot_mapping_t = torch.tensor(slot_mapping_lst, dtype=torch.int) - ref_reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping_t, - block_size, - total_token_num, - ) + if not skip_ref_impl: + ref_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping_t, + block_size, + total_token_num, + ) - ref_output = ref_prefix_prefill( - query, - num_queries_per_kv, - key_cache, - value_cache, - key, - value, - block_table_t, - b_seq_lens, - b_ctx_lens, - b_query_lens, - b_start_loc, - batch_size, - scale, - dtype, - ) - # ref_output = ref_paged_attn( - # query, - # key_cache, - # value_cache, - # b_query_lens, - # b_ctx_lens, - # block_table_t, - # scale, - # ) + ref_output = ref_prefix_prefill( + query, + num_queries_per_kv, + key_cache, + value_cache, + key, + value, + block_table_t, + b_seq_lens, + b_ctx_lens, + b_query_lens, + b_start_loc, + batch_size, + scale, + dtype, + ) + # ref_output = ref_paged_attn( + # query, + # key_cache, + # value_cache, + # b_query_lens, + # b_ctx_lens, + # block_table_t, + # scale, + # ) if implementation == Implementation.FLASH_ATTN: from callers import FlashAttnPrefixPrefillCaller as Caller @@ -1274,8 +1325,22 @@ def test_prefix_vllm_v1_attention( from callers import UnifiedTriton3dAttentionCaller as Caller elif implementation == Implementation.UNF_TRITON_2D: from 
callers import UnifiedTriton2dAttentionCaller as Caller + elif implementation == Implementation.UNF_TRITON_2D_SIMPLE: + from callers import SimpleUnifiedTriton2dAttentionCaller as Caller elif implementation == Implementation.UNF_TRITON_AUTO: from callers import UnifiedTritonAutoAttentionCaller as Caller + elif implementation == Implementation.NT_UNF_TRITON_3D: + from callers import NewTilesUnifiedTriton3dAttentionCaller as Caller + elif implementation == Implementation.NT_UNF_TRITON_2D: + from callers import NewTilesUnifiedTriton2dAttentionCaller as Caller + elif implementation == Implementation.NT_UNF_TRITON_AUTO: + from callers import NewTilesUnifiedTritonAutoAttentionCaller as Caller + elif implementation == Implementation.UNF_TRITON_2D_TUNED: + from callers import TunedUnifiedTriton2dAttentionCaller as Caller + elif implementation == Implementation.GRID_TRITON_3D: + from callers import GridTriton3dAttentionCaller as Caller + elif implementation == Implementation.GRID_TRITON_2D: + from callers import GridTriton2dAttentionCaller as Caller if Caller.requires_allocated_output: output = torch.empty_like(query) @@ -1339,10 +1404,12 @@ def test_prefix_vllm_v1_attention( # captured += l # + '|' captured += l + " " # compare - if enforce_numerical_correctness: + if enforce_numerical_correctness and not skip_ref_impl: # for better reports triton.testing.assert_close(ref_output, output, atol=ATOL, rtol=RTOL) allclose_pass = True + elif skip_ref_impl: + allclose_pass = 'skipped' else: allclose_pass = torch.allclose(ref_output, output, atol=ATOL, rtol=RTOL) @@ -1418,7 +1485,9 @@ def test_prefix_vllm_v1_attention( "num_blocks": num_blocks, "dtype": dtype, "max_value": max_value, + "query_tensor_num_tokens_reserved": query_tensor_num_tokens_reserved, "realistic_prompt_mode": realistic_prompt_mode, + "batch_composition": batch_composition, "gqa_mode": gqa_mode, "prompt_pattern": prompt_pattern, "implementation": implementation, @@ -1705,6 +1774,223 @@ def generate_dummy_data(batch_size): raise e +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seqlen", SEQUENCE_LENGTHS) +@pytest.mark.parametrize("n", MOE_N) +@pytest.mark.parametrize("k", MOE_K) +@pytest.mark.parametrize("e", MOE_NUM_EXPERTS) +@pytest.mark.parametrize("tp", TP_FACTOR) +@pytest.mark.parametrize("topk", MOE_TOP_K) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("max_value", MAX_VALUES) +@pytest.mark.parametrize("implementation", IMPLEMENTATION_UT) +@pytest.mark.parametrize("benchmark_mode", BENCHMARK_MODES) +def test_fused_moe( + capsys, + request, + batch_size, + seqlen, + n: int, + k: int, + e: int, + tp: int, + topk: int, + dtype: torch.dtype, + seed, + max_value, + implementation, + benchmark_mode, +): + # based on: https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_moe.py + from vllm.model_executor.layers.activation import SiluAndMul + + my_id = request.node.nodeid.split("::")[-1] + my_name = my_id.split("[")[0] + my_instance = my_id.split("[")[1][:-1] + + if implementation not in [Implementation.TRITON_TUNED, Implementation.TRITON_FALLBACK]: + pytest.skip() + + def torch_moe(a, w1, w2, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for 
i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + from ibm_triton_lib.kernels import fused_moe + # my_experts = TritonExperts() + # fused_moe = my_experts.apply + + torch.manual_seed(seed) + tdev = torch.device(device) + torch.cuda.set_device(tdev) + # m = batch_size * seqlen + num_tokens = batch_size * seqlen + m = num_tokens + n = int(n//tp) + + # ATOL = 1e-2 + # TODO + ATOL = max(1e-2, 2 * max_value) + RTOL = 0 + + a = None + w1 = None + w2 = None + score = None + torch_output = None + triton_output = None + + inner_exception = None + try: + + a = torch.randn((m, k), device=tdev, dtype=dtype).normal_(mean=0.0, std=0.5 * max_value) + # w1 = torch.randn((e, n, k), device=tdev, dtype=dtype).normal_(mean=0.0, std=0.5 * max_value) + # w2 = torch.randn((e, k, n//2), device=tdev, dtype=dtype).normal_(mean=0.0, std=0.5 * max_value) + w1 = torch.randn((e, 2 * n, k), device=tdev, dtype=dtype).normal_(mean=0.0, std=0.5 * max_value) + w2 = torch.randn((e, k, n), device=tdev, dtype=dtype).normal_(mean=0.0, std=0.5 * max_value) + score = torch.randn((m, e), device=tdev, dtype=dtype) + + input_gating = torch.empty(num_tokens, e, dtype=torch.float32, device=tdev) + + if enforce_numerical_correctness: + torch_output = torch_moe(a, w1, w2, score, topk) + assert torch_output is not None + """ + from fused_moe.py + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + """ + + + use_default_config = True if implementation == Implementation.TRITON_FALLBACK else False + triton_output = fused_moe(a, w1, w2, input_gating, topk, + renormalize=True, use_default_config=use_default_config) #inplace=True ? + assert triton_output is not None + + captured = '' + if capsys is not None: + captured_raw = capsys.readouterr() # returns stdout, stderr + for l in captured_raw: + if len(l) > 0: + # captured += l # + '|' + captured += l + ' ' + + # compare + allclose_pass = float('nan') + if enforce_numerical_correctness: + triton.testing.assert_close(torch_output, triton_output, atol=ATOL, rtol=RTOL) + allclose_pass = True + + call_func_under_test = lambda: fused_moe(a, w1, w2, input_gating, topk, + renormalize=True, inplace=True, + use_default_config=use_default_config) + + # benchmark only correct results + if do_benchmarks: + if my_name not in pytest.global_pds: + pytest.global_pds[my_name] = pd.DataFrame() + + # equals to defaults + warmup_rep = 25 + bench_rep = 100 + ms, min_ms, max_ms = measure_benchmarks( + benchmark_mode, call_func_under_test, warmup_rep, bench_rep + ) + + record = { + "batch_size": batch_size, + "seqlen": seqlen, + "num_tokens": num_tokens, # redundant? 
+ "N": n, + "K": k, + "E": e, + "TP": tp, + "topk": topk, + "max_value": max_value, + "dtype": dtype, + "implementation": implementation, + "ms": ms, + "min_ms": min_ms, + "max_ms": max_ms, + "benchmark_mode": benchmark_mode, + "allclose_pass": allclose_pass, + "ATOL": ATOL, + "RTOL": RTOL, + # "proton_count": proton_count, + # "proton_ns": proton_ns, + # "proton_util_compute": proton_util_compute, + # "proton_util_bw": proton_util_bw, + "captured": captured, + } + + if add_triton_dejavu_envs: + dejavu_envs = {} + _skip_dejavu_envs = [ + "_TRITON_DEJAVU_DETERMINED_CUDA_VERSION", + "DEBUG", + "STORAGE", + ] + for env in os.environ.keys(): + if "TRITON_DEJAVU_" in env: + if any([skip_s in env for skip_s in _skip_dejavu_envs]): + continue + dejavu_envs[env] = os.environ[env] + record.update(dejavu_envs) + + pytest.global_pds[my_name] = pd.concat( + [pytest.global_pds[my_name], pd.Series(record).to_frame().T] + ).reset_index(drop=True) + + if pytest.global_pd_file_prefix is not None: + filename = os.path.abspath( + f"{pytest.global_pd_file_prefix}/{my_name}.csv" + ) + write_df_and_chmod(pytest.global_pds[my_name], filename) + + except Exception as e: + print(e) + inner_exception = e + finally: + # cleanup memory + try: + del a + del w1 + del w2 + del score + del triton_output + del torch_output + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + except Exception as e: + print(e) + # pass + finally: + if inner_exception is not None: + raise inner_exception + + + + def measure_benchmarks( benchmark_mode, call_func_under_test, warmup_rep=25, bench_rep=100 ): @@ -1801,7 +2087,7 @@ def write_df_and_chmod(df, filename, mode=0o777): timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") if STORE_TEST_RESULT_PATH is not None: - gpu_path = os.path.join(STORE_TEST_RESULT_PATH, gpu_name) + gpu_path = os.path.join(os.path.abspath(STORE_TEST_RESULT_PATH), gpu_name) gloabl_pd_file_prefix = os.path.join(gpu_path, timestamp) create_dir_if_not_exist_recursive(gloabl_pd_file_prefix) else: diff --git a/scripts/callers/__init__.py b/scripts/callers/__init__.py index 022e3a2e5..35ba469cf 100644 --- a/scripts/callers/__init__.py +++ b/scripts/callers/__init__.py @@ -56,4 +56,15 @@ UnifiedTriton2dAttentionCaller, UnifiedTriton3dAttentionCaller, UnifiedTritonAutoAttentionCaller, + SimpleUnifiedTriton2dAttentionCaller, + TunedUnifiedTriton2dAttentionCaller, +) +from .unified_triton_newtiles import ( + NewTilesUnifiedTriton2dAttentionCaller, + NewTilesUnifiedTriton3dAttentionCaller, + NewTilesUnifiedTritonAutoAttentionCaller, +) +from .grid_triton import ( + GridTriton2dAttentionCaller, + GridTriton3dAttentionCaller, ) diff --git a/scripts/callers/flash_attn.py b/scripts/callers/flash_attn.py index 6a63d778d..8db659f52 100644 --- a/scripts/callers/flash_attn.py +++ b/scripts/callers/flash_attn.py @@ -204,7 +204,7 @@ def call_and_process_output(): block_table=block_tables, # window_size=(-1, 1), # softcap=0, - # fa_version=2, # TODO + fa_version=3, # TODO ) return call_and_process_output diff --git a/scripts/callers/grid_triton.py b/scripts/callers/grid_triton.py new file mode 100644 index 000000000..8ae898202 --- /dev/null +++ b/scripts/callers/grid_triton.py @@ -0,0 +1,193 @@ +# /******************************************************************************* +# * Copyright 2025 IBM Corporation +# * +# * Licensed under the Apache License, Version 2.0 (the "License"); +# * you may not use this file except in compliance with the License. 
+# * You may obtain a copy of the License at
+# *
+# * http://www.apache.org/licenses/LICENSE-2.0
+# *
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS,
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# * See the License for the specific language governing permissions and
+# * limitations under the License.
+# *******************************************************************************/
+#
+
+import torch
+import triton
+
+from ibm_triton_lib.kernels import unified_attention_grid
+from .base import PrefixPrefillCaller
+
+
+class GridTriton3dAttentionCaller(PrefixPrefillCaller):
+    @staticmethod
+    def make_call_func(
+        output,
+        query,
+        key_cache,
+        value_cache,
+        key,
+        value,
+        block_tables,
+        seq_lens,
+        ctx_lens,
+        query_lens,
+        start_loc,
+        seq_start_loc,
+        softmax_scale,
+        # kv_cache_dtype, # unused
+        force_selection=3,
+    ):
+        """
+        query: shape = [num_tokens, num_heads, head_size]
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
+        k_cache = [num_blocks, block_size, num_kv_heads, head_size]
+        v_cache = [num_blocks, block_size, num_kv_heads, head_size]
+        Returns:
+            shape = [num_tokens, num_heads, head_size]
+        """
+
+        max_query_len = query_lens.max()
+        max_seqlen = seq_lens.max()
+
+        avg_seqlen_q = query_lens.to(torch.float).mean()
+        avg_seqlen_k = seq_lens.to(torch.float).mean()
+
+        # block size and kv-head count come from the caches, whose layout is
+        # [num_blocks, block_size, num_kv_heads, head_size] (see docstring above)
+        block_size = value_cache.shape[1]
+        num_seqs = len(seq_lens)
+        num_query_heads = query.shape[1]
+        num_kv_heads = key_cache.shape[2]
+        num_queries_per_kv = num_query_heads // num_kv_heads
+        head_size = query.shape[2]
+
+        query_lens = torch.diff(start_loc)
+        if max_query_len == 1:
+            num_decodes = len(seq_lens)
+        else:
+            num_decodes = torch.argmax((query_lens != 1).int()).item()
+
+        BLOCK_M_PREFILL = 64
+        BLOCK_M_DECODE = 16
+        BLOCK_Q_PREFILL = BLOCK_M_PREFILL * num_kv_heads // num_query_heads
+        BLOCK_Q_DECODE = BLOCK_M_DECODE * num_kv_heads // num_query_heads
+
+        block_q_seq_boundaries = torch.cumsum(
+            torch.cat(
+                [
+                    torch.tensor([0], dtype=query_lens.dtype, device=query_lens.device),
+                    torch.ceil(query_lens[num_decodes:] / BLOCK_Q_PREFILL).to(torch.int),
+                ]
+            ),
+            dim=0,
+        )
+        num_q_blocks = block_q_seq_boundaries[-1].item()
+
+        # use_split_kv = (num_q_blocks * self.num_heads_kv < 128)
+        use_split_kv = force_selection == 3
+
+        NUM_SEGMENTS = 16
+
+        if use_split_kv:
+            segm_output = torch.empty(
+                num_decodes,
+                num_query_heads,
+                NUM_SEGMENTS,
+                triton.next_power_of_2(head_size),
+                dtype=torch.float32,
+                device=seq_lens.device,
+            )
+            segm_max = torch.empty(
+                num_decodes,
+                num_query_heads,
+                NUM_SEGMENTS,
+                dtype=torch.float32,
+                device=seq_lens.device,
+            )
+            segm_expsum = torch.empty(
+                num_decodes,
+                num_query_heads,
+                NUM_SEGMENTS,
+                dtype=torch.float32,
+                device=seq_lens.device,
+            )
+        else:
+            segm_output = None
+            segm_max = None
+            segm_expsum = None
+
+        if use_split_kv:
+            assert num_decodes == num_seqs, "3d can only do decodes"
+
+        def call_and_process_output():
+            # k must have shape (num_blocks, page_block_size, num_heads_k, head_size)
+            return unified_attention_grid(
+                q=query,
+                k=key_cache,
+                v=value_cache,
+                out=output,
+                cu_seqlens_q=start_loc,
+                max_seqlen_q=max_query_len,
+                seqused_k=seq_lens,
+                max_seqlen_k=max_seqlen,
+                softmax_scale=softmax_scale,
+                causal=True,
+                window_size=(-1, -1),
+                block_table=block_tables,
+                softcap=0,
+                q_descale=None,
+                k_descale=None,  # TODO?
+                v_descale=None,  # TODO?
+ alibi_slopes=None, + use_split_kv=use_split_kv, + num_decodes=num_decodes, + segm_output=segm_output, + segm_max=segm_max, + segm_expsum=segm_expsum, + BLOCK_M_PREFILL=BLOCK_M_PREFILL, + BLOCK_Q_PREFILL=BLOCK_Q_PREFILL, + BLOCK_M_DECODE=BLOCK_M_DECODE, + BLOCK_Q_DECODE=BLOCK_Q_DECODE, + num_q_blocks=num_q_blocks, + block_q_seq_boundaries=block_q_seq_boundaries + ) + + return call_and_process_output + + @staticmethod + def requires_allocated_output() -> bool: + return True + + +class GridTriton2dAttentionCaller(GridTriton3dAttentionCaller): + @staticmethod + def make_call_func( + output, + query, + key_cache, + value_cache, + key, + value, + block_tables, + seq_lens, + ctx_lens, + query_lens, + start_loc, + seq_start_loc, + softmax_scale, + # kv_cache_dtype, # unused + force_selection=2, + ): + + return GridTriton3dAttentionCaller.make_call_func( + output, + query, + key_cache, + value_cache, + key, + value, + block_tables, + seq_lens, + ctx_lens, + query_lens, + start_loc, + seq_start_loc, + softmax_scale, + force_selection=2, + ) + diff --git a/scripts/callers/unified_triton.py b/scripts/callers/unified_triton.py index c1f9b2c81..450465ae2 100644 --- a/scripts/callers/unified_triton.py +++ b/scripts/callers/unified_triton.py @@ -16,8 +16,9 @@ # import torch +import triton -from ibm_triton_lib.kernels import unified_attention +from ibm_triton_lib.kernels import unified_attention, unified_attention_simple, unified_attention_tuned from .base import PrefixPrefillCaller @@ -76,8 +77,8 @@ def call_and_process_output(): k_descale=None, # TODO? v_descale=None, # TODO? alibi_slopes=None, - avg_seqlen_q=avg_seqlen_q, - avg_seqlen_k=avg_seqlen_k, + # avg_seqlen_q=avg_seqlen_q, + # avg_seqlen_k=avg_seqlen_k, force_selection=force_selection, ) @@ -126,6 +127,73 @@ def make_call_func( ) +class SimpleUnifiedTriton2dAttentionCaller(PrefixPrefillCaller): + @staticmethod + def make_call_func( + output, + query, + key_cache, + value_cache, + key, + value, + block_tables, + seq_lens, + ctx_lens, + query_lens, + start_loc, + seq_start_loc, + softmax_scale, + # kv_cache_dtype, # unused + force_selection=2, + ): + """ + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + k_cache = [num_blocks, block_size, num_kv_heads, head_size] + v_cache = [num_blocks, block_size, num_kv_heads, head_size] + Returns: + shape = [num_tokens, num_heads, head_size] + """ + assert force_selection == 2, "simple unified kernel is only applicable to 2d" + + max_query_len = query_lens.max() + max_seqlen = seq_lens.max() + + avg_seqlen_q = query_lens.to(torch.float).mean() + avg_seqlen_k = seq_lens.to(torch.float).mean() + + def call_and_process_output(): + # k must have shape (num_blocks, page_block_size, num_heads_k, head_size) + return unified_attention_simple( + q=query, + k=key_cache, + v=value_cache, + out=output, + cu_seqlens_q=start_loc, + max_seqlen_q=max_query_len, + seqused_k=seq_lens, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=True, + window_size=(-1, -1), + block_table=block_tables, + softcap=0, + q_descale=None, + k_descale=None, # TODO? + v_descale=None, # TODO? 
+ alibi_slopes=None, + # avg_seqlen_q=avg_seqlen_q, + # avg_seqlen_k=avg_seqlen_k, + ) + + return call_and_process_output + + @staticmethod + def requires_allocated_output() -> bool: + return True + + class UnifiedTritonAutoAttentionCaller(UnifiedTriton3dAttentionCaller): @staticmethod def make_call_func( @@ -162,3 +230,80 @@ def make_call_func( softmax_scale, force_selection=None, ) # none triggers vllm default behaviour + + +class TunedUnifiedTriton2dAttentionCaller(PrefixPrefillCaller): + @staticmethod + def make_call_func( + output, + query, + key_cache, + value_cache, + key, + value, + block_tables, + seq_lens, + ctx_lens, + query_lens, + start_loc, + seq_start_loc, + softmax_scale, + # kv_cache_dtype, # unused + force_selection=2, + ): + """ + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + k_cache = [num_blocks, block_size, num_kv_heads, head_size] + v_cache = [num_blocks, block_size, num_kv_heads, head_size] + Returns: + shape = [num_tokens, num_heads, head_size] + """ + assert force_selection == 2, "simple unified kernel is only applicable to 2d" + + max_query_len = query_lens.max() + max_seqlen = seq_lens.max() + + avg_seqlen_q = query_lens.to(torch.float).mean() + avg_seqlen_k = seq_lens.to(torch.float).mean() + + MAX_SEQ_Q = triton.next_power_of_2(int(max_query_len)) + MAX_SEQ_K = triton.next_power_of_2(int(max_seqlen)) + AVG_SEQ_Q = triton.next_power_of_2(int(avg_seqlen_q)) + AVG_SEQ_K = triton.next_power_of_2(int(avg_seqlen_k)) + + def call_and_process_output(): + # k must have shape (num_blocks, page_block_size, num_heads_k, head_size) + return unified_attention_tuned( + q=query, + k=key_cache, + v=value_cache, + out=output, + cu_seqlens_q=start_loc, + max_seqlen_q=max_query_len, + seqused_k=seq_lens, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=True, + window_size=(-1, -1), + block_table=block_tables, + softcap=0, + q_descale=None, + k_descale=None, # TODO? + v_descale=None, # TODO? + alibi_slopes=None, + avg_seqlen_q=avg_seqlen_q, + avg_seqlen_k=avg_seqlen_k, + MAX_SEQ_Q=MAX_SEQ_Q, + MAX_SEQ_K=MAX_SEQ_K, + AVG_SEQ_Q=AVG_SEQ_Q, + AVG_SEQ_K=AVG_SEQ_K, + force_selection=2, + ) + + return call_and_process_output + + @staticmethod + def requires_allocated_output() -> bool: + return True diff --git a/scripts/callers/unified_triton_newtiles.py b/scripts/callers/unified_triton_newtiles.py new file mode 100644 index 000000000..d3d751a63 --- /dev/null +++ b/scripts/callers/unified_triton_newtiles.py @@ -0,0 +1,164 @@ +# /******************************************************************************* +# * Copyright 2025 IBM Corporation +# * +# * Licensed under the Apache License, Version 2.0 (the "License"); +# * you may not use this file except in compliance with the License. +# * You may obtain a copy of the License at +# * +# * http://www.apache.org/licenses/LICENSE-2.0 +# * +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, +# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# * See the License for the specific language governing permissions and +# * limitations under the License. 
+# *******************************************************************************/ +# + +import torch + +from ibm_triton_lib.kernels import unified_attention_newtiles +from .base import PrefixPrefillCaller + + +class NewTilesUnifiedTriton3dAttentionCaller(PrefixPrefillCaller): + @staticmethod + def make_call_func( + output, + query, + key_cache, + value_cache, + key, + value, + block_tables, + seq_lens, + ctx_lens, + query_lens, + start_loc, + seq_start_loc, + softmax_scale, + # kv_cache_dtype, # unused + force_selection=3, + ): + """ + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + k_cache = [num_blocks, block_size, num_kv_heads, head_size] + v_cache = [num_blocks, block_size, num_kv_heads, head_size] + Returns: + shape = [num_tokens, num_heads, head_size] + """ + + max_query_len = query_lens.max() + max_seqlen = seq_lens.max() + + avg_seqlen_q = query_lens.to(torch.float).mean() + avg_seqlen_k = seq_lens.to(torch.float).mean() + + def call_and_process_output(): + # k must have shape (num_blocks, page_block_size, num_heads_k, head_size) + return unified_attention_newtiles( + q=query, + k=key_cache, + v=value_cache, + out=output, + cu_seqlens_q=start_loc, + max_seqlen_q=max_query_len, + seqused_k=seq_lens, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=True, + window_size=(-1, -1), + block_table=block_tables, + softcap=0, + q_descale=None, + k_descale=None, # TODO? + v_descale=None, # TODO? + alibi_slopes=None, + # avg_seqlen_q=avg_seqlen_q, + # avg_seqlen_k=avg_seqlen_k, + force_selection=force_selection, + ) + + return call_and_process_output + + @staticmethod + def requires_allocated_output() -> bool: + return True + + +class NewTilesUnifiedTriton2dAttentionCaller(NewTilesUnifiedTriton3dAttentionCaller): + @staticmethod + def make_call_func( + output, + query, + key_cache, + value_cache, + key, + value, + block_tables, + seq_lens, + ctx_lens, + query_lens, + start_loc, + seq_start_loc, + softmax_scale, + # kv_cache_dtype, # unused + force_selection=2, + ): + + return NewTilesUnifiedTriton3dAttentionCaller.make_call_func( + output, + query, + key_cache, + value_cache, + key, + value, + block_tables, + seq_lens, + ctx_lens, + query_lens, + start_loc, + seq_start_loc, + softmax_scale, + force_selection=2, + ) + + +class NewTilesUnifiedTritonAutoAttentionCaller(NewTilesUnifiedTriton3dAttentionCaller): + @staticmethod + def make_call_func( + output, + query, + key_cache, + value_cache, + key, + value, + block_tables, + seq_lens, + ctx_lens, + query_lens, + start_loc, + seq_start_loc, + softmax_scale, + # kv_cache_dtype, # unused + force_selection=None, + ): + + return NewTilesUnifiedTriton3dAttentionCaller.make_call_func( + output, + query, + key_cache, + value_cache, + key, + value, + block_tables, + seq_lens, + ctx_lens, + query_lens, + start_loc, + seq_start_loc, + softmax_scale, + force_selection=None, + ) # none triggers vllm default behaviour diff --git a/scripts/dejavu-to-moe_configs.py b/scripts/dejavu-to-moe_configs.py new file mode 100644 index 000000000..93ba227ae --- /dev/null +++ b/scripts/dejavu-to-moe_configs.py @@ -0,0 +1,125 @@ + +import sys +import os +import json + +# __vllm_base_path__ = '/home/zrlngl/watsonx/vllm/vllm/model_executor/layers/fused_moe/configs/' +# __vllm_base_path__ = '/home/zrlngl/watsonx/vllm/ngl_configs/' +__vllm_base_path__ = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../ngl_configs/')) + 
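+# The two parsers below assume the string formats that triton-dejavu uses in its
+# cache files. The concrete values in this example are made up for illustration
+# only (they are not taken from a real cache):
+#
+#   key   -> "('2048', '4096', '62', ..., 'torch.float16', 'False', 'False')"
+#            i.e. a stringified tuple with one (quoted) entry per name in `moe_keys`
+#   value -> "BLOCK_SIZE_M: 64, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_stages: 4"
+#
+# moe_key_to_param_dict() splits the key on ', ' and pairs the entries with
+# `moe_keys`; create_config_dict() splits the value on ', ' and ': ' to build the
+# per-token-count config dict that is written into the vLLM fused_moe JSON files.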
+moe_keys = [ + 'N', + 'K', + 'E', + # 'EM', + 'num_valid_tokens', + 'num_actual_tokens', + 'stride_am', + 'stride_ak', + 'stride_be', + 'stride_bk', + 'stride_bn', + 'stride_cm', + 'stride_cn', + 'MUL_ROUTED_WEIGHT', + 'top_k', + 'compute_type', + 'use_fp8_w8a8', + 'use_int8_w8a16', + ] +__skip_config_args__ = ["enable_persistent", "maxnreg"] + + +def moe_key_to_param_dict(k): + kl = k[1:-1].split(', ') + ret = {} + for i, label in enumerate(moe_keys): + ret[label] = kl[i] + return ret + + +def create_config_dict(v): + # for vLLM specific + vlist = v.split(", ") + ret = {} + for e in vlist: + sl = e.split(": ") + if sl[0] in __skip_config_args__: + continue + ret[sl[0]] = int(sl[1]) + return ret + + +def translate_dejavu_cache(cache_path): + print(f"Exporting {cache_path} to {__vllm_base_path__}...") + # tag_path = os.path.dirname(cache_path) + # gpu_name_path = os.path.dirname(os.path.dirname(tag_path[:-2])[:-2]) + # gpu_name = os.path.basename(gpu_name_path)[4:] + # adapt to new structure + path_ids = os.path.abspath(cache_path).split('/') + gpu_name_path = path_ids[-7] + gpu_name = gpu_name_path[4:] + + with open(cache_path, 'r') as f: + dejavu_cache = json.load(f) + + cache_dict = dejavu_cache['cache'] + + # k0 = list(cache_dict.keys())[0] + # v0 = cache_dict[k0] + num_experts = None + + config_per_device = {} + timings_per_device = {} + for k, v in cache_dict.items(): + kd = moe_key_to_param_dict(k) + vd = create_config_dict(v) + ot = dejavu_cache['timings'][k]['values'][dejavu_cache['timings'][k]['lables'].index('ms')] + if num_experts is None: + num_experts = int(kd['E'][1:-1]) + else: + assert num_experts == int(kd['E'][1:-1]) + # num_tokens = int(kd['num_valid_tokens'][1:-1]) + # TODO: how to automatically determine /2? update method signature? + # num_tokens = int(int(kd['num_valid_tokens'][1:-1]) / 2) + num_tokens = int(kd['num_actual_tokens'][1:-1]) + # N = int(kd['N'][1:-1])/num_tokens + # N = int(kd['N'][1:-1]) + # vllm_N = int(kd['stride_am'][1:-1]) + vllm_N = int(kd['K'][1:-1]) + # N = int(kd['stride_am'][1:-1])/2 # due to test script shape generation? 
+            new_dict = {num_tokens: vd}
+            if vllm_N not in config_per_device:
+                config_per_device[vllm_N] = new_dict
+                timings_per_device[vllm_N] = {num_tokens: ot}
+            else:
+                # config_per_device[vllm_N].update(new_dict)
+                if num_tokens not in config_per_device[vllm_N]:
+                    config_per_device[vllm_N][num_tokens] = vd
+                    timings_per_device[vllm_N][num_tokens] = ot
+                else:
+                    if ot >= timings_per_device[vllm_N][num_tokens]:
+                        print(f"configuration for {num_tokens} already exists: {config_per_device[vllm_N][num_tokens]}; "
+                              f"new candidate {vd} is SLOWER, skipping...")
+                    else:
+                        print(f"configuration for {num_tokens} already exists: {config_per_device[vllm_N][num_tokens]}; "
+                              f"overwriting with {vd} because it is FASTER...")
+                        config_per_device[vllm_N][num_tokens] = vd
+                        timings_per_device[vllm_N][num_tokens] = ot
+
+    modified_paths = []
+    for N, config_dict in config_per_device.items():
+        file_name = f"E={int(num_experts)},N={int(N)},device_name={gpu_name}.json"
+        modified_paths.append(file_name)
+        target_path = os.path.abspath(f"{__vllm_base_path__}/{file_name}")
+        # num_tokens / M as key in dict
+        with open(target_path, 'w') as f:
+            json.dump(config_dict, f, indent=4)
+
+    print(f"modified the following files: {modified_paths}")
+    print(f"triton-dejavu has saved {dejavu_cache['total_bench_time_s']}s of benchmarking time")
+    print('...done')
+
+
+if __name__ == '__main__':
+    translate_dejavu_cache(sys.argv[1])
diff --git a/scripts/offline_inference.py b/scripts/offline_inference_llama.py
similarity index 94%
rename from scripts/offline_inference.py
rename to scripts/offline_inference_llama.py
index d29412d70..3c213e6c9 100644
--- a/scripts/offline_inference.py
+++ b/scripts/offline_inference_llama.py
@@ -41,8 +41,8 @@
 from vllm.distributed import cleanup_dist_env_and_memory
 
 llm = LLM(
-    # model="/mnt/nvme5n1p1/zrlngl/fmaas/models/llama3.1-8b-instruct/",
-    model="/net/storage149/autofs/css22/nmg/models/hf/meta-llama/Llama-3.1-8B-Instruct/main/",
+    # model="meta-llama/Llama-3.1-8B-Instruct",
+    model=os.environ["MY_MODEL_PATH"],
     # max_model_len=2048,
     # enforce_eager=True,
     enable_prefix_caching=False,
diff --git a/scripts/offline_inference_mamba.py b/scripts/offline_inference_mamba.py
new file mode 100644
index 000000000..3a09ff639
--- /dev/null
+++ b/scripts/offline_inference_mamba.py
@@ -0,0 +1,97 @@
+# /*******************************************************************************
+# * Copyright 2025 IBM Corporation
+# *
+# * Licensed under the Apache License, Version 2.0 (the "License");
+# * you may not use this file except in compliance with the License.
+# * You may obtain a copy of the License at
+# *
+# * http://www.apache.org/licenses/LICENSE-2.0
+# *
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS,
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# * See the License for the specific language governing permissions and
+# * limitations under the License.
+# *******************************************************************************/
+#
+
+
+import os
+import time
+
+# to enable debug printing
+# os.environ["TRITON_BACKEND_DEBUG"] = "1"
+
+# select the vLLM V1 engine and the attention backend to benchmark
+os.environ["VLLM_USE_V1"] = "1"
+os.environ["VLLM_PLUGINS"] = ""
+# os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
+os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
+# os.environ["VLLM_TRITON_ENABLE_JITCACHE"] = "1"
+os.environ["VLLM_TRITON_ENABLE_JITCACHE"] = "0"
+
+# enable torch profiler, can also be set on cmd line
+enable_profiling = True
+# enable_profiling = False
+
+if enable_profiling:
+    os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_torch_profile_mamba"
+
+
+if __name__ == "__main__":
+    from vllm import LLM, SamplingParams
+    from vllm.distributed import cleanup_dist_env_and_memory
+
+    llm = LLM(
+        model=os.environ["MY_MODEL_PATH"],
+        enforce_eager=True,
+        enable_chunked_prefill=True,
+        enable_prefix_caching=False,
+        tensor_parallel_size=2,
+        max_model_len=31628,
+        max_num_seqs=512,
+        num_scheduler_steps=1,
+    )
+
+    # batch_size = 32
+    max_tokens = 20
+
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
+    # ignore_eos=True)
+
+    prompts = [
+        "Zurich is a beautiful city with",
+        "San Francisco is a large city with",
+        # "Provide a list of instructions for preparing chicken soup for a family "
+        # "of four.",
+        # "Skating and cross country skiing technique differ in",
+    ]
+
+    print(
+        f"SETUP: vllm backend: {os.environ['VLLM_ATTENTION_BACKEND']} "
+        f" JITCache: {os.environ['VLLM_TRITON_ENABLE_JITCACHE']} "
+    )
+    print(f"Inference with {len(prompts)} prompts...")
+    if enable_profiling:
+        llm.start_profile()
+    t0 = time.time()
+    # outputs = llm.generate(prompts, sampling_params)
+    outputs = []
+    for prompt in prompts:
+        outputs.append(llm.generate(prompt, sampling_params))
+
+    if enable_profiling:
+        llm.stop_profile()
+    t1 = time.time()
+
+    print(f"inference time: {t1-t0:.5f}s")
+
+    for output in outputs:
+        output = output[0]  # in case of loop above
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+    # Add a buffer to wait for profiler in the background process
+    # (in case MP is on) to finish writing profiling output.
+ time.sleep(10) diff --git a/scripts/requirements.txt b/scripts/requirements.txt new file mode 100644 index 000000000..840e5980b --- /dev/null +++ b/scripts/requirements.txt @@ -0,0 +1,2 @@ +llnl-hatchet==2025.1.0 +pytest==8.4.1 diff --git a/scripts/setups/granite4_moe_0.conf b/scripts/setups/granite4_moe_0.conf new file mode 100644 index 000000000..0792807f2 --- /dev/null +++ b/scripts/setups/granite4_moe_0.conf @@ -0,0 +1,24 @@ +BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128] +# BATCH_SIZES = [4] +SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] + +# MOE_N = [768] # intermediate size, g4s +MOE_N = [512] # g4t +MOE_K = [4096] # hidden size +MOE_TOP_K = [10] # num_experts_per_tok +# MOE_NUM_EXPERTS = [72] #g4s +MOE_NUM_EXPERTS = [62] #g4t +TP_FACTOR = [1, 2] +# DTYPES = ["bfloat16"] +DTYPES = ["float16"] + +# BENCHMARK_MODES = ["CUDA_EVENTS"] +BENCHMARK_MODES = ["CUDA_GRAPHS"] + +IMPLEMENTATION_UT = ["TRITON_TUNED", "TRITON_FALLBACK"] +# IMPLEMENTATION_UT = ["TRITON_FALLBACK"] + +# TRITON_BACKEND_DEBUG = 1 +# STORE_TEST_RESULT_PATH=/results + +TEST_ALLOW_INCORRECT = 1 diff --git a/scripts/setups/prefix_grid.conf b/scripts/setups/prefix_grid.conf new file mode 100644 index 000000000..bdde70751 --- /dev/null +++ b/scripts/setups/prefix_grid.conf @@ -0,0 +1,37 @@ +BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128] +# BATCH_SIZES = [4] +# order: num_query_heads, num_kv_heads +NUM_HEADS = [[32, 8]] + +SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] +# SEQUENCE_LENGTHS = [64] +PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] +# PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5] +PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0] +PREFIX_PREFILL_BATCH_COMPOSITION = ["DEC_PRE"] + +# RESERVE_INPUT_TOKEN_LENGTH = ["none", 132096] +# RESERVE_INPUT_TOKEN_LENGTH = [132096] +RESERVE_INPUT_TOKEN_LENGTH = ["none"] + +HEAD_SIZES = [128] # only powers of 2! for llama2 & 3 +BLOCK_SIZES = [16] +NUM_BLOCKS = [4321] # "arbitrary values for testing..." 
+ +PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]] +# PROMPT_PATTERNS = [[1.0]] + +MAX_VALUES = [1.0] +BENCHMARK_MODES = ["CUDA_EVENTS"] +# BENCHMARK_MODES = ["CUDA_GRAPHS"] + +# IMPLEMENTATION_UT = ["UNF_TRITON_2D_TUNED", "UNF_TRITON_2D_SIMPLE"] +# IMPLEMENTATION_UT = ["GRID_TRITON_2D", "GRID_TRITON_3D"] +# IMPLEMENTATION_UT = ["FLASH_ATTN"] +IMPLEMENTATION_UT = ["UNF_TRITON_2D", "UNF_TRITON_3D"] + +# TRITON_BACKEND_DEBUG = 1 +# STORE_TEST_RESULT_PATH=/results +STORE_TEST_RESULT_PATH=./zrl-triton-results-and-notebooks/micro_benchmarks/raw_data/ + +TEST_ALLOW_INCORRECT = 1 diff --git a/scripts/setups/prefix_optimize_launchgrid.conf b/scripts/setups/prefix_optimize_launchgrid.conf new file mode 100644 index 000000000..7849b11e7 --- /dev/null +++ b/scripts/setups/prefix_optimize_launchgrid.conf @@ -0,0 +1,43 @@ +BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128] +# BATCH_SIZES = [4] +# order: num_query_heads, num_kv_heads +NUM_HEADS = [[32, 8]] + +SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] +# SEQUENCE_LENGTHS = [64] +PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] +# PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5] +# PREFIX_PREFILL_SHARE_OF_DECODE = [0.5] +PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5] +# PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.5] +# PREFIX_PREFILL_BATCH_COMPOSITION = ["ALTERNATING"] +PREFIX_PREFILL_BATCH_COMPOSITION = ["DEC_PRE"] +# PREFIX_PREFILL_BATCH_COMPOSITION = ["DEC_PRE", "ALTERNATING"] + +# max model length granite4, 'none' means not to reserve more than in the batch +# RESERVE_INPUT_TOKEN_LENGTH = ["none", 132096] +# RESERVE_INPUT_TOKEN_LENGTH = [132096] +RESERVE_INPUT_TOKEN_LENGTH = ["none"] + +HEAD_SIZES = [128] # only powers of 2! for llama2 & 3 +# head_size * head_numbers = hidden_size + +BLOCK_SIZES = [16] +NUM_BLOCKS = [4321] # "arbitrary values for testing..." + +PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]] +# PROMPT_PATTERNS = [[1.0]] + +MAX_VALUES = [1.0] +BENCHMARK_MODES = ["CUDA_EVENTS"] +# BENCHMARK_MODES = ["CUDA_GRAPHS"] + +# IMPLEMENTATION_UT = ["NT_UNF_TRITON_2D", "NT_UNF_TRITON_3D", "UNF_TRITON_2D", "UNF_TRITON_3D"] +IMPLEMENTATION_UT = ["NT_UNF_TRITON_2D", "NT_UNF_TRITON_3D"] +# IMPLEMENTATION_UT = ["UNF_TRITON_3D"] + +# TRITON_BACKEND_DEBUG = 1 +# STORE_TEST_RESULT_PATH=/results +STORE_TEST_RESULT_PATH=./zrl-triton-results-and-notebooks/micro_benchmarks/raw_data/ + +# TEST_ALLOW_INCORRECT = 1 diff --git a/scripts/setups/prefix_tune_2d.conf b/scripts/setups/prefix_tune_2d.conf index 5987fd28e..1cd5a119b 100644 --- a/scripts/setups/prefix_tune_2d.conf +++ b/scripts/setups/prefix_tune_2d.conf @@ -5,13 +5,14 @@ NUM_HEADS = [[32, 8]] SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] # SEQUENCE_LENGTHS = [64] -# PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] -PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5] +PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] +# PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5] # PREFIX_PREFILL_SHARE_OF_DECODE = [0.5] PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5] # PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.5] # PREFIX_PREFILL_BATCH_COMPOSITION = ["ALTERNATING"] -PREFIX_PREFILL_BATCH_COMPOSITION = ["DEC_PRE"] +# PREFIX_PREFILL_BATCH_COMPOSITION = ["DEC_PRE"] +PREFIX_PREFILL_BATCH_COMPOSITION = ["DEC_PRE", "ALTERNATING"] HEAD_SIZES = [128] # only powers of 2! 
for llama2 & 3 # head_size * head_numbers = hidden_size @@ -23,12 +24,18 @@ PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]] # PROMPT_PATTERNS = [[1.0]] MAX_VALUES = [1.0] -BENCHMARK_MODES = ["CUDA_EVENTS"] +# BENCHMARK_MODES = ["CUDA_EVENTS"] +BENCHMARK_MODES = ["CUDA_GRAPHS"] -IMPLEMENTATION_UT = ["UNF_TRITON_2D"] +# IMPLEMENTATION_UT = ["UNF_TRITON_2D"] +# IMPLEMENTATION_UT = ["UNF_TRITON_2D_SIMPLE"] # IMPLEMENTATION_UT = ["FLASH_ATTN", "UNF_TRITON_2D"] +# IMPLEMENTATION_UT = ["NT_UNF_TRITON_2D", "NT_UNF_TRITON_3D", "FLASH_ATTN", "UNF_TRITON_2D", "UNF_TRITON_3D"] +IMPLEMENTATION_UT = ["NT_UNF_TRITON_2D", "NT_UNF_TRITON_3D", "UNF_TRITON_2D", "UNF_TRITON_3D"] +# IMPLEMENTATION_UT = ["UNF_TRITON_3D"] # TRITON_BACKEND_DEBUG = 1 # STORE_TEST_RESULT_PATH=/results +STORE_TEST_RESULT_PATH=./zrl-triton-results-and-notebooks/micro_benchmarks/raw_data/ # TEST_ALLOW_INCORRECT = 1 diff --git a/scripts/setups/tune_2d_ws.conf b/scripts/setups/tune_2d_ws.conf new file mode 100644 index 000000000..f2d435a6a --- /dev/null +++ b/scripts/setups/tune_2d_ws.conf @@ -0,0 +1,32 @@ +BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128] +# BATCH_SIZES = [4] +# order: num_query_heads, num_kv_heads +NUM_HEADS = [[32, 8]] + +SEQUENCE_LENGTHS = [16, 32, 64, 128, 512, 1024, 2048, 4096] +# SEQUENCE_LENGTHS = [64] +# PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5] +PREFIX_PREFILL_SHARE_OF_DECODE = [0.0, 0.5, 1.0] +PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0, 0.5] +# PREFIX_PREFILL_SHARE_OF_PARTIAL_PREFILL = [0.0] +PREFIX_PREFILL_BATCH_COMPOSITION = ["DEC_PRE"] + +HEAD_SIZES = [128] # only powers of 2! for llama2 & 3 +BLOCK_SIZES = [16] +NUM_BLOCKS = [4321] # "arbitrary values for testing..." + +PROMPT_PATTERNS = [[1.0], [0.1, 0.4, 0.5, 1.0, 0.2]] +# PROMPT_PATTERNS = [[1.0]] + +MAX_VALUES = [1.0] +# BENCHMARK_MODES = ["CUDA_EVENTS"] +BENCHMARK_MODES = ["CUDA_GRAPHS"] + +# IMPLEMENTATION_UT = ["UNF_TRITON_2D_TUNED"] +IMPLEMENTATION_UT = ["UNF_TRITON_2D_TUNED", "UNF_TRITON_2D_SIMPLE"] + +# TRITON_BACKEND_DEBUG = 1 +# STORE_TEST_RESULT_PATH=/results +STORE_TEST_RESULT_PATH=./zrl-triton-results-and-notebooks/micro_benchmarks/raw_data/ + +# TEST_ALLOW_INCORRECT = 1 diff --git a/third_party/fmwork b/third_party/fmwork new file mode 160000 index 000000000..2083a4e33 --- /dev/null +++ b/third_party/fmwork @@ -0,0 +1 @@ +Subproject commit 2083a4e3376ba8b6318aba7b8f10b6bfb830b912 diff --git a/third_party/vedantroy_paged_attention.py b/third_party/kernels/vedantroy_paged_attention.py similarity index 100% rename from third_party/vedantroy_paged_attention.py rename to third_party/kernels/vedantroy_paged_attention.py diff --git a/triton-dejavu b/triton-dejavu index c2555ce1a..8f06d4903 160000 --- a/triton-dejavu +++ b/triton-dejavu @@ -1 +1 @@ -Subproject commit c2555ce1a61d2288007366b2dcef1203ed1f26ee +Subproject commit 8f06d4903056e30867620576b251489c3e9baa8c diff --git a/vllm b/vllm index d91278181..f0c503f66 160000 --- a/vllm +++ b/vllm @@ -1 +1 @@ -Subproject commit d91278181d89686b73b2ec88c2db4d55c6c506cb +Subproject commit f0c503f66e2f6aafa966318d488fd92ac662cdf0