Skip to content

Commit fb83400

Browse files
yhmtsai and MarcelKoch committed
update format, make data private for safety, and use early return for simpler structure
Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
1 parent 6a9cf14 commit fb83400

File tree

3 files changed

+75
-67
lines changed

3 files changed

+75
-67
lines changed

common/cuda_hip/matrix/coo_kernels.cpp

Lines changed: 62 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -266,38 +266,39 @@ void spmv2(std::shared_ptr<const DefaultExecutor> exec,
266266
const auto b_ncols = b->get_size()[1];
267267
const dim3 coo_block(config::warp_size, warps_in_block, 1);
268268
const auto nwarps = host_kernel::calculate_nwarps(exec, nnz);
269-
if (nwarps > 0 && b_ncols > 0) {
270-
// not support 16 bit atomic
269+
if (nwarps <= 0 && b_ncols <= 0) {
270+
return;
271+
}
272+
// not support 16 bit atomic
271273
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
272-
if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
273-
GKO_NOT_SUPPORTED(c);
274-
} else
274+
if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
275+
GKO_NOT_SUPPORTED(c);
276+
} else
275277
#endif
276-
{
277-
// TODO: b_ncols needs to be tuned for ROCm.
278-
if (b_ncols < 4) {
279-
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
280-
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
281-
282-
abstract_spmv<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
283-
nnz, num_lines, as_device_type(a->get_const_values()),
284-
a->get_const_col_idxs(),
285-
as_device_type(a->get_const_row_idxs()),
286-
as_device_type(b->get_const_values()), b->get_stride(),
287-
as_device_type(c->get_values()), c->get_stride());
288-
} else {
289-
int num_elems = ceildiv(nnz, nwarps * config::warp_size) *
290-
config::warp_size;
291-
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
292-
ceildiv(b_ncols, config::warp_size));
293-
294-
abstract_spmm<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
295-
nnz, num_elems, as_device_type(a->get_const_values()),
296-
a->get_const_col_idxs(),
297-
as_device_type(a->get_const_row_idxs()), b_ncols,
298-
as_device_type(b->get_const_values()), b->get_stride(),
299-
as_device_type(c->get_values()), c->get_stride());
300-
}
278+
{
279+
// TODO: b_ncols needs to be tuned for ROCm.
280+
if (b_ncols < 4) {
281+
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
282+
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
283+
284+
abstract_spmv<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
285+
nnz, num_lines, as_device_type(a->get_const_values()),
286+
a->get_const_col_idxs(),
287+
as_device_type(a->get_const_row_idxs()),
288+
as_device_type(b->get_const_values()), b->get_stride(),
289+
as_device_type(c->get_values()), c->get_stride());
290+
} else {
291+
int num_elems =
292+
ceildiv(nnz, nwarps * config::warp_size) * config::warp_size;
293+
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
294+
ceildiv(b_ncols, config::warp_size));
295+
296+
abstract_spmm<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
297+
nnz, num_elems, as_device_type(a->get_const_values()),
298+
a->get_const_col_idxs(),
299+
as_device_type(a->get_const_row_idxs()), b_ncols,
300+
as_device_type(b->get_const_values()), b->get_stride(),
301+
as_device_type(c->get_values()), c->get_stride());
301302
}
302303
}
303304
}
@@ -317,40 +318,39 @@ void advanced_spmv2(std::shared_ptr<const DefaultExecutor> exec,
317318
const dim3 coo_block(config::warp_size, warps_in_block, 1);
318319
const auto b_ncols = b->get_size()[1];
319320

320-
if (nwarps > 0 && b_ncols > 0) {
321-
// not support 16 bit atomic
321+
if (nwarps <= 0 && b_ncols <= 0) {
322+
return;
323+
}
324+
// not support 16 bit atomic
322325
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
323-
if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
324-
GKO_NOT_SUPPORTED(c);
325-
} else
326+
if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
327+
GKO_NOT_SUPPORTED(c);
328+
} else
326329
#endif
327-
{
328-
// TODO: b_ncols needs to be tuned for ROCm.
329-
if (b_ncols < 4) {
330-
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
331-
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
332-
333-
abstract_spmv<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
334-
nnz, num_lines, as_device_type(alpha->get_const_values()),
335-
as_device_type(a->get_const_values()),
336-
a->get_const_col_idxs(),
337-
as_device_type(a->get_const_row_idxs()),
338-
as_device_type(b->get_const_values()), b->get_stride(),
339-
as_device_type(c->get_values()), c->get_stride());
340-
} else {
341-
int num_elems = ceildiv(nnz, nwarps * config::warp_size) *
342-
config::warp_size;
343-
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
344-
ceildiv(b_ncols, config::warp_size));
345-
346-
abstract_spmm<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
347-
nnz, num_elems, as_device_type(alpha->get_const_values()),
348-
as_device_type(a->get_const_values()),
349-
a->get_const_col_idxs(),
350-
as_device_type(a->get_const_row_idxs()), b_ncols,
351-
as_device_type(b->get_const_values()), b->get_stride(),
352-
as_device_type(c->get_values()), c->get_stride());
353-
}
330+
{
331+
// TODO: b_ncols needs to be tuned for ROCm.
332+
if (b_ncols < 4) {
333+
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
334+
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
335+
336+
abstract_spmv<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
337+
nnz, num_lines, as_device_type(alpha->get_const_values()),
338+
as_device_type(a->get_const_values()), a->get_const_col_idxs(),
339+
as_device_type(a->get_const_row_idxs()),
340+
as_device_type(b->get_const_values()), b->get_stride(),
341+
as_device_type(c->get_values()), c->get_stride());
342+
} else {
343+
int num_elems =
344+
ceildiv(nnz, nwarps * config::warp_size) * config::warp_size;
345+
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
346+
ceildiv(b_ncols, config::warp_size));
347+
348+
abstract_spmm<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
349+
nnz, num_elems, as_device_type(alpha->get_const_values()),
350+
as_device_type(a->get_const_values()), a->get_const_col_idxs(),
351+
as_device_type(a->get_const_row_idxs()), b_ncols,
352+
as_device_type(b->get_const_values()), b->get_stride(),
353+
as_device_type(c->get_values()), c->get_stride());
354354
}
355355
}
356356
}

