Skip to content

Commit 24cc4b9

Browse files
author
RAYNAUD Paul (raynaudp)
committed
CompressedLBFGS structure + Matrix interface + mul! (first version)
1 parent a204372 commit 24cc4b9

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed

src/compressed_lbfgs.jl

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#=
2+
Compressed LBFGS implementation from:
3+
REPRESENTATIONS OF QUASI-NEWTON MATRICES AND THEIR USE IN LIMITED MEMORY METHODS
4+
Richard H. Byrd, Jorge Nocedal and Robert B. Schnabel (1994)
5+
DOI: 10.1007/BF01582063
6+
7+
Implemented by Paul Raynaud (supervised by Dominique Orban)
8+
=#
9+
10+
using LinearAlgebra
11+
12+
export CompressedLBFGS
13+
14+
mutable struct CompressedLBFGS{T, M<:AbstractMatrix{T}, V<:AbstractVector{T}}
15+
m::Int # memory of the operator
16+
n::Int # vector size
17+
k::Int # k ≤ m, active memory of the operator
18+
α::T # B₀ = αI
19+
Sₖ::M # gather all sₖ₋ₘ
20+
Yₖ::M # gather all yₖ₋ₘ
21+
Dₖ::Diagonal{T,V} # m * m
22+
Lₖ::LowerTriangular{T,M} # m * m
23+
24+
chol_matrix::M # 2m * 2m
25+
intermediate_1::UpperTriangular{T,M} # 2m * 2m
26+
intermediate_2::LowerTriangular{T,M} # 2m * 2m
27+
inverse_intermediate_1::UpperTriangular{T,M} # 2m * 2m
28+
inverse_intermediate_2::LowerTriangular{T,M} # 2m * 2m
29+
sol::V # m
30+
inverse::Bool
31+
end
32+
33+
default_matrix_type(gpu::Bool, T::DataType) = gpu ? CuMatrix{T} : Matrix{T}
34+
default_vector_type(gpu::Bool, T::DataType) = gpu ? CuVector{T} : Vector{T}
35+
36+
function CompressedLBFGS(m::Int, n::Int; T=Float64, gpu=false, M=default_matrix_type(gpu,T), V=default_vector_type(gpu,T))
37+
α = (T)(1)
38+
k = 0
39+
Sₖ = M(undef,n,m)
40+
Yₖ = M(undef,n,m)
41+
Dₖ = Diagonal(V(undef,m))
42+
Lₖ = LowerTriangular(M(undef,m,m))
43+
44+
chol_matrix = M(undef,m,m)
45+
intermediate_1 = UpperTriangular(M(undef,2*m,2*m))
46+
intermediate_2 = LowerTriangular(M(undef,2*m,2*m))
47+
inverse_intermediate_1 = UpperTriangular(M(undef,2*m,2*m))
48+
inverse_intermediate_2 = LowerTriangular(M(undef,2*m,2*m))
49+
sol = V(undef,2*m)
50+
inverse = false
51+
return CompressedLBFGS{T,M,V}(m, n, k, α, Sₖ, Yₖ, Dₖ, Lₖ, chol_matrix, intermediate_1, intermediate_2, inverse_intermediate_1, inverse_intermediate_2, sol, inverse)
52+
end
53+
54+
function Base.push!(op::CompressedLBFGS{T,M,V}, s::V, y::V) where {T,M,V<:AbstractVector{T}}
55+
if op.k < op.m # still some place in structures
56+
op.k += 1
57+
op.Sₖ[:,op.k] .= s
58+
op.Yₖ[:,op.k] .= y
59+
op.Dₖ.diag[op.k] = dot(s,y)
60+
op.Lₖ.data[op.k, op.k] = 0
61+
for i in 1:op.k-1
62+
op.Lₖ.data[op.k, i] = dot(s,op.Yₖ[:,i])
63+
end
64+
# the secan equation fails if this line is uncommented
65+
# op.α = dot(y,s)/dot(s,s)
66+
else # update matrix with circular shift
67+
# must be tested
68+
circshift(op.Sₖ, (0,-1))
69+
circshift(op.Yₖ, (0,-1))
70+
circshift(op.Dₖ, (-1,-1))
71+
# circshift doesn't work for a LowerTriangular matrix
72+
for j in 2:op.k
73+
for i in 1:j-1
74+
op.Lₖ.data[j, i] = dot(op.Sₖ[:,j],op.Yₖ[:,i])
75+
end
76+
end
77+
end
78+
op.inverse = false
79+
return op
80+
end
81+
82+
# Theorem 2.3 (p6)
83+
function Base.Matrix(op::CompressedLBFGS{T,M,V}) where {T,M,V}
84+
B₀ = M(zeros(T,op.n, op.n))
85+
map(i -> B₀[i,i] = op.α, 1:op.n)
86+
87+
BSY = M(undef, op.n, 2*op.k)
88+
(op.k > 0) && (BSY[:,1:op.k] = B₀ * op.Sₖ[:,1:op.k])
89+
(op.k > 0) && (BSY[:,op.k+1:2*op.k] = op.Yₖ[:,1:op.k])
90+
_C = M(undef, 2*op.k, 2*op.k)
91+
(op.k > 0) && (_C[1:op.k, 1:op.k] .= transpose(op.Sₖ[:,1:op.k]) * op.Sₖ[:,1:op.k])
92+
(op.k > 0) && (_C[1:op.k, op.k+1:2*op.k] .= op.Lₖ[1:op.k,1:op.k])
93+
(op.k > 0) && (_C[op.k+1:2*op.k, 1:op.k] .= transpose(op.Lₖ[1:op.k,1:op.k]))
94+
(op.k > 0) && (_C[op.k+1:2*op.k, op.k+1:2*op.k] .= .- op.Dₖ[1:op.k,1:op.k])
95+
C = inv(_C)
96+
97+
Bₖ = B₀ .- BSY * C * transpose(BSY)
98+
return Bₖ
99+
end
100+
101+
function inverse_cholesky(op::CompressedLBFGS)
102+
if !op.inverse
103+
op.chol_matrix[1:op.k,1:op.k] .= op.α .* (transpose(op.Sₖ[:,1:op.k]) * op.Sₖ[:,1:op.k]) .+ op.Lₖ[1:op.k,1:op.k] * inv(op.Dₖ[1:op.k,1:op.k]) * transpose(op.Lₖ[1:op.k,1:op.k])
104+
cholesky!(view(op.chol_matrix,1:op.k,1:op.k))
105+
op.inverse = true
106+
end
107+
Jₖ = transpose(UpperTriangular(op.chol_matrix[1:op.k,1:op.k]))
108+
return Jₖ
109+
end
110+
111+
# Algorithm 3.2 (p15)
112+
function LinearAlgebra.mul!(Bv::V, op::CompressedLBFGS{T,M,V}, v::V) where {T,M,V<:AbstractVector{T}}
113+
# step 1-3 mainly done by Base.push!
114+
# step 4, Jₖ is computed only if needed
115+
Jₖ = inverse_cholesky(op::CompressedLBFGS)
116+
117+
# step 5, try views for mul!
118+
# mul!(op.sol[1:op.k], transpose(op.Yₖ[:,1:op.k]), v) # wrong result
119+
# mul!(op.sol[op.k+1:2*op.k], transpose(op.Yₖ[:,1:op.k]), v, (T)(1), op.α) # wrong result
120+
op.sol[1:op.k] .= transpose(op.Yₖ[:,1:op.k]) * v
121+
op.sol[op.k+1:2*op.k] .= op.α .* transpose(op.Sₖ[:,1:op.k]) * v
122+
123+
# step 6, must be improve
124+
op.intermediate_1[1:op.k,1:op.k] .= .- op.Dₖ[1:op.k,1:op.k]^(1/2)
125+
op.intermediate_1[1:op.k,op.k+1:2*op.k] .= op.Dₖ[1:op.k,1:op.k]^(-1/2) * transpose(op.Lₖ[1:op.k,1:op.k])
126+
op.intermediate_1[op.k+1:2*op.k,1:op.k] .= 0
127+
op.intermediate_1[op.k+1:2*op.k,op.k+1:2*op.k] .= transpose(Jₖ)
128+
129+
op.intermediate_2[1:op.k,1:op.k] .= op.Dₖ[1:op.k,1:op.k]^(1/2)
130+
op.intermediate_2[1:op.k,op.k+1:2*op.k] .= 0
131+
op.intermediate_2[op.k+1:2*op.k,1:op.k] .= .- op.Lₖ[1:op.k,1:op.k] * op.Dₖ[1:op.k,1:op.k]^(-1/2)
132+
op.intermediate_2[op.k+1:2*op.k,op.k+1:2*op.k] .= Jₖ
133+
134+
op.inverse_intermediate_1[1:2*op.k,1:2*op.k] .= inv(op.intermediate_1[1:2*op.k,1:2*op.k])
135+
op.inverse_intermediate_2[1:2*op.k,1:2*op.k] .= inv(op.intermediate_2[1:2*op.k,1:2*op.k])
136+
137+
op.sol[1:2*op.k] .= op.inverse_intermediate_1[1:2*op.k,1:2*op.k] * (op.inverse_intermediate_2[1:2*op.k,1:2*op.k] * op.sol[1:2*op.k])
138+
139+
# step 7
140+
Bv .= op.α .* v .- (op.Yₖ[:,1:op.k] * op.sol[1:op.k] .+ op.α .* op.Sₖ[:,1:op.k] * op.sol[op.k+1:2*op.k])
141+
142+
return Bv
143+
end

src/qn.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ import LinearAlgebra.diag
55

66
include("lbfgs.jl")
77
include("lsr1.jl")
8+
9+
include("compressed_lbfgs.jl")

0 commit comments

Comments
 (0)