Skip to content

Commit 980d0b6

Browse files
committed
add kernels for row-wise (absolute) sum
Signed-off-by: Marcel Koch <marcel.koch@kit.edu>
1 parent 2acbcdd commit 980d0b6

File tree

6 files changed

+133
-4
lines changed

6 files changed

+133
-4
lines changed

common/unified/matrix/csr_kernels.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -313,6 +313,48 @@ void benchmark_lookup(std::shared_ptr<const DefaultExecutor> exec,
313313
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL);
314314

315315

316+
template <typename ValueType, typename IndexType, typename Closure>
317+
void row_wise_sum_impl(std::shared_ptr<const DefaultExecutor> exec,
318+
const matrix::Csr<ValueType, IndexType>* orig,
319+
array<ValueType>& sum, Closure closure)
320+
{
321+
run_kernel(
322+
exec,
323+
[] GKO_KERNEL(auto row, auto row_ptrs, auto value_ptr, auto sum_ptr,
324+
auto closure_) {
325+
for (size_type k = row_ptrs[row];
326+
k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
327+
sum_ptr[row] += closure_(value_ptr[k]);
328+
}
329+
},
330+
sum.get_num_elems(), orig->get_const_row_ptrs(),
331+
orig->get_const_values(), sum.get_data(), closure);
332+
};
333+
334+
335+
template <typename ValueType, typename IndexType>
336+
void row_wise_sum(std::shared_ptr<const DefaultExecutor> exec,
337+
const matrix::Csr<ValueType, IndexType>* orig,
338+
array<ValueType>& sum, bool absolute)
339+
{
340+
run_kernel(
341+
exec,
342+
[] GKO_KERNEL(auto row, auto sum_ptr) {
343+
sum_ptr[row] = zero<ValueType>();
344+
},
345+
sum.get_num_elems(), sum.get_data());
346+
347+
if (absolute) {
348+
row_wise_sum_impl(exec, orig, sum,
349+
[] GKO_KERNEL(auto v) { return abs(v); });
350+
} else {
351+
row_wise_sum_impl(exec, orig, sum, [] GKO_KERNEL(auto v) { return v; });
352+
}
353+
}
354+
355+
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ROW_WISE_SUM);
356+
357+
316358
} // namespace csr
317359
} // namespace GKO_DEVICE_NAMESPACE
318360
} // namespace kernels

core/device_hooks/common_kernels.inc.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,7 @@ GKO_STUB_VALUE_AND_INDEX_TYPE(
767767
GKO_STUB_INDEX_TYPE(GKO_DECLARE_CSR_BUILD_LOOKUP_OFFSETS_KERNEL);
768768
GKO_STUB_INDEX_TYPE(GKO_DECLARE_CSR_BUILD_LOOKUP_KERNEL);
769769
GKO_STUB_INDEX_TYPE(GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL);
770+
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ROW_WISE_SUM);
770771

771772
template <typename ValueType, typename IndexType>
772773
GKO_DECLARE_CSR_SCALE_KERNEL(ValueType, IndexType)

core/matrix/csr_kernels.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -258,6 +258,12 @@ namespace kernels {
258258
IndexType sample_size, IndexType* result)
259259

260260

261+
#define GKO_DECLARE_CSR_ROW_WISE_SUM(ValueType, IndexType) \
262+
void row_wise_sum(std::shared_ptr<const DefaultExecutor> exec, \
263+
const matrix::Csr<ValueType, IndexType>* orig, \
264+
array<ValueType>& sum, bool absolute)
265+
266+
261267
#define GKO_DECLARE_ALL_AS_TEMPLATES \
262268
template <typename MatrixValueType, typename InputValueType, \
263269
typename OutputValueType, typename IndexType> \
@@ -336,7 +342,9 @@ namespace kernels {
336342
template <typename IndexType> \
337343
GKO_DECLARE_CSR_BUILD_LOOKUP_KERNEL(IndexType); \
338344
template <typename IndexType> \
339-
GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL(IndexType)
345+
GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL(IndexType); \
346+
template <typename ValueType, typename IndexType> \
347+
GKO_DECLARE_CSR_ROW_WISE_SUM(ValueType, IndexType)
340348

341349

342350
GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(csr, GKO_DECLARE_ALL_AS_TEMPLATES);

reference/matrix/csr_kernels.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,6 +1458,34 @@ void benchmark_lookup(std::shared_ptr<const DefaultExecutor> exec,
14581458
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL);
14591459

14601460

1461+
template <typename ValueType, typename IndexType>
1462+
void row_wise_sum(std::shared_ptr<const DefaultExecutor> exec,
1463+
const matrix::Csr<ValueType, IndexType>* orig,
1464+
array<ValueType>& sum, bool absolute)
1465+
{
1466+
auto row_ptrs = orig->get_const_row_ptrs();
1467+
auto value_ptr = orig->get_const_values();
1468+
auto sum_ptr = sum.get_data();
1469+
1470+
auto apply = [&](auto closure) {
1471+
for (size_type row = 0; row < orig->get_size()[0]; ++row) {
1472+
sum_ptr[row] = zero<ValueType>();
1473+
for (size_type k = row_ptrs[row];
1474+
k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
1475+
sum_ptr[row] += closure(value_ptr[k]);
1476+
}
1477+
}
1478+
};
1479+
if (absolute) {
1480+
apply([](auto v) { return abs(v); });
1481+
} else {
1482+
apply([](auto v) { return v; });
1483+
}
1484+
}
1485+
1486+
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ROW_WISE_SUM);
1487+
1488+
14611489
} // namespace csr
14621490
} // namespace reference
14631491
} // namespace kernels

