Commit 4eeff4e

remove custom namedtuple and use dataclasses
1 parent 15997cb commit 4eeff4e

20 files changed: +227 -155 lines
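
The regain/utils.py hunk that actually introduces the replacement class is not among the files shown below. As orientation, a minimal sketch of what the new class plausibly looks like, with field names taken from the call sites in the hunks and defaults so that the bare Convergence(obj=...) call in kernel_time_graphical_lasso_.py still works:

from dataclasses import dataclass

import numpy as np


# Hypothetical reconstruction of regain.utils.Convergence; the real definition
# lives in a hunk not shown in this excerpt.
@dataclass
class Convergence:
    obj: float = np.nan     # objective value (np.nan when compute_objective=False)
    rnorm: float = np.nan   # primal residual norm
    snorm: float = np.nan   # dual residual norm
    e_pri: float = np.nan   # primal feasibility tolerance
    e_dual: float = np.nan  # dual feasibility tolerance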

regain/bayesian/gwishart_inference.py

Lines changed: 46 additions & 13 deletions

@@ -39,8 +39,7 @@
 from scipy import linalg
 from scipy.optimize import minimize
 from scipy.special import comb
-from sklearn.covariance import empirical_covariance
-from sklearn.covariance.empirical_covariance_ import log_likelihood
+from sklearn.covariance import empirical_covariance, log_likelihood
 from sklearn.linear_model import LassoLars
 from sklearn.utils import Bunch, check_array
 from sklearn.utils.extmath import fast_logdet
@@ -51,9 +50,11 @@
 def mk_all_ugs(n_dim):
     """Utility for generating all possible graphs."""
     nedges = int(comb(n_dim, 2))
-    m = 2 ** nedges
+    m = 2**nedges

-    ind = np.array([list(binary_repr(x, width=len(binary_repr(m - 1)))) for x in range(m)]).astype(int)
+    ind = np.array(
+        [list(binary_repr(x, width=len(binary_repr(m - 1)))) for x in range(m)]
+    ).astype(int)
     ord = np.argsort(ind.sum(axis=1))
     ind = ind[ord]

@@ -98,7 +99,12 @@ def score_blankets(blankets, X, alphas=(0.01, 0.5, 1)):
             X_mb = np.zeros((X.shape[0], 1))

             y_mb = X[:, i]
-            score = np.sum([LassoLars(alpha=alpha).fit(X_mb, y_mb).score(X_mb, y_mb) for alpha in alphas])
+            score = np.sum(
+                [
+                    LassoLars(alpha=alpha).fit(X_mb, y_mb).score(X_mb, y_mb)
+                    for alpha in alphas
+                ]
+            )

             scores.append(score)
         scores_all.append(scores)
@@ -109,7 +115,12 @@ def score_blankets(blankets, X, alphas=(0.01, 0.5, 1)):


 def _get_graphs(blankets, scores, n_dim, n_resampling=200):
-    idx = np.array([np.random.choice(scores.shape[1], p=scores[i], size=n_resampling) for i in range(n_dim)])
+    idx = np.array(
+        [
+            np.random.choice(scores.shape[1], p=scores[i], size=n_resampling)
+            for i in range(n_dim)
+        ]
+    )

     graphs_ = np.array([blankets[i][idx[i]] for i in range(n_dim)]).transpose(1, 0, 2)
     # symmetrise with AND operator -> product
@@ -268,7 +279,11 @@ def compute_score(X, G, P, S, GWprior=None, score_method="bic"):

         logdetHdiag = sum(np.log(-diagH))
         lognormconst = dof * np.log(2 * np.pi) / 2 + logh - logdetHdiag / 2.0
-        score = lognormconst - GWprior.lognormconst - n_samples * n_dim * np.log(2 * np.pi) / 2
+        score = (
+            lognormconst
+            - GWprior.lognormconst
+            - n_samples * n_dim * np.log(2 * np.pi) / 2
+        )
         GWpost.lognormconst = lognormconst

     elif score_method == "laplace":
@@ -294,7 +309,11 @@ def compute_score(X, G, P, S, GWprior=None, score_method="bic"):
         # neg Hessian will be posdef
         logdetH = 2 * sum(np.log(np.diag(linalg.cholesky(-H))))
         lognormconst = dof * np.log(2 * np.pi) / 2 + logh - logdetH / 2.0
-        score = lognormconst - GWprior.lognormconst - n_samples * n_dim * np.log(2 * np.pi) / 2
+        score = (
+            lognormconst
+            - GWprior.lognormconst
+            - n_samples * n_dim * np.log(2 * np.pi) / 2
+        )
         GWpost.lognormconst = lognormconst

     GWpost.score = score
@@ -315,7 +334,9 @@ def GWishartScore(X, G, d0=3, S0=None, score_method="bic", mode="covsel"):
     noData = np.zeros((0, n_dim))

     P0, S_noData = GWishartFit(noData, G, GWprior)
-    GWtemp = compute_score(noData, G, P0, S_noData, GWprior=GWprior, score_method=score_method)
+    GWtemp = compute_score(
+        noData, G, P0, S_noData, GWprior=GWprior, score_method=score_method
+    )
     GWprior.lognormconst = GWtemp.lognormconst

     # Compute the map precision matrix P
@@ -344,13 +365,17 @@ def bayesian_graphical_lasso(
     alphas = np.logspace(-2, 0, 20)

     # get a series of Markov blankets for vaiours alphas
-    mdl = GraphicalLasso(assume_centered=assume_centered, tol=tol, max_iter=max_iter, verbose=False)
+    mdl = GraphicalLasso(
+        assume_centered=assume_centered, tol=tol, max_iter=max_iter, verbose=False
+    )
     precisions = [mdl.set_params(alpha=a).fit(X).precision_ for a in alphas]
     mblankets = markov_blankets(precisions, tol=tol, unique=True)

     normalized_scores = score_blankets(mblankets, X=X, alphas=[0.01, 0.5, 1])

-    graphs = _get_graphs(mblankets, normalized_scores, n_dim=n_dim, n_resampling=n_resampling)
+    graphs = _get_graphs(
+        mblankets, normalized_scores, n_dim=n_dim, n_resampling=n_resampling
+    )

     nonzeros_all = [np.triu(g, 1) + np.eye(n_dim, dtype=bool) for g in graphs]

@@ -361,7 +386,10 @@ def bayesian_graphical_lasso(
     # Find non-zero elements of upper triangle of G
     # make sure diagonal is non-zero
     # G = nonzeros_all[1] # probably can discard if all zeros?
-    res = [GWishartScore(X, G, d0=d0, S0=S0, mode=mode, score_method=scoring) for G in nonzeros_all]
+    res = [
+        GWishartScore(X, G, d0=d0, S0=S0, mode=mode, score_method=scoring)
+        for G in nonzeros_all
+    ]

     top_n = [x.P for x in sorted(res, key=lambda x: x.score)[::-1][:top_n]]
     return np.mean(top_n, axis=0)
@@ -439,7 +467,12 @@ def __init__(
         top_n=1,
     ):
         super(GraphicalLasso, self).__init__(
-            alpha=alpha, tol=tol, max_iter=max_iter, verbose=verbose, assume_centered=assume_centered, mode=mode
+            alpha=alpha,
+            tol=tol,
+            max_iter=max_iter,
+            verbose=verbose,
+            assume_centered=assume_centered,
+            mode=mode,
         )
         self.alphas = alphas
         self.n_resampling = n_resampling
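
As an aside on the mk_all_ugs hunk above: the function enumerates every undirected graph on n_dim nodes by writing each of the 2**nedges edge subsets as a binary indicator vector. A standalone check for n_dim = 3, re-running the same lines outside the library (illustrative only):

import numpy as np
from numpy import binary_repr
from scipy.special import comb

n_dim = 3
nedges = int(comb(n_dim, 2))  # 3 possible edges on 3 nodes
m = 2**nedges                 # 8 candidate edge subsets

ind = np.array(
    [list(binary_repr(x, width=len(binary_repr(m - 1)))) for x in range(m)]
).astype(int)
ind = ind[np.argsort(ind.sum(axis=1))]  # order graphs by number of edges
print(ind.shape)  # (8, 3): one row of edge indicators per candidate graph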

regain/bayesian/wishart_process_.py

Lines changed: 5 additions & 5 deletions

@@ -34,9 +34,9 @@
 import numpy as np
 from scipy import linalg, stats
 from sklearn.covariance import empirical_covariance
-from sklearn.datasets.base import Bunch
 from sklearn.gaussian_process import kernels
 from sklearn.metrics.pairwise import rbf_kernel
+from sklearn.utils import Bunch
 from sklearn.utils.validation import check_X_y
 from tqdm import trange

@@ -333,17 +333,17 @@ def fit(self, X, y):
         samples_u, loglikes, lps = out

         # Burn in
-        self.lps_after_burnin = lps[self.burn_in :]
-        self.samples_u_after_burnin = samples_u[self.burn_in :]
-        self.loglikes_after_burnin = loglikes[self.burn_in :]
+        self.lps_after_burnin = lps[self.burn_in:]
+        self.samples_u_after_burnin = samples_u[self.burn_in:]
+        self.loglikes_after_burnin = loglikes[self.burn_in:]

         # % Select the best hyperparameters based on the loglikes_after_burnin
         pos = np.argmax(self.loglikes_after_burnin)
         self.lmap = self.lps_after_burnin[pos]
         self.u_map = self.samples_u_after_burnin[pos]

         if self.learn_ell:
-            self.Ls_after_burnin = Ls[self.burn_in :]
+            self.Ls_after_burnin = Ls[self.burn_in:]
             self.Lmap = self.Ls_after_burnin[pos]
         else:
             self.Lmap = L
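
The slice edits above are purely cosmetic (lps[self.burn_in :] and lps[self.burn_in:] denote the same slice); the surrounding logic drops the burn-in portion of the MCMC trace and keeps the sample with the highest log-likelihood. The same pattern in isolation, with made-up arrays:

import numpy as np

rng = np.random.default_rng(0)
burn_in = 100
loglikes = rng.normal(size=500)        # stand-in for per-sample log-likelihoods
samples_u = rng.normal(size=(500, 3))  # stand-in for the sampled parameters

loglikes_after_burnin = loglikes[burn_in:]
pos = np.argmax(loglikes_after_burnin)
u_map = samples_u[burn_in:][pos]       # pick the best post-burn-in sample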

regain/covariance/graphical_lasso_.py

Lines changed: 3 additions & 6 deletions

@@ -44,7 +44,7 @@
 from regain.norm import l1_od_norm
 from regain.prox import prox_logdet, soft_thresholding_off_diagonal
 from regain.update_rules import update_rho
-from regain.utils import convergence
+from regain.utils import Convergence


 def logl(emp_cov, precision):
@@ -169,7 +169,7 @@ def graphical_lasso(
         obj = objective(emp_cov, K, Z, alpha) if compute_objective else np.nan
         rnorm = np.linalg.norm(K - Z, "fro")
         snorm = rho * np.linalg.norm(Z - Z_old, "fro")
-        check = convergence(
+        check = Convergence(
             obj=obj,
             rnorm=rnorm,
             snorm=snorm,
@@ -180,10 +180,7 @@

         Z_old = Z.copy()
         if verbose:
-            print(
-                "obj: %.4f, rnorm: %.4f, snorm: %.4f,"
-                "eps_pri: %.4f, eps_dual: %.4f" % check[:5]
-            )
+            print(check)

         checks.append(check)
         if check.rnorm <= check.e_pri and check.snorm <= check.e_dual:
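
The verbose-print simplification here (and in the files below) leans on the new class: a dataclass's auto-generated __repr__ already labels every field, and the old check[:5] slice would not work on a dataclass anyway, since dataclasses, unlike namedtuples, are not indexable. A self-contained illustration with a hypothetical stand-in:

from dataclasses import dataclass


@dataclass
class Check:  # hypothetical stand-in for regain.utils.Convergence
    obj: float
    rnorm: float
    snorm: float
    e_pri: float
    e_dual: float


check = Check(obj=-12.35, rnorm=0.082, snorm=0.049, e_pri=0.01, e_dual=0.01)
print(check)
# Check(obj=-12.35, rnorm=0.082, snorm=0.049, e_pri=0.01, e_dual=0.01)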

regain/covariance/infimal_convolution_.py

Lines changed: 8 additions & 5 deletions

@@ -40,7 +40,7 @@
 from regain.norm import l1_od_norm
 from regain.prox import prox_laplacian, prox_trace_indicator, soft_thresholding
 from regain.update_rules import update_rho
-from regain.utils import convergence
+from regain.utils import Convergence


 def objective(S, R, K, L, alpha, tau):
@@ -136,24 +136,27 @@ def infimal_convolution(
         obj = objective(S, R, K, L, alpha, tau) if compute_objective else np.nan
         rnorm = np.linalg.norm(R - K + L)
         snorm = rho * np.linalg.norm(R - R_old)
-        check = convergence(
+        check = Convergence(
             obj=obj,
             rnorm=rnorm,
             snorm=snorm,
-            e_pri=np.sqrt(R.size) * tol + rtol * max(np.linalg.norm(R), np.linalg.norm(K - L)),
+            e_pri=np.sqrt(R.size) * tol
+            + rtol * max(np.linalg.norm(R), np.linalg.norm(K - L)),
             e_dual=np.sqrt(R.size) * tol + rtol * rho * np.linalg.norm(U),
         )
         R_old = R.copy()

         if verbose:
-            print("obj: %.4f, rnorm: %.4f, snorm: %.4f," "eps_pri: %.4f, eps_dual: %.4f" % check[:5])
+            print(check)

         checks.append(check)
         if check.rnorm <= check.e_pri and check.snorm <= check.e_dual:
             break
         if check.obj == np.inf:
             break
-        rho_new = update_rho(rho, rnorm, snorm, iteration=iteration_, **(update_rho_options or {}))
+        rho_new = update_rho(
+            rho, rnorm, snorm, iteration=iteration_, **(update_rho_options or {})
+        )
         # scaled dual variables should be also rescaled
         U *= rho / rho_new
         rho = rho_new
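
The e_pri expression that gets wrapped above is the usual ADMM feasibility tolerance; pulled out of the loop it reads as follows (an illustrative helper, not part of the commit, with variable names as in the hunk):

import numpy as np


def admm_tolerances(R, K, L, U, rho, tol, rtol):
    # Primal/dual feasibility tolerances as computed inside infimal_convolution.
    e_pri = np.sqrt(R.size) * tol + rtol * max(
        np.linalg.norm(R), np.linalg.norm(K - L)
    )
    e_dual = np.sqrt(R.size) * tol + rtol * rho * np.linalg.norm(U)
    return e_pri, e_dual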

regain/covariance/kernel_latent_time_graphical_lasso_.py

Lines changed: 3 additions & 6 deletions

@@ -49,7 +49,7 @@
 )
 from regain.prox import prox_logdet, prox_trace_indicator, soft_thresholding
 from regain.update_rules import update_rho
-from regain.utils import convergence
+from regain.utils import Convergence
 from regain.validation import check_norm_prox


@@ -311,7 +311,7 @@ def kernel_latent_time_graphical_lasso(
             else np.nan
         )

-        check = convergence(
+        check = Convergence(
             obj=obj,
             rnorm=rnorm,
             snorm=snorm,
@@ -360,10 +360,7 @@ def kernel_latent_time_graphical_lasso(
             W_M_old[m] = (W_M[m][0].copy(), W_M[m][1].copy())

         if verbose:
-            print(
-                "obj: %.4f, rnorm: %.4f, snorm: %.4f,"
-                "eps_pri: %.4f, eps_dual: %.4f" % check[:5]
-            )
+            print(check)

         checks.append(check)
         if check.rnorm <= check.e_pri and check.snorm <= check.e_dual:

regain/covariance/kernel_time_graphical_lasso_.py

Lines changed: 4 additions & 7 deletions

@@ -50,7 +50,7 @@
 from regain.norm import l1_od_norm
 from regain.prox import prox_logdet, soft_thresholding
 from regain.update_rules import update_rho
-from regain.utils import convergence
+from regain.utils import Convergence
 from regain.validation import check_norm_prox


@@ -163,7 +163,7 @@ def kernel_time_graphical_lasso(
         n_samples = np.ones(n_times)

     checks = [
-        convergence(
+        Convergence(
             obj=objective(n_samples, emp_cov, Z_0, Z_0, Z_M, alpha, kernel, psi)
         )
     ]
@@ -245,7 +245,7 @@ def kernel_time_graphical_lasso(
             else np.nan
         )

-        check = convergence(
+        check = Convergence(
             obj=obj,
             rnorm=rnorm,
             snorm=snorm,
@@ -283,10 +283,7 @@ def kernel_time_graphical_lasso(
             Z_M_old[m] = (Z_M[m][0].copy(), Z_M[m][1].copy())

         if verbose:
-            print(
-                "obj: %.4f, rnorm: %.4f, snorm: %.4f,"
-                "eps_pri: %.4f, eps_dual: %.4f" % check[:5]
-            )
+            print(check)

         checks.append(check)
         if stop_at is not None:

regain/covariance/latent_graphical_lasso_.py

Lines changed: 19 additions & 9 deletions

@@ -34,13 +34,15 @@

 import numpy as np
 from scipy import linalg
-from six.moves import range

-from regain.covariance.graphical_lasso_ import GraphicalLasso, init_precision
-from regain.covariance.graphical_lasso_ import objective as obj_gl
+from regain.covariance.graphical_lasso_ import (
+    GraphicalLasso,
+    init_precision,
+    objective as obj_gl,
+)
 from regain.prox import prox_logdet, prox_trace_indicator, soft_thresholding
 from regain.update_rules import update_rho
-from regain.utils import convergence
+from regain.utils import Convergence


 def objective(emp_cov, R, K, L, alpha, tau):
@@ -144,24 +146,27 @@ def latent_graphical_lasso(
         obj = objective(emp_cov, R, K, L, alpha, tau) if compute_objective else np.nan
         rnorm = np.linalg.norm(R - K + L)
         snorm = rho * np.linalg.norm(R - R_old)
-        check = convergence(
+        check = Convergence(
             obj=obj,
             rnorm=rnorm,
             snorm=snorm,
-            e_pri=np.sqrt(R.size) * tol + rtol * max(np.linalg.norm(R), np.linalg.norm(K - L)),
+            e_pri=np.sqrt(R.size) * tol
+            + rtol * max(np.linalg.norm(R), np.linalg.norm(K - L)),
             e_dual=np.sqrt(R.size) * tol + rtol * rho * np.linalg.norm(U),
         )
         R_old = R.copy()

         if verbose:
-            print("obj: %.4f, rnorm: %.4f, snorm: %.4f," "eps_pri: %.4f, eps_dual: %.4f" % check[:5])
+            print(check)

         checks.append(check)
         if check.rnorm <= check.e_pri and check.snorm <= check.e_dual:
             break
         if check.obj == np.inf:
             break
-        rho_new = update_rho(rho, rnorm, snorm, iteration=iteration_, **(update_rho_options or {}))
+        rho_new = update_rho(
+            rho, rnorm, snorm, iteration=iteration_, **(update_rho_options or {})
+        )
         # scaled dual variables should be also rescaled
         U *= rho / rho_new
         rho = rho_new
@@ -293,7 +298,12 @@ def _fit(self, emp_cov):
             Empirical covariance of data.

         """
-        self.precision_, self.latent_, self.covariance_, self.n_iter_ = latent_graphical_lasso(
+        (
+            self.precision_,
+            self.latent_,
+            self.covariance_,
+            self.n_iter_,
+        ) = latent_graphical_lasso(
             emp_cov,
             alpha=self.alpha,
             tau=self.tau,
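
The last hunk is a pure formatting device: putting parentheses around the left-hand side of a multi-target assignment lets a long tuple unpacking be split across lines without a backslash. A minimal stand-alone illustration (the function below is a hypothetical stand-in for latent_graphical_lasso):

def solver():
    return "precision", "latent", "covariance", 42


(
    precision_,
    latent_,
    covariance_,
    n_iter_,
) = solver()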
