Skip to content

Commit 6ae58bf

Browse files
committed
Reduce scope of workaround.
1 parent 95a2eb6 commit 6ae58bf

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

lib/cusparse/conversions.jl

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,13 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
332332
@eval begin
333333
function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1)
334334
m,n = size(csr)
335-
colPtr = (index == 'O') ? CUDA.ones(Cint, n+1) : CUDA.zeros(Cint, n+1)
336-
rowVal = CUDA.zeros(Cint, nnz(csr))
337-
nzVal = CUDA.zeros($elty, nnz(csr))
335+
colPtr = CuArray{Cint}(undef, n+1)
336+
rowVal = CuArray{Cint}(undef, nnz(csr))
337+
nzVal = CuArray{$elty}(undef, nnz(csr))
338+
if version() <= v"12.6-"
339+
# JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
340+
colPtr .= (index == 'O' ? 1 : 0)
341+
end
338342
function bufferSize()
339343
out = Ref{Csize_t}(1)
340344
cusparseCsr2cscEx2_bufferSize(handle(), m, n, nnz(csr), nonzeros(csr),
@@ -352,9 +356,13 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
352356

353357
function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1)
354358
m,n = size(csc)
355-
rowPtr = (index == 'O') ? CUDA.ones(Cint, m+1) : CUDA.zeros(Cint, m+1)
356-
colVal = CUDA.zeros(Cint,nnz(csc))
357-
nzVal = CUDA.zeros($elty,nnz(csc))
359+
rowPtr = CuArray{Cint}(undef, m+1)
360+
colVal = CuArray{Cint}(undef, nnz(csc))
361+
nzVal = CuArray{$elty}(undef, nnz(csc))
362+
if version() <= v"12.6-"
363+
# JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
364+
rowPtr .= (index == 'O' ? 1 : 0)
365+
end
358366
function bufferSize()
359367
out = Ref{Csize_t}(1)
360368
cusparseCsr2cscEx2_bufferSize(handle(), n, m, nnz(csc), nonzeros(csc),
@@ -379,9 +387,13 @@ for (elty, welty) in ((:Float16, :Float32),
379387
@eval begin
380388
function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1)
381389
m,n = size(csr)
382-
colPtr = (index == 'O') ? CUDA.ones(Cint, n+1) : CUDA.zeros(Cint, n+1)
383-
rowVal = CUDA.zeros(Cint, nnz(csr))
384-
nzVal = CUDA.zeros($elty, nnz(csr))
390+
colPtr = CuArray{Cint}(undef, n+1)
391+
rowVal = CuArray{Cint}(undef, nnz(csr))
392+
nzVal = CuArray{$elty}(undef, nnz(csr))
393+
if version() <= v"12.6-"
394+
# JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
395+
colPtr .= (index == 'O' ? 1 : 0)
396+
end
385397
if $elty == Float16 #broken for ComplexF16?
386398
function bufferSize()
387399
out = Ref{Csize_t}(1)
@@ -405,9 +417,13 @@ for (elty, welty) in ((:Float16, :Float32),
405417

406418
function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1)
407419
m,n = size(csc)
408-
rowPtr = (index == 'O') ? CUDA.ones(Cint, m+1) : CUDA.zeros(Cint, m+1)
409-
colVal = CUDA.zeros(Cint,nnz(csc))
410-
nzVal = CUDA.zeros($elty,nnz(csc))
420+
rowPtr = CuArray{Cint}(undef, m+1)
421+
colVal = CuArray{Cint}(undef, nnz(csc))
422+
nzVal = CuArray{$elty}(undef, nnz(csc))
423+
if version() <= v"12.6-"
424+
# JuliaGPU/CUDA.jl#2806 (NVBUG 5384319)
425+
rowPtr .= (index == 'O' ? 1 : 0)
426+
end
411427
if $elty == Float16 #broken for ComplexF16?
412428
function bufferSize()
413429
out = Ref{Csize_t}(1)
@@ -523,9 +539,9 @@ for (fname,elty) in ((:cusparseSbsr2csr, :Float32),
523539
nb = cld(n, bsr.blockDim)
524540
cudesca = CuMatrixDescriptor('G', 'L', 'N', index)
525541
cudescc = CuMatrixDescriptor('G', 'L', 'N', indc)
526-
csrRowPtr = CUDA.zeros(Cint, m + 1)
527-
csrColInd = CUDA.zeros(Cint, nnz(bsr))
528-
csrNzVal = CUDA.zeros($elty, nnz(bsr))
542+
csrRowPtr = CuArray{Cint}(undef, m + 1)
543+
csrColInd = CuArray{Cint}(undef, nnz(bsr))
544+
csrNzVal = CuArray{$elty}(undef, nnz(bsr))
529545
$fname(handle(), bsr.dir, mb, nb,
530546
cudesca, nonzeros(bsr), bsr.rowPtr, bsr.colVal,
531547
bsr.blockDim, cudescc, csrNzVal, csrRowPtr,

0 commit comments

Comments
 (0)