ENH Add FISTA solver #91
Changes from 12 commits
@@ -22,6 +22,10 @@ class Quadratic(BaseDatafit):
         The coordinatewise gradient Lipschitz constants. Equal to
         norm(X, axis=0) ** 2 / n_samples.

+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(X, ord=2) ** 2 / n_samples.
+
     Note
     ----
     The class is jit compiled at fit time using Numba compiler.
@@ -35,6 +39,7 @@ def get_spec(self):
         spec = (
             ('Xty', float64[:]),
             ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec
@@ -44,6 +49,7 @@ def params_to_dict(self):
     def initialize(self, X, y):
         self.Xty = X.T @ y
         n_features = X.shape[1]
+        self.global_lipschitz = norm(X, ord=2) ** 2 / len(y)
         self.lipschitz = np.zeros(n_features, dtype=X.dtype)
         for j in range(n_features):
             self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
@@ -53,6 +59,7 @@ def initialize_sparse(
         n_features = len(X_indptr) - 1
         self.Xty = np.zeros(n_features, dtype=X_data.dtype)
         self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
+        self.global_lipschitz = 0.
         for j in range(n_features):
             nrm2 = 0.
             xty = 0
@@ -62,6 +69,7 @@ def initialize_sparse(

             self.lipschitz[j] = nrm2 / len(y)
             self.Xty[j] = xty
+            self.global_lipschitz += nrm2 / len(y)

     def value(self, y, w, Xw):
         return np.sum((y - Xw) ** 2) / (2 * len(Xw))
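As a sanity check on the new attribute (not part of this diff, with illustrative data): the spectral norm dominates every column norm, so the global constant upper-bounds each coordinatewise one. A minimal sketch:

import numpy as np
from numpy.linalg import norm

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 60))
n_samples = X.shape[0]

# Coordinatewise constants: squared column norms over n_samples.
coord_lipschitz = norm(X, axis=0) ** 2 / n_samples
# Global constant: squared spectral norm over n_samples.
global_lipschitz = norm(X, ord=2) ** 2 / n_samples

# norm(X, ord=2) >= norm(X[:, j]) for every column j.
assert np.all(global_lipschitz >= coord_lipschitz)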
@@ -111,6 +119,10 @@ class Logistic(BaseDatafit):
         The coordinatewise gradient Lipschitz constants. Equal to
         norm(X, axis=0) ** 2 / (4 * n_samples).

+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(X, ord=2) ** 2 / (4 * n_samples).
+
     Note
     ----
     The class is jit compiled at fit time using Numba compiler.
@@ -123,6 +135,7 @@ def __init__(self):
     def get_spec(self):
         spec = (
             ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec
@@ -140,13 +153,16 @@ def raw_hessian(self, y, Xw):

     def initialize(self, X, y):
         self.lipschitz = (X ** 2).sum(axis=0) / (len(y) * 4)
+        self.global_lipschitz = norm(X, ord=2) ** 2 / (len(y) * 4)

     def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         n_features = len(X_indptr) - 1
         self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
+        self.global_lipschitz = 0.
         for j in range(n_features):
             Xj = X_data[X_indptr[j]:X_indptr[j+1]]
             self.lipschitz[j] = (Xj ** 2).sum() / (len(y) * 4)
+            self.global_lipschitz += (Xj ** 2).sum() / (len(y) * 4)

Review comment on the line above: that will yield a very crude bound, potentially with a loss of the order of … Use a few iterations of the power method instead to approximate the Lipschitz constant of the sparse matrix (there is also the Lanczos iteration, but it is more complicated; let's implement the easy one first).

     def value(self, y, w, Xw):
         return np.log(1. + np.exp(- y * Xw)).sum() / len(y)
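Following the review comment, a power-method estimate of the squared spectral norm could look like the sketch below (not part of this diff; the function name, iteration count, and tolerance are illustrative). Only matrix-vector products are used, so it works for dense arrays and scipy sparse matrices alike:

import numpy as np

def power_method_sq_norm(A, max_iter=20, tol=1e-6, seed=0):
    """Approximate norm(A, ord=2) ** 2 by power iterations on A.T @ A."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(A.shape[1])
    v /= np.linalg.norm(v)
    sq_norm = 0.
    for _ in range(max_iter):
        # One power iteration: v <- A.T @ A @ v, then renormalize.
        v_new = A.T @ (A @ v)
        sq_norm_new = np.linalg.norm(v_new)  # converges to sigma_max(A) ** 2
        if abs(sq_norm_new - sq_norm) <= tol * max(sq_norm_new, 1.):
            sq_norm = sq_norm_new
            break
        sq_norm = sq_norm_new
        v = v_new / sq_norm_new
    return sq_norm

The logistic datafit would then use something like power_method_sq_norm(X) / (4 * len(y)) instead of accumulating squared column norms: that accumulation equals the squared Frobenius norm, which can overestimate the squared spectral norm by a factor up to the rank of X.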
@@ -187,6 +203,11 @@ class QuadraticSVC(BaseDatafit):
     ----------
     lipschitz : array, shape (n_features,)
         The coordinatewise gradient Lipschitz constants.
+        Equal to norm(yXT, axis=0) ** 2.
+
+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(yXT, ord=2) ** 2.

     Note
     ----
@@ -200,6 +221,7 @@ def __init__(self):
     def get_spec(self):
         spec = (
             ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec
@@ -209,18 +231,21 @@ def params_to_dict(self):
     def initialize(self, yXT, y):
         n_features = yXT.shape[1]
         self.lipschitz = np.zeros(n_features, dtype=yXT.dtype)
+        self.global_lipschitz = norm(yXT, ord=2) ** 2
         for j in range(n_features):
             self.lipschitz[j] = norm(yXT[:, j]) ** 2

     def initialize_sparse(
             self, yXT_data, yXT_indptr, yXT_indices, y):
         n_features = len(yXT_indptr) - 1
         self.lipschitz = np.zeros(n_features, dtype=yXT_data.dtype)
+        self.global_lipschitz = 0.
         for j in range(n_features):
             nrm2 = 0.
             for idx in range(yXT_indptr[j], yXT_indptr[j + 1]):
                 nrm2 += yXT_data[idx] ** 2
             self.lipschitz[j] = nrm2
+            self.global_lipschitz += nrm2

     def value(self, y, w, yXTw):
         return (yXTw ** 2).sum() / 2 - np.sum(w)
@@ -264,8 +289,16 @@ class Huber(BaseDatafit):

     Attributes
     ----------
     delta : float
         Shape hyperparameter.

     lipschitz : array, shape (n_features,)
-        The coordinatewise gradient Lipschitz constants.
+        The coordinatewise gradient Lipschitz constants. Equal to
+        norm(X, axis=0) ** 2 / n_samples.
+
+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(X, ord=2) ** 2 / n_samples.

     Note
     ----
@@ -279,7 +312,8 @@ def __init__(self, delta):
     def get_spec(self):
         spec = (
             ('delta', float64),
-            ('lipschitz', float64[:])
+            ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec
@@ -289,18 +323,22 @@ def params_to_dict(self):
     def initialize(self, X, y):
         n_features = X.shape[1]
         self.lipschitz = np.zeros(n_features, dtype=X.dtype)
+        self.global_lipschitz = 0.
         for j in range(n_features):
             self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
+            self.global_lipschitz += (X[:, j] ** 2).sum() / len(y)

     def initialize_sparse(
             self, X_data, X_indptr, X_indices, y):
         n_features = len(X_indptr) - 1
         self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
+        self.global_lipschitz = 0.
         for j in range(n_features):
             nrm2 = 0.
             for idx in range(X_indptr[j], X_indptr[j + 1]):
                 nrm2 += X_data[idx] ** 2
             self.lipschitz[j] = nrm2 / len(y)
+            self.global_lipschitz += nrm2 / len(y)

     def value(self, y, w, Xw):
         n_samples = len(y)
@@ -1,9 +1,10 @@
 from .anderson_cd import AndersonCD
 from .base import BaseSolver
+from .fista import FISTA
 from .gram_cd import GramCD
 from .group_bcd import GroupBCD
 from .multitask_bcd import MultiTaskBCD
 from .prox_newton import ProxNewton


-__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
+__all__ = [AndersonCD, BaseSolver, FISTA, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
@@ -0,0 +1,93 @@ (new file)
import numpy as np
from scipy.sparse import issparse
from numba import njit
from skglm.solvers.base import BaseSolver
from skglm.solvers.common import construct_grad, construct_grad_sparse


@njit
def _prox_vec(w, z, penalty, lipschitz):
    n_features = w.shape[0]
    for j in range(n_features):
        w[j] = penalty.prox_1d(z[j], 1 / lipschitz, j)
    return w


class FISTA(BaseSolver):
    r"""ISTA solver with Nesterov acceleration (FISTA).

    This solver implements accelerated proximal gradient descent for linear problems.

    Attributes
    ----------
    max_iter : int, default 100
        Maximum number of iterations.

    tol : float, default 1e-4
        Tolerance for convergence.

    opt_freq : int, default 10
        Frequency of the optimality condition check.

    verbose : bool, default False
        Amount of verbosity. 0/False is silent.

    References
    ----------
    .. [1] Beck, A. and Teboulle, M.
        "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
        Problems", 2009, SIAM J. Imaging Sci.
        https://epubs.siam.org/doi/10.1137/080716542
    """

    def __init__(self, max_iter=100, tol=1e-4, opt_freq=10, verbose=0):
        self.max_iter = max_iter
        self.tol = tol
        self.opt_freq = opt_freq
        self.verbose = verbose

    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
        n_samples, n_features = X.shape
        all_features = np.arange(n_features)
        t_new = 1

        w = w_init.copy() if w_init is not None else np.zeros(n_features)
        z = w_init.copy() if w_init is not None else np.zeros(n_features)
        Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

        if hasattr(datafit, "global_lipschitz"):
            lipschitz = datafit.global_lipschitz
        else:
            # TODO: OR line search
            raise Exception("Line search is not yet implemented for FISTA solver.")

        for n_iter in range(self.max_iter):
            t_old = t_new
            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
            w_old = w.copy()
            if issparse(X):
                grad = construct_grad_sparse(
                    X.data, X.indptr, X.indices, y, z, X @ z, datafit, all_features)
            else:
                grad = construct_grad(X, y, z, X @ z, datafit, all_features)
            z -= grad / lipschitz
            w = _prox_vec(w, z, penalty, lipschitz)
            Xw = X @ w
            z = w + (t_old - 1.) / t_new * (w - w_old)

            if n_iter % self.opt_freq == 0:
                opt = penalty.subdiff_distance(w, grad, all_features)
                stop_crit = np.max(opt)

                if self.verbose:
                    p_obj = datafit.value(y, w, Xw) + penalty.value(w)
                    print(
                        f"Iteration {n_iter+1}: {p_obj:.10f}, "
                        f"stopping crit: {stop_crit:.2e}"
                    )

                if stop_crit < self.tol:
                    if self.verbose:
                        print(f"Stopping criterion max violation: {stop_crit:.2e}")
                    break
        return w
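For reference, the loop in solve implements the standard FISTA recursion from Beck and Teboulle [1]. With f the datafit, g the penalty, and L the global Lipschitz constant of the gradient of f, one iteration reads (a LaTeX transcription of the code above, not part of the diff):

\begin{aligned}
t_{k+1} &= \frac{1 + \sqrt{1 + 4 t_k^2}}{2}, \\
w_{k+1} &= \operatorname{prox}_{g / L}\!\left(z_k - \tfrac{1}{L} \nabla f(z_k)\right), \\
z_{k+1} &= w_{k+1} + \frac{t_k - 1}{t_{k+1}} \,(w_{k+1} - w_k).
\end{aligned}

The prox step is applied coordinate by coordinate through penalty.prox_1d in _prox_vec, which is valid because the penalties involved here are separable.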
@@ -0,0 +1,52 @@ (new file)
import pytest

import numpy as np
from numpy.linalg import norm
from scipy.sparse import csc_matrix, issparse

from skglm.datafits import Quadratic, Logistic, QuadraticSVC
from skglm.penalties import L1, IndicatorBox
from skglm.solvers import FISTA, AndersonCD
from skglm.utils import make_correlated_data, compiled_clone


np.random.seed(0)
n_samples, n_features = 50, 60
X, y, _ = make_correlated_data(
    n_samples=n_samples, n_features=n_features, random_state=0)
X_sparse = csc_matrix(X * np.random.binomial(1, 0.1, X.shape))
y_classif = np.sign(y)

alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
alpha = alpha_max / 100

tol = 1e-10


@pytest.mark.parametrize("X", [X, X_sparse])
@pytest.mark.parametrize("Datafit, Penalty", [
    (Quadratic, L1),
    (Logistic, L1),
    (QuadraticSVC, IndicatorBox),
])
def test_fista_solver(X, Datafit, Penalty):
    # Datafit is a class, so compare classes rather than use isinstance.
    _y = y if Datafit == Quadratic else y_classif
    datafit = compiled_clone(Datafit())
    _init = y @ X.T if Datafit == QuadraticSVC else X
    if issparse(X):
        datafit.initialize_sparse(_init.data, _init.indptr, _init.indices, _y)
    else:
        datafit.initialize(_init, _y)
    penalty = compiled_clone(Penalty(alpha))

    solver = FISTA(max_iter=1000, tol=tol, opt_freq=1)
    w = solver.solve(X, _y, datafit, penalty)

    solver_cd = AndersonCD(tol=tol, fit_intercept=False)
    w_cd = solver_cd.solve(X, _y, datafit, penalty)[0]

    np.testing.assert_allclose(w, w_cd, rtol=1e-3)


if __name__ == '__main__':
    pass
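Beyond the test, a minimal end-to-end usage sketch of the new solver (names taken from this diff; alpha and the data are illustrative):

import numpy as np
from numpy.linalg import norm

from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import FISTA
from skglm.utils import make_correlated_data, compiled_clone

X, y, _ = make_correlated_data(n_samples=50, n_features=60, random_state=0)
alpha = norm(X.T @ y, ord=np.inf) / (100 * len(y))

# Compile the datafit and penalty; initialize precomputes the
# coordinatewise and global Lipschitz constants.
datafit = compiled_clone(Quadratic())
datafit.initialize(X, y)
penalty = compiled_clone(L1(alpha))

# Run accelerated proximal gradient with the global Lipschitz constant.
w = FISTA(max_iter=1000, tol=1e-10).solve(X, y, datafit, penalty)
print(f"{np.count_nonzero(w)} nonzero coefficients out of {w.size}")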