Skip to content

Commit 283edce

Browse files
committed
Add cast_float_fallback mechanism
1 parent 4824f97 commit 283edce

File tree

4 files changed

+176
-83
lines changed

4 files changed

+176
-83
lines changed

include/kernel_float/fp8.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,39 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
2828
static constexpr bool value = true;
2929
};
3030
} // namespace detail
31+
32+
// Defines explicit `ops::cast` specializations between `T` and the two FP8 types
// (`__nv_fp8_e4m3` and `__nv_fp8_e5m2`), one per direction, four in total.
// Each specialization converts in a single step through the destination type's own
// conversion: `__nv_fp8_e4m3(v)` / `__nv_fp8_e5m2(v)` when producing FP8, `T(v)` when
// consuming FP8.
// The macro opens and closes `namespace ops` itself, so it must be expanded where
// `ops` is reachable as a nested namespace.
// NOTE(review): presumably these exist so that conversions involving `T` bypass the
// generic `T` -> `float` -> `R` route of `cast_float_fallback` -- confirm.
#define KERNEL_FLOAT_FP8_CAST(T) \
33+
namespace ops { \
34+
template<> \
35+
struct cast<T, __nv_fp8_e4m3> { \
36+
KERNEL_FLOAT_INLINE __nv_fp8_e4m3 operator()(T v) const { \
37+
return __nv_fp8_e4m3(v); \
38+
} \
39+
}; \
40+
\
41+
template<> \
42+
struct cast<__nv_fp8_e4m3, T> { \
43+
KERNEL_FLOAT_INLINE T operator()(__nv_fp8_e4m3 v) const { \
44+
return T(v); \
45+
} \
46+
}; \
47+
\
48+
template<> \
49+
struct cast<T, __nv_fp8_e5m2> { \
50+
KERNEL_FLOAT_INLINE __nv_fp8_e5m2 operator()(T v) const { \
51+
return __nv_fp8_e5m2(v); \
52+
} \
53+
}; \
54+
\
55+
template<> \
56+
struct cast<__nv_fp8_e5m2, T> { \
57+
KERNEL_FLOAT_INLINE T operator()(__nv_fp8_e5m2 v) const { \
58+
return T(v); \
59+
} \
60+
}; \
61+
}
62+
63+
KERNEL_FLOAT_FP8_CAST(double)
3164
} // namespace kernel_float
3265

3366
#if KERNEL_FLOAT_FP16_AVAILABLE
@@ -36,6 +69,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
3669
namespace kernel_float {
3770
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
3871
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
72+
KERNEL_FLOAT_FP8_CAST(__half)
3973
} // namespace kernel_float
4074
#endif // KERNEL_FLOAT_FP16_AVAILABLE
4175

@@ -45,6 +79,7 @@ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
4579
namespace kernel_float {
4680
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
4781
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
82+
KERNEL_FLOAT_FP8_CAST(__nv_bfloat16)
4883
} // namespace kernel_float
4984
#endif // KERNEL_FLOAT_BF16_AVAILABLE
5085

include/kernel_float/unops.h

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,27 @@ KERNEL_FLOAT_INLINE map_type<F, V> map(F fun, const V& input) {
4444
return result;
4545
}
4646

47+
namespace detail {
48+
// Trait indicating that elements of type `T` offer less precision than `float`, so that
49+
// operations on elements of type `T` can be performed by upcasting them to `float`.
// The primary template reports `false`: a type opts in by specializing this trait.
50+
template<typename T>
51+
struct allow_float_fallback {
52+
static constexpr bool value = false;
53+
};
54+
55+
// `float` itself is flagged as eligible.
template<>
56+
struct allow_float_fallback<float> {
57+
static constexpr bool value = true;
58+
};
59+
} // namespace detail
60+
4761
enum struct RoundingMode { ANY, DOWN, UP, NEAREST, TOWARD_ZERO };
4862

