Skip to content

Commit 014e32f

Browse files
committed
Implement approximation for pow
1 parent f89cf98 commit 014e32f

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

include/kernel_float/binops.h

+22-3
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,7 @@ struct multiply<bool> {
292292
namespace detail {
293293
template<typename Policy, typename T, size_t N>
294294
struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
295-
KERNEL_FLOAT_INLINE static void
296-
call(ops::divide<T> fun, T* result, const T* lhs, const T* rhs) {
295+
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
297296
T rhs_rcp[N];
298297

299298
// Fast way to perform division is to multiply by the reciprocal
@@ -310,13 +309,33 @@ struct apply_impl<accurate_policy, ops::divide<T>, N, T, T, T>:
310309
// Full specialization of `apply_impl` for dividing a single `float` (N == 1)
// under `fast_policy`. It bypasses the generic reciprocal-multiply path and
// stores `__fdividef(*lhs, *rhs)` into `*result`.
// NOTE(review): `__fdividef` is presumably the CUDA fast approximate division
// intrinsic; the preprocessor guard that encloses this block opens outside the
// visible region -- confirm it restricts this to device compilation.
template<>
311310
struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
312311
KERNEL_FLOAT_INLINE static void
313-
call(ops::divide<float> fun, float* result, const float* lhs, const float* rhs) {
312+
call(ops::divide<float>, float* result, const float* lhs, const float* rhs) {
314313
// Exactly one element is read from each operand and one element is written.
*result = __fdividef(*lhs, *rhs);
315314
}
316315
};
317316
#endif
318317
} // namespace detail
319318

319+
namespace detail {
320+
// Override `pow` using `log2` and `exp2`
321+
template<typename Policy, typename T, size_t N>
322+
struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
323+
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
324+
T lhs_log[N];
325+
T result_log[N];
326+
327+
// Fast way to perform power function is using log2 and exp2
328+
apply_impl<Policy, ops::log2<T>, N, T, T>::call({}, lhs_log, lhs);
329+
apply_impl<Policy, ops::multiply<T>, N, T, T, T>::call({}, result_log, lhs_log, rhs);
330+
apply_impl<Policy, ops::exp2<T>, N, T, T, T>::call({}, result, result_log);
331+
}
332+
};
333+
334+
template<typename T, size_t N>
335+
struct apply_impl<accurate_policy, ops::pow<T>, N, T, T, T>:
336+
apply_base_impl<accurate_policy, ops::pow<T>, N, T, T, T> {};
337+
} // namespace detail
338+
320339
template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>
321340
KERNEL_FLOAT_INLINE zip_common_type<ops::divide<T>, T, T>
322341
fast_divide(const L& left, const R& right) {

single_include/kernel_float.h

+24-5
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2024-11-18 13:40:03.668017
20-
// git hash: ae0e6b16ac2d626e69bb08554044a77671f408ab
19+
// date: 2024-11-18 13:50:24.614671
20+
// git hash: f89cf98f79e78ab6013063dea4b4b516ce163855
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -1950,8 +1950,7 @@ struct multiply<bool> {
19501950
namespace detail {
19511951
template<typename Policy, typename T, size_t N>
19521952
struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
1953-
KERNEL_FLOAT_INLINE static void
1954-
call(ops::divide<T> fun, T* result, const T* lhs, const T* rhs) {
1953+
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
19551954
T rhs_rcp[N];
19561955

19571956
// Fast way to perform division is to multiply by the reciprocal
@@ -1968,13 +1967,33 @@ struct apply_impl<accurate_policy, ops::divide<T>, N, T, T, T>:
19681967
// Full specialization of `apply_impl` for dividing a single `float` (N == 1)
// under `fast_policy`. It bypasses the generic reciprocal-multiply path and
// stores `__fdividef(*lhs, *rhs)` into `*result`.
// NOTE(review): `__fdividef` is presumably the CUDA fast approximate division
// intrinsic; the preprocessor guard that encloses this block opens outside the
// visible region -- confirm it restricts this to device compilation.
template<>
19691968
struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
19701969
KERNEL_FLOAT_INLINE static void
1971-
call(ops::divide<float> fun, float* result, const float* lhs, const float* rhs) {
1970+
call(ops::divide<float>, float* result, const float* lhs, const float* rhs) {
19721971
// Exactly one element is read from each operand and one element is written.
*result = __fdividef(*lhs, *rhs);
19731972
}
19741973
};
19751974
#endif
19761975
} // namespace detail
19771976

1977+
namespace detail {
1978+
// Override `pow` using `log2` and `exp2`
1979+
template<typename Policy, typename T, size_t N>
1980+
struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
1981+
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
1982+
T lhs_log[N];
1983+
T result_log[N];
1984+
1985+
// Fast way to perform power function is using log2 and exp2
1986+
apply_impl<Policy, ops::log2<T>, N, T, T>::call({}, lhs_log, lhs);
1987+
apply_impl<Policy, ops::multiply<T>, N, T, T, T>::call({}, result_log, lhs_log, rhs);
1988+
apply_impl<Policy, ops::exp2<T>, N, T, T, T>::call({}, result, result_log);
1989+
}
1990+
};
1991+
1992+
template<typename T, size_t N>
1993+
struct apply_impl<accurate_policy, ops::pow<T>, N, T, T, T>:
1994+
apply_base_impl<accurate_policy, ops::pow<T>, N, T, T, T> {};
1995+
} // namespace detail
1996+
19781997
template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>
19791998
KERNEL_FLOAT_INLINE zip_common_type<ops::divide<T>, T, T>
19801999
fast_divide(const L& left, const R& right) {

0 commit comments

Comments
 (0)