diff --git a/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 0acff733c0d9..858b3f490fc9 100644 --- a/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -35,13 +35,11 @@ #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - #include "cutlass/gemm/kernel/gemm_transpose_operands.h" #include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" #include "cutlass/trace.h" - #include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -58,58 +56,22 @@ template using void_t = void; template -struct use_dq_gemm : platform::false_type {}; +struct use_dq_gemm : platform::false_type { + using LayoutScaleZero = void; +}; template struct use_dq_gemm> - : platform::true_type {}; - -// SFINAE overload for dequantizing gemm -template < - typename Mma, - typename ElementScale, - typename platform::enable_if::value, bool>::type = true> -CUTLASS_DEVICE static void run_mma(Mma mma, - int gemm_k_iterations, - typename Mma::FragmentC& accum, // NOLINT - typename Mma::IteratorA iterator_A, - typename Mma::IteratorB iterator_B, - typename Mma::FragmentC const& src_accum, - ElementScale* weight_scale_ptr, - MatrixCoord scale_extent, - const int thread_idx, - MatrixCoord tb_offset_scale) { - typename Mma::IteratorScale iterator_scale( - Mma::IteratorScale::Layout(scale_extent.column()), - weight_scale_ptr, - scale_extent, - thread_idx, - tb_offset_scale); - - mma(gemm_k_iterations, - accum, - iterator_A, - iterator_B, - iterator_scale, - src_accum); -} + : platform::true_type { + using LayoutScaleZero = typename Mma::IteratorScale::Layout; +}; -// SFINAE overload for normal gemm. 
This completely ignores the scale parameters -template < - typename Mma, - typename ElementScale, - typename platform::enable_if::value, bool>::type = true> -CUTLASS_DEVICE static void run_mma(Mma mma, - int gemm_k_iterations, - typename Mma::FragmentC& accum, // NOLINT - typename Mma::IteratorA iterator_A, - typename Mma::IteratorB iterator_B, - typename Mma::FragmentC const& src_accum, - ElementScale* weight_scale_ptr, - MatrixCoord scale_extent, - const int thread_idx, - MatrixCoord tb_offset_scale) { - mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum); +template +CUTLASS_HOST_DEVICE bool tensor_aligned(Element const* ref, + int stride, + int alignment) { + return (reinterpret_cast(ref) % alignment == 0) && + (stride % alignment == 0); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -120,11 +82,13 @@ template struct MoeFCGemm { - public: +public: using Mma = Mma_; using Epilogue = Epilogue_; using EpilogueOutputOp = typename Epilogue::OutputOp; @@ -195,6 +159,7 @@ struct MoeFCGemm { int problem_count; int threadblock_count; + int group_size; typename EpilogueOutputOp::Params output_op; @@ -220,6 +185,7 @@ struct MoeFCGemm { Arguments() : problem_count(0), threadblock_count(0), + group_size(-1), ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), @@ -243,10 +209,12 @@ struct MoeFCGemm { int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, + int group_size, GemmCoord* host_problem_sizes = nullptr) : problem_count(problem_count), threadblock_count(threadblock_count), output_op(output_op), + group_size(group_size), ptr_A(const_cast(ptr_A)), ptr_B(const_cast(ptr_B)), weight_scales(const_cast(weight_scales)), @@ -271,6 +239,7 @@ struct MoeFCGemm { struct Params { typename ProblemVisitor::Params problem_visitor; int threadblock_count; + int group_size; typename EpilogueOutputOp::Params output_op; @@ -280,6 +249,7 @@ struct MoeFCGemm { ElementC* ptr_C; ElementC* ptr_D; + // // Methods // @@ -303,6 +273,7 @@ struct MoeFCGemm { workspace, tile_count), threadblock_count(args.threadblock_count), + group_size(args.group_size), output_op(args.output_op), ptr_A(args.ptr_A), ptr_B(args.ptr_B), @@ -338,7 +309,7 @@ struct MoeFCGemm { typename Epilogue::SharedStorage epilogue; }; - public: +public: // // Methods // @@ -350,7 +321,7 @@ struct MoeFCGemm { static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } - + CUTLASS_HOST_DEVICE static Status can_implement(Arguments const& args) { if (platform::is_same::value || platform::is_same::value) { @@ -360,12 +331,68 @@ struct MoeFCGemm { "uint8_t and uint4b_t"); return Status::kInvalid; } + static int const kAlignmentA = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentScale = 128 / sizeof_bits::value; + static int const kAlignmentC = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + if (!tensor_aligned(args.ptr_A, args.gemm_k, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + // TODO: stride is gemm_n or gemm_n / 2 ? 
+ if (!tensor_aligned(args.ptr_B, args.gemm_n, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!tensor_aligned(args.weight_scales, args.gemm_n, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + + if (!tensor_aligned(args.ptr_C, args.gemm_n, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!tensor_aligned(args.ptr_D, args.gemm_n, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (args.weight_scales == nullptr) { + return Status::kErrorNotSupported; + } } else if (args.weight_scales != nullptr) { CUTLASS_TRACE_HOST( "MoeFCGemm::can_implement() - weight scales are ignored for all " "types except uint8_t and uint4b_t"); return Status::kInvalid; } + // Handle the case the input is too short + else if (args.gemm_n < Mma::IteratorB::AccessType::kElements) { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - gemm_n is smaller than the input " + "alignment"); + return Status::kInvalid; + } return Status::kSuccess; } @@ -373,6 +400,52 @@ struct MoeFCGemm { Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { return 0; } + // Initializes the fine grained scale+bias iterator. Needed since the fine + // grained iterator has a different constructor signature than a regular + // cutlass iterator + + template + struct initialize_scale { + CUTLASS_DEVICE static IteratorScale apply( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size); + }; + + template + struct initialize_scale { + CUTLASS_DEVICE static IteratorScale apply( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale(params, + pointer_scale, + extent, + thread_id, + threadblock_offset, + group_size); + } + }; + + template + struct initialize_scale { + CUTLASS_DEVICE static IteratorScale apply( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale( + params, pointer_scale, extent, thread_id, threadblock_offset); + } + }; // The dummy template parameter is not used and exists so that we can compile // this code using a standard earlier than C++17. Prior to C++17, fully @@ -422,7 +495,9 @@ struct MoeFCGemm { (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; // Outer 'persistent' loop to iterate over tiles + int loop = 0; while (problem_visitor.next_tile()) { + loop++; GemmCoord problem_size = problem_visitor.problem_size(); int32_t problem_idx = problem_visitor.problem_index(); int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); @@ -451,7 +526,12 @@ struct MoeFCGemm { platform::is_same::value ? gemm_n : gemm_k * kInterleave; - + ElementScale* ptr_Scale = + use_dq_gemm::value + ? 
params.weight_scales + + problem_idx * gemm_k / params.group_size * gemm_n + : nullptr; + long ldm_Scale = gemm_n; // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ threadblock_offset.m(), @@ -497,8 +577,11 @@ struct MoeFCGemm { // Construct thread-scoped matrix multiply auto CreateMMA = [&]() { if constexpr (use_dq_gemm::value) - return Mma( - shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx); + return Mma(shared_storage.main_loop, + params.group_size, + thread_idx, + warp_idx, + lane_idx); else return Mma( shared_storage.main_loop, thread_idx, warp_idx, lane_idx); @@ -514,27 +597,40 @@ struct MoeFCGemm { __syncthreads(); // Compute threadblock-scoped matrix multiply-add - ElementScale* weight_scale_ptr = - params.weight_scales + problem_idx * problem_size.n(); - run_mma(mma, - gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, - accumulators, - weight_scale_ptr, - {1, problem_size.n()}, - thread_idx, - tb_offset_scale); + if constexpr (use_dq_gemm::value) { + typename MatrixCoord::Index scale_row_extent = + FineGrained == true ? gemm_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = + initialize_scale::apply( + use_dq_gemm::LayoutScaleZero(ldm_Scale), + reinterpret_cast( + ptr_Scale), + {scale_row_extent, problem_size.n()}, + thread_idx, + tb_offset_scale, + params.group_size); + + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_scale, + accumulators); + } else { + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + } // // Epilogue // - - EpilogueOutputOp output_op(params.output_op); - - ElementC* ptr_C = - reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC* ptr_C = (params.ptr_C == nullptr) + ? nullptr + : reinterpret_cast(params.ptr_C) + + problem_idx * gemm_n; ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; @@ -550,7 +646,8 @@ struct MoeFCGemm { ptr_C, problem_size.mn(), thread_idx, - threadblock_offset.mn()); + threadblock_offset.mn(), + nullptr); // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( @@ -558,13 +655,28 @@ struct MoeFCGemm { ptr_D, problem_size.mn(), thread_idx, - threadblock_offset.mn()); + threadblock_offset.mn(), + nullptr); Epilogue epilogue( shared_storage.epilogue, thread_idx, warp_idx, lane_idx); // Execute the epilogue operator to update the destination tensor. 
- epilogue(output_op, iterator_D, accumulators, iterator_C); + if constexpr (platform::is_same< + EpilogueOutputOp, + cutlass::epilogue::thread::LinearCombination< + typename EpilogueOutputOp::ElementOutput, + EpilogueOutputOp::kCount, + typename EpilogueOutputOp::ElementAccumulator, + typename EpilogueOutputOp::ElementCompute, + EpilogueOutputOp::kScale, + EpilogueOutputOp::kRound>>::value) { + EpilogueOutputOp output_op(params.output_op, problem_idx); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } else { + EpilogueOutputOp output_op(params.output_op); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } // Next tile problem_visitor.advance(gridDim.x); diff --git a/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h b/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h index 0636400517ea..b2d5f82b5788 100644 --- a/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h +++ b/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h @@ -37,6 +37,7 @@ class MoeGemmRunner { int64_t gemm_k, int num_experts, std::string activation_type, + const int32_t weightonly_group_size, cudaStream_t stream); void moe_gemm(const T* A, @@ -48,10 +49,11 @@ class MoeGemmRunner { int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, cudaStream_t stream); private: - template + template void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, @@ -62,11 +64,12 @@ class MoeGemmRunner { int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); - template + template void run_gemm(const T* A, const WeightType* B, const T* weight_scales, @@ -77,6 +80,7 @@ class MoeGemmRunner { int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, cudaStream_t stream); private: diff --git a/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h b/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h index 8069db569f5b..7edf9f6b2e59 100644 --- a/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h +++ b/csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h @@ -52,6 +52,7 @@ template @@ -64,6 +65,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, @@ -127,7 +129,10 @@ void generic_moe_gemm_kernelLauncher(const T* A, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op; - + using Operator = typename MixedGemmArchTraits::Operator; + using TaggedOperator = + typename cutlass::arch::TagOperator::TaggedOperator; // Finally, set up the kernel. 
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< ElementType, @@ -150,7 +155,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, - typename MixedGemmArchTraits::Operator>::GemmKernel; + TaggedOperator>::GemmKernel; using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + GemmKernel_::kGroupScheduleMode, + FineGrained>; using GemmGrouped = cutlass::gemm::device::GemmGrouped; @@ -177,7 +183,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); - + typename GemmGrouped::Arguments args( num_experts, threadblock_count, @@ -189,7 +195,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, reinterpret_cast(C), total_rows_before_expert, gemm_n, - gemm_k); + gemm_k, + group_size); GemmGrouped gemm; @@ -221,6 +228,7 @@ template struct dispatch_stages { @@ -269,6 +280,7 @@ struct dispatch_stages(A, @@ -288,6 +301,7 @@ struct dispatch_stages @@ -305,6 +320,7 @@ struct dispatch_stages(A, @@ -337,6 +355,7 @@ struct dispatch_stages void dispatch_gemm_config(const T* A, @@ -359,6 +379,7 @@ void dispatch_gemm_config(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, @@ -369,6 +390,7 @@ void dispatch_gemm_config(const T* A, WeightType, \ arch, \ EpilogueTag, \ + FineGrained, \ ThreadblockShape, \ WarpShape, \ STAGE>::dispatch(A, \ @@ -380,6 +402,7 @@ void dispatch_gemm_config(const T* A, gemm_n, \ gemm_k, \ num_experts, \ + group_size, \ gemm_config, \ multi_processor_count, \ stream, \ @@ -406,6 +429,7 @@ void dispatch_gemm_config(const T* A, WeightType, \ arch, \ EpilogueTag, \ + FineGrained, \ cutlass::gemm::GemmShape, \ cutlass::gemm::GemmShape>( \ A, \ @@ -417,6 +441,7 @@ void dispatch_gemm_config(const T* A, gemm_n, \ gemm_k, \ num_experts, \ + group_size, \ gemm_config, \ multi_processor_count, \ stream, \ @@ -429,6 +454,7 @@ template ::value && std::is_same::value>::type* = nullptr> @@ -442,6 +468,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, @@ -474,6 +501,7 @@ template ::value && !std::is_same::value>::type* = nullptr> @@ -487,6 +515,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, @@ -542,6 +571,7 @@ template < typename WeightType, typename arch, typename EpilogueTag, + bool FineGrained, typename std::enable_if::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, @@ -553,6 +583,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, @@ -588,8 +619,8 @@ MoeGemmRunner::MoeGemmRunner() { } template -template -void MoeGemmRunner::dispatch_to_arch( +template +void MoeGemmRunner::dispatch_to_arch( const T* A, const WeightType* B, const T* weight_scales, @@ -600,25 +631,27 @@ void MoeGemmRunner::dispatch_to_arch( int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy) { -#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \ - 
dispatch_moe_gemm_to_cutlass( \ - A, \ - B, \ - weight_scales, \ - biases, \ - C, \ - total_rows_before_expert, \ - total_rows, \ - gemm_n, \ - gemm_k, \ - num_experts, \ - gemm_config, \ - sm_, \ - multi_processor_count_, \ - stream, \ +#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \ + dispatch_moe_gemm_to_cutlass( \ + A, \ + B, \ + weight_scales, \ + biases, \ + C, \ + total_rows_before_expert, \ + total_rows, \ + gemm_n, \ + gemm_k, \ + num_experts, \ + group_size, \ + gemm_config, \ + sm_, \ + multi_processor_count_, \ + stream, \ occupancy); if (sm_ >= 70 && sm_ < 75) { @@ -633,8 +666,8 @@ void MoeGemmRunner::dispatch_to_arch( } template -template -void MoeGemmRunner::run_gemm( +template +void MoeGemmRunner::run_gemm( const T* A, const WeightType* B, const T* weight_scales, @@ -645,11 +678,12 @@ void MoeGemmRunner::run_gemm( int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, cudaStream_t stream) { static constexpr bool is_weight_only = !std::is_same::value; static constexpr bool only_simt_configs = std::is_same::value; - std::vector candidate_configs = - get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true); + std::vector candidate_configs = get_candidate_configs( + sm_, group_size, is_weight_only, only_simt_configs, true); static constexpr int warm_time = 5; static constexpr int test_time = 10; auto& gemmConfigManager = GemmConfigManager::Instance(); @@ -672,7 +706,7 @@ void MoeGemmRunner::run_gemm( for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { try { for (int i = 0; i < warm_time; i++) { - dispatch_to_arch(A, + dispatch_to_arch(A, B, weight_scales, biases, @@ -682,6 +716,7 @@ void MoeGemmRunner::run_gemm( gemm_n, gemm_k, num_experts, + group_size, candidate_configs[ii], stream); } @@ -692,7 +727,7 @@ void MoeGemmRunner::run_gemm( check_cuda_error(cudaStreamSynchronize(stream)); check_cuda_error(cudaEventRecord(start, stream)); for (int i = 0; i < test_time; i++) { - dispatch_to_arch(A, + dispatch_to_arch(A, B, weight_scales, biases, @@ -702,6 +737,7 @@ void MoeGemmRunner::run_gemm( gemm_n, gemm_k, num_experts, + group_size, candidate_configs[ii], stream); } @@ -728,18 +764,19 @@ void MoeGemmRunner::run_gemm( PADDLE_FATAL("[MoE Configure Search] find no one available config."); } } - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - chosen_config, - stream); + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + group_size, + chosen_config, + stream); } template @@ -755,32 +792,73 @@ void MoeGemmRunner::moe_gemm_bias_act( int64_t gemm_k, int num_experts, std::string activation_type, + const int32_t weightonly_group_size, cudaStream_t stream) { if (activation_type == "none") { if (biases) { - run_gemm(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - stream); + if (weightonly_group_size > 0) { + PADDLE_ENFORCE_GE(sm_, + 80, + phi::errors::Unimplemented( + "Groupwise mode is not supported on SM < 8.0")); + run_gemm(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + weightonly_group_size, + stream); + } else { + run_gemm(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + weightonly_group_size, + stream); + } } else { - run_gemm(A, - B, - weight_scales, - nullptr, - C, - 
total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - stream); + if (weightonly_group_size > 0) { + PADDLE_ENFORCE_GE(sm_, + 80, + phi::errors::Unimplemented( + "Groupwise mode is not supported on SM < 8.0")); + run_gemm(A, + B, + weight_scales, + nullptr, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + weightonly_group_size, + stream); + } else { + run_gemm(A, + B, + weight_scales, + nullptr, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + weightonly_group_size, + stream); + } } } } @@ -795,16 +873,37 @@ void MoeGemmRunner::moe_gemm(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + int group_size, cudaStream_t stream) { - run_gemm(A, - B, - weight_scales, - nullptr, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - stream); + if (group_size > 0) { + PADDLE_ENFORCE_GE(sm_, + 80, + phi::errors::Unimplemented( + "Groupwise mode is not supported on SM < 8.0")); + run_gemm(A, + B, + weight_scales, + nullptr, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + group_size, + stream); + } else { + run_gemm(A, + B, + weight_scales, + nullptr, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + group_size, + stream); + } } diff --git a/csrc/gpu/moe/fused_moe/fused_moe.cu b/csrc/gpu/moe/fused_moe/fused_moe.cu index 538572f00e70..a13190c5c48a 100644 --- a/csrc/gpu/moe/fused_moe/fused_moe.cu +++ b/csrc/gpu/moe/fused_moe/fused_moe.cu @@ -57,6 +57,7 @@ void FusedMoeKernel(const paddle::Tensor& input, const paddle::optional& ffn2_scale, const paddle::optional& ffn2_bias, const std::string& quant_method, + const int weightonly_group_size, const int moe_topk, const bool group_moe, const bool norm_topk_prob, @@ -86,6 +87,7 @@ void FusedMoeKernel(const paddle::Tensor& input, ffn2_scale ? ffn2_scale.get_ptr() : nullptr, ffn2_bias ? 
ffn2_bias.get_ptr() : nullptr, nullptr, + weightonly_group_size, moe_topk, group_moe, norm_topk_prob, @@ -105,6 +107,7 @@ std::vector FusedExpertMoe( const paddle::optional& ffn2_bias, const paddle::optional& ffn2_scale, const std::string& quant_method, + const int weightonly_group_size, const int moe_topk, const bool norm_topk_prob, const bool group_moe) { @@ -122,6 +125,7 @@ std::vector FusedExpertMoe( ffn2_scale, ffn2_bias, quant_method, + weightonly_group_size, moe_topk, group_moe, norm_topk_prob, @@ -137,6 +141,7 @@ std::vector FusedExpertMoe( ffn2_scale, ffn2_bias, quant_method, + weightonly_group_size, moe_topk, group_moe, norm_topk_prob, @@ -184,6 +189,7 @@ PD_BUILD_OP(fused_expert_moe) paddle::Optional("ffn2_scale")}) .Outputs({"output"}) .Attrs({"quant_method:std::string", + "weightonly_group_size:int", "moe_topk:int", "norm_topk_prob:bool", "group_moe:bool"}) diff --git a/csrc/gpu/moe/fused_moe/moe/fused_moe_helper.h b/csrc/gpu/moe/fused_moe/moe/fused_moe_helper.h index 002efd166aa2..c576159ec7da 100644 --- a/csrc/gpu/moe/fused_moe/moe/fused_moe_helper.h +++ b/csrc/gpu/moe/fused_moe/moe/fused_moe_helper.h @@ -126,6 +126,7 @@ class MoeHelper { const paddle::Tensor *ffn2_scale, const paddle::Tensor *ffn2_bias, const paddle::Tensor *moe_token_type_ids, + const int weightonly_group_size, const int moe_topk, const bool group_moe, const bool norm_topk_prob, @@ -304,6 +305,7 @@ class MoeHelper { hidden_size, num_experts, "none", + weightonly_group_size, stream); } else if (gemm_method_ == "weight_only_int4") { int4_moe_gemm_runner_->moe_gemm_bias_act( @@ -319,6 +321,7 @@ class MoeHelper { hidden_size, num_experts, "none", + weightonly_group_size, stream); } else { fp16_moe_gemm_runner_->moe_gemm_bias_act( @@ -333,6 +336,7 @@ class MoeHelper { hidden_size, num_experts, "none", + weightonly_group_size, stream); } @@ -356,6 +360,7 @@ class MoeHelper { hidden_size, inter_size / 2, num_experts, + weightonly_group_size, stream); } else if (gemm_method_ == "weight_only_int4") { int4_moe_gemm_runner_->moe_gemm( @@ -369,6 +374,7 @@ class MoeHelper { hidden_size, inter_size / 2, num_experts, + weightonly_group_size, stream); } else { fp16_moe_gemm_runner_->moe_gemm( @@ -381,6 +387,7 @@ class MoeHelper { hidden_size, inter_size / 2, num_experts, + weightonly_group_size, stream); } diff --git a/csrc/gpu/moe/fused_moe/moe_ffn.cu b/csrc/gpu/moe/fused_moe/moe_ffn.cu index 3c69c20e3df7..270f1985d92b 100644 --- a/csrc/gpu/moe/fused_moe/moe_ffn.cu +++ b/csrc/gpu/moe/fused_moe/moe_ffn.cu @@ -26,6 +26,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, const paddle::optional& ffn1_scale, const paddle::optional& ffn2_scale, const std::string& quant_method, + const int32_t weightonly_group_size, paddle::Tensor ffn_out) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; @@ -62,8 +63,16 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, ffn1_bias ? const_cast(ffn1_bias.get_ptr())->data() : nullptr; - + /* + group size + different quant and no quant, no quant and quant channel wise have same group + size no quant : group_size = -1 + quant channel wise : group_size = -1 + quant group wise : group_size = 64 || 128 + */ + int group_size = weightonly_group_size; if (quant_method == "weight_only_int8") { + group_size = group_size == -1 ? 
hidden_size : group_size; int8_moe_gemm_runner.moe_gemm_bias_act( reinterpret_cast(permuted_data), reinterpret_cast(ffn1_weight.data()), @@ -77,8 +86,10 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, hidden_size, num_experts, "none", + group_size, stream); } else if (quant_method == "weight_only_int4") { + group_size = group_size == -1 ? hidden_size : group_size; int4_moe_gemm_runner.moe_gemm_bias_act( reinterpret_cast(permuted_data), reinterpret_cast(ffn1_weight.data()), @@ -92,6 +103,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, hidden_size, num_experts, "none", + group_size, stream); } else { fp16_moe_gemm_runner.moe_gemm_bias_act( @@ -106,13 +118,16 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, hidden_size, num_experts, "none", + -1, stream); } auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr); auto act_out = act_out_tensor.data(); - + // reset group_size + group_size = weightonly_group_size; if (quant_method == "weight_only_int8") { + group_size = group_size == -1 ? inter_size / 2 : group_size; int8_moe_gemm_runner.moe_gemm( reinterpret_cast(act_out), reinterpret_cast(ffn2_weight.data()), @@ -124,9 +139,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, hidden_size, inter_size / 2, num_experts, + group_size, stream); } else if (quant_method == "weight_only_int4") { + group_size = group_size == -1 ? inter_size / 2 : group_size; int4_moe_gemm_runner.moe_gemm( reinterpret_cast(act_out), reinterpret_cast(ffn2_weight.data()), @@ -138,6 +155,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, hidden_size, inter_size / 2, num_experts, + group_size, stream); } else { fp16_moe_gemm_runner.moe_gemm( @@ -150,6 +168,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, hidden_size, inter_size / 2, num_experts, + -1, stream); } } @@ -162,7 +181,8 @@ std::vector MoeExpertFFN( const paddle::optional& ffn1_bias, const paddle::optional& ffn1_scale, const paddle::optional& ffn2_scale, - const std::string& quant_method) { + const std::string& quant_method, + const int32_t weightonly_group_size) { const auto input_type = permute_input.dtype(); auto ffn_out = paddle::empty_like(permute_input); @@ -176,6 +196,7 @@ std::vector MoeExpertFFN( ffn1_scale, ffn2_scale, quant_method, + weightonly_group_size, ffn_out); break; case paddle::DataType::FLOAT16: @@ -187,6 +208,7 @@ std::vector MoeExpertFFN( ffn1_scale, ffn2_scale, quant_method, + weightonly_group_size, ffn_out); break; default: @@ -226,7 +248,7 @@ PD_BUILD_OP(moe_expert_ffn) paddle::Optional("ffn1_scale"), paddle::Optional("ffn2_scale")}) .Outputs({"output_tensor"}) - .Attrs({"quant_method:std::string"}) + .Attrs({"quant_method:std::string", "weightonly_group_size: int"}) .SetKernelFn(PD_KERNEL(MoeExpertFFN)) .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype)); diff --git a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py index dca9623f0cb4..e106b2d003ac 100644 --- a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py @@ -675,17 +675,17 @@ def set_state_dict(self, state_dict): if self.use_weight_only: q_a_proj_quanted_weight, q_a_proj_weight_scale = weight_quantize( - q_a_proj_weight.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size + q_a_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) - 
self.transformer_block.q_a_proj_weights[idx].set_value(q_a_proj_quanted_weight.cuda()) - self.transformer_block.q_a_proj_weights_scale[idx].set_value(q_a_proj_weight_scale.cuda()) + self.transformer_block.q_a_proj_weights[idx].set_value(q_a_proj_quanted_weight) + self.transformer_block.q_a_proj_weights_scale[idx].set_value(q_a_proj_weight_scale) q_b_proj_quanted_weight, q_b_proj_weight_scale = weight_quantize( - q_b_proj_weight.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size + q_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) - self.transformer_block.q_b_proj_weights[idx].set_value(q_b_proj_quanted_weight.cuda()) + self.transformer_block.q_b_proj_weights[idx].set_value(q_b_proj_quanted_weight) self.transformer_block.q_a_layernorm_weights[idx].set_value(q_a_layernorm_weight) - self.transformer_block.q_b_proj_weights_scale[idx].set_value(q_b_proj_weight_scale.cuda()) + self.transformer_block.q_b_proj_weights_scale[idx].set_value(q_b_proj_weight_scale) elif "fp8" in self.quant_type: q_a_proj_quanted_weight = ( paddle.to_tensor( @@ -733,10 +733,10 @@ def set_state_dict(self, state_dict): if self.use_weight_only: q_proj_quanted_weight, q_proj_weight_scale = weight_quantize( - q_proj_weight.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size + q_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) - self.transformer_block.q_proj_weights[idx].set_value(q_proj_quanted_weight.cuda()) - self.transformer_block.q_proj_weights_scale[idx].set_value(q_proj_weight_scale.cuda()) + self.transformer_block.q_proj_weights[idx].set_value(q_proj_quanted_weight) + self.transformer_block.q_proj_weights_scale[idx].set_value(q_proj_weight_scale) elif "fp8" in self.quant_type: q_proj_quanted_weight = ( paddle.to_tensor(state_dict[f"{self.base_model_prefix}.layers.{idx}.self_attn.q_proj.weight"]) @@ -804,21 +804,17 @@ def set_state_dict(self, state_dict): if self.use_weight_only: kv_a_proj_with_mqa_quanted_weight, kv_a_proj_with_mqa_weight_scale = weight_quantize( - kv_a_proj_with_mqa_weight.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size - ) - self.transformer_block.kv_a_proj_with_mqa_weights[idx].set_value( - kv_a_proj_with_mqa_quanted_weight.cuda() - ) - self.transformer_block.kv_a_proj_with_mqa_weights_scale[idx].set_value( - kv_a_proj_with_mqa_weight_scale.cuda() + kv_a_proj_with_mqa_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) + self.transformer_block.kv_a_proj_with_mqa_weights[idx].set_value(kv_a_proj_with_mqa_quanted_weight) + self.transformer_block.kv_a_proj_with_mqa_weights_scale[idx].set_value(kv_a_proj_with_mqa_weight_scale) kv_b_proj_quanted_weight, kv_b_proj_weight_scale = weight_quantize( - kv_b_proj_weight.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size + kv_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) - self.transformer_block.kv_b_proj_weights[idx].set_value(kv_b_proj_quanted_weight.cuda()) + self.transformer_block.kv_b_proj_weights[idx].set_value(kv_b_proj_quanted_weight) self.transformer_block.kv_a_layernorm_weights[idx].set_value(kv_a_layernorm_weight) - self.transformer_block.kv_b_proj_weights_scale[idx].set_value(kv_b_proj_weight_scale.cuda()) + self.transformer_block.kv_b_proj_weights_scale[idx].set_value(kv_b_proj_weight_scale) elif "fp8" in self.quant_type: kv_a_proj_with_mqa_quanted_weight = ( paddle.to_tensor( @@ -862,10 +858,10 @@ def set_state_dict(self, state_dict): if self.use_weight_only: 
linear_quanted_weight, linear_weight_scale = weight_quantize( - linear_weight.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size + linear_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) - self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight.cuda()) - self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale.cuda()) + self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight) + self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale) elif "fp8" in self.quant_type: linear_quanted_weight = ( paddle.to_tensor(state_dict[f"{self.base_model_prefix}.layers.{idx}.self_attn.o_proj.weight"]) @@ -902,10 +898,10 @@ def set_state_dict(self, state_dict): if self.use_weight_only: ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize( - ffn1_weight_tensor.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size + ffn1_weight_tensor, algo=self.quant_algo, group_size=self.weightonly_group_size ) - self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight_tensor.cuda()) - self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale_tensor.cuda()) + self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight_tensor) + self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale_tensor) elif "fp8" in self.quant_type: ffn1_quanted_weight_tensor = ( paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)).cast(paddle.float8_e4m3fn) @@ -935,10 +931,10 @@ def set_state_dict(self, state_dict): ).cast(paddle.get_default_dtype()) if self.use_weight_only: ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize( - ffn2_weight_tensor.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size + ffn2_weight_tensor, algo=self.quant_algo, group_size=self.weightonly_group_size ) - self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight_tensor.cuda()) - self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale_tensor.cuda()) + self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight_tensor) + self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale_tensor) elif "fp8" in self.quant_type: ffn2_quanted_weight_tensor = ( paddle.to_tensor(state_dict[f"{self.base_model_prefix}.layers.{idx}.mlp.down_proj.weight"]) @@ -981,10 +977,10 @@ def set_state_dict(self, state_dict): if self.use_weight_only: ffn1_quanted_weight, ffn1_weight_scale = weight_quantize( - ffn1_weight, algo=self.quant_algo, group_size=-1 + ffn1_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) ffn2_quanted_weight, ffn2_weight_scale = weight_quantize( - ffn2_weight, algo=self.quant_algo, group_size=-1 + ffn2_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) ffn1_weights.append(ffn1_quanted_weight.reshape([self.transformer_block.config.embed_dim, -1])) ffn2_weights.append(ffn2_quanted_weight.reshape([-1, self.transformer_block.config.embed_dim])) @@ -1042,10 +1038,10 @@ def set_state_dict(self, state_dict): weight_block_size=self.weight_block_size, ) ffn1_quanted_weight, ffn1_weight_scale = weight_quantize( - ffn1_weight, algo=self.moe_quant_type, group_size=-1 + ffn1_weight, algo=self.moe_quant_type, group_size=self.weightonly_group_size ) ffn2_quanted_weight, ffn2_weight_scale = weight_quantize( - ffn2_weight, algo=self.moe_quant_type, group_size=-1 + ffn2_weight, algo=self.moe_quant_type, group_size=self.weightonly_group_size ) 
ffn1_weights.append( ffn1_quanted_weight.reshape([self.transformer_block.config.embed_dim, -1]) @@ -1170,7 +1166,7 @@ def set_state_dict(self, state_dict): if self.use_weight_only: shared_expert_ffn1_quanted_weight, shared_expert_ffn1_weight_scale = weight_quantize( - shared_expert_ffn1_weight.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size + shared_expert_ffn1_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) self.transformer_block.shared_expert_ffn1_weights[idx].set_value(shared_expert_ffn1_quanted_weight) self.transformer_block.shared_expert_ffn1_weights_scale[idx].set_value( @@ -1178,14 +1174,13 @@ def set_state_dict(self, state_dict): ) shared_expert_ffn2_quanted_weight, shared_expert_ffn2_weight_scale = weight_quantize( - shared_expert_ffn2_weight.cpu(), algo=self.quant_algo, group_size=self.weightonly_group_size - ) - self.transformer_block.shared_expert_ffn2_weights[idx].set_value( - shared_expert_ffn2_quanted_weight.cuda() + shared_expert_ffn2_weight, algo=self.quant_algo, group_size=self.weightonly_group_size ) + self.transformer_block.shared_expert_ffn2_weights[idx].set_value(shared_expert_ffn2_quanted_weight) self.transformer_block.shared_expert_ffn2_weights_scale[idx].set_value( - shared_expert_ffn2_weight_scale.cuda() + shared_expert_ffn2_weight_scale ) + elif "fp8" in self.quant_type: shared_expert_ffn1_quanted_weight = ( paddle.to_tensor(concated_gate_up_weight).transpose((1, 0)).cast(paddle.float8_e4m3fn) diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 300619109a9d..00123d7d4da7 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -1325,6 +1325,7 @@ def get_moe_scores( self.ffn1_weights_scale[i] if hasattr(self, "ffn1_weights_scale") else None, self.ffn2_weights_scale[i] if hasattr(self, "ffn2_weights_scale") else None, self.quant_type if hasattr(self, "quant_type") else "None", + self.config.weightonly_group_size, ) fused_moe_out = moe_expert_reduce( @@ -1347,6 +1348,7 @@ def get_moe_scores( self.ffn2_biases[i], self.ffn2_weights_scale[i] if hasattr(self, "ffn2_weights_scale") else None, self.quant_type if hasattr(self, "quant_type") else "None", + self.config.weightonly_group_size, self.config.moe_config.top_k, self.config.moe_config.norm_topk_prob, False, @@ -1781,10 +1783,30 @@ def __init__(self, config: FusedMultiTransformerConfig): ffn1_weight_scale_attr = self.get_attr(config.ffn1_weight_scale_attrs, i) ffn2_weight_scale_attr = self.get_attr(config.ffn2_weight_scale_attrs, i) if self.config.moe_config.use_moe(i): + if self.weightonly_group_size < 0: + base_shape = ( + [self.config.moe_config.num_experts, self.config.moe_config.moe_intermediate_size * 2] + if config.activation.endswith("glu") + else [self.config.moe_config.num_experts, self.config.moe_config.moe_intermediate_size] + ) + else: + base_shape_group = (self.embed_dim + self.weightonly_group_size - 1) // self.weightonly_group_size + base_shape = ( + [ + self.config.moe_config.num_experts, + base_shape_group, + self.config.moe_config.moe_intermediate_size * 2, + ] + if config.activation.endswith("glu") + else [ + self.config.moe_config.num_experts, + base_shape_group, + self.config.moe_config.moe_intermediate_size, + ] + ) + ffn1_weight_scale = self.create_parameter( - shape=[self.config.moe_config.num_experts, self.config.moe_config.moe_intermediate_size * 2] - if 
config.activation.endswith("glu") - else [self.config.moe_config.num_experts, self.config.moe_config.moe_intermediate_size], + shape=base_shape, attr=ffn1_weight_scale_attr, dtype=self.weight_scale_dtype, is_bias=False, @@ -1806,8 +1828,15 @@ def __init__(self, config: FusedMultiTransformerConfig): ) if self.config.moe_config.use_moe(i): + if self.weightonly_group_size < 0: + base_shape = [self.config.moe_config.num_experts, self.embed_dim] + else: + base_shape_group = ( + self.config.moe_config.moe_intermediate_size + self.weightonly_group_size - 1 + ) // self.weightonly_group_size + base_shape = [self.config.moe_config.num_experts, base_shape_group, self.embed_dim] ffn2_weight_scale = self.create_parameter( - shape=[self.config.moe_config.num_experts, self.embed_dim], + shape=base_shape, attr=ffn2_weight_scale_attr, dtype=self.weight_scale_dtype, is_bias=False, diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index 401ee655ca98..31e7801e2044 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -109,12 +109,15 @@ def __init__(self, config: Qwen2Config, base_model_prefix: str): self.use_fake_parameter = config.get("use_fake_parameter", False) self.use_weight_only = False + self.weightonly_group_size = -1 if config.quant_type == "weight_only_int8": self.use_weight_only = True self.quant_algo = "weight_only_int8" + self.weightonly_group_size = config.weightonly_group_size elif config.quant_type == "weight_only_int4": self.use_weight_only = True self.quant_algo = "weight_only_int4" + self.weightonly_group_size = config.weightonly_group_size elif "a8w8" in config.quant_type: self.quant_model_path = config.model_name_or_path self.shift = config.quantization_config.shift @@ -312,6 +315,7 @@ def __init__(self, config: Qwen2Config, base_model_prefix: str): kv_num_heads=self.num_key_value_heads, intermediate_size=self.intermediate_size, quant_type=self.quant_type, + weightonly_group_size=self.weightonly_group_size, activation="swiglu", num_layers=config.num_hidden_layers, tp_degree=config.tensor_parallel_degree, @@ -673,7 +677,9 @@ def concat(tensor_list, axis=-1): if self.use_weight_only: qkv_weight = paddle.transpose(qkv_weight, perm=[1, 0]) - qkv_quanted_weight, qkv_weight_scale = weight_quantize(qkv_weight, algo=self.quant_algo) + qkv_quanted_weight, qkv_weight_scale = weight_quantize( + qkv_weight, algo=self.quant_algo, group_size=self.weightonly_group_size + ) self.transformer_block.qkv_weights[idx].copy_(qkv_quanted_weight, False) self.transformer_block.qkv_weights_scale[idx].copy_(qkv_weight_scale, False) elif "fp8" in self.quant_type: @@ -711,7 +717,9 @@ def concat(tensor_list, axis=-1): paddle.get_default_dtype() ) if self.use_weight_only: - linear_quanted_weight, linear_weight_scale = weight_quantize(linear_weight, algo=self.quant_algo) + linear_quanted_weight, linear_weight_scale = weight_quantize( + linear_weight, algo=self.quant_algo, group_size=self.weightonly_group_size + ) self.transformer_block.linear_weights[idx].copy_(linear_quanted_weight, False) self.transformer_block.linear_weights_scale[idx].copy_(linear_weight_scale, False) elif "fp8" in self.quant_type: @@ -770,7 +778,9 @@ def concat(tensor_list, axis=-1): ffn1_weight = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype()) if self.use_weight_only: - ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(ffn1_weight, algo=self.quant_algo) + 
ffn1_quanted_weight, ffn1_weight_scale = weight_quantize( + ffn1_weight, algo=self.quant_algo, group_size=self.weightonly_group_size + ) self.transformer_block.ffn1_weights[idx].copy_(ffn1_quanted_weight, False) self.transformer_block.ffn1_weights_scale[idx].copy_(ffn1_weight_scale, False) elif "fp8" in self.quant_type: @@ -807,7 +817,9 @@ def concat(tensor_list, axis=-1): paddle.get_default_dtype() ) if self.use_weight_only: - ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(ffn2_weight, algo=self.quant_algo) + ffn2_quanted_weight, ffn2_weight_scale = weight_quantize( + ffn2_weight, algo=self.quant_algo, group_size=self.weightonly_group_size + ) self.transformer_block.ffn2_weights[idx].copy_(ffn2_quanted_weight, False) self.transformer_block.ffn2_weights_scale[idx].copy_(ffn2_weight_scale, False) elif "fp8" in self.quant_type: diff --git a/tests/llm/test_predictor_v1.py b/tests/llm/test_predictor_v1.py index 98cd7b55c51d..e540548ba1cb 100644 --- a/tests/llm/test_predictor_v1.py +++ b/tests/llm/test_predictor_v1.py @@ -187,3 +187,84 @@ def tearDown(self): LLMTest.tearDown(self) if os.path.exists(self.save_file_path): shutil.rmtree(self.save_file_path) + + +@parameterized_class( + ["model_name_or_path", "model_class"], + [ + ["deepseek-ai/DeepSeek-V2-Lite-Chat", AutoModelForCausalLM], + ], +) +class GroupWiseWeightQuantInferenceTest(LLMTest, unittest.TestCase): + config_path: str = "./tests/fixtures/llm/predictor.yaml" + model_name_or_path: str = None + model_class = None + + def setUp(self) -> None: + super().setUp() + self.model_class.from_pretrained(self.model_name_or_path, dtype="float16").save_pretrained(self.output_dir) + AutoTokenizer.from_pretrained(self.model_name_or_path).save_pretrained(self.output_dir) + global global_result + model_tag = os.path.basename(self.model_name_or_path) + + if model_tag not in global_result: + self.run_predictor({"inference_model": True, "block_attn": True, "append_attn": True, "max_length": 48}) + self.golden_result = self._read_result(os.path.join(self.output_dir, "predict.json")) + global_result[model_tag] = self.golden_result + else: + self.golden_result = global_result[model_tag] + + @parameterized.expand( + [ + ( + { + "quant_type": "weight_only_int4", + "weightonly_group_size": 64, + }, + ), + ( + { + "quant_type": "weight_only_int4", + "weightonly_group_size": 128, + }, + ), + ( + { + "quant_type": "weight_only_int8", + "weightonly_group_size": 64, + }, + ), + ( + { + "quant_type": "weight_only_int8", + "weightonly_group_size": 128, + }, + ), + ] + ) + def test_groupwise_weight_quant_inference(self, param_case): + config_params = {"inference_model": True, "block_attn": True, "append_attn": True, "max_length": 48} + config_params.update(param_case) + print(config_params) + + self.run_predictor(config_params) + + result = self._read_result(os.path.join(self.output_dir, "predict.json")) + assert len(self.golden_result) == len(result) + + partial_match, full_match = 0, 0 + for golden_item, result_item in zip(self.golden_result, result): + score = levenshtein_similarity(golden_item, result_item) + if score >= 0.95: + full_match += 1 + if score >= 0.6: + partial_match += 1 + + if not config_params["inference_model"]: + self.assertGreaterEqual(full_match / len(self.golden_result), 0.3) + self.assertGreaterEqual(partial_match / len(self.golden_result), 0.4) + elif config_params.get("use_fake_parameter", False): + pass + else: + self.assertGreaterEqual(full_match / len(self.golden_result), 0.5) + self.assertGreaterEqual(partial_match / 
len(self.golden_result), 0.8)
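
Note on the scale-buffer layout assumed by the kernel change: MoeFCGemm now offsets into a single flat weight-scale buffer per problem (ptr_Scale = weight_scales + problem_idx * gemm_k / group_size * gemm_n), which implies a row-major [num_experts, gemm_k / group_size, gemm_n] layout for group-wise scales and the old [num_experts, 1, gemm_n] layout for channel-wise scales. The sketch below only illustrates that addressing; the function name and the shapes in main() are illustrative, not part of the patch.

#include <cassert>
#include <cstdio>

// Offset (in elements) of expert `expert_idx`'s block inside one flat,
// row-major weight-scale buffer of shape [num_experts, k / group_size, n].
// group_size <= 0 is treated as channel-wise quantization: one row per expert.
long long expert_scale_offset(long long expert_idx, long long gemm_k,
                              long long gemm_n, long long group_size) {
  const long long rows_per_expert = (group_size > 0) ? gemm_k / group_size : 1;
  return expert_idx * rows_per_expert * gemm_n;
}

int main() {
  const long long gemm_k = 4096, gemm_n = 1536;
  // Group-wise (group_size = 128): expert 3 starts after 3 * (4096/128) * 1536.
  printf("group-wise offset for expert 3: %lld\n",
         expert_scale_offset(3, gemm_k, gemm_n, 128));
  // Channel-wise (group_size = -1): one scale row of gemm_n elements per expert.
  assert(expert_scale_offset(3, gemm_k, gemm_n, -1) == 3 * gemm_n);
  return 0;
}

The same layout shows up on the Python side in fused_transformer_layers.py, where the group-wise scale parameters gain an extra per-expert dimension of size (k + group_size - 1) // group_size.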
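
Note on the initialize_scale helper added to MoeFCGemm: it exists because the fine-grained scale iterator takes an extra group_size constructor argument while the per-channel iterator does not, and the branch that is not taken must still compile. Below is a stripped-down sketch of the same selection pattern with stand-in iterator types (PerChannelScaleIterator, FineGrainedScaleIterator are placeholders, not the real CUTLASS iterators).

#include <cstdio>

// Stand-in iterators: the fine-grained one needs group_size, the other doesn't.
struct PerChannelScaleIterator {
  PerChannelScaleIterator(const float* scales, int n) : scales_(scales), n_(n) {}
  const float* scales_;
  int n_;
};
struct FineGrainedScaleIterator {
  FineGrainedScaleIterator(const float* scales, int n, int group_size)
      : scales_(scales), n_(n), group_size_(group_size) {}
  const float* scales_;
  int n_;
  int group_size_;
};

// Only the specialization that matches the chosen iterator is instantiated,
// so the constructor with the "wrong" arity is never required to exist.
template <typename IteratorScale, bool FineGrained>
struct initialize_scale;

template <typename IteratorScale>
struct initialize_scale<IteratorScale, true> {
  static IteratorScale apply(const float* scales, int n, int group_size) {
    return IteratorScale(scales, n, group_size);  // group-wise path
  }
};

template <typename IteratorScale>
struct initialize_scale<IteratorScale, false> {
  static IteratorScale apply(const float* scales, int n, int /*group_size*/) {
    return IteratorScale(scales, n);  // channel-wise path ignores group_size
  }
};

int main() {
  float scales[8] = {};
  auto fine = initialize_scale<FineGrainedScaleIterator, true>::apply(scales, 8, 64);
  auto coarse = initialize_scale<PerChannelScaleIterator, false>::apply(scales, 8, 64);
  printf("%d %d\n", fine.group_size_, coarse.n_);
  return 0;
}

The extra dummy template parameter mentioned in the patch works around the pre-C++17 restriction on explicit specialization at class scope; the sketch side-steps that by living at namespace scope, but the dispatch idea is the same.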
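
Note on the group_size defaulting in MoeFFNKernel (moe_ffn.cu): for the weight-only paths a requested weightonly_group_size of -1 (channel-wise) is rewritten to the reduction dimension of the current GEMM (hidden_size for the first GEMM, inter_size / 2 for the second), i.e. channel-wise is expressed as one group spanning all of K, while 64 or 128 is passed through unchanged for group-wise quantization. A minimal sketch of that mapping under those assumptions; the helper name and the sizes in main() are illustrative only.

#include <cassert>

// Channel-wise quantization (group_size == -1) becomes a single group that
// spans the whole reduction dimension of the current GEMM, so the weight-only
// paths in MoeFFNKernel always hand a positive group size downstream.
long long effective_group_size(long long requested_group_size, long long gemm_k) {
  return requested_group_size == -1 ? gemm_k : requested_group_size;
}

int main() {
  const long long hidden_size = 2048, inter_size = 1408;
  // First GEMM reduces over hidden_size, second over inter_size / 2.
  assert(effective_group_size(-1, hidden_size) == hidden_size);
  assert(effective_group_size(128, hidden_size) == 128);
  assert(effective_group_size(-1, inter_size / 2) == inter_size / 2);
  return 0;
}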