Skip to content

Commit 3284ced

Browse files
author
RAYNAUD Paul (raynaudp)
committed
circular push!
1 parent 4ee7ddc commit 3284ced

File tree

1 file changed

+44
-35
lines changed

1 file changed

+44
-35
lines changed

src/compressed_lbfgs.jl

Lines changed: 44 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -36,67 +36,76 @@ default_gpu() = CUDA.functional() ? true : false
3636
default_matrix_type(gpu::Bool, T::DataType) = gpu ? CuMatrix{T} : Matrix{T}
3737
default_vector_type(gpu::Bool, T::DataType) = gpu ? CuVector{T} : Vector{T}
3838

39-
function CompressedLBFGS(m::Int, n::Int; T=Float64, gpu=default_gpu(), M=default_matrix_type(gpu,T), V=default_vector_type(gpu,T))
39+
function CompressedLBFGS(m::Int, n::Int; T=Float64, gpu=default_gpu(), M=default_matrix_type(gpu, T), V=default_vector_type(gpu, T))
4040
α = (T)(1)
4141
k = 0
42-
Sₖ = M(undef,n,m)
43-
Yₖ = M(undef,n,m)
44-
Dₖ = Diagonal(V(undef,m))
45-
Lₖ = LowerTriangular(M(undef,m,m))
46-
47-
chol_matrix = M(undef,m,m)
48-
intermediate_1 = UpperTriangular(M(undef,2*m,2*m))
49-
intermediate_2 = LowerTriangular(M(undef,2*m,2*m))
50-
inverse_intermediate_1 = UpperTriangular(M(undef,2*m,2*m))
51-
inverse_intermediate_2 = LowerTriangular(M(undef,2*m,2*m))
52-
intermediary_vector = V(undef,2*m)
53-
sol = V(undef,2*m)
42+
Sₖ = M(undef, n, m)
43+
Yₖ = M(undef, n, m)
44+
Dₖ = Diagonal(V(undef, m))
45+
Lₖ = LowerTriangular(M(undef, m, m))
46+
47+
chol_matrix = M(undef, m, m)
48+
intermediate_1 = UpperTriangular(M(undef, 2*m, 2*m))
49+
intermediate_2 = LowerTriangular(M(undef, 2*m, 2*m))
50+
inverse_intermediate_1 = UpperTriangular(M(undef, 2*m, 2*m))
51+
inverse_intermediate_2 = LowerTriangular(M(undef, 2*m, 2*m))
52+
intermediary_vector = V(undef, 2*m)
53+
sol = V(undef, 2*m)
5454
intermediate_structure_updated = false
5555
return CompressedLBFGS{T,M,V}(m, n, k, α, Sₖ, Yₖ, Dₖ, Lₖ, chol_matrix, intermediate_1, intermediate_2, inverse_intermediate_1, inverse_intermediate_2, intermediary_vector, sol, intermediate_structure_updated)
5656
end
5757

5858
function Base.push!(op::CompressedLBFGS{T,M,V}, s::V, y::V) where {T,M,V<:AbstractVector{T}}
5959
if op.k < op.m # still some place in structures
6060
op.k += 1
61-
op.Sₖ[:,op.k] .= s
62-
op.Yₖ[:,op.k] .= y
63-
op.Dₖ.diag[op.k] = dot(s,y)
61+
op.Sₖ[:, op.k] .= s
62+
op.Yₖ[:, op.k] .= y
63+
op.Dₖ.diag[op.k] = dot(s, y)
6464
op.Lₖ.data[op.k, op.k] = 0
6565
for i in 1:op.k-1
66-
op.Lₖ.data[op.k, i] = dot(s,op.Yₖ[:,i])
66+
# op.Lₖ.data[op.k, i] = dot(s, op.Yₖ[:, i])
67+
op.Lₖ.data[op.k, i] = dot(op.Sₖ[:, op.k], op.Yₖ[:, i])
6768
end
6869
# the secant equation fails if this line is uncommented
69-
# op.α = dot(y,s)/dot(s,s)
7070
else # update matrix with circular shift
71+
println("else")
7172
# must be tested
72-
circshift(op.Sₖ, (0,-1))
73-
circshift(op.Yₖ, (0,-1))
74-
circshift(op.Dₖ, (-1,-1))
73+
op.Sₖ .= circshift(op.Sₖ, (0, -1))
74+
op.Yₖ .= circshift(op.Yₖ, (0, -1))
75+
op.Dₖ .= circshift(op.Dₖ, (-1, -1))
76+
op.Sₖ[:, op.k] .= s
77+
op.Yₖ[:, op.k] .= y
78+
op.Dₖ.diag[op.k] = dot(s, y)
7579
# circshift doesn't work for a LowerTriangular matrix
7680
# for the time being, recompute every entry of the Lₖ matrix from scratch
77-
for j in 2:op.k
81+
for j in 1:op.k
7882
for i in 1:j-1
79-
op.Lₖ.data[j, i] = dot(op.Sₖ[:,j],op.Yₖ[:,i])
83+
op.Lₖ.data[j, i] = dot(op.Sₖ[:, j], op.Yₖ[:, i])
8084
end
8185
end
8286
end
87+
@show op.Lₖ
88+
@show op.Sₖ
89+
@show op.Yₖ
90+
@show op.Dₖ
91+
# op.α = dot(y,s)/dot(s,s)
8392
op.intermediate_structure_updated = false
8493
return op
8594
end
8695

