@@ -266,38 +266,39 @@ void spmv2(std::shared_ptr<const DefaultExecutor> exec,
266
266
const auto b_ncols = b->get_size ()[1 ];
267
267
const dim3 coo_block (config::warp_size, warps_in_block, 1 );
268
268
const auto nwarps = host_kernel::calculate_nwarps (exec, nnz);
269
- if (nwarps > 0 && b_ncols > 0 ) {
270
- // not support 16 bit atomic
269
+ if (nwarps <= 0 || b_ncols <= 0 ) {  // negation of the former `nwarps > 0 && b_ncols > 0` guard
270
+ return ;  // nothing to launch; also keeps ceildiv(nnz, nwarps * warp_size) below away from a zero divisor
271
+ }
272
+ // not support 16 bit atomic
271
273
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
272
- if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
273
- GKO_NOT_SUPPORTED (c);
274
- } else
274
+ if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
275
+ GKO_NOT_SUPPORTED (c);
276
+ } else
275
277
#endif
276
- {
277
- // TODO: b_ncols needs to be tuned for ROCm.
278
- if (b_ncols < 4 ) {
279
- const dim3 coo_grid (ceildiv (nwarps, warps_in_block), b_ncols);
280
- int num_lines = ceildiv (nnz, nwarps * config::warp_size);
281
-
282
- abstract_spmv<<<coo_grid, coo_block, 0 , exec->get_stream ()>>>(
283
- nnz, num_lines, as_device_type (a->get_const_values ()),
284
- a->get_const_col_idxs (),
285
- as_device_type (a->get_const_row_idxs ()),
286
- as_device_type (b->get_const_values ()), b->get_stride (),
287
- as_device_type (c->get_values ()), c->get_stride ());
288
- } else {
289
- int num_elems = ceildiv (nnz, nwarps * config::warp_size) *
290
- config::warp_size;
291
- const dim3 coo_grid (ceildiv (nwarps, warps_in_block),
292
- ceildiv (b_ncols, config::warp_size));
293
-
294
- abstract_spmm<<<coo_grid, coo_block, 0 , exec->get_stream ()>>>(
295
- nnz, num_elems, as_device_type (a->get_const_values ()),
296
- a->get_const_col_idxs (),
297
- as_device_type (a->get_const_row_idxs ()), b_ncols,
298
- as_device_type (b->get_const_values ()), b->get_stride (),
299
- as_device_type (c->get_values ()), c->get_stride ());
300
- }
278
+ {
279
+ // TODO: b_ncols needs to be tuned for ROCm.
280
+ if (b_ncols < 4 ) {
281
+ const dim3 coo_grid (ceildiv (nwarps, warps_in_block), b_ncols);
282
+ int num_lines = ceildiv (nnz, nwarps * config::warp_size);
283
+
284
+ abstract_spmv<<<coo_grid, coo_block, 0 , exec->get_stream ()>>>(
285
+ nnz, num_lines, as_device_type (a->get_const_values ()),
286
+ a->get_const_col_idxs (),
287
+ as_device_type (a->get_const_row_idxs ()),
288
+ as_device_type (b->get_const_values ()), b->get_stride (),
289
+ as_device_type (c->get_values ()), c->get_stride ());
290
+ } else {
291
+ int num_elems =
292
+ ceildiv (nnz, nwarps * config::warp_size) * config::warp_size;
293
+ const dim3 coo_grid (ceildiv (nwarps, warps_in_block),
294
+ ceildiv (b_ncols, config::warp_size));
295
+
296
+ abstract_spmm<<<coo_grid, coo_block, 0 , exec->get_stream ()>>>(
297
+ nnz, num_elems, as_device_type (a->get_const_values ()),
298
+ a->get_const_col_idxs (),
299
+ as_device_type (a->get_const_row_idxs ()), b_ncols,
300
+ as_device_type (b->get_const_values ()), b->get_stride (),
301
+ as_device_type (c->get_values ()), c->get_stride ());
301
302
}
302
303
}
303
304
}
@@ -317,40 +318,39 @@ void advanced_spmv2(std::shared_ptr<const DefaultExecutor> exec,
317
318
const dim3 coo_block (config::warp_size, warps_in_block, 1 );
318
319
const auto b_ncols = b->get_size ()[1 ];
319
320
320
- if (nwarps > 0 && b_ncols > 0 ) {
321
- // not support 16 bit atomic
321
+ if (nwarps <= 0 || b_ncols <= 0 ) {  // negation of the former `nwarps > 0 && b_ncols > 0` guard
322
+ return ;  // nothing to launch; also keeps ceildiv(nnz, nwarps * warp_size) below away from a zero divisor
323
+ }
324
+ // not support 16 bit atomic
322
325
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
323
- if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
324
- GKO_NOT_SUPPORTED (c);
325
- } else
326
+ if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
327
+ GKO_NOT_SUPPORTED (c);
328
+ } else
326
329
#endif
327
- {
328
- // TODO: b_ncols needs to be tuned for ROCm.
329
- if (b_ncols < 4 ) {
330
- int num_lines = ceildiv (nnz, nwarps * config::warp_size);
331
- const dim3 coo_grid (ceildiv (nwarps, warps_in_block), b_ncols);
332
-
333
- abstract_spmv<<<coo_grid, coo_block, 0 , exec->get_stream ()>>>(
334
- nnz, num_lines, as_device_type (alpha->get_const_values ()),
335
- as_device_type (a->get_const_values ()),
336
- a->get_const_col_idxs (),
337
- as_device_type (a->get_const_row_idxs ()),
338
- as_device_type (b->get_const_values ()), b->get_stride (),
339
- as_device_type (c->get_values ()), c->get_stride ());
340
- } else {
341
- int num_elems = ceildiv (nnz, nwarps * config::warp_size) *
342
- config::warp_size;
343
- const dim3 coo_grid (ceildiv (nwarps, warps_in_block),
344
- ceildiv (b_ncols, config::warp_size));
345
-
346
- abstract_spmm<<<coo_grid, coo_block, 0 , exec->get_stream ()>>>(
347
- nnz, num_elems, as_device_type (alpha->get_const_values ()),
348
- as_device_type (a->get_const_values ()),
349
- a->get_const_col_idxs (),
350
- as_device_type (a->get_const_row_idxs ()), b_ncols,
351
- as_device_type (b->get_const_values ()), b->get_stride (),
352
- as_device_type (c->get_values ()), c->get_stride ());
353
- }
330
+ {
331
+ // TODO: b_ncols needs to be tuned for ROCm.
332
+ if (b_ncols < 4 ) {
333
+ int num_lines = ceildiv (nnz, nwarps * config::warp_size);
334
+ const dim3 coo_grid (ceildiv (nwarps, warps_in_block), b_ncols);
335
+
336
+ abstract_spmv<<<coo_grid, coo_block, 0 , exec->get_stream ()>>>(
337
+ nnz, num_lines, as_device_type (alpha->get_const_values ()),
338
+ as_device_type (a->get_const_values ()), a->get_const_col_idxs (),
339
+ as_device_type (a->get_const_row_idxs ()),
340
+ as_device_type (b->get_const_values ()), b->get_stride (),
341
+ as_device_type (c->get_values ()), c->get_stride ());
342
+ } else {
343
+ int num_elems =
344
+ ceildiv (nnz, nwarps * config::warp_size) * config::warp_size;
345
+ const dim3 coo_grid (ceildiv (nwarps, warps_in_block),
346
+ ceildiv (b_ncols, config::warp_size));
347
+
348
+ abstract_spmm<<<coo_grid, coo_block, 0 , exec->get_stream ()>>>(
349
+ nnz, num_elems, as_device_type (alpha->get_const_values ()),
350
+ as_device_type (a->get_const_values ()), a->get_const_col_idxs (),
351
+ as_device_type (a->get_const_row_idxs ()), b_ncols,
352
+ as_device_type (b->get_const_values ()), b->get_stride (),
353
+ as_device_type (c->get_values ()), c->get_stride ());
354
354
}
355
355
}
356
356
}
0 commit comments