@@ -7,7 +7,8 @@ Compressed LBFGS implementation from:
7
7
Implemented by Paul Raynaud (supervised by Dominique Orban)
8
8
=#
9
9
10
- using LinearAlgebra
10
+ using LinearAlgebra, LinearAlgebra. BLAS
11
+ using CUDA
11
12
12
13
export CompressedLBFGS
13
14
@@ -26,14 +27,16 @@ mutable struct CompressedLBFGS{T, M<:AbstractMatrix{T}, V<:AbstractVector{T}}
26
27
intermediate_2:: LowerTriangular{T,M} # 2m * 2m
27
28
inverse_intermediate_1:: UpperTriangular{T,M} # 2m * 2m
28
29
inverse_intermediate_2:: LowerTriangular{T,M} # 2m * 2m
30
+ intermediary_vector:: V # 2m
29
31
sol:: V # m
30
- inverse :: Bool
32
+ intermediate_structure_updated :: Bool
31
33
end
32
34
35
+ default_gpu () = CUDA. functional () ? true : false
33
36
default_matrix_type (gpu:: Bool , T:: DataType ) = gpu ? CuMatrix{T} : Matrix{T}
34
37
default_vector_type (gpu:: Bool , T:: DataType ) = gpu ? CuVector{T} : Vector{T}
35
38
36
- function CompressedLBFGS (m:: Int , n:: Int ; T= Float64, gpu= false , M= default_matrix_type (gpu,T), V= default_vector_type (gpu,T))
39
+ function CompressedLBFGS (m:: Int , n:: Int ; T= Float64, gpu= default_gpu () , M= default_matrix_type (gpu,T), V= default_vector_type (gpu,T))
37
40
α = (T)(1 )
38
41
k = 0
39
42
Sₖ = M (undef,n,m)
@@ -46,9 +49,10 @@ function CompressedLBFGS(m::Int, n::Int; T=Float64, gpu=false, M=default_matrix_
46
49
intermediate_2 = LowerTriangular (M (undef,2 * m,2 * m))
47
50
inverse_intermediate_1 = UpperTriangular (M (undef,2 * m,2 * m))
48
51
inverse_intermediate_2 = LowerTriangular (M (undef,2 * m,2 * m))
52
+ intermediary_vector = V (undef,2 * m)
49
53
sol = V (undef,2 * m)
50
- inverse = false
51
- return CompressedLBFGS {T,M,V} (m, n, k, α, Sₖ, Yₖ, Dₖ, Lₖ, chol_matrix, intermediate_1, intermediate_2, inverse_intermediate_1, inverse_intermediate_2, sol, inverse )
54
+ intermediate_structure_updated = false
55
+ return CompressedLBFGS {T,M,V} (m, n, k, α, Sₖ, Yₖ, Dₖ, Lₖ, chol_matrix, intermediate_1, intermediate_2, inverse_intermediate_1, inverse_intermediate_2, intermediary_vector, sol, intermediate_structure_updated )
52
56
end
53
57
54
58
function Base. push! (op:: CompressedLBFGS{T,M,V} , s:: V , y:: V ) where {T,M,V<: AbstractVector{T} }
@@ -69,13 +73,14 @@ function Base.push!(op::CompressedLBFGS{T,M,V}, s::V, y::V) where {T,M,V<:Abstra
69
73
circshift (op. Yₖ, (0 ,- 1 ))
70
74
circshift (op. Dₖ, (- 1 ,- 1 ))
71
75
# circshift doesn't work for a LowerTriangular matrix
76
+ # for the time being, reinstantiate completely the Lₖ matrix
72
77
for j in 2 : op. k
73
78
for i in 1 : j- 1
74
79
op. Lₖ. data[j, i] = dot (op. Sₖ[:,j],op. Yₖ[:,i])
75
80
end
76
81
end
77
82
end
78
- op. inverse = false
83
+ op. intermediate_structure_updated = false
79
84
return op
80
85
end
81
86
@@ -98,46 +103,56 @@ function Base.Matrix(op::CompressedLBFGS{T,M,V}) where {T,M,V}
98
103
return Bₖ
99
104
end
100
105
106
+ # step 4, Jₖ is computed only if needed
101
107
function inverse_cholesky (op:: CompressedLBFGS )
102
- if ! op. inverse
103
- op. chol_matrix[1 : op. k,1 : op. k] .= op. α .* (transpose (op. Sₖ[:,1 : op. k]) * op. Sₖ[:,1 : op. k]) .+ op. Lₖ[1 : op. k,1 : op. k] * inv (op. Dₖ[1 : op. k,1 : op. k]) * transpose (op. Lₖ[1 : op. k,1 : op. k])
104
- cholesky! (view (op. chol_matrix,1 : op. k,1 : op. k))
105
- op. inverse = true
106
- end
107
- Jₖ = transpose (UpperTriangular (op. chol_matrix[1 : op. k,1 : op. k]))
108
+ view (op. chol_matrix, 1 : op. k, 1 : op. k) .= op. α .* (transpose (view (op. Sₖ, :, 1 : op. k)) * view (op. Sₖ, :, 1 : op. k)) .+ view (op. Lₖ, 1 : op. k, 1 : op. k) * inv (op. Dₖ[1 : op. k, 1 : op. k]) * transpose (view (op. Lₖ, 1 : op. k, 1 : op. k))
109
+ cholesky! (view (op. chol_matrix,1 : op. k,1 : op. k))
110
+ Jₖ = transpose (UpperTriangular (view (op. chol_matrix, 1 : op. k, 1 : op. k)))
108
111
return Jₖ
109
112
end
110
113
114
+ # step 6, must be improve
115
+ function precompile_iterated_structure! (op:: CompressedLBFGS )
116
+ Jₖ = inverse_cholesky (op)
117
+
118
+ view (op. intermediate_1, 1 : op. k,1 : op. k) .= .- view (op. Dₖ, 1 : op. k, 1 : op. k)^ (1 / 2 )
119
+ view (op. intermediate_1, 1 : op. k,op. k+ 1 : 2 * op. k) .= view (op. Dₖ, 1 : op. k, 1 : op. k)^ (- 1 / 2 ) * transpose (view (op. Lₖ, 1 : op. k, 1 : op. k))
120
+ view (op. intermediate_1, op. k+ 1 : 2 * op. k, 1 : op. k) .= 0
121
+ view (op. intermediate_1, op. k+ 1 : 2 * op. k, op. k+ 1 : 2 * op. k) .= transpose (Jₖ)
122
+
123
+ view (op. intermediate_2, 1 : op. k, 1 : op. k) .= view (op. Dₖ, 1 : op. k, 1 : op. k)^ (1 / 2 )
124
+ view (op. intermediate_2, 1 : op. k, op. k+ 1 : 2 * op. k) .= 0
125
+ view (op. intermediate_2, op. k+ 1 : 2 * op. k, 1 : op. k) .= .- view (op. Lₖ, 1 : op. k, 1 : op. k) * view (op. Dₖ, 1 : op. k, 1 : op. k)^ (- 1 / 2 )
126
+ view (op. intermediate_2, op. k+ 1 : 2 * op. k, op. k+ 1 : 2 * op. k) .= Jₖ
127
+
128
+ view (op. inverse_intermediate_1, 1 : 2 * op. k, 1 : 2 * op. k) .= inv (op. intermediate_1[ 1 : 2 * op. k,1 : 2 * op. k])
129
+ view (op. inverse_intermediate_2, 1 : 2 * op. k, 1 : 2 * op. k) .= inv (op. intermediate_2[ 1 : 2 * op. k,1 : 2 * op. k])
130
+
131
+ op. intermediate_structure_updated = true
132
+ end
133
+
111
134
# Algorithm 3.2 (p15)
112
135
function LinearAlgebra. mul! (Bv:: V , op:: CompressedLBFGS{T,M,V} , v:: V ) where {T,M,V<: AbstractVector{T} }
113
136
# step 1-3 mainly done by Base.push!
114
- # step 4, Jₖ is computed only if needed
115
- Jₖ = inverse_cholesky (op:: CompressedLBFGS )
137
+
138
+ # steps 4 and 6, in case the intermediary required structure are not up to date
139
+ (! op. intermediate_structure_updated) && (precompile_iterated_structure! (op))
116
140
117
141
# step 5, try views for mul!
118
- # mul!(op.sol[1:op.k], transpose(op.Yₖ[:,1:op.k]), v) # wrong result
119
- # mul!(op.sol[op.k+1:2*op.k], transpose(op.Yₖ[:,1:op.k]), v, (T)(1), op.α) # wrong result
120
- op. sol[1 : op. k] .= transpose (op. Yₖ[:,1 : op. k]) * v
121
- op. sol[op. k+ 1 : 2 * op. k] .= op. α .* transpose (op. Sₖ[:,1 : op. k]) * v
122
-
123
- # step 6, must be improve
124
- op. intermediate_1[1 : op. k,1 : op. k] .= .- op. Dₖ[1 : op. k,1 : op. k]^ (1 / 2 )
125
- op. intermediate_1[1 : op. k,op. k+ 1 : 2 * op. k] .= op. Dₖ[1 : op. k,1 : op. k]^ (- 1 / 2 ) * transpose (op. Lₖ[1 : op. k,1 : op. k])
126
- op. intermediate_1[op. k+ 1 : 2 * op. k,1 : op. k] .= 0
127
- op. intermediate_1[op. k+ 1 : 2 * op. k,op. k+ 1 : 2 * op. k] .= transpose (Jₖ)
128
-
129
- op. intermediate_2[1 : op. k,1 : op. k] .= op. Dₖ[1 : op. k,1 : op. k]^ (1 / 2 )
130
- op. intermediate_2[1 : op. k,op. k+ 1 : 2 * op. k] .= 0
131
- op. intermediate_2[op. k+ 1 : 2 * op. k,1 : op. k] .= .- op. Lₖ[1 : op. k,1 : op. k] * op. Dₖ[1 : op. k,1 : op. k]^ (- 1 / 2 )
132
- op. intermediate_2[op. k+ 1 : 2 * op. k,op. k+ 1 : 2 * op. k] .= Jₖ
133
-
134
- op. inverse_intermediate_1[1 : 2 * op. k,1 : 2 * op. k] .= inv (op. intermediate_1[1 : 2 * op. k,1 : 2 * op. k])
135
- op. inverse_intermediate_2[1 : 2 * op. k,1 : 2 * op. k] .= inv (op. intermediate_2[1 : 2 * op. k,1 : 2 * op. k])
136
-
137
- op. sol[1 : 2 * op. k] .= op. inverse_intermediate_1[1 : 2 * op. k,1 : 2 * op. k] * (op. inverse_intermediate_2[1 : 2 * op. k,1 : 2 * op. k] * op. sol[1 : 2 * op. k])
142
+ mul! (view (op. sol, 1 : op. k), transpose (view (op. Yₖ, :, 1 : op. k)), v)
143
+ mul! (view (op. sol, op. k+ 1 : 2 * op. k), transpose (view (op. Sₖ, :,1 : op. k)), v)
144
+ # scal!(op.α, view(op.sol, op.k+1:2*op.k)) # more allocation, slower
145
+ view (op. sol, op. k+ 1 : 2 * op. k) .*= op. α
146
+
147
+ # view(op.sol, 1:2*op.k) .= view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k) * (view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k) * view(op.sol, 1:2*op.k))
148
+ mul! (view (op. intermediary_vector, 1 : 2 * op. k), view (op. inverse_intermediate_2, 1 : 2 * op. k, 1 : 2 * op. k), view (op. sol, 1 : 2 * op. k))
149
+ mul! (view (op. sol, 1 : 2 * op. k), view (op. inverse_intermediate_1, 1 : 2 * op. k, 1 : 2 * op. k), view (op. intermediary_vector, 1 : 2 * op. k))
138
150
139
151
# step 7
140
- Bv .= op. α .* v .- (op. Yₖ[:,1 : op. k] * op. sol[1 : op. k] .+ op. α .* op. Sₖ[:,1 : op. k] * op. sol[op. k+ 1 : 2 * op. k])
141
-
152
+ # Bv .= op.α .* v .- (view(op.Yₖ, :,1:op.k) * view(op.sol, 1:op.k) .+ op.α .* view(op.Sₖ, :, 1:op.k) * view(op.sol, op.k+1:2*op.k))
153
+
154
+ mul! (Bv, view (op. Yₖ, :, 1 : op. k), view (op. sol, 1 : op. k))
155
+ mul! (Bv, view (op. Sₖ, :, 1 : op. k), view (op. sol, op. k+ 1 : 2 * op. k), - op. α, (T)(- 1 ))
156
+ Bv .+ = op. α .* v
142
157
return Bv
143
158
end
0 commit comments