8796
# Theorem 2.3 (p6)
8897
function Base.Matrix(op::CompressedLBFGS{T,M,V}) where {T,M,V}
89-
B₀ = M(zeros(T,op.n, op.n))
90-
map(i -> B₀[i,i] = op.α, 1:op.n)
98+
B₀ = M(zeros(T, op.n, op.n))
99+
map(i -> B₀[i, i] = op.α, 1:op.n)
91100

92101
BSY = M(undef, op.n, 2*op.k)
93-
(op.k > 0) && (BSY[:,1:op.k] = B₀ * op.Sₖ[:,1:op.k])
94-
(op.k > 0) && (BSY[:,op.k+1:2*op.k] = op.Yₖ[:,1:op.k])
102+
(op.k > 0) && (BSY[:, 1:op.k] = B₀ * op.Sₖ[:, 1:op.k])
103+
(op.k > 0) && (BSY[:, op.k+1:2*op.k] = op.Yₖ[:, 1:op.k])
95104
_C = M(undef, 2*op.k, 2*op.k)
96-
(op.k > 0) && (_C[1:op.k, 1:op.k] .= transpose(op.Sₖ[:,1:op.k]) * op.Sₖ[:,1:op.k])
97-
(op.k > 0) && (_C[1:op.k, op.k+1:2*op.k] .= op.Lₖ[1:op.k,1:op.k])
98-
(op.k > 0) && (_C[op.k+1:2*op.k, 1:op.k] .= transpose(op.Lₖ[1:op.k,1:op.k]))
99-
(op.k > 0) && (_C[op.k+1:2*op.k, op.k+1:2*op.k] .= .- op.Dₖ[1:op.k,1:op.k])
105+
(op.k > 0) && (_C[1:op.k, 1:op.k] .= transpose(op.Sₖ[:, 1:op.k]) * op.Sₖ[:, 1:op.k])
106+
(op.k > 0) && (_C[1:op.k, op.k+1:2*op.k] .= op.Lₖ[1:op.k, 1:op.k])
107+
(op.k > 0) && (_C[op.k+1:2*op.k, 1:op.k] .= transpose(op.Lₖ[1:op.k, 1:op.k]))
108+
(op.k > 0) && (_C[op.k+1:2*op.k, op.k+1:2*op.k] .= .- op.Dₖ[1:op.k, 1:op.k])
100109
C = inv(_C)
101110

102111
Bₖ = B₀ .- BSY * C * transpose(BSY)
@@ -106,7 +115,7 @@ end
106115
# step 4, Jₖ is computed only if needed
107116
function inverse_cholesky(op::CompressedLBFGS)
108117
view(op.chol_matrix, 1:op.k, 1:op.k) .= op.α .* (transpose(view(op.Sₖ, :, 1:op.k)) * view(op.Sₖ, :, 1:op.k)) .+ view(op.Lₖ, 1:op.k, 1:op.k) * inv(op.Dₖ[1:op.k, 1:op.k]) * transpose(view(op.Lₖ, 1:op.k, 1:op.k))
109-
cholesky!(view(op.chol_matrix,1:op.k,1:op.k))
118+
cholesky!(Symmetric(view(op.chol_matrix, 1:op.k, 1:op.k)))
110119
Jₖ = transpose(UpperTriangular(view(op.chol_matrix, 1:op.k, 1:op.k)))
111120
return Jₖ
112121
end
@@ -125,8 +134,8 @@ function precompile_iterated_structure!(op::CompressedLBFGS)
125134
view(op.intermediate_2, op.k+1:2*op.k, 1:op.k) .= .- view(op.Lₖ, 1:op.k, 1:op.k) * view(op.Dₖ, 1:op.k, 1:op.k)^(-1/2)
126135
view(op.intermediate_2, op.k+1:2*op.k, op.k+1:2*op.k) .= Jₖ
127136

128-
view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k) .= inv(op.intermediate_1[ 1:2*op.k,1:2*op.k])
129-
view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k) .= inv(op.intermediate_2[ 1:2*op.k,1:2*op.k])
137+
view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k) .= inv(op.intermediate_1[1:2*op.k, 1:2*op.k])
138+
view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k) .= inv(op.intermediate_2[1:2*op.k, 1:2*op.k])
130139

131140
op.intermediate_structure_updated = true
132141
end
@@ -140,7 +149,7 @@ function LinearAlgebra.mul!(Bv::V, op::CompressedLBFGS{T,M,V}, v::V) where {T,M,
140149

141150
# step 5, try views for mul!
142151
mul!(view(op.sol, 1:op.k), transpose(view(op.Yₖ, :, 1:op.k)), v)
143-
mul!(view(op.sol, op.k+1:2*op.k), transpose(view(op.Sₖ, :,1:op.k)), v)
152+
mul!(view(op.sol, op.k+1:2*op.k), transpose(view(op.Sₖ, :, 1:op.k)), v)
144153
# scal!(op.α, view(op.sol, op.k+1:2*op.k)) # more allocation, slower
145154
view(op.sol, op.k+1:2*op.k) .*= op.α
146155

0 commit comments

Comments
 (0)