ENH Add FISTA solver #91


Merged: 31 commits into main from PABannier:fista on Oct 22, 2022

Changes from 12 commits

Commits (31)
0868b0f
POC FISTA
PABannier Oct 12, 2022
8584299
CLN
PABannier Oct 14, 2022
c82e32e
changed obj_freq from 100 to 10
PABannier Oct 14, 2022
4940a0d
WIP Lipschitz
PABannier Oct 14, 2022
e47c68a
ADD global lipschitz constants
PABannier Oct 14, 2022
3635f24
FISTA with global lipschitz
PABannier Oct 14, 2022
4880112
writing tests
PABannier Oct 14, 2022
46a9a76
better tests
PABannier Oct 14, 2022
9f0653a
support sparse matrices
PABannier Oct 14, 2022
fe159be
fix mistake
PABannier Oct 14, 2022
8e74e8a
RM toy_fista
PABannier Oct 14, 2022
a24ed9c
green
PABannier Oct 14, 2022
4362c2c
mv `_prox_vec` to utils
PABannier Oct 16, 2022
2665d5d
rm `opt_freq`
PABannier Oct 16, 2022
2e408bc
fix tests
PABannier Oct 16, 2022
8524cf7
Update skglm/solvers/fista.py
PABannier Oct 16, 2022
dd658f8
huber comment
PABannier Oct 16, 2022
7c9fbe1
Merge branch 'fista' of https://github.com/PABannier/skglm into fista
PABannier Oct 16, 2022
cbc5418
WIP
PABannier Oct 16, 2022
b6c664c
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Oct 20, 2022
e76dfb1
implement power method
Badr-MOUFAD Oct 20, 2022
2a4bce3
private ``prox_vec``
Badr-MOUFAD Oct 20, 2022
cd39a62
random init in power method && default args
Badr-MOUFAD Oct 21, 2022
0e4d42a
use power method for ``global_lipschitz``
Badr-MOUFAD Oct 21, 2022
2bbc8f5
fix && refactor unittest
Badr-MOUFAD Oct 21, 2022
ed3686a
add docs for tol and max_iter && clean ups
Badr-MOUFAD Oct 21, 2022
aa15c46
remove square form spectral norm
Badr-MOUFAD Oct 21, 2022
27b918d
refactor ``_prox_vec`` function
Badr-MOUFAD Oct 21, 2022
9d8e3c0
fix bug segmentation fault
Badr-MOUFAD Oct 21, 2022
e5ce21b
add Fista to docs && fix unittest
Badr-MOUFAD Oct 21, 2022
5d2dbaf
cosmetic changes
mathurinm Oct 22, 2022
42 changes: 40 additions & 2 deletions skglm/datafits/single_task.py
@@ -22,6 +22,10 @@ class Quadratic(BaseDatafit):
The coordinatewise gradient Lipschitz constants. Equal to
norm(X, axis=0) ** 2 / n_samples.

global_lipschitz : float
Global Lipschitz constant. Equal to
norm(X, ord=2) ** 2 / n_samples.

Note
----
The class is jit compiled at fit time using Numba compiler.
@@ -35,6 +39,7 @@ def get_spec(self):
spec = (
('Xty', float64[:]),
('lipschitz', float64[:]),
('global_lipschitz', float64),
)
return spec

@@ -44,6 +49,7 @@ def params_to_dict(self):
def initialize(self, X, y):
self.Xty = X.T @ y
n_features = X.shape[1]
self.global_lipschitz = norm(X, ord=2) ** 2 / len(y)
self.lipschitz = np.zeros(n_features, dtype=X.dtype)
for j in range(n_features):
self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
@@ -53,6 +59,7 @@ def initialize_sparse(
n_features = len(X_indptr) - 1
self.Xty = np.zeros(n_features, dtype=X_data.dtype)
self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
self.global_lipschitz = 0.
for j in range(n_features):
nrm2 = 0.
xty = 0
@@ -62,6 +69,7 @@

self.lipschitz[j] = nrm2 / len(y)
self.Xty[j] = xty
self.global_lipschitz += nrm2 / len(y)

def value(self, y, w, Xw):
return np.sum((y - Xw) ** 2) / (2 * len(Xw))
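As a sanity check on the new attribute (an illustrative sketch, not part of the diff): the quadratic datafit's gradient is X.T @ (X @ w - y) / n_samples, and norm(X, ord=2) ** 2 / n_samples is indeed a Lipschitz constant for it.

import numpy as np
from numpy.linalg import norm

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 30))
y = rng.standard_normal(50)
n = len(y)

def grad(w):
    # gradient of ||y - Xw||^2 / (2 n)
    return X.T @ (X @ w - y) / n

L = norm(X, ord=2) ** 2 / n
w1, w2 = rng.standard_normal((2, 30))
# the Lipschitz inequality ||grad(w1) - grad(w2)|| <= L * ||w1 - w2|| holds
assert norm(grad(w1) - grad(w2)) <= L * norm(w1 - w2) + 1e-12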
@@ -111,6 +119,10 @@ class Logistic(BaseDatafit):
The coordinatewise gradient Lipschitz constants. Equal to
norm(X, axis=0) ** 2 / (4 * n_samples).

global_lipschitz : float
Global Lipschitz constant. Equal to
norm(X, ord=2) ** 2 / (4 * n_samples).

Note
----
The class is jit compiled at fit time using Numba compiler.
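The 1/4 factor comes from the sigmoid: the Hessian of the logistic datafit is X.T @ diag(s * (1 - s)) @ X / n_samples with s * (1 - s) <= 1/4, hence the bound norm(X, ord=2) ** 2 / (4 * n_samples). A quick numerical check (illustrative sketch, not part of the diff):

import numpy as np
from numpy.linalg import norm

rng = np.random.default_rng(0)
X = rng.standard_normal((40, 25))
y = np.sign(rng.standard_normal(40))
n = len(y)

def grad(w):
    # gradient of the mean logistic loss, log(1 + exp(-y * Xw)).mean()
    return X.T @ (-y / (1 + np.exp(y * (X @ w)))) / n

L = norm(X, ord=2) ** 2 / (4 * n)
w1, w2 = rng.standard_normal((2, 25))
assert norm(grad(w1) - grad(w2)) <= L * norm(w1 - w2) + 1e-12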
@@ -123,6 +135,7 @@ def __init__(self):
def get_spec(self):
spec = (
('lipschitz', float64[:]),
('global_lipschitz', float64),
)
return spec

@@ -140,13 +153,16 @@ def raw_hessian(self, y, Xw):

def initialize(self, X, y):
self.lipschitz = (X ** 2).sum(axis=0) / (len(y) * 4)
self.global_lipschitz = norm(X, ord=2) ** 2 / (len(y) * 4)

def initialize_sparse(self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1
self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
self.global_lipschitz = 0.
for j in range(n_features):
Xj = X_data[X_indptr[j]:X_indptr[j+1]]
self.lipschitz[j] = (Xj ** 2).sum() / (len(y) * 4)
self.global_lipschitz += (Xj ** 2).sum() / (len(y) * 4)
Collaborator (review comment on this hunk):
That will yield a very crude bound, potentially with a loss of the order of n_features.

Use a few iterations of the power method instead to approximate the Lipschitz constant of the sparse matrix (there is also the Lanczos iteration, but it is more complicated; let's implement the easy one first).
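A minimal power-iteration sketch of the suggested fix (a hypothetical helper, assuming a scipy sparse or dense matrix; the PR adds its own implementation in commit e76dfb1):

import numpy as np

def power_method_sq_norm(X, n_iter=20, seed=0):
    # estimate ||X||_2^2, the largest eigenvalue of X.T @ X, by power iteration
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(X.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        v = X.T @ (X @ v)
        v /= np.linalg.norm(v)
    # Rayleigh quotient of X.T @ X at the (approximate) top eigenvector
    return v @ (X.T @ (X @ v))

The Logistic datafit would then use power_method_sq_norm(X) / (4 * len(y)) in place of the column-norm sum above.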


def value(self, y, w, Xw):
return np.log(1. + np.exp(- y * Xw)).sum() / len(y)
@@ -187,6 +203,11 @@ class QuadraticSVC(BaseDatafit):
----------
lipschitz : array, shape (n_features,)
The coordinatewise gradient Lipschitz constants.
Equal to norm(yXT, axis=0) ** 2.

global_lipschitz : float
Global Lipschitz constant. Equal to
norm(yXT, ord=2) ** 2.

Note
----
@@ -200,6 +221,7 @@ def __init__(self):
def get_spec(self):
spec = (
('lipschitz', float64[:]),
('global_lipschitz', float64),
)
return spec

@@ -209,18 +231,21 @@ def params_to_dict(self):
def initialize(self, yXT, y):
n_features = yXT.shape[1]
self.lipschitz = np.zeros(n_features, dtype=yXT.dtype)
self.global_lipschitz = norm(yXT, ord=2) ** 2
for j in range(n_features):
self.lipschitz[j] = norm(yXT[:, j]) ** 2

def initialize_sparse(
self, yXT_data, yXT_indptr, yXT_indices, y):
n_features = len(yXT_indptr) - 1
self.lipschitz = np.zeros(n_features, dtype=yXT_data.dtype)
self.global_lipschitz = 0.
for j in range(n_features):
nrm2 = 0.
for idx in range(yXT_indptr[j], yXT_indptr[j + 1]):
nrm2 += yXT_data[idx] ** 2
self.lipschitz[j] = nrm2
self.global_lipschitz += nrm2

def value(self, y, w, yXTw):
return (yXTw ** 2).sum() / 2 - np.sum(w)
@@ -264,8 +289,16 @@ class Huber(BaseDatafit):

Attributes
----------
delta : float
Shape hyperparameter.

lipschitz : array, shape (n_features,)
The coordinatewise gradient Lipschitz constants.
The coordinatewise gradient Lipschitz constants. Equal to
norm(X, axis=0) ** 2 / n_samples.

global_lipschitz : float
Global Lipschitz constant. Equal to
norm(X, ord=2) ** 2 / n_samples.

Note
----
@@ -279,7 +312,8 @@ def __init__(self, delta):
def get_spec(self):
spec = (
('delta', float64),
('lipschitz', float64[:])
('lipschitz', float64[:]),
('global_lipschitz', float64),
)
return spec

@@ -289,18 +323,22 @@ def params_to_dict(self):
def initialize(self, X, y):
n_features = X.shape[1]
self.lipschitz = np.zeros(n_features, dtype=X.dtype)
self.global_lipschitz = 0.
for j in range(n_features):
self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
self.global_lipschitz += (X[:, j] ** 2).sum() / len(y)

def initialize_sparse(
self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1
self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
self.global_lipschitz = 0.
for j in range(n_features):
nrm2 = 0.
for idx in range(X_indptr[j], X_indptr[j + 1]):
nrm2 += X_data[idx] ** 2
self.lipschitz[j] = nrm2 / len(y)
self.global_lipschitz += nrm2 / len(y)

def value(self, y, w, Xw):
n_samples = len(y)
3 changes: 2 additions & 1 deletion skglm/solvers/__init__.py
@@ -1,9 +1,10 @@
from .anderson_cd import AndersonCD
from .base import BaseSolver
from .fista import FISTA
from .gram_cd import GramCD
from .group_bcd import GroupBCD
from .multitask_bcd import MultiTaskBCD
from .prox_newton import ProxNewton


__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
__all__ = [AndersonCD, BaseSolver, FISTA, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
93 changes: 93 additions & 0 deletions skglm/solvers/fista.py
@@ -0,0 +1,93 @@
import numpy as np
from scipy.sparse import issparse
from numba import njit
from skglm.solvers.base import BaseSolver
from skglm.solvers.common import construct_grad, construct_grad_sparse


@njit
def _prox_vec(w, z, penalty, lipschitz):
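    # apply the penalty's separable prox coordinate-wise, with step size 1 / lipschitz;
    # e.g. for L1, prox_1d is soft-thresholding: sign(z_j) * max(|z_j| - alpha / lipschitz, 0)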
n_features = w.shape[0]
for j in range(n_features):
w[j] = penalty.prox_1d(z[j], 1 / lipschitz, j)
return w


class FISTA(BaseSolver):
r"""ISTA solver with Nesterov acceleration (FISTA).

This solver implements accelerated proximal gradient descent for linear problems.

Attributes
----------
max_iter : int, default 100
Maximum number of iterations.

tol : float, default 1e-4
Tolerance for convergence.

opt_freq : int, default 10
Frequency for optimality condition check.

verbose : bool, default False
Amount of verbosity. 0/False is silent.

References
----------
.. [1] Beck, A. and Teboulle M.
"A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
problems", 2009, SIAM J. Imaging Sci.
https://epubs.siam.org/doi/10.1137/080716542
"""

def __init__(self, max_iter=100, tol=1e-4, opt_freq=10, verbose=0):
self.max_iter = max_iter
self.tol = tol
self.opt_freq = opt_freq
self.verbose = verbose

def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
n_samples, n_features = X.shape
all_features = np.arange(n_features)
t_new = 1

w = w_init.copy() if w_init is not None else np.zeros(n_features)
z = w_init.copy() if w_init is not None else np.zeros(n_features)
Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

if hasattr(datafit, "global_lipschitz"):
lipschitz = datafit.global_lipschitz
else:
# TODO: fall back to a line search when the datafit exposes no global Lipschitz constant
raise Exception("Line search is not yet implemented for FISTA solver.")

for n_iter in range(self.max_iter):
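            # FISTA momentum: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2; the extrapolated
            # point z below uses the weight (t_old - 1) / t_new (Beck and Teboulle, 2009)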
t_old = t_new
t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
w_old = w.copy()
if issparse(X):
grad = construct_grad_sparse(
X.data, X.indptr, X.indices, y, z, X @ z, datafit, all_features)
else:
grad = construct_grad(X, y, z, X @ z, datafit, all_features)
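            # forward-backward step: gradient step on the extrapolated point z,
            # followed by the proximal operator of the penalty with step 1 / lipschitz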
z -= grad / lipschitz
w = _prox_vec(w, z, penalty, lipschitz)
Xw = X @ w
z = w + (t_old - 1.) / t_new * (w - w_old)

if n_iter % self.opt_freq == 0:
opt = penalty.subdiff_distance(w, grad, all_features)
stop_crit = np.max(opt)

if self.verbose:
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
print(
f"Iteration {n_iter+1}: {p_obj:.10f}, "
f"stopping crit: {stop_crit:.2e}"
)

if stop_crit < self.tol:
if self.verbose:
print(f"Stopping criterion max violation: {stop_crit:.2e}")
break
return w
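For reference, a minimal usage sketch of the new solver, mirroring the test below (it assumes the same imports and helpers as the test file):

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import FISTA
from skglm.utils import make_correlated_data, compiled_clone

X, y, _ = make_correlated_data(n_samples=50, n_features=60, random_state=0)
alpha = np.max(np.abs(X.T @ y)) / len(y) / 100

datafit = compiled_clone(Quadratic())
datafit.initialize(X, y)  # computes lipschitz and global_lipschitz
penalty = compiled_clone(L1(alpha))

w = FISTA(max_iter=1000, tol=1e-8).solve(X, y, datafit, penalty)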
52 changes: 52 additions & 0 deletions skglm/tests/test_fista.py
@@ -0,0 +1,52 @@
import pytest

import numpy as np
from numpy.linalg import norm
from scipy.sparse import csc_matrix, issparse

from skglm.datafits import Quadratic, Logistic, QuadraticSVC
from skglm.penalties import L1, IndicatorBox
from skglm.solvers import FISTA, AndersonCD
from skglm.utils import make_correlated_data, compiled_clone


np.random.seed(0)
n_samples, n_features = 50, 60
X, y, _ = make_correlated_data(
n_samples=n_samples, n_features=n_features, random_state=0)
X_sparse = csc_matrix(X * np.random.binomial(1, 0.1, X.shape))
y_classif = np.sign(y)

alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
alpha = alpha_max / 100

tol = 1e-10


@pytest.mark.parametrize("X", [X, X_sparse])
@pytest.mark.parametrize("Datafit, Penalty", [
(Quadratic, L1),
(Logistic, L1),
(QuadraticSVC, IndicatorBox),
])
def test_fista_solver(X, Datafit, Penalty):
    _y = y if issubclass(Datafit, Quadratic) else y_classif
    datafit = compiled_clone(Datafit())
    if issubclass(Datafit, QuadraticSVC):
        # the dual SVC datafit operates on yXT, of shape (n_features, n_samples)
        _X = X.T.multiply(_y).tocsc() if issparse(X) else _y * X.T
    else:
        _X = X
    if issparse(_X):
        datafit.initialize_sparse(_X.data, _X.indptr, _X.indices, _y)
    else:
        datafit.initialize(_X, _y)
    penalty = compiled_clone(Penalty(alpha))

    solver = FISTA(max_iter=1000, tol=tol, opt_freq=1)
    w = solver.solve(_X, _y, datafit, penalty)

    solver_cd = AndersonCD(tol=tol, fit_intercept=False)
    w_cd = solver_cd.solve(_X, _y, datafit, penalty)[0]

np.testing.assert_allclose(w, w_cd, rtol=1e-3)


if __name__ == '__main__':
pass