@@ -67,29 +67,20 @@ struct LayoutDetailsB<TypeB, arch::Sm70> {
6767 using Operator = cutlass::arch::OpMultiplyAdd;
6868};
6969
70- // Specializations for Turing+ when B is FP16. These are currently only used for
71- // MoE networks.
72- template <typename Arch>
73- struct LayoutDetailsB <
74- half_t ,
75- Arch,
76- typename platform::enable_if<Arch::kMinComputeCapability >= 75 >::type> {
77- static constexpr int ThreadblockK = 64 ;
78- using Layout = layout::RowMajor;
79- static constexpr int ElementsPerAccess =
80- 128 / cutlass::sizeof_bits<half_t >::value;
81- using Operator = cutlass::arch::OpMultiplyAdd;
82- };
70+ // Specialization for Turing+ (compute capability >= 75) when B is a 16-bit
71+ // type (half_t or bfloat16_t). This is currently only used for MoE networks.
8372
84- template <typename Arch>
73+ template <typename TypeB, typename Arch>
8574struct LayoutDetailsB <
86- bfloat16_t ,
75+ TypeB ,
8776 Arch,
88- typename platform::enable_if<Arch::kMinComputeCapability >= 75 >::type> {
77+ typename platform::enable_if<
78+ Arch::kMinComputeCapability >= 75 &&
79+ (platform::is_same<TypeB, half_t >::value ||
80+ platform::is_same<TypeB, bfloat16_t >::value)>::type> {
8981 static constexpr int ThreadblockK = 64 ;
9082 using Layout = layout::RowMajor;
91- static constexpr int ElementsPerAccess =
92- 128 / cutlass::sizeof_bits<bfloat16_t >::value;
83+ static constexpr int ElementsPerAccess = 8 ;
9384 using Operator = cutlass::arch::OpMultiplyAdd;
9485};
9586
0 commit comments