Skip to content

Commit 9385655

Browse files
committed
Changes to make code compile under HIPRTC
1 parent 5490ea7 commit 9385655

File tree

10 files changed

+270
-249
lines changed

10 files changed

+270
-249
lines changed

include/kernel_float/base.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ using promoted_vector_value_type = promote_t<vector_value_type<Vs>...>;
270270

271271
template<typename V>
272272
KERNEL_FLOAT_INLINE vector_storage_type<V> into_vector_storage(V&& input) {
273-
return into_vector_impl<V>::call(std::forward<V>(input));
273+
return into_vector_impl<V>::call(static_cast<V&&>(input));
274274
}
275275

276276
} // namespace kernel_float

include/kernel_float/bf16.h

Lines changed: 28 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
#include "macros.h"
55

66
#if KERNEL_FLOAT_BF16_AVAILABLE
7+
//#define CUDA_NO_BFLOAT16 (1)
8+
//#define __CUDA_NO_BFLOAT16_OPERATORS__ (1)
9+
//#define __CUDA_NO_BFLOAT162_OPERATORS__ (1)
10+
//#define __CUDA_NO_BFLOAT16_CONVERSIONS__ (1)
11+
712
#if KERNEL_FLOAT_IS_CUDA
813
#include <cuda_bf16.h>
914
#elif KERNEL_FLOAT_IS_HIP
@@ -76,21 +81,24 @@ struct allow_float_fallback<__bfloat16> {
7681
}; \
7782
}
7883

79-
KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2)
80-
KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
81-
KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
84+
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
8285
KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)
86+
8387
KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp)
8488
KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10)
85-
KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor)
8689
KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log)
8790
// log10: the scalar and packed (two-lane) variants must both be base-10 logarithms.
KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log10)
88-
KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
89-
KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt)
90-
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
91+
9192
KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt)
92-
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
93+
KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt)
9394
KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp)
95+
96+
KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2)
97+
KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor)
98+
KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
99+
KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
100+
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
101+
KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
94102
#endif
95103

96104
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
@@ -99,7 +107,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp)
99107
template<> \
100108
struct NAME<__bfloat16> { \
101109
KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 left, __bfloat16 right) const { \
102-
return FUN1(left, right); \
110+
return ops::cast<decltype(FUN1(left, right)), __bfloat16> {}(FUN1(left, right)); \
103111
} \
104112
}; \
105113
} \
@@ -159,29 +167,6 @@ struct apply_impl<ops::fma<__bfloat16>, 2, __bfloat16, __bfloat16, __bfloat16, _
159167
} // namespace detail
160168
#endif
161169

162-
namespace ops {
163-
template<>
164-
struct cast<double, __bfloat16> {
165-
KERNEL_FLOAT_INLINE __bfloat16 operator()(double input) {
166-
return __double2bfloat16(input);
167-
};
168-
};
169-
170-
template<>
171-
struct cast<float, __bfloat16> {
172-
KERNEL_FLOAT_INLINE __bfloat16 operator()(float input) {
173-
return __float2bfloat16(input);
174-
};
175-
};
176-
177-
template<>
178-
struct cast<__bfloat16, float> {
179-
KERNEL_FLOAT_INLINE float operator()(__bfloat16 input) {
180-
return __bfloat162float(input);
181-
};
182-
};
183-
} // namespace ops
184-
185170
#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \
186171
namespace ops { \
187172
template<> \
@@ -198,31 +183,33 @@ struct cast<__bfloat16, float> {
198183
}; \
199184
}
200185

186+
KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input))
187+
KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), __bfloat162float(input))
188+
201189
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
202190
// clang-format off
203191
// there are no official char casts. Instead, cast to int and then to char
204192
KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input));
205193
KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input));
206194
KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input));
207195

208-
KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input));
209-
KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input));
196+
KERNEL_FLOAT_BF16_CAST(signed short, __short2bfloat16_rn(input), __bfloat162short_rz(input));
197+
KERNEL_FLOAT_BF16_CAST(signed int, __int2bfloat16_rn(input), __bfloat162int_rz(input));
210198
KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input)));
211199
KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input));
212200

