diff --git a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu
index 51d39a2ae1897a..36797560f356ac 100644
--- a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu
+++ b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu
@@ -62,7 +62,8 @@ void WeightQuantizeKernel(const Context& dev_ctx,
                           scale->data<T>(),
                           weight_shape,
                           arch,
-                          algo);
+                          algo,
+                          group_size);
     trans(dev_ctx, quanted_x, out, axis);
   } else if (algo == "weight_only_int8") {
     dev_ctx.template Alloc<T>(scale);
@@ -72,7 +73,8 @@ void WeightQuantizeKernel(const Context& dev_ctx,
                           scale->data<T>(),
                           weight_shape,
                           arch,
-                          algo);
+                          algo,
+                          group_size);
 #ifdef PADDLE_WITH_HIP
     std::vector<int> axis = {1, 0};
     funcs::Transpose<Context, int8_t, 2> trans;
@@ -93,7 +95,8 @@ void WeightQuantizeKernel(const Context& dev_ctx,
                           scale->data<T>(),
                           weight_shape,
                           arch,
-                          algo);
+                          algo,
+                          group_size);
 #ifdef PADDLE_WITH_HIP
     DenseTensor x_int_tmp(out->type());
     x_int_tmp.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n / 2)});
diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
index 039c99e9dfa617..bede52adccf2ae 100644
--- a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
+++ b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
@@ -425,6 +425,213 @@ __global__ void per_channel_quant_gpu_int4_col_pack(const T* weight_data,
   }
 }
 
+template <typename T, typename ScaleT, int VectorSize>
+__global__ void per_group_quant_gpu_int4_col_pack(const T* weight_data,
+                                                  int8_t* quanted_weight_data,
+                                                  ScaleT* scale_data,
+                                                  int total_k,
+                                                  int total_vec_n,
+                                                  int group_size) {
+  int n = blockIdx.x * blockDim.x + threadIdx.x;
+  if (n < total_vec_n) {
+    const int4* vec_weight_data_ptr =
+        reinterpret_cast<const int4*>(weight_data);
+
+    phi::AlignedVector<float, VectorSize> abs_max;
+
+    // Walk the k dimension one group of rows at a time
+    for (int k = 0; k < total_k; k += group_size) {
+      // Init per-group abs_max
+#pragma unroll
+      for (int i = 0; i < VectorSize; ++i) {
+        abs_max[i] = static_cast<float>(0.0f);
+      }
+      for (int g = 0; g < group_size && k + g < total_k; g++) {
+        int linear_index = (k + g) * total_vec_n + n;
+        phi::AlignedVector<T, VectorSize> weight;
+        *reinterpret_cast<int4*>(&weight) = vec_weight_data_ptr[linear_index];
+#pragma unroll
+        for (int i = 0; i < VectorSize; ++i) {
+          abs_max[i] = fmaxf(abs_max[i], fabsf(weight[i]));
+        }
+      }
+      // Compute scale
+      phi::AlignedVector<ScaleT, VectorSize> scale;
+#pragma unroll
+      for (int i = 0; i < VectorSize; ++i) {
+        scale[i] = static_cast<ScaleT>(abs_max[i] / static_cast<float>(7.0f));
+      }
+      *reinterpret_cast<float4*>(
+          scale_data + (k / group_size) * (total_vec_n * VectorSize) +
+          n * VectorSize) = *reinterpret_cast<float4*>(&scale);
+
+      // Group-wise weight quant
+      for (int g = 0; g < group_size / 2; g++) {
+        phi::AlignedVector<int8_t, VectorSize> quanted_weight;
+        // Pack two int4 values into one int8
+        for (int packed_idx = 0;
+             packed_idx < 2 && k + g * 2 + packed_idx < total_k;
+             packed_idx++) {
+          int linear_index = (k + g * 2 + packed_idx) * total_vec_n + n;
+          phi::AlignedVector<T, VectorSize> weight;
+          *reinterpret_cast<int4*>(&weight) = *reinterpret_cast<const int4*>(
+              vec_weight_data_ptr + linear_index);
+#pragma unroll
+          for (int i = 0; i < VectorSize; ++i) {
+            float weight_elt =
+                (static_cast<float>(weight[i]) / static_cast<float>(scale[i]));
+            int8_t clipped_weight = static_cast<int8_t>(
+                lroundf(fmaxf(-7.0f, fminf(7.0f, weight_elt))));
+            // Reset the low or high 4 bits, then write the quantized nibble
+            quanted_weight[i] &= ~(0x0F << (4 * packed_idx));
+            quanted_weight[i] |= ((clipped_weight & 0x0F) << (4 * packed_idx));
+          }
+        }
+        int linear_index =
+            (k / 2 + g) * total_vec_n * VectorSize + n * VectorSize;
+
+        *reinterpret_cast<int2*>(quanted_weight_data + linear_index) =
+            *reinterpret_cast<int2*>(&quanted_weight);
+      }
+    }
+  }
+}
+
+template <typename T, typename ScaleT, int VectorSize>
+__global__ void per_group_quant_gpu_int4_row_pack(const T* weight_data,
+                                                  int8_t* quanted_weight_data,
+                                                  ScaleT* scale_data,
+                                                  int total_k,
+                                                  int total_vec_n,
+                                                  int group_size) {
+  int n = blockIdx.x * blockDim.x + threadIdx.x;
+  if (n < total_vec_n) {
+    const int4* vec_weight_data_ptr =
+        reinterpret_cast<const int4*>(weight_data);
+
+    phi::AlignedVector<float, VectorSize> abs_max;
+
+    // Walk the k dimension one group of rows at a time
+    for (int k = 0; k < total_k; k += group_size) {
+      // Init per-group abs_max
+#pragma unroll
+      for (int i = 0; i < VectorSize; ++i) {
+        abs_max[i] = static_cast<float>(0.0f);
+      }
+      for (int g = 0; g < group_size && k + g < total_k; g++) {
+        int linear_index = (k + g) * total_vec_n + n;
+        phi::AlignedVector<T, VectorSize> weight;
+        *reinterpret_cast<int4*>(&weight) = vec_weight_data_ptr[linear_index];
+#pragma unroll
+        for (int i = 0; i < VectorSize; ++i) {
+          abs_max[i] = fmaxf(abs_max[i], fabsf(weight[i]));
+        }
+      }
+      // Compute scale
+      phi::AlignedVector<ScaleT, VectorSize> scale;
+#pragma unroll
+      for (int i = 0; i < VectorSize; ++i) {
+        scale[i] = static_cast<ScaleT>(abs_max[i] / static_cast<float>(7.0f));
+      }
+      *reinterpret_cast<float4*>(
+          scale_data + (k / group_size) * (total_vec_n * VectorSize) +
+          n * VectorSize) = *reinterpret_cast<float4*>(&scale);
+
+      // Group-wise weight quant
+      for (int g = 0; g < group_size && k + g < total_k; g++) {
+        int linear_index = (k + g) * total_vec_n + n;
+        phi::AlignedVector<T, VectorSize> weight;
+        phi::AlignedVector<int8_t, VectorSize / 2> quanted_weight;
+        *reinterpret_cast<int4*>(&weight) =
+            *reinterpret_cast<const int4*>(vec_weight_data_ptr + linear_index);
+#pragma unroll
+        for (int i = 0; i < VectorSize / 2; i++) {
+          int8_t packed_int4s = 0;
+          for (int pack = 0; pack < 2; pack++) {
+            int vector_index = i * 2 + pack;
+            const float weight_elt = static_cast<float>(weight[vector_index]) /
+                                     static_cast<float>(scale[vector_index]);
+            int8_t clipped_weight = static_cast<int8_t>(
+                lroundf(fmaxf(-7.0f, fminf(7.0f, weight_elt))));
+            packed_int4s |= ((clipped_weight & 0x0F) << (4 * pack));
+          }
+          quanted_weight[i] = packed_int4s;
+        }
+        int quant_weight_idx =
+            (k + g) * total_vec_n * VectorSize / 2 + n * VectorSize / 2;
+        *reinterpret_cast<int32_t*>(quanted_weight_data + quant_weight_idx) =
+            *reinterpret_cast<int32_t*>(&quanted_weight);
+      }
+    }
+  }
+}
+
+template <typename T, typename ScaleT, int VectorSize>
+__global__ void group_wise_quant_gpu(const T* weight_data,
+                                     int8_t* quanted_weight_data,
+                                     ScaleT* scale_data,
+                                     int total_k,
+                                     int total_vec_n,
+                                     int group_size) {
+  int n = blockIdx.x * blockDim.x + threadIdx.x;
+  // This could be optimized further with group-wise parallelism
+  if (n < total_vec_n) {
+    const int4* vec_weight_data_ptr =
+        reinterpret_cast<const int4*>(weight_data);
+    int2* vec_quanted_weight_data =
+        reinterpret_cast<int2*>(quanted_weight_data);
+
+    phi::AlignedVector<float, VectorSize> abs_max;
+
+    // Walk the k dimension one group of rows at a time
+    for (int k = 0; k < total_k; k += group_size) {
+      // Init per-group abs_max
+#pragma unroll
+      for (int i = 0; i < VectorSize; ++i) {
+        abs_max[i] = static_cast<float>(0.0f);
+      }
+      for (int g = 0; g < group_size && k + g < total_k; g++) {
+        int linear_index = (k + g) * total_vec_n + n;
+        phi::AlignedVector<T, VectorSize> weight;
+        *reinterpret_cast<int4*>(&weight) = vec_weight_data_ptr[linear_index];
+#pragma unroll
+        for (int i = 0; i < VectorSize; ++i) {
+          abs_max[i] = fmaxf(abs_max[i], fabsf(weight[i]));
+        }
+      }
+      // Compute scale
+      phi::AlignedVector<ScaleT, VectorSize> scale;
+#pragma unroll
+      for (int i = 0; i < VectorSize; ++i) {
+        scale[i] = static_cast<ScaleT>(abs_max[i] / static_cast<float>(127.0f));
+      }
+      *reinterpret_cast<float4*>(
+          scale_data + (k / group_size) * (total_vec_n * VectorSize) +
+          n * VectorSize) = *reinterpret_cast<float4*>(&scale);
+
+      // Group-wise weight quant
+      for (int g = 0; g < group_size && k + g < total_k; g++) {
+        phi::AlignedVector<int8_t, VectorSize> quanted_weight;
+        int linear_index = (k + g) * total_vec_n + n;
+        phi::AlignedVector<T, VectorSize> weight;
+        *reinterpret_cast<int4*>(&weight) =
+            *reinterpret_cast<const int4*>(vec_weight_data_ptr + linear_index);
+#pragma unroll
+        for (int i = 0; i < VectorSize; ++i) {
+          float scaled_weight =
+              (static_cast<float>(weight[i]) / static_cast<float>(abs_max[i])) *
+              static_cast<float>(127.0f);
+          int8_t clipped_weight = static_cast<int8_t>(
+              lroundf(fmaxf(-127.0f, fminf(127.0f, scaled_weight))));
+          quanted_weight[i] = clipped_weight;
+        }
+        *reinterpret_cast<int2*>(vec_quanted_weight_data + linear_index) =
+            *reinterpret_cast<int2*>(&quanted_weight);
+      }
+    }
+  }
+}
+
 template <typename GPUContext, typename T, typename ScaleT>
 void weight_quant_gpu(const GPUContext& dev_ctx,
                       const T* weight_data,
@@ -432,7 +639,8 @@ void weight_quant_gpu(const GPUContext& dev_ctx,
                       ScaleT* scale_data,
                       const std::vector<int>& shape,
                       const int32_t arch,
-                      const std::string& algo) {
+                      const std::string& algo,
+                      const int32_t group_size) {
   int total_k = shape[0];
   int total_n = shape[1];
   int numel = total_k * total_n;
@@ -457,24 +665,54 @@ void weight_quant_gpu(const GPUContext& dev_ctx,
 #else
   if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) ||
       (arch == 75)) {
-    per_channel_quant_gpu_int4_col_pack<T, ScaleT, kVectorSize>
-        <<<kGridSize, kBlockSize, 0, dev_ctx.stream()>>>(weight_data,
-                                                         quanted_weight_data,
-                                                         scale_data,
-                                                         total_k,
-                                                         vec_total_n);
+    if (group_size == -1) {  // per channel
+      per_channel_quant_gpu_int4_col_pack<T, ScaleT, kVectorSize>
+          <<<kGridSize, kBlockSize, 0, dev_ctx.stream()>>>(weight_data,
+                                                           quanted_weight_data,
+                                                           scale_data,
+                                                           total_k,
+                                                           vec_total_n);
+    } else {
+      per_group_quant_gpu_int4_col_pack<T, ScaleT, kVectorSize>
+          <<<kGridSize, kBlockSize, 0, dev_ctx.stream()>>>(weight_data,
+                                                           quanted_weight_data,
+                                                           scale_data,
+                                                           total_k,
+                                                           vec_total_n,
+                                                           group_size);
+    }
   } else if ((arch == 70)) {
-    per_channel_quant_gpu_int4_row_pack<T, ScaleT, kVectorSize>
+    if (group_size == -1) {
+      per_channel_quant_gpu_int4_row_pack<T, ScaleT, kVectorSize>
+          <<<kGridSize, kBlockSize, 0, dev_ctx.stream()>>>(weight_data,
+                                                           quanted_weight_data,
+                                                           scale_data,
+                                                           total_k,
+                                                           vec_total_n);
+    } else {
+      per_group_quant_gpu_int4_row_pack<T, ScaleT, kVectorSize>
+          <<<kGridSize, kBlockSize, 0, dev_ctx.stream()>>>(weight_data,
+                                                           quanted_weight_data,
+                                                           scale_data,
+                                                           total_k,
+                                                           vec_total_n,
+                                                           group_size);
+    }
+  }
+#endif
+  } else {
+    if (group_size == -1) {  // per channel
+      per_channel_quant_gpu<T, ScaleT, kVectorSize><<<kGridSize, kBlockSize, 0, dev_ctx.stream()>>>(
+          weight_data, quanted_weight_data, scale_data, total_k, vec_total_n);
+    } else {
+      group_wise_quant_gpu<T, ScaleT, kVectorSize>
           <<<kGridSize, kBlockSize, 0, dev_ctx.stream()>>>(weight_data,
                                                            quanted_weight_data,
                                                            scale_data,
                                                            total_k,
-                                                           vec_total_n);
+                                                           vec_total_n,
+                                                           group_size);
+    }
   }
-#endif
-  } else {
-    per_channel_quant_gpu<T, ScaleT, kVectorSize><<<kGridSize, kBlockSize, 0, dev_ctx.stream()>>>(
-        weight_data, quanted_weight_data, scale_data, total_k, vec_total_n);
   }
 }
 
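Note (editorial, not part of the patch): the three kernels above share one per-group abs-max recipe; the int4 variants divide by 7.0 and pack two nibbles per byte, while the int8 variant divides by 127.0. A rough NumPy sketch of the int8 scheme, with illustrative names and np.rint standing in for lroundf, looks like this:

import numpy as np

def group_wise_quant_int8_ref(weight, group_size):
    # weight: float array of shape [k, n]; groups are taken along the k axis.
    k, n = weight.shape
    assert k % group_size == 0, "sketch assumes k is a multiple of group_size"
    groups = weight.reshape(k // group_size, group_size, n)
    abs_max = np.abs(groups).max(axis=1)   # [k // group_size, n]
    scale = abs_max / 127.0                # one scale per (group, column)
    # Like the kernel, this assumes abs_max > 0 for every group/column.
    q = np.rint(groups / scale[:, None, :]).clip(-127, 127).astype(np.int8)
    return q.reshape(k, n), scale
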
diff --git a/test/quantization/test_weight_only_linear.py b/test/quantization/test_weight_only_linear.py
index 164ad49ae7e0f9..9e88e1ba281978 100644
--- a/test/quantization/test_weight_only_linear.py
+++ b/test/quantization/test_weight_only_linear.py
@@ -730,7 +730,7 @@ def test_weight_only_linear(self):
     not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
     "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
 )
-class WeightOnlyQuantizeCPUGPUTestCase(unittest.TestCase):
+class WeightOnlyQuantizeCPUGPUTestCase1(unittest.TestCase):
     def config(self):
         self.dtype = 'float16'
         self.batch = 1
@@ -738,18 +738,19 @@ def config(self):
         self.in_features = 64
         self.out_features = 256
         self.group_size = -1
+        self.algo = "weight_only_int4"
 
     def weightQuantizeCPUGPUConsistenceCheck(self, weight_float):
         for arch in [70, 75, 80, 86]:
             weight_gpu, weight_scale_gpu = Q.weight_quantize(
                 weight_float.cuda(),
-                algo="weight_only_int4",
+                algo=self.algo,
                 arch=arch,
                 group_size=self.group_size,
             )
             weight_cpu, weight_scale_cpu = Q.weight_quantize(
                 weight_float.cpu(),
-                algo="weight_only_int4",
+                algo=self.algo,
                 arch=arch,
                 group_size=self.group_size,
             )
@@ -787,6 +788,94 @@ def setUp(self):
         self.weightQuantizeCPUGPUConsistenceCheck(self.float_weight)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyQuantizeCPUGPUTestCase2(WeightOnlyQuantizeCPUGPUTestCase1):
+    def config(self):
+        super().config()
+        self.group_size = 64
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyQuantizeCPUGPUTestCase3(WeightOnlyQuantizeCPUGPUTestCase1):
+    def config(self):
+        super().config()
+        self.group_size = 128
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyQuantizeCPUGPUTestCase4(WeightOnlyQuantizeCPUGPUTestCase1):
+    def config(self):
+        super().config()
+        self.group_size = 64
+        self.dtype = 'bfloat16'
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyQuantizeCPUGPUTestCase5(WeightOnlyQuantizeCPUGPUTestCase1):
+    def config(self):
+        super().config()
+        self.group_size = 128
+        self.dtype = 'bfloat16'
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyQuantizeCPUGPUTestCase6(WeightOnlyQuantizeCPUGPUTestCase1):
+    def config(self):
+        super().config()
+        self.group_size = 64
+        self.algo = "weight_only_int8"
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyQuantizeCPUGPUTestCase7(WeightOnlyQuantizeCPUGPUTestCase1):
+    def config(self):
+        super().config()
+        self.group_size = 128
+        self.algo = "weight_only_int8"
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyQuantizeCPUGPUTestCase8(WeightOnlyQuantizeCPUGPUTestCase1):
+    def config(self):
+        super().config()
+        self.group_size = 64
+        self.dtype = 'bfloat16'
+        self.algo = "weight_only_int8"
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyQuantizeCPUGPUTestCase9(WeightOnlyQuantizeCPUGPUTestCase1):
+    def config(self):
+        super().config()
+        self.group_size = 128
+        self.dtype = 'bfloat16'
+        self.algo = "weight_only_int8"
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda() or get_cuda_version() < 11020
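Editorial usage sketch (not part of the patch): on a CUDA build, the new group_size argument can be exercised directly through the same weight_quantize entry point the tests use (imported here from paddle.nn.quant); the shapes are illustrative and the expected scale layout follows from the kernels above.

import paddle
from paddle.nn.quant import weight_quantize

w = paddle.randn([128, 256]).astype('float16')  # [in_features, out_features]
qw, scales = weight_quantize(
    w.cuda(), algo="weight_only_int4", arch=80, group_size=64
)
# With group_size=64 there are 128 / 64 = 2 groups along k, so `scales` should
# hold one value per (group, output channel) rather than one per output channel.
print(qw.shape, scales.shape)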