4963
namespace ops {
64+
5065
template<typename T, typename R, RoundingMode m = RoundingMode::ANY, typename = void>
5166
struct cast;
5267

53-
template<typename T, typename R>
54-
struct cast<T, R, RoundingMode::ANY> {
55-
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
56-
return R(input);
57-
}
58-
};
59-
6068
template<typename T, RoundingMode m>
6169
struct cast<T, T, m> {
6270
KERNEL_FLOAT_INLINE T operator()(T input) noexcept {
@@ -70,6 +78,41 @@ struct cast<T, T, RoundingMode::ANY> {
7078
return input;
7179
}
7280
};
81+
82+
template<typename T, typename R, typename = void>
83+
struct cast_float_fallback;
84+
85+
// Generic conversion for `RoundingMode::ANY`: forwards to `cast_float_fallback<T, R>`,
// which either converts directly with `R(input)` or goes through `float` in two steps.
template<typename T, typename R>
86+
struct cast<T, R, RoundingMode::ANY> {
87+
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
88+
return cast_float_fallback<T, R> {}(input);
89+
}
90+
};
91+
92+
// Primary template of `cast_float_fallback`: chosen whenever the float-fallback
// partial specialization is not enabled. Converts directly via `R(input)`.
template<typename T, typename R, typename>
93+
struct cast_float_fallback {
94+
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
95+
return R(input);
96+
}
97+
};
98+
99+
// clang-format off
100+
// Float-fallback specialization. Enabled when neither `T` nor `R` is `float` and at
// least one of them is flagged by `detail::allow_float_fallback`. The conversion is
// done in two steps, `T` -> `float` -> `R`, each step going through `cast` again.
// Excluding `float` on either side means the two inner casts (which each have `float`
// as source or destination) can never re-enable this specialization.
template<typename T, typename R>
101+
struct cast_float_fallback<
102+
T,
103+
R,
104+
enable_if_t<
105+
!is_same_type<T, float> &&
106+
!is_same_type<R, float> &&
107+
(detail::allow_float_fallback<T>::value || detail::allow_float_fallback<R>::value)
108+
>
109+
> {
110+
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
111+
// Inner cast widens to `float`; outer cast narrows/converts to the requested `R`.
return cast<float, R> {}(cast<T, float> {}(input));
112+
}
113+
};
114+
// clang-format on
115+
73116
} // namespace ops
74117