213-
KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input));
214-
KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input));
201+
KERNEL_FLOAT_BF16_CAST(unsigned short, __ushort2bfloat16_rn(input), __bfloat162ushort_rz(input));
202+
KERNEL_FLOAT_BF16_CAST(unsigned int, __uint2bfloat16_rn(input), __bfloat162uint_rz(input));
215203
KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input)));
216204
KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input));
217205
// clang-format on
218206
#endif
219207

220208
#if KERNEL_FLOAT_IS_CUDA
221-
KERNEL_FLOAT_BF16_CAST(
222-
bool,
223-
__nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00},
224-
(__nv_bfloat16_raw(input).x & 0x7FFF) != 0);
225-
209+
//KERNEL_FLOAT_BF16_CAST(
210+
// bool,
211+
// __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00},
212+
// (__nv_bfloat16_raw(input).x & 0x7FFF) != 0);
226213
#elif KERNEL_FLOAT_IS_HIP
227214
KERNEL_FLOAT_BF16_CAST(
228215
bool,

include/kernel_float/binops.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
6565
return result;
6666
}
6767

68-
#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \
69-
template<typename L, typename R, typename C = promoted_vector_value_type<L, R>> \
70-
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME(L&& left, R&& right) { \
71-
return zip_common(ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
68+
#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \
69+
template<typename L, typename R, typename C = promoted_vector_value_type<L, R>> \
70+
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME(L&& left, R&& right) { \
71+
return zip_common(ops::NAME<C> {}, static_cast<L&&>(left), static_cast<R&&>(right)); \
7272
}
7373

7474
#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \

include/kernel_float/fp16.h

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
#include "macros.h"
55

66
#if KERNEL_FLOAT_FP16_AVAILABLE
7+
//#define CUDA_NO_HALF (1)
8+
//#define __CUDA_NO_HALF_OPERATORS__ (1)
9+
//#define __CUDA_NO_HALF2_OPERATORS__ (1)
10+
//#define __CUDA_NO_HALF_CONVERSIONS__ (1)
11+
712
#if KERNEL_FLOAT_IS_CUDA
813
#include <cuda_fp16.h>
914
#elif KERNEL_FLOAT_IS_HIP
@@ -64,41 +69,44 @@ struct allow_float_fallback<__half> {
6469
#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2)
6570
#endif
6671

67-
KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2)
68-
KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)
69-
KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil)
72+
KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin)
7073
KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos)
74+
7175
KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp)
7276
KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10)
73-
KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor)
7477
KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log)
7578
// log10: the scalar and packed (two-lane) variants must both be base-10 logarithms.
KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log10)
76-
KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint)
77-
KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt)
78-
KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin)
79+
7980
KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt)
80-
KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc)
81+
KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt)
8182
KERNEL_FLOAT_FP16_UNARY_FUN(rcp, hrcp, h2rcp)
8283

84+
KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2)
85+
KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor)
86+
KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil)
87+
KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint)
88+
KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc)
89+
KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)
90+
8391
#if KERNEL_FLOAT_IS_DEVICE
84-
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
85-
namespace ops { \
86-
template<> \
87-
struct NAME<__half> { \
88-
KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \
89-
return FUN1(left, right); \
90-
} \
91-
}; \
92-
} \
93-
namespace detail { \
94-
template<> \
95-
struct apply_impl<ops::NAME<__half>, 2, __half, __half, __half> { \
96-
KERNEL_FLOAT_INLINE static void \
97-
call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \
98-
__half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \
99-
result[0] = r.x, result[1] = r.y; \
100-
} \
101-
}; \
92+
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
93+
namespace ops { \
94+
template<> \
95+
struct NAME<__half> { \
96+
KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \
97+
return ops::cast<decltype(FUN1(left, right)), __half> {}(FUN1(left, right)); \
98+
} \
99+
}; \
100+
} \
101+
namespace detail { \
102+
template<> \
103+
struct apply_impl<ops::NAME<__half>, 2, __half, __half, __half> { \
104+
KERNEL_FLOAT_INLINE static void \
105+
call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \
106+
__half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \
107+
result[0] = r.x, result[1] = r.y; \
108+
} \
109+
}; \
102110
}
103111
#else
104112
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2)
@@ -190,13 +198,13 @@ KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input));
190198
KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input));
191199
KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input));
192200

