Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function ADModelBackend(
hessian_backend = if hessian_backend isa Union{AbstractNLPModel, ADBackend}
hessian_backend
else
HB(nvar, f, ncon, c!; kwargs...)
HB(nvar, f, ncon, c!; show_time, kwargs...)
end
end
show_time && println("hessian backend $HB: $b seconds;")
Expand Down Expand Up @@ -170,7 +170,7 @@ function ADModelBackend(
jacobian_backend = if jacobian_backend isa Union{AbstractNLPModel, ADBackend}
jacobian_backend
else
JB(nvar, f, ncon, c!; kwargs...)
JB(nvar, f, ncon, c!; show_time, kwargs...)
end
end
show_time && println("jacobian backend $JB: $b seconds;")
Expand All @@ -180,7 +180,7 @@ function ADModelBackend(
hessian_backend = if hessian_backend isa Union{AbstractNLPModel, ADBackend}
hessian_backend
else
HB(nvar, f, ncon, c!; kwargs...)
HB(nvar, f, ncon, c!; show_time, kwargs...)
end
end
show_time && println("hessian backend $HB: $b seconds;")
Expand Down Expand Up @@ -263,7 +263,7 @@ function ADModelNLSBackend(
hessian_backend = if hessian_backend isa Union{AbstractNLPModel, ADBackend}
hessian_backend
else
HB(nvar, f, ncon, c!; kwargs...)
HB(nvar, f, ncon, c!; show_time, kwargs...)
end
end
show_time && println("hessian backend $HB: $b seconds;")
Expand Down Expand Up @@ -304,7 +304,7 @@ function ADModelNLSBackend(
if jacobian_residual_backend isa Union{AbstractNLPModel, ADBackend}
jacobian_residual_backend
else
JBLS(nvar, x -> zero(eltype(x)), nequ, F!; kwargs...)
JBLS(nvar, x -> zero(eltype(x)), nequ, F!; show_time, kwargs...)
end
end
show_time && println("jacobian_residual backend $JBLS: $b seconds;")
Expand All @@ -314,7 +314,7 @@ function ADModelNLSBackend(
hessian_residual_backend = if hessian_residual_backend isa Union{AbstractNLPModel, ADBackend}
hessian_residual_backend
else
HBLS(nvar, x -> zero(eltype(x)), nequ, F!; kwargs...)
HBLS(nvar, x -> zero(eltype(x)), nequ, F!; show_time, kwargs...)
end
end
show_time && println("hessian_residual backend $HBLS: $b seconds. \n")
Expand Down Expand Up @@ -410,7 +410,7 @@ function ADModelNLSBackend(
jacobian_backend = if jacobian_backend isa Union{AbstractNLPModel, ADBackend}
jacobian_backend
else
JB(nvar, f, ncon, c!; kwargs...)
JB(nvar, f, ncon, c!; show_time, kwargs...)
end
end
show_time && println("jacobian backend $JB: $b seconds;")
Expand All @@ -420,7 +420,7 @@ function ADModelNLSBackend(
hessian_backend = if hessian_backend isa Union{AbstractNLPModel, ADBackend}
hessian_backend
else
HB(nvar, f, ncon, c!; kwargs...)
HB(nvar, f, ncon, c!; show_time, kwargs...)
end
end
show_time && println("hessian backend $HB: $b seconds;")
Expand Down Expand Up @@ -471,7 +471,7 @@ function ADModelNLSBackend(
if jacobian_residual_backend isa Union{AbstractNLPModel, ADBackend}
jacobian_residual_backend
else
JBLS(nvar, x -> zero(eltype(x)), nequ, F!; kwargs...)
JBLS(nvar, x -> zero(eltype(x)), nequ, F!; show_time, kwargs...)
end
end
show_time && println("jacobian_residual backend $JBLS: $b seconds;")
Expand All @@ -481,7 +481,7 @@ function ADModelNLSBackend(
hessian_residual_backend = if hessian_residual_backend isa Union{AbstractNLPModel, ADBackend}
hessian_residual_backend
else
HBLS(nvar, x -> zero(eltype(x)), nequ, F!; kwargs...)
HBLS(nvar, x -> zero(eltype(x)), nequ, F!; show_time, kwargs...)
end
end
show_time && println("hessian_residual backend $HBLS: $b seconds. \n")
Expand Down
108 changes: 67 additions & 41 deletions src/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,15 @@ function SparseEnzymeADJacobian(
x0::AbstractVector = rand(nvar),
coloring_algorithm::AbstractColoringAlgorithm = GreedyColoringAlgorithm{:direct}(),
detector::AbstractSparsityDetector = TracerSparsityDetector(),
show_time::Bool = false,
kwargs...,
)
output = similar(x0, ncon)
J = compute_jacobian_sparsity(c!, output, x0, detector = detector)
SparseEnzymeADJacobian(nvar, f, ncon, c!, J; x0, coloring_algorithm, kwargs...)
timer = @elapsed begin
output = similar(x0, ncon)
J = compute_jacobian_sparsity(c!, output, x0, detector = detector)
end
show_time && println(" • Sparsity pattern detection of the Jacobian: $timer seconds.")
SparseEnzymeADJacobian(nvar, f, ncon, c!, J; x0, coloring_algorithm, show_time, kwargs...)
end

function SparseEnzymeADJacobian(
Expand All @@ -127,18 +131,26 @@ function SparseEnzymeADJacobian(
J::SparseMatrixCSC{Bool, Int};
x0::AbstractVector{T} = rand(nvar),
coloring_algorithm::AbstractColoringAlgorithm = GreedyColoringAlgorithm{:direct}(),
show_time::Bool = false,
kwargs...,
) where {T}
# We should support :row and :bidirectional in the future
problem = ColoringProblem{:nonsymmetric, :column}()
result_coloring = coloring(J, problem, coloring_algorithm, decompression_eltype = T)
timer = @elapsed begin
# We should support :row and :bidirectional in the future
problem = ColoringProblem{:nonsymmetric, :column}()
result_coloring = coloring(J, problem, coloring_algorithm, decompression_eltype = T)

rowval = J.rowval
colptr = J.colptr
nzval = T.(J.nzval)
compressed_jacobian = similar(x0, ncon)
end
show_time && println(" • Coloring of the sparse Jacobian: $timer seconds.")

rowval = J.rowval
colptr = J.colptr
nzval = T.(J.nzval)
compressed_jacobian = similar(x0, ncon)
v = similar(x0)
cx = zeros(T, ncon)
timer = @elapsed begin
v = similar(x0)
cx = zeros(T, ncon)
end
show_time && println(" • Allocation of the AD buffers for the sparse Jacobian: $timer seconds.")

SparseEnzymeADJacobian(
nvar,
Expand Down Expand Up @@ -177,10 +189,14 @@ function SparseEnzymeADHessian(
x0::AbstractVector = rand(nvar),
coloring_algorithm::AbstractColoringAlgorithm = GreedyColoringAlgorithm{:substitution}(),
detector::AbstractSparsityDetector = TracerSparsityDetector(),
show_time::Bool = false,
kwargs...,
)
H = compute_hessian_sparsity(f, nvar, c!, ncon, detector = detector)
SparseEnzymeADHessian(nvar, f, ncon, c!, H; x0, coloring_algorithm, kwargs...)
timer = @elapsed begin
H = compute_hessian_sparsity(f, nvar, c!, ncon, detector = detector)
end
show_time && println(" • Sparsity pattern detection of the Hessian: $timer seconds.")
SparseEnzymeADHessian(nvar, f, ncon, c!, H; x0, coloring_algorithm, show_time, kwargs...)
end

function SparseEnzymeADHessian(
Expand All @@ -191,38 +207,48 @@ function SparseEnzymeADHessian(
H::SparseMatrixCSC{Bool, Int};
x0::AbstractVector{T} = rand(nvar),
coloring_algorithm::AbstractColoringAlgorithm = GreedyColoringAlgorithm{:substitution}(),
show_time::Bool = false,
kwargs...,
) where {T}
problem = ColoringProblem{:symmetric, :column}()
result_coloring = coloring(H, problem, coloring_algorithm, decompression_eltype = T)

trilH = tril(H)
rowval = trilH.rowval
colptr = trilH.colptr
nzval = T.(trilH.nzval)
if coloring_algorithm isa GreedyColoringAlgorithm{:direct}
coloring_mode = :direct
compressed_hessian_icol = similar(x0)
compressed_hessian = compressed_hessian_icol
else
coloring_mode = :substitution
group = column_groups(result_coloring)
ncolors = length(group)
compressed_hessian_icol = similar(x0)
compressed_hessian = similar(x0, (nvar, ncolors))
timer = @elapsed begin
problem = ColoringProblem{:symmetric, :column}()
result_coloring = coloring(H, problem, coloring_algorithm, decompression_eltype = T)

trilH = tril(H)
rowval = trilH.rowval
colptr = trilH.colptr
nzval = T.(trilH.nzval)
if coloring_algorithm isa GreedyColoringAlgorithm{:direct}
coloring_mode = :direct
compressed_hessian_icol = similar(x0)
compressed_hessian = compressed_hessian_icol
else
coloring_mode = :substitution
group = column_groups(result_coloring)
ncolors = length(group)
compressed_hessian_icol = similar(x0)
compressed_hessian = similar(x0, (nvar, ncolors))
end
end
v = similar(x0)
y = similar(x0, ncon)
cx = similar(x0, ncon)
grad = similar(x0)
function ℓ(x, y, obj_weight, cx)
res = obj_weight * f(x)
if ncon != 0
c!(cx, x)
res += sum(cx[i] * y[i] for i = 1:ncon)
show_time && println(" • Coloring of the sparse Hessian: $timer seconds.")

timer = @elapsed begin
v = similar(x0)
y = similar(x0, ncon)
cx = similar(x0, ncon)
grad = similar(x0)

function ℓ(x, y, obj_weight, cx)
res = obj_weight * f(x)
if ncon != 0
c!(cx, x)
res += sum(cx[i] * y[i] for i = 1:ncon)
end
return res
end
return res
end
show_time && println(" • Allocation of the AD buffers for the sparse Hessian: $timer seconds.")


return SparseEnzymeADHessian(
nvar,
Expand Down
Loading
Loading