Skip to content

Commit 762840f

Browse files
stefanchstefanch
authored andcommitted
cleanup
1 parent 34b6448 commit 762840f

File tree

3 files changed

+11
-44
lines changed

3 files changed

+11
-44
lines changed

sgdml/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,7 @@ def reset(command=None, **kwargs):
11811181
print(ui.info_str('[INFO]') + ' Benchmark cache is already empty.')
11821182
else:
11831183
print(' Cancelled.')
1184+
print('')
11841185

11851186

11861187
def main():

sgdml/predict.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def __init__(
211211
The parameters `batch_size` and `num_workers` are only
212212
relevant if this code runs on a CPU. Both can be set
213213
automatically via the function
214-
`set_opt_num_workers_and_batch_size_fast`. Enabling
214+
`prepare_parallel`. Enabling
215215
calculations via PyTorch is only recommended with GPU
216216
support. CPU calcuations are faster with our NumPy
217217
implementation.
@@ -354,7 +354,7 @@ def _set_num_workers(
354354
Note
355355
----
356356
This parameter can be optimally determined using
357-
`set_opt_num_workers_and_batch_size_fast`.
357+
`prepare_parallel`.
358358
359359
Parameters
360360
----------
@@ -414,7 +414,7 @@ def _set_batch_size(
414414
Note
415415
----
416416
This parameter can be optimally determined using
417-
`set_opt_num_workers_and_batch_size_fast`.
417+
`prepare_parallel`.
418418
419419
Parameters
420420
----------
@@ -441,6 +441,8 @@ def _set_bulk_mp(
441441

442442
def set_opt_num_workers_and_batch_size_fast(self, n_bulk=1, n_reps=1): # deprecated
443443
"""
444+
Warning
445+
-------
444446
Deprecated! Please use the function `prepare_parallel` in future projects.
445447
446448
Parameters
@@ -497,7 +499,11 @@ def prepare_parallel(self, n_bulk=1, n_reps=1, return_is_from_cache=False): # n
497499
n_reps : int, optional
498500
Number of repetitions (bigger value: more
499501
accurate, but also slower).
500-
502+
return_is_from_cache : bool, optional
503+
If enabled, this function returns a second value
504+
indicating if the returned results were obtained
505+
from cache.
506+
501507
Returns
502508
-------
503509
int
@@ -883,4 +889,3 @@ def predict(self, r):
883889
F = desc.r_to_d_desc_op(r, pdist, res[1:], self.ucell_size).reshape(1, -1)
884890
# F = res[1:].reshape(1,-1).dot(r_d_desc)
885891
return E, F
886-

sgdml/train.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -486,12 +486,6 @@ def train( # noqa: C901
486486
j_end=M if use_cg else None,
487487
)
488488

489-
# if use_cg:
490-
# print(
491-
# ui.info_str('[INFO]')
492-
# + ' Nystroem preconditioner uses %s training points.' % M
493-
# )
494-
495489
# test 2
496490

497491
# use_ny = True
@@ -512,24 +506,13 @@ def train( # noqa: C901
512506

513507
lam = 1e-8
514508

515-
# print(M*R_d_desc.shape[2])
516-
517509
# ny_idxs = np.random.choice(K.shape[0], M*R_d_desc.shape[2], replace=False)
518-
519510
# K_mm = K[ny_idxs, :]
520511
# K_mm = K_mm[:, ny_idxs]
521-
522512
# K_nm = K[:, ny_idxs]
523513

524-
# P_inv3 = -(1.0/(-lam)) * (np.eye(K.shape[0]) - K_nm.dot(np.linalg.solve((K_mm + K_nm.T.dot(K_nm)),K_nm.T)))
525-
526-
# P_inv3_new = (-1.0/lam)*np.eye(K.shape[0]) - (-1.0/lam)*K_nm.dot(np.linalg.solve(((-lam)*K_mm + K_nm.T.dot(K_nm)),K_nm.T))
527-
# P_inv3 = P_inv3_new
528-
529514
_lup = sp.linalg.lu_factor((-lam) * K_mm + K_nm.T.dot(K_nm))
530515
def mv(v):
531-
# M = (-1.0/lam)*np.eye(K.shape[0]) - (-1.0/lam)*K_nm.dot(np.linalg.solve(((-lam)*K_mm + K_nm.T.dot(K_nm)),K_nm.T))
532-
# M = (-1.0/lam)*np.eye(K.shape[0]) - (-1.0/lam)*K_nm.dot(np.linalg.solve(((-lam)*K_mm + K_nm.T.dot(K_nm)),K_nm.T.dot(v)))
533516
P_v = -(-1.0 / lam) * (
534517
K_nm.dot(sp.linalg.lu_solve(_lup, K_nm.T.dot(v))) - v
535518
)
@@ -683,32 +666,10 @@ def callback(xk):
683666

684667
del K
685668

686-
# from scipy.sparse.linalg import cg
687-
688-
# alphas, status = cg(
689-
# -K_op,
690-
# y,
691-
# tol=1e-4,
692-
# maxiter=3 * n_atoms * n_train,
693-
# M=P_op,
694-
# callback=callback,
695-
# ) # M=P_inv3
696-
# alphas = -alphas
697-
698-
699-
#import scipy.sparse.linalg as spla
700-
701-
#M2 = spla.spilu(-K_op)
702-
#M = spla.LinearOperator(K_op.shape, M2.solve)
703-
704-
705669
from scipy.sparse.linalg import cg
706670
alphas, status = cg(-K_op, y, M=P_op, tol=1e-4, maxiter=3 * n_atoms * n_train, callback=callback)
707671
alphas = -alphas
708672

709-
# sys.stdout.flush()
710-
# print('\n')
711-
712673
# test 2
713674

714675
# alphas = K_mm_mn.dot(alphas) # remove me later

0 commit comments

Comments
 (0)