Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 20 additions & 179 deletions paper/examples/Benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# ======== IMPORTS ======== #
#############################
using Random, LinearAlgebra
using ProximalOperators, ProximalCore, ProximalAlgorithms
using ShiftedProximalOperators
using NLPModels, NLPModelsModifiers
using RegularizedOptimization, RegularizedProblems
Expand All @@ -14,7 +13,7 @@ using LaTeXStrings

# Local includes
include("comparison-config.jl")
using .ComparisonConfig: CFG, CFG2
using .ComparisonConfig: CFG

#############################
# ===== Helper utils ====== #
Expand Down Expand Up @@ -58,7 +57,7 @@ function run_tr_svm!(model, x0; λ = 1.0, qn = :LSR1, atol = 1e-3, rtol = 1e-3,
t = @elapsed RegularizedOptimization.solve!(solver, reg_nlp, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
return (
name = "TR ($(String(qn)), SVM)",
name = "TR",
status = string(stats.status),
time = t,
iters = get(stats.solver_specific, :outer_iter, missing),
Expand All @@ -84,7 +83,7 @@ function run_r2n_svm!(model, x0; λ = 1.0, qn = :LBFGS, atol = 1e-3, rtol = 1e-3
t = @elapsed RegularizedOptimization.solve!(solver, reg_nlp, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
return (
name = "R2N ($(String(qn)), SVM)",
name = "R2N",
status = string(stats.status),
time = t,
iters = get(stats.solver_specific, :outer_iter, missing),
Expand All @@ -108,7 +107,7 @@ function run_LM_svm!(nls_model, x0; λ = 1.0, atol = 1e-3, rtol = 1e-3, verbose
t = @elapsed RegularizedOptimization.solve!(solver, reg_nls, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
return (
name = "LM (SVM)",
name = "LM",
status = string(stats.status),
time = t,
iters = get(stats.solver_specific, :outer_iter, missing),
Expand All @@ -132,7 +131,7 @@ function run_LMTR_svm!(nls_model, x0; λ = 1.0, atol = 1e-3, rtol = 1e-3, verbos
t = @elapsed RegularizedOptimization.solve!(solver, reg_nls, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
return (
name = "LMTR (SVM)",
name = "LMTR",
status = string(stats.status),
time = t,
iters = get(stats.solver_specific, :outer_iter, missing),
Expand All @@ -159,18 +158,14 @@ function bench_svm!(cfg = CFG)
println("\n=== SVM: solver comparison ===")
for m in results
println("\n→ ", m.name)
println(" status = ", m.status)
println(" time (s) = ", round(m.time, digits = 4))
m.iters !== missing && println(" outer iters = ", m.iters)
println(" # f eval = ", m.fevals)
println(" # ∇f eval = ", m.gevals)
m.proxcalls !== missing && println(" # prox calls = ", Int(m.proxcalls))
println(" final objective= ", round(obj(model, m.solution), digits = 4))
println(" accuracy (%) = ", round(acc(residual(nls_train, m.solution)), digits = 1))
println(" status = ", m.status)
println(" time (s) = ", round(m.time, digits = 4))
println(" # f eval = ", m.fevals)
println(" # ∇f eval = ", m.gevals)
m.proxcalls !== missing && println(" # prox calls = ", Int(m.proxcalls))
println(" final objective = ", round(obj(model, m.solution), digits = 4))
end

println("\nSVM Config:"); print_config(cfg)

data_svm = [
(; name=m.name,
status=string(m.status),
Expand All @@ -185,182 +180,28 @@ function bench_svm!(cfg = CFG)
return data_svm
end

#############################
# ======= NNMF bench ====== #
#############################

function run_tr_nnmf!(model, x0; λ = 1.0, qn = :LSR1, atol = 1e-3, rtol = 1e-3, verbose = 0, sub_kwargs = (;), selected = nothing)
qn_model = ensure_qn(model, qn)
reset!(qn_model)
reg_nlp = RegularizedNLPModel(qn_model, NormL0(λ), selected)
solver = TRSolver(reg_nlp)
stats = RegularizedExecutionStats(reg_nlp)
RegularizedOptimization.solve!(solver, reg_nlp, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
reset!(qn_model) # Reset counters before timing
reg_nlp = RegularizedNLPModel(qn_model, NormL0(λ), selected) # Re-create to reset prox eval count
solver = TRSolver(reg_nlp)
t = @elapsed RegularizedOptimization.solve!(solver, reg_nlp, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
return (
name = "TR ($(String(qn)), NNMF)",
status = string(stats.status),
time = t,
iters = get(stats.solver_specific, :outer_iter, missing),
fevals = neval_obj(qn_model),
gevals = neval_grad(qn_model),
proxcalls = get(stats.solver_specific, :prox_evals, missing),
solution = stats.solution,
final_obj = obj(model, stats.solution)
)
end

function run_r2n_nnmf!(model, x0; λ = 1.0, qn = :LBFGS, atol = 1e-3, rtol = 1e-3, verbose = 0, sub_kwargs = (;), σk = 1e5, selected = nothing)
qn_model = ensure_qn(model, qn)
reset!(qn_model)
reg_nlp = RegularizedNLPModel(qn_model, NormL0(λ), selected)
solver = R2NSolver(reg_nlp)
stats = RegularizedExecutionStats(reg_nlp)
RegularizedOptimization.solve!(solver, reg_nlp, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose,
sub_kwargs = sub_kwargs)

reset!(qn_model) # Reset counters before timing
reg_nlp = RegularizedNLPModel(qn_model, NormL0(λ), selected) # Re-create to reset prox eval count
solver = R2NSolver(reg_nlp)
t = @elapsed RegularizedOptimization.solve!(solver, reg_nlp, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose,
sub_kwargs = sub_kwargs)
return (
name = "R2N ($(String(qn)), NNMF)",
status = string(stats.status),
time = t,
iters = get(stats.solver_specific, :outer_iter, missing),
fevals = neval_obj(qn_model),
gevals = neval_grad(qn_model),
proxcalls = get(stats.solver_specific, :prox_evals, missing),
solution = stats.solution,
final_obj = obj(model, stats.solution)
)
end

function run_LM_nnmf!(nls_model, x0; λ = 1.0, atol = 1e-3, rtol = 1e-3, verbose = 0, selected = nothing, sub_kwargs = (;))
reg_nls = RegularizedNLSModel(nls_model, NormL0(λ), selected)
solver = LMSolver(reg_nls)
stats = RegularizedExecutionStats(reg_nls)
RegularizedOptimization.solve!(solver, reg_nls, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
reset!(nls_model) # Reset counters before timing
reg_nls = RegularizedNLSModel(nls_model, NormL0(λ), selected)
solver = LMSolver(reg_nls)
t = @elapsed RegularizedOptimization.solve!(solver, reg_nls, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
return (
name = "LM (NNMF)",
status = string(stats.status),
time = t,
iters = get(stats.solver_specific, :outer_iter, missing),
fevals = neval_residual(nls_model),
gevals = neval_jtprod_residual(nls_model) + neval_jprod_residual(nls_model),
proxcalls = get(stats.solver_specific, :prox_evals, missing),
solution = stats.solution,
final_obj = obj(nls_model, stats.solution)
)
end

function run_LMTR_nnmf!(nls_model, x0; λ = 1.0, atol = 1e-3, rtol = 1e-3, verbose = 0, selected = nothing, sub_kwargs = (;))
reg_nls = RegularizedNLSModel(nls_model, NormL0(λ), selected)
solver = LMTRSolver(reg_nls)
stats = RegularizedExecutionStats(reg_nls)
RegularizedOptimization.solve!(solver, reg_nls, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
reset!(nls_model) # Reset counters before timing
reg_nls = RegularizedNLSModel(nls_model, NormL0(λ), selected)
solver = LMTRSolver(reg_nls)
t = @elapsed RegularizedOptimization.solve!(solver, reg_nls, stats;
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
return (
name = "LMTR (NNMF)",
status = string(stats.status),
time = t,
iters = get(stats.solver_specific, :outer_iter, missing),
fevals = neval_residual(nls_model),
gevals = neval_jtprod_residual(nls_model) + neval_jprod_residual(nls_model),
proxcalls = get(stats.solver_specific, :prox_evals, missing),
solution = stats.solution,
final_obj = obj(nls_model, stats.solution)
)
end

function bench_nnmf!(cfg = CFG2; m = 100, n = 50, k = 5)
Random.seed!(cfg.SEED)

model, nls_model, _, selected = nnmf_model(m, n, k)

# build x0 on positive orthant as original
x0 = max.(rand(model.meta.nvar), 0.0)

# heuristic lambda (copied logic)
cfg.LAMBDA_L0 = norm(grad(model, rand(model.meta.nvar)), Inf) / 200

results = NamedTuple[]
(:TR in cfg.RUN_SOLVERS) && push!(results, run_tr_nnmf!(model, x0; λ = cfg.LAMBDA_L0, qn = cfg.QN_FOR_TR, atol = cfg.TOL, rtol = cfg.RTOL, verbose = cfg.VERBOSE_RO, sub_kwargs = cfg.SUB_KWARGS_R2N, selected = selected))
(:R2N in cfg.RUN_SOLVERS) && push!(results, run_r2n_nnmf!(model, x0; λ = cfg.LAMBDA_L0, qn = cfg.QN_FOR_R2N, atol = cfg.TOL, rtol = cfg.RTOL, verbose = cfg.VERBOSE_RO, sub_kwargs = cfg.SUB_KWARGS_R2N, selected = selected))
(:LM in cfg.RUN_SOLVERS) && push!(results, run_LM_nnmf!(nls_model, x0; λ = cfg.LAMBDA_L0, atol = cfg.TOL, rtol = cfg.RTOL, verbose = cfg.VERBOSE_RO, selected = selected, sub_kwargs = cfg.SUB_KWARGS_R2N))
(:LMTR in cfg.RUN_SOLVERS) && push!(results, run_LMTR_nnmf!(nls_model, x0; λ = cfg.LAMBDA_L0, atol = cfg.TOL, rtol = cfg.RTOL, verbose = cfg.VERBOSE_RO, selected = selected, sub_kwargs = cfg.SUB_KWARGS_R2N))

println("\n=== NNMF: solver comparison ===")
for m in results
println("\n→ ", m.name)
println(" status = ", m.status)
println(" time (s) = ", round(m.time, digits = 4))
m.iters !== missing && println(" outer iters = ", m.iters)
println(" # f eval = ", m.fevals)
println(" # ∇f eval = ", m.gevals)
m.proxcalls !== missing && println(" # prox calls = ", Int(m.proxcalls))
println(" final objective= ", round(obj(model, m.solution), digits = 4))
end

println("\nNNMF Config:"); print_config(cfg)

data_nnmf = [
(; name=m.name,
status=string(m.status),
time=round(m.time, digits=4),
fe=m.fevals,
ge=m.gevals,
prox = m.proxcalls === missing ? missing : Int(m.proxcalls),
obj = round(m.final_obj, digits=4))
for m in results
]

return data_nnmf
end

# #############################
# # ========= Main ========== #
# #############################

function main(latex_out = false)
function main(;latex_out = false)
data_svm = bench_svm!(CFG)
data_nnmf = bench_nnmf!(CFG2)

all_data = vcat(data_svm, data_nnmf)

println("\n=== Full Benchmark Table ===")
# what is inside the table
for row in all_data
for row in data_svm
println(row)
end

# save as latex format
if latex_out

table_str = pretty_table(String, all_data;
header = ["Method", "Status", L"$t$($s$)", L"$\#f$", L"$\#\nabla f$", L"$\#prox$", "Objective"],
backend = Val(:latex),
alignment = [:l, :c, :r, :r, :r, :r, :r],
)
table_str = pretty_table(String,
data_svm;
backend = :latex,
column_labels = ["Method", "Status", L"$t$($s$)", L"$\#f$", L"$\#\nabla f$", L"$\#prox$", "Objective"],
style = LatexTableStyle(column_label = String[]),
table_format = latex_table_format__booktabs
)

open("Benchmark.tex", "w") do io
write(io, table_str)
Expand Down
3 changes: 0 additions & 3 deletions paper/examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,5 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
NLPModelsModifiers = "e01155f1-5c6f-4375-a9d8-616dd036575f"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
RegularizedProblems = "ea076b23-609f-44d2-bb12-a4ae45328278"
ShiftedProximalOperators = "d4fd37fa-580c-4e43-9b30-361c21aae263"
1 change: 0 additions & 1 deletion paper/examples/comparison-config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,5 @@ end

# One global, constant *binding* to a mutable object = type stable & editable
const CFG = Config(QN_FOR_R2N=:LSR1)
const CFG2 = Config(QN_FOR_TR = :LBFGS)

end # module
4 changes: 2 additions & 2 deletions paper/examples/example1.jl → paper/examples/example.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using LinearAlgebra, Random, ProximalOperators
using NLPModels, RegularizedProblems, RegularizedOptimization
using LinearAlgebra, Random, ShiftedProximalOperators
using NLPModels, NLPModelsModifiers, RegularizedProblems, RegularizedOptimization
using MLDatasets

Random.seed!(1234)
Expand Down