gemm! alias #202

@harrisonritz

Love the package!

BLAS.gemm! fails for any PDMat argument unless you pass its underlying dense matrix (e.g. A.mat) instead.
Maybe something like this could be more general:

pd_gemm!(tA, tB, alpha, A, B, beta, C) = BLAS.gemm!(
    tA, tB, alpha,
    A isa AbstractPDMat ? A.mat : A,
    B isa AbstractPDMat ? B.mat : B,
    beta,
    C isa AbstractPDMat ? C.mat : C,
);
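
The same idea could also be written with dispatch instead of `isa` checks. This is only a rough sketch, not something that exists in PDMats: `unwrap_dense` and `pd_gemm_dispatch!` are hypothetical names. It only unwraps `PDMat`, since other `AbstractPDMat` subtypes such as `PDiagMat` have no `.mat` field, and it leaves `C` alone on the assumption that overwriting a `PDMat`'s storage would leave its stored Cholesky factor stale.

```julia
using LinearAlgebra, PDMats

# Hypothetical helper (not part of PDMats): return the dense storage of a
# PDMat and pass every other matrix through unchanged.
unwrap_dense(A::PDMat) = A.mat
unwrap_dense(A::AbstractMatrix) = A

# Hypothetical wrapper: same argument order as BLAS.gemm!.
# C is deliberately not unwrapped (assumption: the output should be a plain
# dense matrix, so a PDMat's cached Cholesky factor is never invalidated).
pd_gemm_dispatch!(tA, tB, alpha, A, B, beta, C) =
    BLAS.gemm!(tA, tB, alpha, unwrap_dense(A), unwrap_dense(B), beta, C)

# Small usage example: C = P * B'
X = randn(4, 4)
P = PDMat(X' * X + I)
B = randn(4, 4)
C = zeros(4, 4)
pd_gemm_dispatch!('N', 'T', 1.0, P, B, 0.0, C)
```

Types that BLAS cannot handle still hit the usual `MethodError` from `BLAS.gemm!`, which seems preferable to failing on a missing `.mat` field.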

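In the benchmarks below, pd_gemm! runs just as fast as calling BLAS.gemm! on xx.mat directly (median about 506 ns, no allocations), whereas mul! has a median of about 3.2 μs and allocates.
Minimal example: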

using LinearAlgebra, PDMats, BenchmarkTools


ix = randn(20,20);
xx = PDMat(Hermitian(ix'*ix));
aa = randn(20,20);

pd_gemm!(tA, tB, alpha, A, B, beta, C) = BLAS.gemm!(
    tA, tB, alpha,
    A isa AbstractPDMat ? A.mat : A,
    B isa AbstractPDMat ? B.mat : B,
    beta,
    C isa AbstractPDMat ? C.mat : C,
);

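# 1. Generic mul! on the PDMat
yy = zeros(20,20);
@benchmark mul!($yy, $xx, $aa', 1.0, 1.0)

# 2. BLAS.gemm! called directly on the underlying dense matrix
yy = zeros(20,20);
@benchmark BLAS.gemm!('N', 'T', 1.0, $(xx.mat), $aa, 1.0, $yy)

# 3. Proposed wrapper
yy = zeros(20,20);
@benchmark pd_gemm!('N', 'T', 1.0, $xx, $aa, 1.0, $yy)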

Results, in the same order as the three benchmarks above:

BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.384 μs … 331.486 μs  ┊ GC (min … max):  0.00% … 98.48%
 Time  (median):     3.176 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.285 μs ±  15.201 μs  ┊ GC (mean ± σ):  22.44% ±  6.26%

            █▂                                                 
  █▂▁▁▁▁▁▁▂▇██▇▆▅▄▃▃▂▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  2.38 μs         Histogram: frequency by time         6.3 μs <

 Memory estimate: 20.58 KiB, allocs estimate: 4.

BenchmarkTools.Trial: 10000 samples with 193 evaluations.
 Range (min … max):  505.394 ns … 993.523 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     506.477 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   518.730 ns ±  28.264 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▂▁▁▁▁▄▃▁▁▁▂▂▁▂▁▁                                             ▁
  ██████████████████▇▇▇▆▆▆▇▇▆▆▆▆▆▆▇▆▇▆▆▆▆▆▆▆▆▅▆▆▅▆▆▅▅▄▄▅▃▄▅▄▅▄▅ █
  505 ns        Histogram: log(frequency) by time        642 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

BenchmarkTools.Trial: 10000 samples with 194 evaluations.
 Range (min … max):  505.371 ns … 909.577 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     506.443 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   510.625 ns ±  13.425 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▆▂▁      ▁  ▂▃▁                                              ▁
  ████████▇███▇███▇▆▆▆▇▇▇▇▇▇▇▆▆▆▆▅▆▆▆▆▆▅▄▅▅▄▅▅▄▄▄▃▅▄▃▄▅▅▃▄▄▄▂▃▃ █
  505 ns        Histogram: log(frequency) by time        573 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