reference/test/matrix/csr_kernels.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2558,6 +2558,37 @@ TYPED_TEST(Csr, CanGetSubmatrixWithIndexSet)
25582558
}
25592559

25602560

2561+
TYPED_TEST(Csr, CanComputeRowWiseSum)
2562+
{
2563+
using value_type = typename TestFixture::value_type;
2564+
gko::array<value_type> sum(this->exec, this->mtx3_sorted->get_size()[0]);
2565+
this->create_mtx3(this->mtx3_sorted, this->mtx3_unsorted);
2566+
this->mtx3_sorted->scale(gko::initialize<gko::matrix::Dense<value_type>>(
2567+
{-gko::one<value_type>()}, this->exec));
2568+
2569+
gko::kernels::reference::csr::row_wise_sum(
2570+
this->exec, this->mtx3_sorted.get(), sum, false);
2571+
2572+
gko::array<value_type> sum_result(this->exec, {-3, -12, -5});
2573+
GKO_ASSERT_ARRAY_EQ(sum, sum_result);
2574+
}
2575+
2576+
2577+
TYPED_TEST(Csr, CanComputeRowWiseAbsoluteSum)
2578+
{
2579+
using value_type = typename TestFixture::value_type;
2580+
gko::array<value_type> sum(this->exec, this->mtx3_sorted->get_size()[0]);
2581+
this->create_mtx3(this->mtx3_sorted, this->mtx3_unsorted);
2582+
this->mtx3_sorted->scale(gko::initialize<gko::matrix::Dense<value_type>>(
2583+
{-gko::one<value_type>()}, this->exec));
2584+
2585+
gko::kernels::reference::csr::row_wise_sum(
2586+
this->exec, this->mtx3_sorted.get(), sum, true);
2587+
2588+
gko::array<value_type> sum_result(this->exec, {3, 12, 5});
2589+
GKO_ASSERT_ARRAY_EQ(sum, sum_result);
2590+
}
2591+
25612592
template <typename ValueIndexType>
25622593
class CsrLookup : public ::testing::Test {
25632594
protected:

test/matrix/csr_kernels.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -78,6 +78,25 @@ TEST_F(Csr, InvScaleIsEquivalentToRef)
7878
}
7979

8080

81+
TEST_F(Csr, RowWiseSumIsEquivalentToRef)
82+
{
83+
set_up_apply_data();
84+
gko::array<value_type> sum{ref, x->get_size()[0]};
85+
gko::array<value_type> dsum{exec, dx->get_size()[0]};
86+
87+
for (auto use_absolute : {false, true}) {
88+
SCOPED_TRACE(use_absolute ? "With absolute" : "Without absolute");
89+
90+
gko::kernels::reference::csr::row_wise_sum(ref, x.get(), sum,
91+
use_absolute);
92+
gko::kernels::EXEC_NAMESPACE::csr::row_wise_sum(exec, dx.get(), dsum,
93+
use_absolute);
94+
95+
GKO_ASSERT_ARRAY_EQ(sum, dsum);
96+
}
97+
}
98+
99+
81100
template <typename IndexType>
82101
class CsrLookup : public CommonTestFixture {
83102
public:

0 commit comments

Comments
 (0)