Skip to content

Commit ace60cf

Browse files
authored
Make kron of AdjOrTrans sparse matrices sparse (#42181)
1 parent 2802807 commit ace60cf

File tree

3 files changed

+52
-24
lines changed

3 files changed

+52
-24
lines changed

stdlib/SparseArrays/src/SparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Base: ReshapedArray, promote_op, setindex_shape_check, to_shape, tail,
99
require_one_based_indexing
1010
using Base.Sort: Forward
1111
using LinearAlgebra
12+
using LinearAlgebra: AdjOrTrans
1213

1314
import Base: +, -, *, \, /, &, |, xor, ==, zero
1415
import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,

stdlib/SparseArrays/src/linalg.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,7 +1363,15 @@ end
13631363
end
13641364
return C
13651365
end
1366-
1366+
@inline function kron!(C::SparseMatrixCSC, A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AbstractSparseMatrixCSC)
1367+
return kron!(C, copy(A), B)
1368+
end
1369+
@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC})
1370+
return kron!(C, A, copy(B))
1371+
end
1372+
@inline function kron!(C::SparseMatrixCSC, A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC})
1373+
return kron!(C, copy(A), copy(B))
1374+
end
13671375
@inline function kron!(z::SparseVector, x::SparseVector, y::SparseVector)
13681376
nnzx = nnz(x); nnzy = nnz(y);
13691377
nzind = nonzeroinds(z)
@@ -1391,6 +1399,11 @@ function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S
13911399
sizehint!(C, nnz(A)*nnz(B))
13921400
return @inbounds kron!(C, A, B)
13931401
end
1402+
kron(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AbstractSparseMatrixCSC) = kron(copy(A), B)
1403+
kron(A::AbstractSparseMatrixCSC, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = kron(A, copy(B))
1404+
function kron(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC})
1405+
return kron(copy(A), copy(B))
1406+
end
13941407

13951408
# sparse vector ⊗ sparse vector
13961409
function kron(x::SparseVector{T1,S1}, y::SparseVector{T2,S2}) where {T1,S1,T2,S2}
@@ -1407,21 +1420,29 @@ Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, x
14071420
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, x::SparseVector, A::AbstractSparseMatrixCSC) = kron!(C, SparseMatrixCSC(x), A)
14081421

14091422
kron(A::AbstractSparseMatrixCSC, x::SparseVector) = kron(A, SparseMatrixCSC(x))
1423+
kron(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, x::SparseVector) =
1424+
kron(copy(A), x)
14101425
kron(x::SparseVector, A::AbstractSparseMatrixCSC) = kron(SparseMatrixCSC(x), A)
1426+
kron(x::SparseVector, A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) =
1427+
kron(x, copy(A))
14111428

14121429
# sparse vec/mat ⊗ vec/mat and vice versa
14131430
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::Union{SparseVector,AbstractSparseMatrixCSC}, B::VecOrMat) = kron!(C, A, sparse(B))
14141431
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC}) = kron!(C, sparse(A), B)
14151432

1416-
kron(A::Union{SparseVector,AbstractSparseMatrixCSC}, B::VecOrMat) = kron(A, sparse(B))
1417-
kron(A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC}) = kron(sparse(A), B)
1433+
kron(A::Union{SparseVector,AbstractSparseMatrixCSC,AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}}, B::VecOrMat) =
1434+
kron(A, sparse(B))
1435+
kron(A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC,AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}}) =
1436+
kron(sparse(A), B)
14181437

14191438
# sparse vec/mat ⊗ Diagonal and vice versa
14201439
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::Diagonal{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}}) where {T<:Number, S<:Number} = kron!(C, sparse(A), B)
14211440
Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number} = kron!(C, A, sparse(B))
14221441

1423-
kron(A::Diagonal{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}}) where {T<:Number, S<:Number} = kron(sparse(A), B)
1424-
kron(A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number} = kron(A, sparse(B))
1442+
kron(A::Diagonal{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}, AdjOrTrans{S,<:AbstractSparseMatrixCSC}}) where {T<:Number, S<:Number} =
1443+
kron(sparse(A), B)
1444+
kron(A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}, AdjOrTrans{S,<:AbstractSparseMatrixCSC}}, B::Diagonal{S}) where {T<:Number, S<:Number} =
1445+
kron(A, sparse(B))
14251446

