Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,7 @@ def triton_scale_nvfp4_quant(
stochastic_casting (bool): Whether to use stochastic casting.

Returns:
torch.Tensor: [M / 2] nvfp4 scaled tensor packed into in8
torch.Tensor: [M / 2] nvfp4 scaled tensor packed into int8
torch.Tensor: [M / group_size] nvfp4 shared exponents into int8

eg.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,6 @@

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>

// clang-format off
// The fixed ordering of the headers is required for CUTLASS 3.2+
#include <cute/tensor.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp> // @manual
#include <cutlass/gemm/device/gemm_universal_adapter.h> // @manual
#include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
// clang-format on

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
#include "f4f4bf16_grouped/f4f4bf16_grouped_manifest.cuh"
Expand Down Expand Up @@ -160,14 +150,10 @@ at::Tensor dispatch_fp4_grouped_kernel(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
std::optional<at::Tensor> zero_start_index_M = std::nullopt,
std::optional<at::Tensor> M_sizes = std::nullopt,
std::optional<at::Tensor> global_scale = std::nullopt,
std::optional<at::Tensor> starting_row_after_padding = std::nullopt,
bool use_mx = true) {
TORCH_CHECK(
zero_start_index_M.has_value() != M_sizes.has_value(),
"One of zero_start_index_M or M_sizes must be provided.");
TORCH_CHECK(M_sizes.has_value(), "M_sizes is assumed to be provided.");
TORCH_CHECK(
starting_row_after_padding.has_value(),
Expand All @@ -187,8 +173,6 @@ at::Tensor dispatch_fp4_grouped_kernel(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down Expand Up @@ -228,7 +212,6 @@ at::Tensor f4f4bf16_grouped_stacked(
x_scale,
w_scale,
Y,
std::nullopt,
M_sizes,
global_scale,
starting_row_after_padding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_f(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_f(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_t(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_t(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_128_64_256_1_1_1_f(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_128_64_256_1_1_1_f(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_128_64_256_1_1_1_t(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_128_64_256_1_1_1_t(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_128_256_2_1_1_f(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_128_256_2_1_1_f(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_128_256_2_1_1_t(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_128_256_2_1_1_t(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_256_256_2_1_1_f(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_256_256_2_1_1_f(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_256_256_2_1_1_t(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_256_256_2_1_1_t(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_64_256_2_1_1_f(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_64_256_2_1_1_f(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ at::Tensor f4f4bf16_grouped_256_64_256_2_1_1_t(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
int64_t G,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes,
std::optional<at::Tensor> global_scale,
std::optional<at::Tensor> starting_row_after_padding) {
Expand All @@ -36,8 +34,6 @@ at::Tensor f4f4bf16_grouped_256_64_256_2_1_1_t(
x_scale,
w_scale,
output,
G,
zero_start_index_M,
M_sizes,
global_scale,
starting_row_after_padding);
Expand Down
Loading
Loading