75118
/**
@@ -125,20 +168,6 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(negate, -, -input)
125168
KERNEL_FLOAT_DEFINE_UNARY_OP(bit_not, ~, ~input)
126169
KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast<bool, T> {}(!ops::cast<T, bool> {}(input))))
127170

128-
namespace detail {
129-
// Indicates that elements of type `T` offer less precision than floats, thus operations
130-
// on elements of type `T` can be performed by upcasting them to ` float`.
131-
template<typename T>
132-
struct allow_float_fallback {
133-
static constexpr bool value = false;
134-
};
135-
136-
template<>
137-
struct allow_float_fallback<float> {
138-
static constexpr bool value = true;
139-
};
140-
} // namespace detail
141-
142171
#define KERNEL_FLOAT_DEFINE_UNARY_MATH(NAME) \
143172
namespace ops { \
144173
template<typename T, typename = void> \

single_include/kernel_float.h

Lines changed: 90 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2023-10-12 17:25:02.978518
20-
// git hash: 25f9bb64a14ef5a93b356d6089becd7139a0141f
19+
// date: 2023-10-12 19:42:20.177310
20+
// git hash: 4824f9787b219562d394b19c74f701ff75d8fb56
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -892,19 +892,27 @@ KERNEL_FLOAT_INLINE map_type<F, V> map(F fun, const V& input) {
892892
return result;
893893
}
894894

895+
namespace detail {
896+
// Indicates that elements of type `T` offer less precision than floats, thus operations
897+
// on elements of type `T` can be performed by upcasting them to `float`.
898+
template<typename T>
899+
struct allow_float_fallback {
900+
static constexpr bool value = false;
901+
};
902+
903+
template<>
904+
struct allow_float_fallback<float> {
905+
static constexpr bool value = true;
906+
};
907+
} // namespace detail
908+
895909
enum struct RoundingMode { ANY, DOWN, UP, NEAREST, TOWARD_ZERO };
896910

897911
namespace ops {
912+
898913
template<typename T, typename R, RoundingMode m = RoundingMode::ANY, typename = void>
899914
struct cast;
900915

901-
template<typename T, typename R>
902-
struct cast<T, R, RoundingMode::ANY> {
903-
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
904-
return R(input);
905-
}
906-
};
907-
908916
template<typename T, RoundingMode m>
909917
struct cast<T, T, m> {
910918
KERNEL_FLOAT_INLINE T operator()(T input) noexcept {
@@ -918,6 +926,41 @@ struct cast<T, T, RoundingMode::ANY> {
918926
return input;
919927
}
920928
};
929+
930+
template<typename T, typename R, typename = void>
931+
struct cast_float_fallback;
932+
933+
template<typename T, typename R>
934+
struct cast<T, R, RoundingMode::ANY> {
935+
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
936+
return cast_float_fallback<T, R> {}(input);
937+
}
938+
};
939+
940+
template<typename T, typename R, typename>
941+
struct cast_float_fallback {
942+
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
943+
return R(input);
944+
}
945+
};
946+
947+
// clang-format off
948+
template<typename T, typename R>
949+
struct cast_float_fallback<
950+
T,
951+
R,
952+
enable_if_t<
953+
!is_same_type<T, float> &&
954+
!is_same_type<R, float> &&
955+
(detail::allow_float_fallback<T>::value || detail::allow_float_fallback<R>::value)
956+
>
957+
> {
958+
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
959+
return cast<float, R> {}(cast<T, float> {}(input));
960+
}
961+
};
962+
// clang-format on
963+
921964
} // namespace ops
922965

923966
/**
@@ -973,20 +1016,6 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(negate, -, -input)
9731016
KERNEL_FLOAT_DEFINE_UNARY_OP(bit_not, ~, ~input)
9741017
KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast<bool, T> {}(!ops::cast<T, bool> {}(input))))
9751018

976-
namespace detail {
977-
// Indicates that elements of type `T` offer less precision than floats, thus operations
978-
// on elements of type `T` can be performed by upcasting them to ` float`.
979-
template<typename T>
980-
struct allow_float_fallback {
981-
static constexpr bool value = false;
982-
};
983-
984-
template<>
985-
struct allow_float_fallback<float> {
986-
static constexpr bool value = true;
987-
};
988-
} // namespace detail
989-
9901019
#define KERNEL_FLOAT_DEFINE_UNARY_MATH(NAME) \
9911020
namespace ops { \
9921021
template<typename T, typename = void> \
@@ -1460,7 +1489,7 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
14601489
template<typename T> \
14611490
struct NAME { \
14621491
KERNEL_FLOAT_INLINE T operator()(T left, T right) { \
1463-
return T(EXPR); \
1492+
return ops::cast<decltype(EXPR), T> {}(EXPR); \
14641493
} \
14651494
}; \
14661495
} \
@@ -3497,15 +3526,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
34973526
}; \
34983527
}
34993528
#else
3500-
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
3501-
namespace ops { \
3502-
template<> \
3503-
struct NAME<__half> { \
3504-
KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \
3505-
return __half(ops::NAME<float> {}(float(left), float(right))); \
3506-
} \
3507-
}; \
3508-
}
3529+
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2)
35093530
#endif
35103531

35113532
KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2)
@@ -3793,16 +3814,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
37933814
}; \
37943815
}
37953816
#else
3796-
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \
3797-
namespace ops { \
3798-
template<> \
3799-
struct NAME<__nv_bfloat16> { \
3800-
KERNEL_FLOAT_INLINE __nv_bfloat16 \
3801-
operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \
3802-
return __nv_bfloat16(ops::NAME<float> {}(float(left), float(right))); \
3803-
} \
3804-
}; \
3805-
}
3817+
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2)
38063818
#endif
38073819

38083820
KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2)
@@ -3822,20 +3834,6 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2)
38223834
// greater_equal: the packed intrinsic must be `__hge2` (>=) to match the scalar `__hge`;
// `__hgt2` would compute a strict greater-than.
KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hge2)
38233835

38243836
namespace ops {
3825-
template<typename T>
3826-
struct cast<T, __nv_bfloat16> {
3827-
KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(T input) {
3828-
return __float2bfloat16(ops::cast<T, float> {}(input));
3829-
};
3830-
};
3831-
3832-
template<typename T>
3833-
struct cast<__nv_bfloat16, T> {
3834-
KERNEL_FLOAT_INLINE T operator()(__nv_bfloat16 input) {
3835-
return ops::cast<float, T> {}(__bfloat162float(input));
3836-
};
3837-
};
3838-
38393837
template<>
38403838
struct cast<double, __nv_bfloat16> {
38413839
KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(double input) {
@@ -3957,10 +3955,6 @@ struct dot_impl<__nv_bfloat16, N> {
39573955

39583956

39593957
namespace kernel_float {
3960-
#if KERNEL_FLOAT_CUDA_ARCH >= 800
3961-
KERNEL_FLOAT_BF16_CAST(__half, __float2bfloat16(input), __bfloat162float(input));
3962-
#endif
3963-
39643958
template<>
39653959
struct promote_type<__nv_bfloat16, __half> {
39663960
using type = float;
@@ -4007,6 +4001,39 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
40074001
static constexpr bool value = true;
40084002
};
40094003
} // namespace detail
4004+
4005+
#define KERNEL_FLOAT_FP8_CAST(T) \
4006+
namespace ops { \
4007+
template<> \
4008+
struct cast<T, __nv_fp8_e4m3> { \
4009+
KERNEL_FLOAT_INLINE __nv_fp8_e4m3 operator()(T v) const { \
4010+
return __nv_fp8_e4m3(v); \
4011+
} \
4012+
}; \
4013+
\
4014+
template<> \
4015+
struct cast<__nv_fp8_e4m3, T> { \
4016+
KERNEL_FLOAT_INLINE T operator()(__nv_fp8_e4m3 v) const { \
4017+
return T(v); \
4018+
} \
4019+
}; \
4020+
\
4021+
template<> \
4022+
struct cast<T, __nv_fp8_e5m2> { \
4023+
KERNEL_FLOAT_INLINE __nv_fp8_e5m2 operator()(T v) const { \
4024+
return __nv_fp8_e5m2(v); \
4025+
} \
4026+
}; \
4027+
\
4028+
template<> \
4029+
struct cast<__nv_fp8_e5m2, T> { \
4030+
KERNEL_FLOAT_INLINE T operator()(__nv_fp8_e5m2 v) const { \
4031+
return T(v); \
4032+
} \
4033+
}; \
4034+
}
4035+
4036+
KERNEL_FLOAT_FP8_CAST(double)
40104037
} // namespace kernel_float
40114038

40124039
#if KERNEL_FLOAT_FP16_AVAILABLE
@@ -4015,6 +4042,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
40154042
namespace kernel_float {
40164043
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
40174044
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
4045+
KERNEL_FLOAT_FP8_CAST(__half)
40184046
} // namespace kernel_float
40194047
#endif // KERNEL_FLOAT_FP16_AVAILABLE
40204048

@@ -4024,6 +4052,7 @@ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
40244052
namespace kernel_float {
40254053
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
40264054
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
4055+
KERNEL_FLOAT_FP8_CAST(__nv_bfloat16)
40274056
} // namespace kernel_float
40284057
#endif // KERNEL_FLOAT_BF16_AVAILABLE
40294058

0 commit comments

Comments
 (0)