Skip to content

Commit 3da5ba0

Browse files
committed
Simplify binary operation definition for fp16 and bf16
1 parent 283edce commit 3da5ba0

File tree

3 files changed

+3
-38
lines changed

3 files changed

+3
-38
lines changed

include/kernel_float/bf16.h

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -176,16 +176,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
176176
}; \
177177
}
178178
#else
179-
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \
180-
namespace ops { \
181-
template<> \
182-
struct NAME<__nv_bfloat16> { \
183-
KERNEL_FLOAT_INLINE __nv_bfloat16 \
184-
operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \
185-
return __nv_bfloat16(ops::NAME<float> {}(float(left), float(right))); \
186-
} \
187-
}; \
188-
}
179+
// Fallback definition for the #else branch: the macro expands to nothing, so
// the KERNEL_FLOAT_BF16_BINARY_FUN(...) invocations below emit no
// __nv_bfloat16-specific specializations on this path.
// NOTE(review): presumably the generic ops::NAME<T> template is used for
// __nv_bfloat16 instead -- confirm against binops.h.
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2)
189180
#endif
190181

191182
KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2)
@@ -205,20 +196,6 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2)
205196
// greater_equal: pair the scalar >= intrinsic (__hge) with its packed
// counterpart __hge2. The previous third argument, __hgt2, is the strict
// greater-than comparison already used by `greater`; it reports false for
// lanes where both operands are equal.
// NOTE(review): FUN2 is presumably the two-lane (__nv_bfloat162) intrinsic
// used by the vectorized specialization -- confirm against the macro body.
KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hge2)
206197

207198
namespace ops {
208-
template<typename T>
209-
struct cast<T, __nv_bfloat16> {
210-
KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(T input) {
211-
return __float2bfloat16(ops::cast<T, float> {}(input));
212-
};
213-
};
214-
215-
template<typename T>
216-
struct cast<__nv_bfloat16, T> {
217-
KERNEL_FLOAT_INLINE T operator()(__nv_bfloat16 input) {
218-
return ops::cast<float, T> {}(__bfloat162float(input));
219-
};
220-
};
221-
222199
template<>
223200
struct cast<double, __nv_bfloat16> {
224201
KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(double input) {
@@ -340,10 +317,6 @@ struct dot_impl<__nv_bfloat16, N> {
340317
#include "fp16.h"
341318

342319
namespace kernel_float {
343-
#if KERNEL_FLOAT_CUDA_ARCH >= 800
344-
KERNEL_FLOAT_BF16_CAST(__half, __float2bfloat16(input), __bfloat162float(input));
345-
#endif
346-
347320
template<>
348321
struct promote_type<__nv_bfloat16, __half> {
349322
using type = float;

include/kernel_float/binops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
9191
template<typename T> \
9292
struct NAME { \
9393
KERNEL_FLOAT_INLINE T operator()(T left, T right) { \
94-
return T(EXPR); \
94+
return ops::cast<decltype(EXPR), T> {}(EXPR); \
9595
} \
9696
}; \
9797
} \

include/kernel_float/fp16.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
170170
}; \
171171
}
172172
#else
173-
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
174-
namespace ops { \
175-
template<> \
176-
struct NAME<__half> { \
177-
KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \
178-
return __half(ops::NAME<float> {}(float(left), float(right))); \
179-
} \
180-
}; \
181-
}
173+
// Fallback definition for the #else branch: the macro expands to nothing, so
// the KERNEL_FLOAT_FP16_BINARY_FUN(...) invocations below emit no
// __half-specific specializations on this path.
// NOTE(review): presumably the generic ops::NAME<T> template is used for
// __half instead -- confirm against binops.h.
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2)
182174
#endif
183175

184176
KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2)

0 commit comments

Comments
 (0)