core/test/base/dense_cache.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ TYPED_TEST(GenericDenseCache, GenericCanInitWithDifferentExecutor)
310310

311311
auto second_buffer =
312312
this->cache.template get<value_type>(another_ref, this->size);
313+
313314
ASSERT_NE(second_buffer, nullptr);
314315
GKO_ASSERT_EQUAL_DIMENSIONS(second_buffer->get_size(), this->size);
315316
ASSERT_EQ(second_buffer->get_executor(), another_ref);
@@ -322,6 +323,7 @@ TYPED_TEST(GenericDenseCache, WorkspaceIsNotCopied)
322323
{
323324
using value_type = typename TestFixture::value_type;
324325
auto buffer = this->cache.template get<value_type>(this->ref, this->size);
326+
325327
gko::detail::GenericDenseCache cache(this->cache);
326328

327329
ASSERT_EQ(cache.workspace.get_size(), 0);
@@ -333,6 +335,7 @@ TYPED_TEST(GenericDenseCache, WorkspaceIsNotMoved)
333335
{
334336
using value_type = typename TestFixture::value_type;
335337
auto buffer = this->cache.template get<value_type>(this->ref, this->size);
338+
336339
gko::detail::GenericDenseCache cache(std::move(this->cache));
337340

338341
ASSERT_EQ(cache.workspace.get_size(), 0);
@@ -345,6 +348,7 @@ TYPED_TEST(GenericDenseCache, WorkspaceIsNotCopyAssigned)
345348
using value_type = typename TestFixture::value_type;
346349
auto buffer = this->cache.template get<value_type>(this->ref, this->size);
347350
gko::detail::GenericDenseCache cache;
351+
348352
cache = this->cache;
349353

350354
ASSERT_EQ(cache.workspace.get_size(), 0);
@@ -357,6 +361,7 @@ TYPED_TEST(GenericDenseCache, WorkspaceIsNotMoveAssigned)
357361
using value_type = typename TestFixture::value_type;
358362
auto buffer = this->cache.template get<value_type>(this->ref, this->size);
359363
gko::detail::GenericDenseCache cache;
364+
360365
cache = std::move(this->cache);
361366

362367
ASSERT_EQ(cache.workspace.get_size(), 0);

include/ginkgo/core/base/dense_cache.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ struct GenericDenseCache {
117117
GenericDenseCache(GenericDenseCache&&) noexcept;
118118
GenericDenseCache& operator=(const GenericDenseCache&);
119119
GenericDenseCache& operator=(GenericDenseCache&&) noexcept;
120-
mutable array<char> workspace;
121120

122121
/**
123122
* Pointer access to the underlying vector with specific type.
@@ -127,6 +126,9 @@ struct GenericDenseCache {
127126
template <typename ValueType>
128127
std::shared_ptr<matrix::Dense<ValueType>> get(
129128
std::shared_ptr<const Executor> exec, dim<2> size) const;
129+
130+
private:
131+
mutable array<char> workspace;
130132
};
131133

132134

@@ -147,10 +149,6 @@ struct ScalarCache {
147149
ScalarCache(ScalarCache&& other) noexcept;
148150
ScalarCache& operator=(const ScalarCache& other);
149151
ScalarCache& operator=(ScalarCache&& other) noexcept;
150-
std::shared_ptr<const Executor> exec;
151-
double value;
152-
mutable std::map<std::string, std::shared_ptr<const gko::LinOp>> scalars;
153-
154152

155153
/**
156154
* Pointer access to the underlying vector with specific type.
@@ -159,6 +157,11 @@ struct ScalarCache {
159157
*/
160158
template <typename ValueType>
161159
std::shared_ptr<const matrix::Dense<ValueType>> get() const;
160+
161+
private:
162+
std::shared_ptr<const Executor> exec;
163+
double value;
164+
mutable std::map<std::string, std::shared_ptr<const gko::LinOp>> scalars;
162165
};
163166

164167

0 commit comments

Comments
 (0)