14261447
# sparse outer product
14271448
kron!(C::SparseMatrixCSC, A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = broadcast!(*, C, A, B)

stdlib/SparseArrays/test/sparse.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -454,29 +454,35 @@ end
454454
c_di = Diagonal(rand(m)); c = sparse(c_di); c_d = Array(c_di)
455455
d_di = Diagonal(rand(n)); d = sparse(d_di); d_d = Array(d_di)
456456
# mat ⊗ mat
457-
@test Array(kron(a, b)) == kron(a_d, b_d)
458-
@test Array(kron(a_d, b)) == kron(a_d, b_d)
459-
@test Array(kron(a, b_d)) == kron(a_d, b_d)
460-
@test issparse(kron(c, d_di))
461-
@test Array(kron(c, d_di)) == kron(c_d, d_d)
462-
@test issparse(kron(c_di, d))
463-
@test Array(kron(c_di, d)) == kron(c_d, d_d)
464-
@test issparse(kron(c_di, y))
465-
@test Array(kron(c_di, y)) == kron(c_di, y_d)
466-
@test issparse(kron(x, d_di))
467-
@test Array(kron(x, d_di)) == kron(x_d, d_di)
457+
for t in (identity, adjoint, transpose)
458+
@test Array(kron(t(a), b)::SparseMatrixCSC) == kron(t(a_d), b_d)
459+
@test Array(kron(a, t(b))::SparseMatrixCSC) == kron(a_d, t(b_d))
460+
@test Array(kron(t(a), t(b))::SparseMatrixCSC) == kron(t(a_d), t(b_d))
461+
@test Array(kron(a_d, t(b))::SparseMatrixCSC) == kron(a_d, t(b_d))
462+
@test Array(kron(t(a), b_d)::SparseMatrixCSC) == kron(t(a_d), b_d)
463+
@test issparse(kron(c, d_di))
464+
@test Array(kron(c, d_di)) == kron(c_d, d_d)
465+
@test issparse(kron(c_di, d))
466+
@test Array(kron(c_di, d)) == kron(c_d, d_d)
467+
@test issparse(kron(c_di, y))
468+
@test Array(kron(c_di, y)) == kron(c_di, y_d)
469+
@test issparse(kron(x, d_di))
470+
@test Array(kron(x, d_di)) == kron(x_d, d_di)
471+
end
468472
# vec ⊗ vec
469473
@test Vector(kron(x, y)) == kron(x_d, y_d)
470474
@test Vector(kron(x_d, y)) == kron(x_d, y_d)
471475
@test Vector(kron(x, y_d)) == kron(x_d, y_d)
472-
# mat ⊗ vec
473-
@test Array(kron(a, y)) == kron(a_d, y_d)
474-
@test Array(kron(a_d, y)) == kron(a_d, y_d)
475-
@test Array(kron(a, y_d)) == kron(a_d, y_d)
476-
# vec ⊗ mat
477-
@test Array(kron(x, b)) == kron(x_d, b_d)
478-
@test Array(kron(x_d, b)) == kron(x_d, b_d)
479-
@test Array(kron(x, b_d)) == kron(x_d, b_d)
476+
for t in (identity, adjoint, transpose)
477+
# mat ⊗ vec
478+
@test Array(kron(t(a), y)::SparseMatrixCSC) == kron(t(a_d), y_d)
479+
@test Array(kron(t(a_d), y)) == kron(t(a_d), y_d)
480+
@test Array(kron(t(a), y_d)::SparseMatrixCSC) == kron(t(a_d), y_d)
481+
# vec ⊗ mat
482+
@test Array(kron(x, t(b))::SparseMatrixCSC) == kron(x_d, t(b_d))
483+
@test Array(kron(x_d, t(b))::SparseMatrixCSC) == kron(x_d, t(b_d))
484+
@test Array(kron(x, t(b_d))) == kron(x_d, t(b_d))
485+
end
480486
# vec ⊗ vec'
481487
@test issparse(kron(v, y'))
482488
@test issparse(kron(x, y'))

0 commit comments

Comments
 (0)