Skip to content

Commit a218c9d

Browse files
MarcelKochyhmtsai
andcommitted
review updates:
- only provide kernel for absolute sum - refactoring Co-authored-by: Yu-Hsiang M. Tsai <yhmtsai@gmail.com>
1 parent f71c9f9 commit a218c9d

File tree

7 files changed

+37
-84
lines changed

7 files changed

+37
-84
lines changed

common/unified/matrix/csr_kernels.cpp

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -313,46 +313,25 @@ void benchmark_lookup(std::shared_ptr<const DefaultExecutor> exec,
313313
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL);
314314

315315

316-
template <typename ValueType, typename IndexType, typename Closure>
317-
void row_wise_sum_impl(std::shared_ptr<const DefaultExecutor> exec,
318-
const matrix::Csr<ValueType, IndexType>* orig,
319-
array<ValueType>& sum, Closure closure)
316+
template <typename ValueType, typename IndexType>
317+
void row_wise_absolute_sum(std::shared_ptr<const DefaultExecutor> exec,
318+
const matrix::Csr<ValueType, IndexType>* orig,
319+
array<ValueType>& sum)
320320
{
321321
run_kernel(
322322
exec,
323-
[] GKO_KERNEL(auto row, auto row_ptrs, auto value_ptr, auto sum_ptr,
324-
auto closure_) {
325-
for (size_type k = row_ptrs[row];
326-
k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
327-
sum_ptr[row] += closure_(value_ptr[k]);
323+
[] GKO_KERNEL(auto row, auto row_ptrs, auto value_ptr, auto sum_ptr) {
324+
sum_ptr[row] = zero<device_type<ValueType>>();
325+
for (auto k = row_ptrs[row]; k < row_ptrs[row + 1]; ++k) {
326+
sum_ptr[row] += abs(value_ptr[k]);
328327
}
329328
},
330329
sum.get_num_elems(), orig->get_const_row_ptrs(),
331-
orig->get_const_values(), sum.get_data(), closure);
332-
};
333-
334-
335-
template <typename ValueType, typename IndexType>
336-
void row_wise_sum(std::shared_ptr<const DefaultExecutor> exec,
337-
const matrix::Csr<ValueType, IndexType>* orig,
338-
array<ValueType>& sum, bool absolute)
339-
{
340-
run_kernel(
341-
exec,
342-
[] GKO_KERNEL(auto row, auto sum_ptr) {
343-
sum_ptr[row] = zero<ValueType>();
344-
},
345-
sum.get_num_elems(), sum.get_data());
346-
347-
if (absolute) {
348-
row_wise_sum_impl(exec, orig, sum,
349-
[] GKO_KERNEL(auto v) { return abs(v); });
350-
} else {
351-
row_wise_sum_impl(exec, orig, sum, [] GKO_KERNEL(auto v) { return v; });
352-
}
330+
orig->get_const_values(), sum.get_data());
353331
}
354332

355-
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ROW_WISE_SUM);
333+
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
334+
GKO_DECLARE_CSR_ROW_WISE_ABSOLUTE_SUM);
356335

357336

358337
} // namespace csr

core/device_hooks/common_kernels.inc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ GKO_STUB_VALUE_AND_INDEX_TYPE(
767767
GKO_STUB_INDEX_TYPE(GKO_DECLARE_CSR_BUILD_LOOKUP_OFFSETS_KERNEL);
768768
GKO_STUB_INDEX_TYPE(GKO_DECLARE_CSR_BUILD_LOOKUP_KERNEL);
769769
GKO_STUB_INDEX_TYPE(GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL);
770-
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ROW_WISE_SUM);
770+
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ROW_WISE_ABSOLUTE_SUM);
771771

772772
template <typename ValueType, typename IndexType>
773773
GKO_DECLARE_CSR_SCALE_KERNEL(ValueType, IndexType)