193-
KERNEL_FLOAT_FP16_CAST(signed short, __half2short_rz(input), __short2half_rn(input));
194-
KERNEL_FLOAT_FP16_CAST(signed int, __half2int_rz(input), __int2half_rn(input));
201+
KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input));
202+
KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input));
195203
KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input)));
196204
KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input));
197205

198-
KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn(input));
199-
KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input));
206+
KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input));
207+
KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input));
200208
KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input)));
201209
KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input));
202210
#endif

include/kernel_float/meta.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,6 @@ struct enable_if_impl<true, T> {
270270
template<bool C, typename T = void>
271271
using enable_if_t = typename detail::enable_if_impl<C, T>::type;
272272

273-
template<typename T, typename...>
274-
using identity_t = T;
275-
276273
KERNEL_FLOAT_INLINE
277274
constexpr size_t round_up_to_power_of_two(size_t n) {
278275
size_t result = 1;

include/kernel_float/prelude.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ KERNEL_FLOAT_TYPE_ALIAS(float16x, __half)
6767
#endif
6868

6969
#if KERNEL_FLOAT_BF16_AVAILABLE
70-
KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __nv_bfloat16)
71-
KERNEL_FLOAT_TYPE_ALIAS(bf16x, __nv_bfloat16)
70+
KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __bfloat16)
71+
KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16)
7272
#endif
7373

7474
#if KERNEL_FLOAT_BF8_AVAILABLE
@@ -82,12 +82,12 @@ static constexpr extent<N> kextent = {};
8282

8383
template<typename... Args>
8484
KERNEL_FLOAT_INLINE kvec<promote_t<Args...>, sizeof...(Args)> make_kvec(Args&&... args) {
85-
return ::kernel_float::make_vec(std::forward<Args>(args)...);
85+
return ::kernel_float::make_vec(static_cast<Args&&>(args)...);
8686
};
8787

8888
template<typename V>
8989
KERNEL_FLOAT_INLINE into_vector_type<V> into_kvec(V&& input) {
90-
return ::kernel_float::into_vec(std::forward<V>(input));
90+
return ::kernel_float::into_vec(static_cast<V&&>(input));
9191
}
9292

9393
template<typename T = double>

