# NOTE(review): the text below is a unified git diff for
# src/shiftedGroupNormL2Binf.jl whose line breaks have been collapsed into
# spaces; it is not directly runnable Julia source. These `#` lines sit before
# the `diff --git` header and are not part of any hunk.
#
# Hunk 1 -- functor `(ψ::ShiftedGroupNormL2Binf)(y)`:
#   the radius passed to `IndBallLinf` changes from `1.1 * ψ.Δ` to `1.01 * ψ.Δ`.
#
# Hunk 2 -- body of `prox!` (its signature and closing `end` are outside this
# view):
#   * removes `ϵ`, the `lmin`/`lmax` bracket, the `fl * fm > 0` sign test and
#     the `fzero(froot, lmin, lmax)` call;
#   * adds `linfproj(x) = max.(-ψ.Δ, min.(ψ.Δ, x))` and rewrites `froot` in
#     terms of it;
#   * initialises `n = NaN` and assigns it from
#     `find_zero(froot, (lo, hi), Roots.A42())` in one of three branches
#     labelled "case 1", "case 2a" and "case 4a";
#   * writes `y[idx] .= 0` when `isnan(n)`, otherwise applies `l2prox` to the
#     soft-thresholded residual as before;
#   * removes the inner `abs(n - σλ) ≈ 0` guard that previously also zeroed
#     `y[idx]`.
#
# NOTE(review): in the added "case 4a" line the upper end of the bracket is
#   `min(σλ + xnorm + ψ.Δ*sqrt(n), qnorm)`, evaluated while `n` is still `NaN`,
#   so the bound itself is `NaN`. Cases 1 and 2a use `√xlength` in the same
#   position -- presumably `sqrt(xlength)` was intended here. TODO confirm.
# NOTE(review): case 4a compares `xnorm` (from `norm`) with `ψ.Δ`, whereas
#   cases 1 and 2a compare `xnorminf` (from `ψ.χ`). Verify which is intended.
# NOTE(review): `qproj ≈ 0.0` uses `isapprox` against zero with default
#   tolerances, which only holds for an exact zero -- confirm that is wanted.
# NOTE(review): with the `abs(n - σλ) ≈ 0` guard removed, nothing in this hunk
#   protects `step = n / (σ * (n - σλ))` when the root lies at the lower
#   bracket end `σλ + τ`. Confirm the resulting magnitude is acceptable.
diff --git a/src/shiftedGroupNormL2Binf.jl b/src/shiftedGroupNormL2Binf.jl index 162e4e85..4ae4a312 100644 --- a/src/shiftedGroupNormL2Binf.jl +++ b/src/shiftedGroupNormL2Binf.jl @@ -33,7 +33,7 @@ end function (ψ::ShiftedGroupNormL2Binf)(y) @. ψ.xsy = ψ.sj + y - indball_val = IndBallLinf(1.1 * ψ.Δ)(ψ.xsy) + indball_val = IndBallLinf(1.01 * ψ.Δ)(ψ.xsy) ψ.xsy .+= ψ.xk return ψ.h(ψ.xsy) + indball_val end @@ -78,40 +78,42 @@ function prox!( V2 <: AbstractVector{R}, } ψ.sol .= q .+ ψ.xk .+ ψ.sj - ϵ = 1 ## sasha's initial guess + softthres(x, a) = sign.(x) .* max.(0, abs.(x) .- a) l2prox(x, a) = max(0, 1 - a / norm(x)) .* x + linfproj(x) = max.(-ψ.Δ, min.(ψ.Δ, x)) for (idx, λ) ∈ zip(ψ.h.idx, ψ.h.lambda) σλ = λ * σ - ## find root for each block froot(n) = n - norm( - σ .* softthres( - (ψ.sol[idx] ./ σ .- (n / (σ * (n - σλ))) .* ψ.xk[idx]), - ψ.Δ * (n / (σ * (n - σλ))), - ) .- ψ.sol[idx], + (n / (n - σλ)) .* ( + ψ.xk[idx] .+ linfproj( + ((n - σλ)/n) .* ψ.sol[idx] .- ψ.xk[idx] + ) + ) ) - lmin = σλ * (1 + eps(R)) # lower bound - fl = froot(lmin) - - ansatz = lmin + ϵ #ansatz for upper bound - step = ansatz / (σ * (ansatz - σλ)) - zlmax = norm(softthres((ψ.sol[idx] ./ σ .- step .* ψ.xk[idx]), ψ.Δ * step)) - lmax = norm(ψ.sol[idx]) + σ * (zlmax + abs((ϵ - 1) / ϵ + 1) * λ * norm(ψ.xk[idx])) - fm = froot(lmax) - if fl * fm > 0 + xlength = length(ψ.xk[idx]) + xnorminf = ψ.χ(ψ.xk[idx]) + xnorm = norm(ψ.xk[idx]) + qnorm = norm(ψ.sol[idx]) + qproj = sum(linfproj(ψ.sol[idx] .- ψ.xk[idx]) - ψ.xk[idx]) + τ = 1e4*eps(R) + n = NaN + if xnorminf > ψ.Δ #case 1 + n = find_zero(froot, (σλ + xnorminf - ψ.Δ, σλ + xnorm + ψ.Δ*√xlength), Roots.A42()) + elseif qnorm > σλ && xnorminf < ψ.Δ #case 2a + n = find_zero(froot, (σλ + τ, qnorm + xnorm + ψ.Δ*√xlength), Roots.A42()) + elseif qnorm > σλ && xnorm ≈ ψ.Δ && qproj ≈ 0.0 #case 4a + n = find_zero(froot, (σλ + τ, min(σλ + xnorm + ψ.Δ*sqrt(n), qnorm)), Roots.A42()) + end + if isnan(n) y[idx] .= 0 else - n = fzero(froot, lmin, lmax) step = n / (σ * 
(n - σλ)) - if abs(n - σλ) ≈ 0 - y[idx] .= 0 - else - y[idx] .= l2prox( - ψ.sol[idx] .- σ .* softthres((ψ.sol[idx] ./ σ .- step .* ψ.xk[idx]), ψ.Δ * step), - σλ, - ) - end + y[idx] .= l2prox( + ψ.sol[idx] .- σ .* softthres((ψ.sol[idx] ./ σ .- step .* ψ.xk[idx]), ψ.Δ * step), + σλ, + ) end y[idx] .-= (ψ.xk[idx] + ψ.sj[idx]) end