core/distributed/preconditioner/schwarz.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace preconditioner {
3333
namespace {
3434

3535

36-
GKO_REGISTER_OPERATION(row_wise_sum, csr::row_wise_sum);
36+
GKO_REGISTER_OPERATION(row_wise_absolute_sum, csr::row_wise_absolute_sum);
3737

3838

3939
}
@@ -163,7 +163,8 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::generate(
163163

164164
array<ValueType> l1_diag_arr{exec, local_matrix->get_size()[0]};
165165

166-
exec->run(make_row_wise_sum(non_local_matrix.get(), l1_diag_arr, true));
166+
exec->run(
167+
make_row_wise_absolute_sum(non_local_matrix.get(), l1_diag_arr));
167168

168169
// compute local_matrix_copy <- diag(l1) + local_matrix_copy
169170
auto l1_diag = matrix::Diagonal<ValueType>::create(

core/matrix/csr_kernels.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,10 @@ namespace kernels {
258258
IndexType sample_size, IndexType* result)
259259

260260

261-
#define GKO_DECLARE_CSR_ROW_WISE_SUM(ValueType, IndexType) \
262-
void row_wise_sum(std::shared_ptr<const DefaultExecutor> exec, \
263-
const matrix::Csr<ValueType, IndexType>* orig, \
264-
array<ValueType>& sum, bool absolute)
261+
#define GKO_DECLARE_CSR_ROW_WISE_ABSOLUTE_SUM(ValueType, IndexType) \
262+
void row_wise_absolute_sum(std::shared_ptr<const DefaultExecutor> exec, \
263+
const matrix::Csr<ValueType, IndexType>* orig, \
264+
array<ValueType>& sum)
265265

266266

267267
#define GKO_DECLARE_ALL_AS_TEMPLATES \
@@ -344,7 +344,7 @@ namespace kernels {
344344
template <typename IndexType> \
345345
GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL(IndexType); \
346346
template <typename ValueType, typename IndexType> \
347-
GKO_DECLARE_CSR_ROW_WISE_SUM(ValueType, IndexType)
347+
GKO_DECLARE_CSR_ROW_WISE_ABSOLUTE_SUM(ValueType, IndexType)
348348

349349

350350
GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(csr, GKO_DECLARE_ALL_AS_TEMPLATES);

reference/matrix/csr_kernels.cpp

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,31 +1459,25 @@ GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL);
14591459

14601460

14611461
template <typename ValueType, typename IndexType>
1462-
void row_wise_sum(std::shared_ptr<const DefaultExecutor> exec,
1463-
const matrix::Csr<ValueType, IndexType>* orig,
1464-
array<ValueType>& sum, bool absolute)
1462+
void row_wise_absolute_sum(std::shared_ptr<const DefaultExecutor> exec,
1463+
const matrix::Csr<ValueType, IndexType>* orig,
1464+
array<ValueType>& sum)
14651465
{
14661466
auto row_ptrs = orig->get_const_row_ptrs();
14671467
auto value_ptr = orig->get_const_values();
14681468
auto sum_ptr = sum.get_data();
14691469

1470-
auto apply = [&](auto closure) {
1471-
for (size_type row = 0; row < orig->get_size()[0]; ++row) {
1472-
sum_ptr[row] = zero<ValueType>();
1473-
for (size_type k = row_ptrs[row];
1474-
k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
1475-
sum_ptr[row] += closure(value_ptr[k]);
1476-
}
1470+
for (size_type row = 0; row < orig->get_size()[0]; ++row) {
1471+
sum_ptr[row] = zero<ValueType>();
1472+
for (size_type k = row_ptrs[row];
1473+
k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
1474+
sum_ptr[row] += abs(value_ptr[k]);
14771475
}
1478-
};
1479-
if (absolute) {
1480-
apply([](auto v) { return abs(v); });
1481-
} else {
1482-
apply([](auto v) { return v; });
14831476
}
14841477
}
14851478

1486-
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ROW_WISE_SUM);
1479+
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
1480+
GKO_DECLARE_CSR_ROW_WISE_ABSOLUTE_SUM);
14871481

14881482

14891483
} // namespace csr

reference/test/matrix/csr_kernels.cpp

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2558,32 +2558,16 @@ TYPED_TEST(Csr, CanGetSubmatrixWithIndexSet)
25582558
}
25592559

25602560

2561-
TYPED_TEST(Csr, CanComputeRowWiseSum)
2562-
{
2563-
using value_type = typename TestFixture::value_type;
2564-
gko::array<value_type> sum(this->exec, this->mtx3_sorted->get_size()[0]);
2565-
this->create_mtx3(this->mtx3_sorted, this->mtx3_unsorted);
2566-
this->mtx3_sorted->scale(gko::initialize<gko::matrix::Dense<value_type>>(
2567-
{-gko::one<value_type>()}, this->exec));
2568-
2569-
gko::kernels::reference::csr::row_wise_sum(
2570-
this->exec, this->mtx3_sorted.get(), sum, false);
2571-
2572-
gko::array<value_type> sum_result(this->exec, {-3, -12, -5});
2573-
GKO_ASSERT_ARRAY_EQ(sum, sum_result);
2574-
}
2575-
2576-
25772561
TYPED_TEST(Csr, CanComputeRowWiseAbsoluteSum)
25782562
{
25792563
using value_type = typename TestFixture::value_type;
25802564
gko::array<value_type> sum(this->exec, this->mtx3_sorted->get_size()[0]);
2581-
this->create_mtx3(this->mtx3_sorted, this->mtx3_unsorted);
2565+
this->create_mtx3(this->mtx3_sorted.get(), this->mtx3_unsorted.get());
25822566
this->mtx3_sorted->scale(gko::initialize<gko::matrix::Dense<value_type>>(
25832567
{-gko::one<value_type>()}, this->exec));
25842568

2585-
gko::kernels::reference::csr::row_wise_sum(
2586-
this->exec, this->mtx3_sorted.get(), sum, true);
2569+
gko::kernels::reference::csr::row_wise_absolute_sum(
2570+
this->exec, this->mtx3_sorted.get(), sum);
25872571

25882572
gko::array<value_type> sum_result(this->exec, {3, 12, 5});
25892573
GKO_ASSERT_ARRAY_EQ(sum, sum_result);

test/matrix/csr_kernels.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,11 @@ TEST_F(Csr, RowWiseSumIsEquivalentToRef)
8484
gko::array<value_type> sum{ref, x->get_size()[0]};
8585
gko::array<value_type> dsum{exec, dx->get_size()[0]};
8686

87-
for (auto use_absolute : {false, true}) {
88-
SCOPED_TRACE(use_absolute ? "With absolute" : "Without absolute");
87+
gko::kernels::reference::csr::row_wise_absolute_sum(ref, x.get(), sum);
88+
gko::kernels::GKO_DEVICE_NAMESPACE::csr::row_wise_absolute_sum(
89+
exec, dx.get(), dsum);
8990

90-
gko::kernels::reference::csr::row_wise_sum(ref, x.get(), sum,
91-
use_absolute);
92-
gko::kernels::EXEC_NAMESPACE::csr::row_wise_sum(exec, dx.get(), dsum,
93-
use_absolute);
94-
95-
GKO_ASSERT_ARRAY_EQ(sum, dsum);
96-
}
91+
GKO_ASSERT_ARRAY_EQ(sum, dsum);
9792
}
9893

9994

0 commit comments

Comments
 (0)