@@ -332,9 +332,13 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
332
332
@eval begin
333
333
function CuSparseMatrixCSC {$elty} (csr:: CuSparseMatrixCSR{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
334
334
m,n = size (csr)
335
- colPtr = (index == ' O' ) ? CUDA. ones (Cint, n+ 1 ) : CUDA. zeros (Cint, n+ 1 )
336
- rowVal = CUDA. zeros (Cint, nnz (csr))
337
- nzVal = CUDA. zeros ($ elty, nnz (csr))
335
+ colPtr = CuArray {Cint} (undef, n+ 1 )
336
+ rowVal = CuArray {Cint} (undef, nnz (csr))
337
+ nzVal = CuArray {$elty} (undef, nnz (csr))
338
+ if version () <= v " 12.6-"
339
+ # JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
340
+ colPtr .= (index == ' O' ? 1 : 0 )
341
+ end
338
342
function bufferSize ()
339
343
out = Ref {Csize_t} (1 )
340
344
cusparseCsr2cscEx2_bufferSize (handle (), m, n, nnz (csr), nonzeros (csr),
@@ -352,9 +356,13 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
352
356
353
357
function CuSparseMatrixCSR {$elty} (csc:: CuSparseMatrixCSC{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
354
358
m,n = size (csc)
355
- rowPtr = (index == ' O' ) ? CUDA. ones (Cint, m+ 1 ) : CUDA. zeros (Cint, m+ 1 )
356
- colVal = CUDA. zeros (Cint,nnz (csc))
357
- nzVal = CUDA. zeros ($ elty,nnz (csc))
359
+ rowPtr = CuArray {Cint} (undef, m+ 1 )
360
+ colVal = CuArray {Cint} (undef, nnz (csc))
361
+ nzVal = CuArray {$elty} (undef, nnz (csc))
362
+ if version () <= v " 12.6-"
363
+ # JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
364
+ rowPtr .= (index == ' O' ? 1 : 0 )
365
+ end
358
366
function bufferSize ()
359
367
out = Ref {Csize_t} (1 )
360
368
cusparseCsr2cscEx2_bufferSize (handle (), n, m, nnz (csc), nonzeros (csc),
@@ -379,9 +387,13 @@ for (elty, welty) in ((:Float16, :Float32),
379
387
@eval begin
380
388
function CuSparseMatrixCSC {$elty} (csr:: CuSparseMatrixCSR{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
381
389
m,n = size (csr)
382
- colPtr = (index == ' O' ) ? CUDA. ones (Cint, n+ 1 ) : CUDA. zeros (Cint, n+ 1 )
383
- rowVal = CUDA. zeros (Cint, nnz (csr))
384
- nzVal = CUDA. zeros ($ elty, nnz (csr))
390
+ colPtr = CuArray {Cint} (undef, n+ 1 )
391
+ rowVal = CuArray {Cint} (undef, nnz (csr))
392
+ nzVal = CuArray {$elty} (undef, nnz (csr))
393
+ if version () <= v " 12.6-"
394
+ # JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
395
+ colPtr .= (index == ' O' ? 1 : 0 )
396
+ end
385
397
if $ elty == Float16 # broken for ComplexF16?
386
398
function bufferSize ()
387
399
out = Ref {Csize_t} (1 )
@@ -405,9 +417,13 @@ for (elty, welty) in ((:Float16, :Float32),
405
417
406
418
function CuSparseMatrixCSR {$elty} (csc:: CuSparseMatrixCSC{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
407
419
m,n = size (csc)
408
- rowPtr = (index == ' O' ) ? CUDA. ones (Cint, m+ 1 ) : CUDA. zeros (Cint, m+ 1 )
409
- colVal = CUDA. zeros (Cint,nnz (csc))
410
- nzVal = CUDA. zeros ($ elty,nnz (csc))
420
+ rowPtr = CuArray {Cint} (undef, m+ 1 )
421
+ colVal = CuArray {Cint} (undef, nnz (csc))
422
+ nzVal = CuArray {$elty} (undef, nnz (csc))
423
+ if version () <= v " 12.6-"
424
+ # JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
425
+ rowPtr .= (index == ' O' ? 1 : 0 )
426
+ end
411
427
if $ elty == Float16 # broken for ComplexF16?
412
428
function bufferSize ()
413
429
out = Ref {Csize_t} (1 )
@@ -523,9 +539,9 @@ for (fname,elty) in ((:cusparseSbsr2csr, :Float32),
523
539
nb = cld (n, bsr. blockDim)
524
540
cudesca = CuMatrixDescriptor (' G' , ' L' , ' N' , index)
525
541
cudescc = CuMatrixDescriptor (' G' , ' L' , ' N' , indc)
526
- csrRowPtr = CUDA . zeros ( Cint, m + 1 )
527
- csrColInd = CUDA . zeros ( Cint, nnz (bsr))
528
- csrNzVal = CUDA . zeros ( $ elty, nnz (bsr))
542
+ csrRowPtr = CuArray { Cint} (undef , m + 1 )
543
+ csrColInd = CuArray { Cint} (undef , nnz (bsr))
544
+ csrNzVal = CuArray { $elty} (undef , nnz (bsr))
529
545
$ fname (handle (), bsr. dir, mb, nb,
530
546
cudesca, nonzeros (bsr), bsr. rowPtr, bsr. colVal,
531
547
bsr. blockDim, cudescc, csrNzVal, csrRowPtr,
0 commit comments