From 12a99227fd1cc335ec129826317f7520446c462e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 25 Dec 2021 23:32:16 -0500 Subject: [PATCH 1/4] unwrap generic matmul --- Project.toml | 2 +- src/OffsetArrays.jl | 2 + src/linearalgebra.jl | 94 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 src/linearalgebra.jl diff --git a/Project.toml b/Project.toml index 490aaad5..4ae0419f 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "1.10.8" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] Adapt = "2, 3" @@ -24,7 +25,6 @@ DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/OffsetArrays.jl b/src/OffsetArrays.jl index a00aedff..aebcdae8 100644 --- a/src/OffsetArrays.jl +++ b/src/OffsetArrays.jl @@ -853,6 +853,8 @@ end import Adapt Adapt.adapt_structure(to, O::OffsetArray) = parent_call(x -> Adapt.adapt(to, x), O) +include("linearalgebra.jl") + if Base.VERSION >= v"1.4.2" include("precompile.jl") _precompile_() diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl new file mode 100644 index 00000000..cb479e55 --- /dev/null +++ b/src/linearalgebra.jl @@ -0,0 +1,94 @@ +using LinearAlgebra +using LinearAlgebra: MulAddMul, mul! +lapack_axes(t::AbstractChar, M::AbstractVecOrMat) = (axes(M, t=='N' ? 1 : 2), axes(M, t=='N' ? 2 : 1)) + +# The signature of this differs from LinearAlgebra's only on C +function LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector, + _add::MulAddMul = MulAddMul()) + + mB_axis = Base.axes1(B) + mA_axis, nA_axis = lapack_axes(tA, A) + + if mB_axis != nA_axis + throw(DimensionMismatch("mul! can't contract axis $(UnitRange(nA_axis)) from A with axes(B) == ($(UnitRange(mB_axis)),)")) + end + if mA_axis != Base.axes1(C) + throw(DimensionMismatch("mul! got axes(C) == ($(UnitRange(Base.axes1(C))),), expected $(UnitRange(mA_axis))")) + end + + C1 = no_offset_view(C) + A1 = no_offset_view(A) + B1 = no_offset_view(B) + + if tA == 'T' + mul!(C1, transpose(A1), B1, _add.alpha, _add.beta) + elseif tA == 'C' + mul!(C1, adjoint(A1), B1, _add.alpha, _add.beta) + elseif tA == 'N' + mul!(C1, A1, B1, _add.alpha, _add.beta) + else + error("illegal char") + end + + C +end + +LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, + _add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add) +LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, + _add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add) + +function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, + _add::MulAddMul) + + mA_axis, nA_axis = lapack_axes(tA, A) + mB_axis, nB_axis = lapack_axes(tB, B) + + if nA_axis != mB_axis + throw(DimensionMismatch("mul! can't contract axis $(UnitRange(nA_axis)) from A with $(UnitRange(mB_axis)) from B")) + elseif mA_axis != axes(C,1) + throw(DimensionMismatch("mul! got axes(C,1) == $(UnitRange(axes(C,1))), expected $(UnitRange(mA_axis)) from A")) + elseif nB_axis != axes(C,2) + throw(DimensionMismatch("mul! got axes(C,2) == $(UnitRange(axes(C,2))), expected $(UnitRange(nB_axis)) from B")) + end + + C1 = no_offset_view(C) + A1 = no_offset_view(A) + B1 = no_offset_view(B) + + if tA == 'N' + if tB == 'N' + mul!(C1, A1, B1, _add.alpha, _add.beta) + elseif tB == 'T' + mul!(C1, A1, transpose(B1), _add.alpha, _add.beta) + elseif tB == 'C' + mul!(C1, A1, adjoint(B1), _add.alpha, _add.beta) + else + error("illegal char") + end + elseif tA == 'T' + if tB == 'N' + mul!(C1, transpose(A1), B1, _add.alpha, _add.beta) + elseif tB == 'T' + mul!(C1, transpose(A1), transpose(B1), _add.alpha, _add.beta) + elseif tB == 'C' + mul!(C1, transpose(A1), adjoint(B1), _add.alpha, _add.beta) + else + error("illegal char") + end + elseif tA == 'C' + if tB == 'N' + mul!(C1, adjoint(A1), B1, _add.alpha, _add.beta) + elseif tB == 'T' + mul!(C1, adjoint(A1), transpose(B1), _add.alpha, _add.beta) + elseif tB == 'C' + mul!(C1, adjoint(A1), adjoint(B1), _add.alpha, _add.beta) + else + error("illegal char") + end + else + error("illegal char") + end + + C +end From 6878a1427e36acc27da79da7559e707986a62eee Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 26 Dec 2021 10:03:46 -0500 Subject: [PATCH 2/4] adapt to delayed MulAddMul --- src/OffsetArrays.jl | 1 + src/linearalgebra.jl | 57 +++++++++++++++++++++++++++++--------------- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/src/OffsetArrays.jl b/src/OffsetArrays.jl index aebcdae8..b46d4149 100644 --- a/src/OffsetArrays.jl +++ b/src/OffsetArrays.jl @@ -651,6 +651,7 @@ if isdefined(Base, :IdentityUnitRange) no_offset_view(a::Base.Slice) = Base.Slice(UnitRange(a)) no_offset_view(S::SubArray) = view(parent(S), map(no_offset_view, parentindices(S))...) end +no_offset_view(A::PermutedDimsArray{T,N,perm,iperm,P}) where {T,N,perm,iperm,P} = PermutedDimsArray(no_offset_view(parent(A)), perm) no_offset_view(a::Array) = a no_offset_view(i::Number) = i no_offset_view(A::AbstractArray) = _no_offset_view(axes(A), A) diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index cb479e55..1b010060 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -2,9 +2,14 @@ using LinearAlgebra using LinearAlgebra: MulAddMul, mul! lapack_axes(t::AbstractChar, M::AbstractVecOrMat) = (axes(M, t=='N' ? 1 : 2), axes(M, t=='N' ? 2 : 1)) -# The signature of this differs from LinearAlgebra's only on C -function LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector, - _add::MulAddMul = MulAddMul()) +# The signatures of these differs from LinearAlgebra's *only* on C. +LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector, + _add::MulAddMul) = unwrap_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) +LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector, + alpha, beta) = unwrap_matvecmul!(C, tA, A, B, alpha, beta) + +function unwrap_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector, + alpha, beta) mB_axis = Base.axes1(B) mA_axis, nA_axis = lapack_axes(tA, A) @@ -21,11 +26,11 @@ function LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrM B1 = no_offset_view(B) if tA == 'T' - mul!(C1, transpose(A1), B1, _add.alpha, _add.beta) + mul!(C1, transpose(A1), B1, alpha, beta) elseif tA == 'C' - mul!(C1, adjoint(A1), B1, _add.alpha, _add.beta) + mul!(C1, adjoint(A1), B1, alpha, beta) elseif tA == 'N' - mul!(C1, A1, B1, _add.alpha, _add.beta) + mul!(C1, A1, B1, alpha, beta) else error("illegal char") end @@ -33,13 +38,22 @@ function LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrM C end +# The signatures of these differs from LinearAlgebra's *only* on C: +# Old path +LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, + _add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) +LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, + _add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) + +# New path LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, - _add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add) + alpha, beta) = unwrap_matmatmul!(C, tA, tB, A, B, alpha, beta) LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, - _add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add) + alpha, beta) = unwrap_matmatmul!(C, tA, tB, A, B, alpha, beta) -function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, - _add::MulAddMul) +# Worker +@inline function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, + alpha, beta) mA_axis, nA_axis = lapack_axes(tA, A) mB_axis, nB_axis = lapack_axes(tB, B) @@ -58,31 +72,31 @@ function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::Abst if tA == 'N' if tB == 'N' - mul!(C1, A1, B1, _add.alpha, _add.beta) + mul!(C1, A1, B1, alpha, beta) elseif tB == 'T' - mul!(C1, A1, transpose(B1), _add.alpha, _add.beta) + mul!(C1, A1, transpose(B1), alpha, beta) elseif tB == 'C' - mul!(C1, A1, adjoint(B1), _add.alpha, _add.beta) + mul!(C1, A1, adjoint(B1), alpha, beta) else error("illegal char") end elseif tA == 'T' if tB == 'N' - mul!(C1, transpose(A1), B1, _add.alpha, _add.beta) + mul!(C1, transpose(A1), B1, alpha, beta) elseif tB == 'T' - mul!(C1, transpose(A1), transpose(B1), _add.alpha, _add.beta) + mul!(C1, transpose(A1), transpose(B1), alpha, beta) elseif tB == 'C' - mul!(C1, transpose(A1), adjoint(B1), _add.alpha, _add.beta) + mul!(C1, transpose(A1), adjoint(B1), alpha, beta) else error("illegal char") end elseif tA == 'C' if tB == 'N' - mul!(C1, adjoint(A1), B1, _add.alpha, _add.beta) + mul!(C1, adjoint(A1), B1, alpha, beta) elseif tB == 'T' - mul!(C1, adjoint(A1), transpose(B1), _add.alpha, _add.beta) + mul!(C1, adjoint(A1), transpose(B1), alpha, beta) elseif tB == 'C' - mul!(C1, adjoint(A1), adjoint(B1), _add.alpha, _add.beta) + mul!(C1, adjoint(A1), adjoint(B1), alpha, beta) else error("illegal char") end @@ -92,3 +106,8 @@ function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::Abst C end + +no_offset_view(A::Adjoint) = Adjoint(no_offset_view(parent(A))) +no_offset_view(A::Transpose) = Transpose(no_offset_view(parent(A))) +no_offset_view(D::Diagonal) = Diagonal(no_offset_view(parent(D))) + From 6e0507522a991bfb35036de5ea0de65ebaefdf0d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 26 Dec 2021 10:56:54 -0500 Subject: [PATCH 3/4] fix Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4ae0419f..4655ac0b 100644 --- a/Project.toml +++ b/Project.toml @@ -29,4 +29,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "CatIndices", "DistributedArrays", "DelimitedFiles", "Documenter", "Test", "LinearAlgebra", "EllipsisNotation", "StaticArrays", "FillArrays"] +test = ["Aqua", "CatIndices", "DistributedArrays", "DelimitedFiles", "Documenter", "Test", "EllipsisNotation", "StaticArrays", "FillArrays"] From 816b928803ffb2c66d6c544864a4c4f7eda4d9ff Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 27 Dec 2021 23:12:07 -0500 Subject: [PATCH 4/4] use transpose not T etc for generic_matmul --- src/linearalgebra.jl | 97 +++++++++----------------------------------- 1 file changed, 20 insertions(+), 77 deletions(-) diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index 1b010060..2815ca41 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -1,18 +1,14 @@ using LinearAlgebra -using LinearAlgebra: MulAddMul, mul! -lapack_axes(t::AbstractChar, M::AbstractVecOrMat) = (axes(M, t=='N' ? 1 : 2), axes(M, t=='N' ? 2 : 1)) +using LinearAlgebra: MulAddMul, mul!, AdjOrTrans -# The signatures of these differs from LinearAlgebra's *only* on C. -LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector, - _add::MulAddMul) = unwrap_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) -LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector, - alpha, beta) = unwrap_matvecmul!(C, tA, A, B, alpha, beta) +@inline LinearAlgebra.generic_matvecmul!(C::OffsetVector, fA::Function, A::AbstractVecOrMat, B::AbstractVector, + alpha, beta) = unwrap_matvecmul!(C, fA, A, B, alpha, beta) -function unwrap_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector, +@inline function unwrap_matvecmul!(C::OffsetVector, fA, A::AbstractVecOrMat, B::AbstractVector, alpha, beta) mB_axis = Base.axes1(B) - mA_axis, nA_axis = lapack_axes(tA, A) + mA_axis, nA_axis = axes(fA(A)) if mB_axis != nA_axis throw(DimensionMismatch("mul! can't contract axis $(UnitRange(nA_axis)) from A with axes(B) == ($(UnitRange(mB_axis)),)")) @@ -21,42 +17,25 @@ function unwrap_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::Abstract throw(DimensionMismatch("mul! got axes(C) == ($(UnitRange(Base.axes1(C))),), expected $(UnitRange(mA_axis))")) end - C1 = no_offset_view(C) - A1 = no_offset_view(A) - B1 = no_offset_view(B) - - if tA == 'T' - mul!(C1, transpose(A1), B1, alpha, beta) - elseif tA == 'C' - mul!(C1, adjoint(A1), B1, alpha, beta) - elseif tA == 'N' - mul!(C1, A1, B1, alpha, beta) - else - error("illegal char") - end - + mul!(no_offset_view(C), fA(no_offset_view(A)), no_offset_view(B), alpha, beta) C end # The signatures of these differs from LinearAlgebra's *only* on C: -# Old path -LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, - _add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) -LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, - _add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) +@inline LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, fA::Function, fB::Function, A::AbstractMatrix, B::AbstractMatrix, + alpha, beta) = unwrap_matmatmul!(C, fA, fB, A, B, alpha, beta) + +@inline LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, fA::Function, fB::Function, A::AbstractVecOrMat, B::AbstractVecOrMat, + alpha, beta) = unwrap_matmatmul!(C, fA, fB, A, B, alpha, beta) -# New path -LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, - alpha, beta) = unwrap_matmatmul!(C, tA, tB, A, B, alpha, beta) -LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, - alpha, beta) = unwrap_matmatmul!(C, tA, tB, A, B, alpha, beta) +@inline LinearAlgebra.generic_matmatmul!(C::AdjOrTrans{<:Any, <:OffsetArray}, fA::Function, fB::Function, A::AbstractMatrix, B::AbstractMatrix, + alpha, beta) = unwrap_matmatmul!(C, fA, fB, A, B, alpha, beta) -# Worker -@inline function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, +@inline function unwrap_matmatmul!(C::AbstractVecOrMat, fA, fB, A::AbstractVecOrMat, B::AbstractVecOrMat, alpha, beta) - mA_axis, nA_axis = lapack_axes(tA, A) - mB_axis, nB_axis = lapack_axes(tB, B) + mA_axis, nA_axis = axes(fA(A)) + mB_axis, nB_axis = axes(fB(B)) if nA_axis != mB_axis throw(DimensionMismatch("mul! can't contract axis $(UnitRange(nA_axis)) from A with $(UnitRange(mB_axis)) from B")) @@ -66,48 +45,12 @@ LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A throw(DimensionMismatch("mul! got axes(C,2) == $(UnitRange(axes(C,2))), expected $(UnitRange(nB_axis)) from B")) end - C1 = no_offset_view(C) - A1 = no_offset_view(A) - B1 = no_offset_view(B) - - if tA == 'N' - if tB == 'N' - mul!(C1, A1, B1, alpha, beta) - elseif tB == 'T' - mul!(C1, A1, transpose(B1), alpha, beta) - elseif tB == 'C' - mul!(C1, A1, adjoint(B1), alpha, beta) - else - error("illegal char") - end - elseif tA == 'T' - if tB == 'N' - mul!(C1, transpose(A1), B1, alpha, beta) - elseif tB == 'T' - mul!(C1, transpose(A1), transpose(B1), alpha, beta) - elseif tB == 'C' - mul!(C1, transpose(A1), adjoint(B1), alpha, beta) - else - error("illegal char") - end - elseif tA == 'C' - if tB == 'N' - mul!(C1, adjoint(A1), B1, alpha, beta) - elseif tB == 'T' - mul!(C1, adjoint(A1), transpose(B1), alpha, beta) - elseif tB == 'C' - mul!(C1, adjoint(A1), adjoint(B1), alpha, beta) - else - error("illegal char") - end - else - error("illegal char") - end - + # Must be sure `no_offset_view(C)` won't match signature above! + mul!(no_offset_view(C), fA(no_offset_view(A)), fB(no_offset_view(B)), alpha, beta) C end -no_offset_view(A::Adjoint) = Adjoint(no_offset_view(parent(A))) -no_offset_view(A::Transpose) = Transpose(no_offset_view(parent(A))) +no_offset_view(A::Adjoint) = adjoint(no_offset_view(parent(A))) +no_offset_view(A::Transpose) = transpose(no_offset_view(parent(A))) no_offset_view(D::Diagonal) = Diagonal(no_offset_view(parent(D)))