include/kernel_float/unops.h

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast<bool, T> {}(!ops::cast<T
123123
template<typename T> \
124124
struct NAME<T, enable_if_t<detail::allow_float_fallback<T>::value>> { \
125125
KERNEL_FLOAT_INLINE T operator()(T input_arg) { \
126-
float input = float(input_arg); \
127-
return T(EXPR_F32); \
126+
float input = ops::cast<T, float> {}(input_arg); \
127+
return ops::cast<decltype(EXPR_F32), T> {}(EXPR_F32); \
128128
} \
129129
}; \
130130
\
@@ -140,52 +140,56 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast<bool, T> {}(!ops::cast<T
140140
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(NAME, ::NAME(input), ::NAME(input)) \
141141
KERNEL_FLOAT_DEFINE_UNARY_FUN(NAME)
142142

143+
KERNEL_FLOAT_DEFINE_UNARY_MATH(sin)
144+
KERNEL_FLOAT_DEFINE_UNARY_MATH(cos)
145+
KERNEL_FLOAT_DEFINE_UNARY_MATH(tan)
146+
KERNEL_FLOAT_DEFINE_UNARY_MATH(asin)
143147
KERNEL_FLOAT_DEFINE_UNARY_MATH(acos)
144-
KERNEL_FLOAT_DEFINE_UNARY_MATH(abs)
148+
KERNEL_FLOAT_DEFINE_UNARY_MATH(atan)
149+
150+
KERNEL_FLOAT_DEFINE_UNARY_MATH(sinh)
151+
KERNEL_FLOAT_DEFINE_UNARY_MATH(cosh)
152+
KERNEL_FLOAT_DEFINE_UNARY_MATH(tanh)
145153
KERNEL_FLOAT_DEFINE_UNARY_MATH(acosh)
146-
KERNEL_FLOAT_DEFINE_UNARY_MATH(asin)
147154
KERNEL_FLOAT_DEFINE_UNARY_MATH(asinh)
148-
KERNEL_FLOAT_DEFINE_UNARY_MATH(atan)
149155
KERNEL_FLOAT_DEFINE_UNARY_MATH(atanh)
150-
KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt)
151-
KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil)
152-
KERNEL_FLOAT_DEFINE_UNARY_MATH(cos)
153-
KERNEL_FLOAT_DEFINE_UNARY_MATH(cosh)
154-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erf)
155-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfc)
156-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcinv)
157-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcx)
158-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfinv)
156+
159157
KERNEL_FLOAT_DEFINE_UNARY_MATH(exp)
160-
KERNEL_FLOAT_DEFINE_UNARY_MATH(exp10)
161158
KERNEL_FLOAT_DEFINE_UNARY_MATH(exp2)
159+
KERNEL_FLOAT_DEFINE_UNARY_MATH(exp10)
162160
KERNEL_FLOAT_DEFINE_UNARY_MATH(expm1)
163-
KERNEL_FLOAT_DEFINE_UNARY_MATH(fabs)
164-
KERNEL_FLOAT_DEFINE_UNARY_MATH(floor)
165-
KERNEL_FLOAT_DEFINE_UNARY_MATH(ilogb)
166-
KERNEL_FLOAT_DEFINE_UNARY_MATH(lgamma)
167161
KERNEL_FLOAT_DEFINE_UNARY_MATH(log)
168-
KERNEL_FLOAT_DEFINE_UNARY_MATH(log10)
169162
KERNEL_FLOAT_DEFINE_UNARY_MATH(log2)
170-
KERNEL_FLOAT_DEFINE_UNARY_MATH(nearbyint)
163+
KERNEL_FLOAT_DEFINE_UNARY_MATH(log10)
164+
KERNEL_FLOAT_DEFINE_UNARY_MATH(log1p)
165+
166+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erf)
167+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfinv)
168+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfc)
169+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcx)
170+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcinv)
171171
KERNEL_FLOAT_DEFINE_UNARY_MATH(normcdf)
172-
KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt)
173-
KERNEL_FLOAT_DEFINE_UNARY_MATH(sin)
174-
KERNEL_FLOAT_DEFINE_UNARY_MATH(sinh)
175-
KERNEL_FLOAT_DEFINE_UNARY_MATH(sqrt)
176-
KERNEL_FLOAT_DEFINE_UNARY_MATH(tan)
177-
KERNEL_FLOAT_DEFINE_UNARY_MATH(tanh)
172+
KERNEL_FLOAT_DEFINE_UNARY_MATH(lgamma)
178173
KERNEL_FLOAT_DEFINE_UNARY_MATH(tgamma)
179-
KERNEL_FLOAT_DEFINE_UNARY_MATH(trunc)
180-
KERNEL_FLOAT_DEFINE_UNARY_MATH(rint)
174+
175+
KERNEL_FLOAT_DEFINE_UNARY_MATH(sqrt)
181176
KERNEL_FLOAT_DEFINE_UNARY_MATH(rsqrt)
177+
KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt)
178+
KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt)
179+
180+
KERNEL_FLOAT_DEFINE_UNARY_MATH(abs)
181+
KERNEL_FLOAT_DEFINE_UNARY_MATH(fabs)
182+
KERNEL_FLOAT_DEFINE_UNARY_MATH(floor)
182183
KERNEL_FLOAT_DEFINE_UNARY_MATH(round)
184+
KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil)
185+
KERNEL_FLOAT_DEFINE_UNARY_MATH(trunc)
186+
KERNEL_FLOAT_DEFINE_UNARY_MATH(rint)
183187

184188
// The following functions are not supported on HIP
185189
#if !KERNEL_FLOAT_IS_HIP
186-
KERNEL_FLOAT_DEFINE_UNARY_MATH(signbit)
187-
KERNEL_FLOAT_DEFINE_UNARY_MATH(isinf)
188190
KERNEL_FLOAT_DEFINE_UNARY_MATH(isnan)
191+
KERNEL_FLOAT_DEFINE_UNARY_MATH(isinf)
192+
KERNEL_FLOAT_DEFINE_UNARY_MATH(isfinite)
189193
#endif
190194

191195
// CUDA offers special reciprocal functions (rcp), but only on the device.

0 commit comments

Comments
 (0)