Skip to content

Commit 1611258

Browse files
committed
Remove apply_fallback_impl
1 parent 2730789 commit 1611258

File tree

1 file changed

+30
-39
lines changed

1 file changed

+30
-39
lines changed

include/kernel_float/apply.h

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,41 +7,41 @@ namespace kernel_float {
77
namespace detail {
88

99
template<typename... Es>
10-
struct broadcast_extent_helper;
10+
struct broadcast_extent_impl;
1111

1212
template<typename E>
13-
struct broadcast_extent_helper<E> {
13+
struct broadcast_extent_impl<E> {
1414
using type = E;
1515
};
1616

1717
template<size_t N>
18-
struct broadcast_extent_helper<extent<N>, extent<N>> {
18+
struct broadcast_extent_impl<extent<N>, extent<N>> {
1919
using type = extent<N>;
2020
};
2121

2222
template<size_t N>
23-
struct broadcast_extent_helper<extent<1>, extent<N>> {
23+
struct broadcast_extent_impl<extent<1>, extent<N>> {
2424
using type = extent<N>;
2525
};
2626

2727
template<size_t N>
28-
struct broadcast_extent_helper<extent<N>, extent<1>> {
28+
struct broadcast_extent_impl<extent<N>, extent<1>> {
2929
using type = extent<N>;
3030
};
3131

3232
template<>
33-
struct broadcast_extent_helper<extent<1>, extent<1>> {
33+
struct broadcast_extent_impl<extent<1>, extent<1>> {
3434
using type = extent<1>;
3535
};
3636

3737
template<typename A, typename B, typename C, typename... Rest>
38-
struct broadcast_extent_helper<A, B, C, Rest...>:
39-
broadcast_extent_helper<typename broadcast_extent_helper<A, B>::type, C, Rest...> {};
38+
struct broadcast_extent_impl<A, B, C, Rest...>:
39+
broadcast_extent_impl<typename broadcast_extent_impl<A, B>::type, C, Rest...> {};
4040

4141
} // namespace detail
4242

4343
template<typename... Es>
44-
using broadcast_extent = typename detail::broadcast_extent_helper<Es...>::type;
44+
using broadcast_extent = typename detail::broadcast_extent_impl<Es...>::type;
4545

4646
template<typename... Vs>
4747
using broadcast_vector_extent_type = broadcast_extent<vector_extent_type<Vs>...>;
@@ -128,15 +128,24 @@ struct accurate_policy {};
128128
* the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve
129129
* approximations that slightly compromise precision.
130130
*/
131-
struct fast_policy {};
131+
struct fast_policy {
132+
using fallback_policy = accurate_policy;
133+
};
132134

133135
/**
134136
* This template policy allows developers to specify a custom degree of approximation for their computations. By
135137
* adjusting the `Level` parameter, you can fine-tune the balance between accuracy and performance to meet the
136138
* specific needs of your application. Higher values mean more precision.
137139
*/
138140
template<int Level = -1>
139-
struct approx_level_policy {};
141+
struct approx_level_policy {
142+
using fallback_policy = approx_level_policy<>;
143+
};
144+
145+
template<>
146+
struct approx_level_policy<> {
147+
using fallback_policy = fast_policy;
148+
};
140149

141150
/**
142151
* The `approx_policy` serves as the default approximation policy, providing a standard level of approximation
@@ -145,15 +154,17 @@ struct approx_level_policy {};
145154
*/
146155
using approx_policy = approx_level_policy<>;
147156

148-
#ifndef KERNEL_FLOAT_POLICY
149-
#define KERNEL_FLOAT_POLICY accurate_policy
150-
#endif
151-
152157
/**
153158
* The `default_policy` acts as the standard computation policy. It can be configured externally using the
154-
* `KERNEL_FLOAT_POLICY` macro. If `KERNEL_FLOAT_POLICY` is not defined, it defaults to `accurate_policy`.
159+
* `KERNEL_FLOAT_GLOBAL_POLICY` macro or, failing that, the `KERNEL_FLOAT_POLICY` macro. If neither is defined, it defaults to `accurate_policy`.
155160
*/
161+
#if defined(KERNEL_FLOAT_GLOBAL_POLICY)
162+
using default_policy = KERNEL_FLOAT_GLOBAL_POLICY;
163+
#elif defined(KERNEL_FLOAT_POLICY)
156164
using default_policy = KERNEL_FLOAT_POLICY;
165+
#else
166+
using default_policy = accurate_policy;
167+
#endif
157168

158169
namespace detail {
159170

@@ -164,35 +175,15 @@ struct invoke_impl {
164175
}
165176
};
166177

167-
//
168178
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
169-
struct apply_fallback_impl {
170-
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
171-
static_assert(N > 0, "operation not implemented");
172-
}
173-
};
179+
struct apply_impl;
174180

175181
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
176-
struct apply_base_impl: apply_fallback_impl<Policy, F, N, Output, Args...> {};
182+
struct apply_base_impl: apply_impl<typename Policy::fallback_policy, F, N, Output, Args...> {};
177183

178184
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
179185
struct apply_impl: apply_base_impl<Policy, F, N, Output, Args...> {};
180186

181-
// `fast_policy` falls back to `accurate_policy`
182-
template<typename F, size_t N, typename Output, typename... Args>
183-
struct apply_fallback_impl<fast_policy, F, N, Output, Args...>:
184-
apply_impl<accurate_policy, F, N, Output, Args...> {};
185-
186-
// `approx_policy` falls back to `fast_policy`
187-
template<typename F, size_t N, typename Output, typename... Args>
188-
struct apply_fallback_impl<approx_policy, F, N, Output, Args...>:
189-
apply_impl<fast_policy, F, N, Output, Args...> {};
190-
191-
// `approx_level_policy` falls back to `approx_policy`
192-
template<int Level, typename F, size_t N, typename Output, typename... Args>
193-
struct apply_fallback_impl<approx_level_policy<Level>, F, N, Output, Args...>:
194-
apply_impl<approx_policy, F, N, Output, Args...> {};
195-
196187
// `apply_impl` is only implemented for `accurate_policy`; all other policies fall back to `apply_base_impl`.
197188
template<typename F, size_t N, typename Output, typename... Args>
198189
struct apply_impl<accurate_policy, F, N, Output, Args...> {
@@ -266,4 +257,4 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
266257

267258
} // namespace kernel_float
268259

269-
#endif // KERNEL_FLOAT_APPLY_H
260+
#endif // KERNEL_FLOAT_APPLY_H

0 commit comments

Comments
 (0)