Skip to content

Commit 5208898

Browse files
committed
Merge branch 'dev'
2 parents 17c83b6 + c0c7d10 commit 5208898

File tree

23 files changed

+1318
-598
lines changed

23 files changed

+1318
-598
lines changed

docs/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ API Reference
33
.. toctree::
44
api/types.rst
55
api/primitives.rst
6+
api/conversion.rst
67
api/generation.rst
78
api/unary_operators.rst
89
api/binary_operators.rst

docs/build_api.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,18 @@ def build_index_page(groups):
8383
"reduce",
8484
"zip",
8585
"zip_common",
86-
"cast",
87-
"broadcast",
88-
"convert",
8986
"make_vec",
9087
"into_vec",
9188
"concat",
9289
"select",
9390
"for_each",
9491
],
92+
"Conversion": [
93+
"convert",
94+
"cast",
95+
"cast_to",
96+
"broadcast",
97+
],
9598
"Generation": [
9699
("range", "range()"),
97100
("range", "range(F fun)"),
@@ -186,13 +189,14 @@ def build_index_page(groups):
186189
"sin",
187190
"sinh",
188191
("sqrt", "sqrt(const V&)"),
192+
"rsqrt",
189193
"tan",
190194
"tanh",
191195
"tgamma",
192-
"trunc",
196+
"rcp",
193197
"rint",
194-
"rsqrt",
195198
"round",
199+
"trunc",
196200
"signbit",
197201
"isinf",
198202
"isnan",
@@ -203,6 +207,9 @@ def build_index_page(groups):
203207
"fast_cos",
204208
"fast_sin",
205209
"fast_tan",
210+
"fast_rcp",
211+
"fast_sqrt",
212+
"fast_rsqrt",
206213
"fast_div",
207214
],
208215
"Conditional": [
@@ -211,7 +218,6 @@ def build_index_page(groups):
211218
("where", "where(const C&)"),
212219
],
213220
"Memory read/write": [
214-
"cast_to",
215221
("read", "read(const T*, const I&, const M&)"),
216222
("write", "write(T*, const I&, const V&, const M&)"),
217223

@@ -220,8 +226,9 @@ def build_index_page(groups):
220226

221227
("read_aligned", "read_aligned(const T*)"),
222228
("write_aligned", "write_aligned(T*, const V&)"),
229+
"assert_aligned",
223230

224-
("aligned_ptr", "aligned_ptr", "struct"),
231+
("vector_ptr", "vector_ptr", "struct"),
225232
],
226233
"Utilities": [
227234
("constant", "constant", "struct"),

examples/vector_add/main.cu

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@ void cuda_check(cudaError_t code) {
1313
}
1414

1515
template<int N>
16-
__global__ void my_kernel(int length, const __half* input, double constant, float* output) {
16+
__global__ void my_kernel(
17+
int length,
18+
kf::vec_ptr<const half, N> input,
19+
double constant,
20+
kf::vec_ptr<half, N, float> output) {
1721
int i = blockIdx.x * blockDim.x + threadIdx.x;
1822

1923
if (i * N < length) {
20-
auto a = kf::read_aligned<N>(input + i * N);
21-
auto b = kf::fma(a, a, kf::cast<__half>(constant));
22-
kf::write_aligned<N>(output + i * N, b);
24+
output(i) = kf::fma(input[i], input[i], kf::cast<__half>(constant));
2325
}
2426
}
2527

@@ -51,9 +53,9 @@ void run_kernel(int n) {
5153
int grid_size = (n + items_per_block - 1) / items_per_block;
5254
my_kernel<items_per_thread><<<grid_size, block_size>>>(
5355
n,
54-
kf::aligned_ptr(input_dev),
56+
kf::assert_aligned(input_dev),
5557
constant,
56-
kf::aligned_ptr(output_dev));
58+
kf::assert_aligned(output_dev));
5759

5860
// Copy results back
5961
cuda_check(cudaMemcpy(output_dev, output_result.data(), sizeof(float) * n, cudaMemcpyDefault));
@@ -80,7 +82,7 @@ int main() {
8082

8183
run_kernel<1>(n);
8284
run_kernel<2>(n);
83-
run_kernel<3>(n);
85+
// run_kernel<3>(n);
8486
run_kernel<4>(n);
8587
run_kernel<8>(n);
8688

examples/vector_add_tiling/main.cu

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@ void cuda_check(cudaError_t code) {
1414
}
1515

1616
template<int N, int B>
17-
__global__ void my_kernel(
18-
int length,
19-
kf::aligned_ptr<const __half> input,
20-
double constant,
21-
kf::aligned_ptr<float> output) {
17+
__global__ void my_kernel(int length, const __half* input, double constant, float* output) {
2218
auto tiling = kf::tiling<
2319
kf::tile_factor<N>,
2420
kf::block_size<B>,
@@ -27,9 +23,9 @@ __global__ void my_kernel(
2723
auto points = int(blockIdx.x * tiling.tile_size(0)) + tiling.local_points(0);
2824
auto mask = tiling.local_mask();
2925

30-
auto a = input.read(points, mask);
26+
auto a = kf::read(input, points, mask);
3127
auto b = (a * a) * constant;
32-
output.write(points, b, mask);
28+
kf::write(output, points, b, mask);
3329
}
3430

3531
template<int items_per_thread, int block_size = 256>
@@ -57,11 +53,8 @@ void run_kernel(int n) {
5753
// Launch kernel!
5854
int items_per_block = block_size * items_per_thread;
5955
int grid_size = (n + items_per_block - 1) / items_per_block;
60-
my_kernel<items_per_thread, block_size><<<grid_size, block_size>>>(
61-
n,
62-
kf::aligned_ptr(input_dev),
63-
constant,
64-
kf::aligned_ptr(output_dev));
56+
my_kernel<items_per_thread, block_size>
57+
<<<grid_size, block_size>>>(n, input_dev, constant, output_dev);
6558

6659
// Copy results back
6760
cuda_check(cudaMemcpy(output_dev, output_result.data(), sizeof(float) * n, cudaMemcpyDefault));

include/kernel_float/apply.h

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ struct apply_recur_impl<1> {
152152
result[0] = fun(inputs[0]...);
153153
}
154154
};
155+
156+
template<typename F, size_t N, typename Output, typename... Args>
157+
struct apply_fastmath_impl: apply_impl<F, N, Output, Args...> {};
155158
} // namespace detail
156159

157160
template<typename F, typename... Args>
@@ -174,7 +177,34 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
174177
using E = broadcast_vector_extent_type<Args...>;
175178
vector_storage<Output, E::value> result;
176179

177-
detail::apply_impl<F, E::value, Output, vector_value_type<Args>...>::call(
180+
// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
181+
#if KERNEL_FLOAT_FAST_MATH
182+
using apply_impl = detail::apply_fastmath_impl<F, E::value, Output, vector_value_type<Args>...>;
183+
#else
184+
using apply_impl = detail::apply_impl<F, E::value, Output, vector_value_type<Args>...>;
185+
#endif
186+
187+
apply_impl::call(
188+
fun,
189+
result.data(),
190+
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(
191+
into_vector_storage(args))
192+
.data())...);
193+
194+
return result;
195+
}
196+
197+
/**
198+
* Apply the function `fun` to the elements of the input vectors `args` and return the results as a new vector. This
199+
* uses fast-math if available for the given function `F`, otherwise this function behaves like `map`.
200+
*/
201+
template<typename F, typename... Args>
202+
KERNEL_FLOAT_INLINE map_type<F, Args...> fast_map(F fun, const Args&... args) {
203+
using Output = result_t<F, vector_value_type<Args>...>;
204+
using E = broadcast_vector_extent_type<Args...>;
205+
vector_storage<Output, E::value> result;
206+
207+
detail::apply_fastmath_impl<F, E::value, Output, vector_value_type<Args>...>::call(
178208
fun,
179209
result.data(),
180210
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(

include/kernel_float/base.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,20 @@ struct extent<N> {
8989
static constexpr size_t size = N;
9090
};
9191

92+
namespace detail {
// Trait that flags element types `T` whose operations may be performed by upcasting the
// elements to `float`, i.e., `T` offers no more precision than `float`.
// The primary template answers `false`; a type opts in by providing a specialization.
template<typename T>
struct allow_float_fallback {
    static constexpr bool value = false;
};

// `float` itself is allowed: "upcasting" to `float` is the identity here.
template<>
struct allow_float_fallback<float> {
    static constexpr bool value = true;
};
}  // namespace detail
105+
92106
template<typename T>
93107
struct into_vector_impl {
94108
using value_type = T;

include/kernel_float/bf16.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt)
7272
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
7373
KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt)
7474
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
75-
76-
KERNEL_FLOAT_BF16_UNARY_FUN(fast_exp, ::hexp, ::h2exp)
77-
KERNEL_FLOAT_BF16_UNARY_FUN(fast_log, ::hlog, ::h2log)
78-
KERNEL_FLOAT_BF16_UNARY_FUN(fast_cos, ::hcos, ::h2cos)
79-
KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
75+
KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp)
8076
#endif
8177

8278
#if KERNEL_FLOAT_CUDA_ARCH >= 800
@@ -114,10 +110,8 @@ KERNEL_FLOAT_BF16_BINARY_FUN(divide, __hdiv, __h2div)
114110
KERNEL_FLOAT_BF16_BINARY_FUN(min, __hmin, __hmin2)
115111
KERNEL_FLOAT_BF16_BINARY_FUN(max, __hmax, __hmax2)
116112

117-
KERNEL_FLOAT_BF16_BINARY_FUN(fast_div, __hdiv, __h2div)
118-
119113
KERNEL_FLOAT_BF16_BINARY_FUN(equal_to, __heq, __heq2)
120-
KERNEL_FLOAT_BF16_BINARY_FUN(not_equal_to, __heq, __heq2)
114+
KERNEL_FLOAT_BF16_BINARY_FUN(not_equal_to, __hneu, __hneu2)
121115
KERNEL_FLOAT_BF16_BINARY_FUN(less, __hlt, __hlt2)
122116
KERNEL_FLOAT_BF16_BINARY_FUN(less_equal, __hle, __hle2)
123117
KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2)

include/kernel_float/binops.h

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
namespace kernel_float {
88

99
template<typename F, typename L, typename R>
10-
using zip_type = vector<
11-
result_t<F, vector_value_type<L>, vector_value_type<R>>,
12-
broadcast_vector_extent_type<L, R>>;
10+
using zip_type = map_type<F, L, R>;
1311

1412
/**
1513
* Combines the elements from the two inputs (`left` and `right`) element-wise, applying a provided binary
@@ -25,20 +23,7 @@ using zip_type = vector<
2523
*/
2624
template<typename F, typename L, typename R>
2725
KERNEL_FLOAT_INLINE zip_type<F, L, R> zip(F fun, const L& left, const R& right) {
28-
using A = vector_value_type<L>;
29-
using B = vector_value_type<R>;
30-
using O = result_t<F, A, B>;
31-
using E = broadcast_vector_extent_type<L, R>;
32-
vector_storage<O, E::value> result;
33-
34-
detail::apply_impl<F, E::value, O, A, B>::call(
35-
fun,
36-
result.data(),
37-
detail::broadcast_impl<A, vector_extent_type<L>, E>::call(into_vector_storage(left)).data(),
38-
detail::broadcast_impl<B, vector_extent_type<R>, E>::call(into_vector_storage(right))
39-
.data());
40-
41-
return result;
26+
return ::kernel_float::map(fun, left, right);
4227
}
4328

4429
template<typename F, typename L, typename R>
@@ -67,7 +52,14 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
6752

6853
vector_storage<O, E::value> result;
6954

70-
detail::apply_impl<F, E::value, O, T, T>::call(
55+
// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
56+
#if KERNEL_FLOAT_FAST_MATH
57+
using apply_impl = detail::apply_fastmath_impl<F, E::value, O, T, T>;
58+
#else
59+
using apply_impl = detail::apply_impl<F, E::value, O, T, T>;
60+
#endif
61+
62+
apply_impl::call(
7163
fun,
7264
result.data(),
7365
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
@@ -277,36 +269,17 @@ KERNEL_FLOAT_DEFINE_BINARY(
277269
#if KERNEL_FLOAT_IS_DEVICE
278270
KERNEL_FLOAT_DEFINE_BINARY(
279271
rhypot,
280-
(T(1) / ops::hypot<T>()(left, right)),
272+
(ops::rcp<T>(ops::hypot<T>()(left, right))),
281273
::rhypot(left, right),
282274
::rhypotf(left, right))
283275
#else
284276
KERNEL_FLOAT_DEFINE_BINARY(
285277
rhypot,
286-
(T(1) / ops::hypot<T>()(left, right)),
278+
(ops::rcp<T>(ops::hypot<T>()(left, right))),
287279
(double(1) / ::hypot(left, right)),
288280
(float(1) / ::hypotf(left, right)))
289281
#endif
290282

291-
#if KERNEL_FLOAT_IS_DEVICE
292-
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \
293-
KERNEL_FLOAT_DEFINE_BINARY( \
294-
FUN_NAME, \
295-
ops::OP_NAME<T> {}(left, right), \
296-
ops::OP_NAME<double> {}(left, right), \
297-
ops::OP_NAME<float> {}(left, right))
298-
#else
299-
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \
300-
KERNEL_FLOAT_DEFINE_BINARY( \
301-
FUN_NAME, \
302-
ops::OP_NAME<T> {}(left, right), \
303-
ops::OP_NAME<double> {}(left, right), \
304-
ops::OP_NAME<float> {}(left, right))
305-
#endif
306-
307-
KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_div, divide, __fdividef)
308-
KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_pow, pow, __powf)
309-
310283
namespace ops {
311284
template<>
312285
struct add<bool> {
@@ -323,6 +296,52 @@ struct multiply<bool> {
323296
};
324297
}; // namespace ops
325298

299+
namespace detail {
300+
template<typename T, size_t N>
301+
struct apply_fastmath_impl<ops::divide<T>, N, T, T, T> {
302+
KERNEL_FLOAT_INLINE static void
303+
call(ops::divide<T> fun, T* result, const T* lhs, const T* rhs) {
304+
T rhs_rcp[N];
305+
306+
// Fast way to perform division is to multiply by the reciprocal
307+
apply_fastmath_impl<ops::rcp<T>, N, T, T, T>::call({}, rhs_rcp, rhs);
308+
apply_fastmath_impl<ops::multiply<T>, N, T, T, T>::call({}, result, lhs, rhs_rcp);
309+
}
310+
};
311+
312+
#if KERNEL_FLOAT_IS_DEVICE
313+
template<size_t N>
314+
struct apply_fastmath_impl<ops::divide<float>, N, float, float, float> {
315+
KERNEL_FLOAT_INLINE static void
316+
call(ops::divide<float> fun, float* result, const float* lhs, const float* rhs) {
317+
#pragma unroll
318+
for (size_t i = 0; i < N; i++) {
319+
result[i] = __fdividef(lhs[i], rhs[i]);
320+
}
321+
}
322+
};
323+
#endif
324+
} // namespace detail
325+
326+
template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>
327+
KERNEL_FLOAT_INLINE zip_common_type<ops::divide<T>, T, T>
328+
fast_divide(const L& left, const R& right) {
329+
using E = broadcast_vector_extent_type<L, R>;
330+
vector_storage<T, E::value> result;
331+
332+
detail::apply_fastmath_impl<ops::divide<T>, E::value, T, T, T>::call(
333+
ops::divide<T> {},
334+
result.data(),
335+
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
336+
into_vector_storage(left))
337+
.data(),
338+
detail::convert_impl<vector_value_type<R>, vector_extent_type<R>, T, E>::call(
339+
into_vector_storage(right))
340+
.data());
341+
342+
return result;
343+
}
344+
326345
namespace detail {
327346
template<typename T>
328347
struct cross_impl {

0 commit comments

Comments
 (0)