16
16
17
17
// ================================================================================
18
18
// this file has been auto-generated, do not modify its contents!
19
- // date: 2023-10-12 17:25:02.978518
20
- // git hash: 25f9bb64a14ef5a93b356d6089becd7139a0141f
19
+ // date: 2023-10-12 19:42:20.177310
20
+ // git hash: 4824f9787b219562d394b19c74f701ff75d8fb56
21
21
// ================================================================================
22
22
23
23
#ifndef KERNEL_FLOAT_MACROS_H
@@ -892,19 +892,27 @@ KERNEL_FLOAT_INLINE map_type<F, V> map(F fun, const V& input) {
892
892
return result;
893
893
}
894
894
895
+ namespace detail {
896
+ // Indicates that elements of type `T` offer less precision than floats, thus operations
897
+ // on elements of type `T` can be performed by upcasting them to `float`.
898
+ template <typename T>
899
+ struct allow_float_fallback {
900
+ static constexpr bool value = false ; // default: a type must opt in through a specialization
901
+ };
902
+
903
+ template <>
904
+ struct allow_float_fallback <float > {
905
+ static constexpr bool value = true ; // `float` itself is opted in
906
+ };
907
+ } // namespace detail
908
+
895
909
enum struct RoundingMode { ANY, DOWN, UP, NEAREST, TOWARD_ZERO }; // mode selector for `ops::cast`; `ANY` is that template's default
896
910
897
911
namespace ops {
912
+
898
913
template <typename T, typename R, RoundingMode m = RoundingMode::ANY, typename = void >
899
914
struct cast ; // declared without a body here: conversions from `T` to `R` are provided by specializations
900
915
901
- template <typename T, typename R>
902
- struct cast <T, R, RoundingMode::ANY> {
903
- KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
904
- return R (input);
905
- }
906
- };
907
-
908
916
template <typename T, RoundingMode m>
909
917
struct cast <T, T, m> {
910
918
KERNEL_FLOAT_INLINE T operator ()(T input) noexcept {
@@ -918,6 +926,41 @@ struct cast<T, T, RoundingMode::ANY> {
918
926
return input;
919
927
}
920
928
};
929
+
930
+ template <typename T, typename R, typename = void >
931
+ struct cast_float_fallback ; // forward declaration; the defaulted third parameter is the `enable_if_t` slot
932
+
933
+ template <typename T, typename R>
934
+ struct cast <T, R, RoundingMode::ANY> { // generic `T` -> `R` conversion when the rounding mode is `ANY`
935
+ KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
936
+ return cast_float_fallback<T, R> {}(input); // delegate so the float-fallback specialization can take over
937
+ }
938
+ };
939
+
940
+ template <typename T, typename R, typename >
941
+ struct cast_float_fallback { // primary definition: used when the `enable_if_t` specialization does not match
942
+ KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
943
+ return R (input); // direct conversion from `T` to `R`
944
+ }
945
+ };
946
+
947
+ // clang-format off
948
+ template <typename T, typename R>
949
+ struct cast_float_fallback < // specialization: convert through `float` as an intermediate type
950
+ T,
951
+ R,
952
+ enable_if_t <
953
+ !is_same_type<T, float > && // excluded: `cast<T, float>` below would otherwise route back into this specialization
954
+ !is_same_type<R, float > && // excluded for the same reason on the `cast<float, R>` step
955
+ (detail::allow_float_fallback<T>::value || detail::allow_float_fallback<R>::value) // at least one side opted in
956
+ >
957
+ > {
958
+ KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
959
+ return cast<float , R> {}(cast<T, float > {}(input)); // two steps: `T` -> `float`, then `float` -> `R`
960
+ }
961
+ };
962
+ // clang-format on
963
+
921
964
} // namespace ops
922
965
923
966
/* *
@@ -973,20 +1016,6 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(negate, -, -input)
973
1016
KERNEL_FLOAT_DEFINE_UNARY_OP (bit_not, ~, ~input)
974
1017
KERNEL_FLOAT_DEFINE_UNARY_OP (logical_not, !, (ops::cast<bool , T> {}(!ops::cast<T, bool > {}(input))))
975
1018
976
- namespace detail {
977
- // Indicates that elements of type `T` offer less precision than floats, thus operations
978
- // on elements of type `T` can be performed by upcasting them to `float`.
979
- template <typename T>
980
- struct allow_float_fallback {
981
- static constexpr bool value = false ;
982
- };
983
-
984
- template <>
985
- struct allow_float_fallback <float > {
986
- static constexpr bool value = true ;
987
- };
988
- } // namespace detail
989
-
990
1019
#define KERNEL_FLOAT_DEFINE_UNARY_MATH (NAME ) \
991
1020
namespace ops { \
992
1021
template <typename T, typename = void > \
@@ -1460,7 +1489,7 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
1460
1489
template <typename T> \
1461
1490
struct NAME { \
1462
1491
KERNEL_FLOAT_INLINE T operator ()(T left, T right) { \
1463
- return T (EXPR); \
1492
+ return ops::cast< decltype (EXPR), T> {}(EXPR); \
1464
1493
} \
1465
1494
}; \
1466
1495
} \
@@ -3497,15 +3526,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
3497
3526
}; \
3498
3527
}
3499
3528
#else
3500
- #define KERNEL_FLOAT_FP16_BINARY_FUN (NAME, FUN1, FUN2 ) \
3501
- namespace ops { \
3502
- template <> \
3503
- struct NAME <__half> { \
3504
- KERNEL_FLOAT_INLINE __half operator ()(__half left, __half right) const { \
3505
- return __half (ops::NAME<float > {}(float (left), float (right))); \
3506
- } \
3507
- }; \
3508
- }
3529
+ #define KERNEL_FLOAT_FP16_BINARY_FUN (NAME, FUN1, FUN2 )
3509
3530
#endif
3510
3531
3511
3532
KERNEL_FLOAT_FP16_BINARY_FUN (add, __hadd, __hadd2)
@@ -3793,16 +3814,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
3793
3814
}; \
3794
3815
}
3795
3816
#else
3796
- #define KERNEL_FLOAT_BF16_BINARY_FUN (NAME, FUN1, FUN2 ) \
3797
- namespace ops { \
3798
- template <> \
3799
- struct NAME <__nv_bfloat16> { \
3800
- KERNEL_FLOAT_INLINE __nv_bfloat16 \
3801
- operator ()(__nv_bfloat16 left, __nv_bfloat16 right) const { \
3802
- return __nv_bfloat16 (ops::NAME<float > {}(float (left), float (right))); \
3803
- } \
3804
- }; \
3805
- }
3817
+ #define KERNEL_FLOAT_BF16_BINARY_FUN (NAME, FUN1, FUN2 )
3806
3818
#endif
3807
3819
3808
3820
KERNEL_FLOAT_BF16_BINARY_FUN (add, __hadd, __hadd2)
@@ -3822,20 +3834,6 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2)
3822
3834
KERNEL_FLOAT_BF16_BINARY_FUN (greater_equal, __hge, __hge2) // packed intrinsic must be `__hge2` (>=); `__hgt2` is strict > and is wrong for equal lanes
3823
3835
3824
3836
namespace ops {
3825
- template <typename T>
3826
- struct cast <T, __nv_bfloat16> {
3827
- KERNEL_FLOAT_INLINE __nv_bfloat16 operator ()(T input) {
3828
- return __float2bfloat16 (ops::cast<T, float > {}(input));
3829
- };
3830
- };
3831
-
3832
- template <typename T>
3833
- struct cast <__nv_bfloat16, T> {
3834
- KERNEL_FLOAT_INLINE T operator ()(__nv_bfloat16 input) {
3835
- return ops::cast<float , T> {}(__bfloat162float (input));
3836
- };
3837
- };
3838
-
3839
3837
template <>
3840
3838
struct cast <double , __nv_bfloat16> {
3841
3839
KERNEL_FLOAT_INLINE __nv_bfloat16 operator ()(double input) {
@@ -3957,10 +3955,6 @@ struct dot_impl<__nv_bfloat16, N> {
3957
3955
3958
3956
3959
3957
namespace kernel_float {
3960
- #if KERNEL_FLOAT_CUDA_ARCH >= 800
3961
- KERNEL_FLOAT_BF16_CAST (__half, __float2bfloat16(input), __bfloat162float(input));
3962
- #endif
3963
-
3964
3958
template <>
3965
3959
struct promote_type <__nv_bfloat16, __half> {
3966
3960
using type = float ;
@@ -4007,6 +4001,39 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
4007
4001
static constexpr bool value = true ;
4008
4002
};
4009
4003
} // namespace detail
4004
+
4005
// KERNEL_FLOAT_FP8_CAST(T): defines four explicit `ops::cast` specializations covering
// `T` -> `__nv_fp8_e4m3`, `__nv_fp8_e4m3` -> `T`, `T` -> `__nv_fp8_e5m2` and
// `__nv_fp8_e5m2` -> `T`. Each one performs a direct constructor-style conversion of its
// argument. The macro is instantiated for `double`, `__half` and `__nv_bfloat16` in this file.
+ #define KERNEL_FLOAT_FP8_CAST (T ) \
4006
+ namespace ops { \
4007
+ template <> \
4008
+ struct cast <T, __nv_fp8_e4m3> { \
4009
+ KERNEL_FLOAT_INLINE __nv_fp8_e4m3 operator ()(T v) const { \
4010
+ return __nv_fp8_e4m3 (v); \
4011
+ } \
4012
+ }; \
4013
+ \
4014
+ template <> \
4015
+ struct cast <__nv_fp8_e4m3, T> { \
4016
+ KERNEL_FLOAT_INLINE T operator ()(__nv_fp8_e4m3 v) const { \
4017
+ return T (v); \
4018
+ } \
4019
+ }; \
4020
+ \
4021
+ template <> \
4022
+ struct cast <T, __nv_fp8_e5m2> { \
4023
+ KERNEL_FLOAT_INLINE __nv_fp8_e5m2 operator ()(T v) const { \
4024
+ return __nv_fp8_e5m2 (v); \
4025
+ } \
4026
+ }; \
4027
+ \
4028
+ template <> \
4029
+ struct cast <__nv_fp8_e5m2, T> { \
4030
+ KERNEL_FLOAT_INLINE T operator ()(__nv_fp8_e5m2 v) const { \
4031
+ return T (v); \
4032
+ } \
4033
+ }; \
4034
+ }
4035
+
4036
+ KERNEL_FLOAT_FP8_CAST (double )
4010
4037
} // namespace kernel_float
4011
4038
4012
4039
#if KERNEL_FLOAT_FP16_AVAILABLE
@@ -4015,6 +4042,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
4015
4042
namespace kernel_float {
4016
4043
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half, __nv_fp8_e4m3)
4017
4044
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half, __nv_fp8_e5m2)
4045
+ KERNEL_FLOAT_FP8_CAST (__half)
4018
4046
} // namespace kernel_float
4019
4047
#endif // KERNEL_FLOAT_FP16_AVAILABLE
4020
4048
@@ -4024,6 +4052,7 @@ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
4024
4052
namespace kernel_float {
4025
4053
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16, __nv_fp8_e4m3)
4026
4054
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16, __nv_fp8_e5m2)
4055
+ KERNEL_FLOAT_FP8_CAST (__nv_bfloat16)
4027
4056
} // namespace kernel_float
4028
4057
#endif // KERNEL_FLOAT_BF16_AVAILABLE
4029
4058
0 commit comments