Skip to content

Commit 4ee7ddc

Browse files
author
RAYNAUD Paul (raynaudp)
committed
allocation free for the core of the mul! method
1 parent 24cc4b9 commit 4ee7ddc

File tree

2 files changed

+53
-36
lines changed

2 files changed

+53
-36
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
33
version = "2.4.1"
44

55
[deps]
6+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
67
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
78
LDLFactorizations = "40e66cde-538c-5869-a4ad-c39174c6795b"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -11,6 +12,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1112
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
1213

1314
[compat]
15+
CUDA = "3.12.1"
1416
FastClosures = "0.2, 0.3"
1517
LDLFactorizations = "0.8.1, 0.9, 0.10"
1618
TimerOutputs = "^0.5"

src/compressed_lbfgs.jl

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ Compressed LBFGS implementation from:
77
Implemented by Paul Raynaud (supervised by Dominique Orban)
88
=#
99

10-
using LinearAlgebra
10+
using LinearAlgebra, LinearAlgebra.BLAS
11+
using CUDA
1112

1213
export CompressedLBFGS
1314

@@ -26,14 +27,16 @@ mutable struct CompressedLBFGS{T, M<:AbstractMatrix{T}, V<:AbstractVector{T}}
2627
intermediate_2::LowerTriangular{T,M} # 2m * 2m
2728
inverse_intermediate_1::UpperTriangular{T,M} # 2m * 2m
2829
inverse_intermediate_2::LowerTriangular{T,M} # 2m * 2m
30+
intermediary_vector::V # 2m
2931
sol::V # m
30-
inverse::Bool
32+
intermediate_structure_updated::Bool
3133
end
3234

35+
# Default value of the `gpu` keyword: `true` exactly when `CUDA.functional()` reports a usable CUDA setup.
default_gpu() = CUDA.functional()
3336
# Select the default matrix storage type for element type `T`:
# `CuMatrix{T}` when `gpu` is true, `Matrix{T}` otherwise.
function default_matrix_type(gpu::Bool, T::DataType)
  if !gpu
    return Matrix{T}
  end
  return CuMatrix{T}
end
3437
# Select the default vector storage type for element type `T`:
# `CuVector{T}` when `gpu` is true, `Vector{T}` otherwise.
function default_vector_type(gpu::Bool, T::DataType)
  if !gpu
    return Vector{T}
  end
  return CuVector{T}
end
3538

36-
function CompressedLBFGS(m::Int, n::Int; T=Float64, gpu=false, M=default_matrix_type(gpu,T), V=default_vector_type(gpu,T))
39+
function CompressedLBFGS(m::Int, n::Int; T=Float64, gpu=default_gpu(), M=default_matrix_type(gpu,T), V=default_vector_type(gpu,T))
3740
α = (T)(1)
3841
k = 0
3942
Sₖ = M(undef,n,m)
@@ -46,9 +49,10 @@ function CompressedLBFGS(m::Int, n::Int; T=Float64, gpu=false, M=default_matrix_
4649
intermediate_2 = LowerTriangular(M(undef,2*m,2*m))
4750
inverse_intermediate_1 = UpperTriangular(M(undef,2*m,2*m))
4851
inverse_intermediate_2 = LowerTriangular(M(undef,2*m,2*m))
52+
intermediary_vector = V(undef,2*m)
4953
sol = V(undef,2*m)
50-
inverse = false
51-
return CompressedLBFGS{T,M,V}(m, n, k, α, Sₖ, Yₖ, Dₖ, Lₖ, chol_matrix, intermediate_1, intermediate_2, inverse_intermediate_1, inverse_intermediate_2, sol, inverse)
54+
intermediate_structure_updated = false
55+
return CompressedLBFGS{T,M,V}(m, n, k, α, Sₖ, Yₖ, Dₖ, Lₖ, chol_matrix, intermediate_1, intermediate_2, inverse_intermediate_1, inverse_intermediate_2, intermediary_vector, sol, intermediate_structure_updated)
5256
end
5357

5458
function Base.push!(op::CompressedLBFGS{T,M,V}, s::V, y::V) where {T,M,V<:AbstractVector{T}}
@@ -69,13 +73,14 @@ function Base.push!(op::CompressedLBFGS{T,M,V}, s::V, y::V) where {T,M,V<:Abstra
6973
circshift(op.Yₖ, (0,-1))
7074
circshift(op.Dₖ, (-1,-1))
7175
# circshift doesn't work for a LowerTriangular matrix
76+
# for the time being, completely recompute the strictly lower triangular entries of Lₖ
7277
for j in 2:op.k
7378
for i in 1:j-1
7479
op.Lₖ.data[j, i] = dot(op.Sₖ[:,j],op.Yₖ[:,i])
7580
end
7681
end
7782
end
78-
op.inverse = false
83+
op.intermediate_structure_updated = false
7984
return op
8085
end
8186

@@ -98,46 +103,56 @@ function Base.Matrix(op::CompressedLBFGS{T,M,V}) where {T,M,V}
98103
return Bₖ
99104
end
100105

106+
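# step 4, recomputes Jₖ on every call; mul! invokes it (via precompile_iterated_structure!) only when the cached structures are out of date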
101107
function inverse_cholesky(op::CompressedLBFGS)
102-
if !op.inverse
103-
op.chol_matrix[1:op.k,1:op.k] .= op.α .* (transpose(op.Sₖ[:,1:op.k]) * op.Sₖ[:,1:op.k]) .+ op.Lₖ[1:op.k,1:op.k] * inv(op.Dₖ[1:op.k,1:op.k]) * transpose(op.Lₖ[1:op.k,1:op.k])
104-
cholesky!(view(op.chol_matrix,1:op.k,1:op.k))
105-
op.inverse = true
106-
end
107-
Jₖ = transpose(UpperTriangular(op.chol_matrix[1:op.k,1:op.k]))
108+
view(op.chol_matrix, 1:op.k, 1:op.k) .= op.α .* (transpose(view(op.Sₖ, :, 1:op.k)) * view(op.Sₖ, :, 1:op.k)) .+ view(op.Lₖ, 1:op.k, 1:op.k) * inv(op.Dₖ[1:op.k, 1:op.k]) * transpose(view(op.Lₖ, 1:op.k, 1:op.k))
109+
cholesky!(view(op.chol_matrix,1:op.k,1:op.k))
110+
Jₖ = transpose(UpperTriangular(view(op.chol_matrix, 1:op.k, 1:op.k)))
108111
return Jₖ
109112
end
110113

114+
# step 6, must be improved
115+
function precompile_iterated_structure!(op::CompressedLBFGS)
116+
Jₖ = inverse_cholesky(op)
117+
118+
view(op.intermediate_1, 1:op.k,1:op.k) .= .- view(op.Dₖ, 1:op.k, 1:op.k)^(1/2)
119+
view(op.intermediate_1, 1:op.k,op.k+1:2*op.k) .= view(op.Dₖ, 1:op.k, 1:op.k)^(-1/2) * transpose(view(op.Lₖ, 1:op.k, 1:op.k))
120+
view(op.intermediate_1, op.k+1:2*op.k, 1:op.k) .= 0
121+
view(op.intermediate_1, op.k+1:2*op.k, op.k+1:2*op.k) .= transpose(Jₖ)
122+
123+
view(op.intermediate_2, 1:op.k, 1:op.k) .= view(op.Dₖ, 1:op.k, 1:op.k)^(1/2)
124+
view(op.intermediate_2, 1:op.k, op.k+1:2*op.k) .= 0
125+
view(op.intermediate_2, op.k+1:2*op.k, 1:op.k) .= .- view(op.Lₖ, 1:op.k, 1:op.k) * view(op.Dₖ, 1:op.k, 1:op.k)^(-1/2)
126+
view(op.intermediate_2, op.k+1:2*op.k, op.k+1:2*op.k) .= Jₖ
127+
128+
view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k) .= inv(op.intermediate_1[ 1:2*op.k,1:2*op.k])
129+
view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k) .= inv(op.intermediate_2[ 1:2*op.k,1:2*op.k])
130+
131+
op.intermediate_structure_updated = true
132+
end
133+
111134
# Algorithm 3.2 (p15)
112135
function LinearAlgebra.mul!(Bv::V, op::CompressedLBFGS{T,M,V}, v::V) where {T,M,V<:AbstractVector{T}}
113136
# step 1-3 mainly done by Base.push!
114-
# step 4, Jₖ is computed only if needed
115-
Jₖ = inverse_cholesky(op::CompressedLBFGS)
137+
138+
# steps 4 and 6, performed only when the required intermediate structures are not up to date
139+
(!op.intermediate_structure_updated) && (precompile_iterated_structure!(op))
116140

117141
# step 5, try views for mul!
118-
# mul!(op.sol[1:op.k], transpose(op.Yₖ[:,1:op.k]), v) # wrong result
119-
# mul!(op.sol[op.k+1:2*op.k], transpose(op.Yₖ[:,1:op.k]), v, (T)(1), op.α) # wrong result
120-
op.sol[1:op.k] .= transpose(op.Yₖ[:,1:op.k]) * v
121-
op.sol[op.k+1:2*op.k] .= op.α .* transpose(op.Sₖ[:,1:op.k]) * v
122-
123-
# step 6, must be improve
124-
op.intermediate_1[1:op.k,1:op.k] .= .- op.Dₖ[1:op.k,1:op.k]^(1/2)
125-
op.intermediate_1[1:op.k,op.k+1:2*op.k] .= op.Dₖ[1:op.k,1:op.k]^(-1/2) * transpose(op.Lₖ[1:op.k,1:op.k])
126-
op.intermediate_1[op.k+1:2*op.k,1:op.k] .= 0
127-
op.intermediate_1[op.k+1:2*op.k,op.k+1:2*op.k] .= transpose(Jₖ)
128-
129-
op.intermediate_2[1:op.k,1:op.k] .= op.Dₖ[1:op.k,1:op.k]^(1/2)
130-
op.intermediate_2[1:op.k,op.k+1:2*op.k] .= 0
131-
op.intermediate_2[op.k+1:2*op.k,1:op.k] .= .- op.Lₖ[1:op.k,1:op.k] * op.Dₖ[1:op.k,1:op.k]^(-1/2)
132-
op.intermediate_2[op.k+1:2*op.k,op.k+1:2*op.k] .= Jₖ
133-
134-
op.inverse_intermediate_1[1:2*op.k,1:2*op.k] .= inv(op.intermediate_1[1:2*op.k,1:2*op.k])
135-
op.inverse_intermediate_2[1:2*op.k,1:2*op.k] .= inv(op.intermediate_2[1:2*op.k,1:2*op.k])
136-
137-
op.sol[1:2*op.k] .= op.inverse_intermediate_1[1:2*op.k,1:2*op.k] * (op.inverse_intermediate_2[1:2*op.k,1:2*op.k] * op.sol[1:2*op.k])
142+
mul!(view(op.sol, 1:op.k), transpose(view(op.Yₖ, :, 1:op.k)), v)
143+
mul!(view(op.sol, op.k+1:2*op.k), transpose(view(op.Sₖ, :,1:op.k)), v)
144+
# scal!(op.α, view(op.sol, op.k+1:2*op.k)) # more allocation, slower
145+
view(op.sol, op.k+1:2*op.k) .*= op.α
146+
147+
# view(op.sol, 1:2*op.k) .= view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k) * (view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k) * view(op.sol, 1:2*op.k))
148+
mul!(view(op.intermediary_vector, 1:2*op.k), view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k), view(op.sol, 1:2*op.k))
149+
mul!(view(op.sol, 1:2*op.k), view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k), view(op.intermediary_vector, 1:2*op.k))
138150

139151
# step 7
140-
Bv .= op.α .* v .- (op.Yₖ[:,1:op.k] * op.sol[1:op.k] .+ op.α .* op.Sₖ[:,1:op.k] * op.sol[op.k+1:2*op.k])
141-
152+
# Bv .= op.α .* v .- (view(op.Yₖ, :,1:op.k) * view(op.sol, 1:op.k) .+ op.α .* view(op.Sₖ, :, 1:op.k) * view(op.sol, op.k+1:2*op.k))
153+
154+
mul!(Bv, view(op.Yₖ, :, 1:op.k), view(op.sol, 1:op.k))
155+
mul!(Bv, view(op.Sₖ, :, 1:op.k), view(op.sol, op.k+1:2*op.k), - op.α, (T)(-1))
156+
Bv .+= op.α .* v
142157
return Bv
143158
end

0 commit comments

Comments
 (0)