Skip to content

Commit 8af877a

Browse files
committed
release candidate
1 parent 6f1364e commit 8af877a

File tree

8 files changed

+110
-75
lines changed

8 files changed

+110
-75
lines changed

sgdml/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2323
# SOFTWARE.
2424

25-
__version__ = '1.0.0.dev0'
25+
__version__ = '1.0.0'
2626

2727
MAX_PRINT_WIDTH = 100
2828
LOG_LEVELNAME_WIDTH = 7 # do not modify
@@ -109,7 +109,7 @@ def __init__(self, name):
109109
hd.setFormatter(formatter)
110110
hd.setLevel(
111111
logging.INFO
112-
) # control logging level here
112+
) # control logging level here
113113

114114
self.addHandler(hd)
115115
return

sgdml/cli.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def _print_splash(max_memory, max_processes, use_torch):
140140
'You can update your installation by running \'pip install sgdml --upgrade\'.'
141141
)
142142

143-
144143
_print_billboard()
145144

146145

@@ -187,6 +186,7 @@ def _print_billboard():
187186
bbs = None
188187
try:
189188
import json
189+
190190
bbs = json.loads(resp_str)
191191
except:
192192
pass
@@ -1174,7 +1174,7 @@ def _online_err(err, size, n, mae_n_sum, rmse_n_sum):
11741174
mae_n_sum += np.sum(err) / size
11751175
mae = mae_n_sum / n
11761176

1177-
rmse_n_sum += np.sum(err ** 2) / size
1177+
rmse_n_sum += np.sum(err**2) / size
11781178
rmse = np.sqrt(rmse_n_sum / n)
11791179

11801180
return mae, mae_n_sum, rmse, rmse_n_sum
@@ -1768,7 +1768,7 @@ def test(
17681768
'rmse': e_rmse.item(),
17691769
}
17701770

1771-
model['f_err'] = {'mae':f_mae.item(), 'rmse': f_rmse.item()}
1771+
model['f_err'] = {'mae': f_mae.item(), 'rmse': f_rmse.item()}
17721772
np.savez_compressed(model_path, **model)
17731773

17741774
if is_test and model['n_test'] > 0:
@@ -1994,7 +1994,7 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False):
19941994
)
19951995

19961996
# Available resources
1997-
total_memory = psutil.virtual_memory().total // 2 ** 30
1997+
total_memory = psutil.virtual_memory().total // 2**30
19981998
total_cpus = mp.cpu_count()
19991999

20002000
parser = argparse.ArgumentParser()

sgdml/get.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# MIT License
44
#
5-
# Copyright (c) 2018-2021 Stefan Chmiela
5+
# Copyright (c) 2018-2023 Stefan Chmiela
66
#
77
# Permission is hereby granted, free of charge, to any person obtaining a copy
88
# of this software and associated documentation files (the "Software"), to deal
@@ -131,15 +131,15 @@ def main():
131131
print()
132132
print('Available %ss:' % args.command)
133133

134-
print('{:<2} {:<25} {:>4}'.format('ID', 'Name', 'Size'))
135-
print('-' * 36)
134+
print('{:<2} {:<31} {:>4}'.format('ID', 'Name', 'Size'))
135+
print('-' * 42)
136136

137137
items = line[0].split(b';')
138138
for i, item in enumerate(items):
139139
name, size = item.split(b',')
140-
size = int(size) / 1024 ** 2 # Bytes to MBytes
140+
size = int(size) / 1024**2 # Bytes to MBytes
141141

142-
print('{:>2d} {:<24} {:>5.1f} MB'.format(i, name.decode("utf-8"), size))
142+
print('{:>2d} {:<30} {:>5.1f} MB'.format(i, name.decode("utf-8"), size))
143143
print()
144144

145145
down_list = raw_input(
@@ -155,12 +155,12 @@ def main():
155155
down_idxs = [int(idx) for idx in re.split(r'\s+', down_list.strip())]
156156
down_idxs = list(set(down_idxs))
157157
else:
158-
print(ui.warn_str('ABORTED'))
158+
print(ui.color_str('ABORTED', fore_color=ui.RED, bold=True))
159159

160160
for idx in down_idxs:
161161
if idx not in range(len(items)):
162162
print(
163-
ui.warn_str('[WARN]')
163+
ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)
164164
+ ' Index '
165165
+ str(idx)
166166
+ ' out of range, skipping.'

sgdml/predict.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def _predict_wkr(
172172

173173
# avoid divisions (slower)
174174
sig_inv = 1.0 / sig
175-
mat52_base_fact = 5.0 / (3 * sig ** 3)
175+
mat52_base_fact = 5.0 / (3 * sig**3)
176176
diag_scale_fact = 5.0 / sig
177177
sqrt5 = np.sqrt(5.0)
178178

@@ -313,7 +313,7 @@ def __init__(
313313
if log_level is not None:
314314
self.log.setLevel(log_level)
315315

316-
total_memory = psutil.virtual_memory().total // 2 ** 30 # bytes to GB)
316+
total_memory = psutil.virtual_memory().total // 2**30 # bytes to GB)
317317
self.max_memory = (
318318
min(max_memory, total_memory) if max_memory is not None else total_memory
319319
)
@@ -378,14 +378,14 @@ def __init__(
378378
self.torch_predict = torch.nn.DataParallel(self.torch_predict)
379379

380380
# Send model to device
381-
#self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
381+
# self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
382382
if _torch_cuda_is_available:
383383
self.torch_device = 'cuda'
384384
elif _torch_mps_is_available:
385385
self.torch_device = 'mps'
386386
else:
387387
self.torch_device = 'cpu'
388-
388+
389389
while True:
390390
try:
391391
self.torch_predict.to(self.torch_device)
@@ -405,9 +405,9 @@ def __init__(
405405
model.set_n_perm_batches(
406406
model.get_n_perm_batches() + 1
407407
) # uncache
408-
#self.torch_predict.to( # NOTE!
408+
# self.torch_predict.to( # NOTE!
409409
# self.torch_device
410-
#) # try sending to device again
410+
# ) # try sending to device again
411411
pass
412412
else:
413413
self.log.critical(
@@ -1194,12 +1194,16 @@ def predict(self, R=None, return_E=True):
11941194
print()
11951195
os._exit(1)
11961196
else:
1197-
R_torch = torch.from_numpy(R.reshape(-1, self.n_atoms, 3)).type(torch.float32).to(
1198-
self.torch_device
1197+
R_torch = (
1198+
torch.from_numpy(R.reshape(-1, self.n_atoms, 3))
1199+
.type(torch.float32)
1200+
.to(self.torch_device)
11991201
)
12001202

12011203
model = self.torch_predict
1202-
if R_torch.shape[0] < torch.cuda.device_count() and isinstance(model, torch.nn.DataParallel):
1204+
if R_torch.shape[0] < torch.cuda.device_count() and isinstance(
1205+
model, torch.nn.DataParallel
1206+
):
12031207
model = self.torch_predict.module
12041208
E_torch_F_torch = model.forward(R_torch, return_E=return_E)
12051209

sgdml/solvers/iterative.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _init_kernel_operator(
148148
n_train = R_desc.shape[0]
149149

150150
# dummy alphas
151-
v_F = np.zeros((n-n_train, 1)) if task['use_E_cstr'] else np.zeros((n, 1))
151+
v_F = np.zeros((n - n_train, 1)) if task['use_E_cstr'] else np.zeros((n, 1))
152152
v_E = np.zeros((n_train, 1)) if task['use_E_cstr'] else None
153153

154154
# Note: The standard deviation is set to 1.0, because we are predicting normalized labels here.
@@ -372,12 +372,16 @@ def _lev_scores(
372372
dim_m = dim_i * min(n_inducing_pts, 10)
373373

374374
# Which columns to use for leverage score approximation?
375-
lev_approx_idxs = np.sort(np.random.choice(n_train*dim_i + (n_train if use_E_cstr else 0), dim_m, replace=False)) # random subset of columns
376-
#lev_approx_idxs = np.sort(np.random.choice(n_train*dim_i, dim_m, replace=False)) # random subset of columns
377-
378-
#lev_approx_idxs = np.s_[
375+
lev_approx_idxs = np.sort(
376+
np.random.choice(
377+
n_train * dim_i + (n_train if use_E_cstr else 0), dim_m, replace=False
378+
)
379+
) # random subset of columns
380+
# lev_approx_idxs = np.sort(np.random.choice(n_train*dim_i, dim_m, replace=False)) # random subset of columns
381+
382+
# lev_approx_idxs = np.s_[
379383
# :dim_m
380-
#] # first 'dim_m' columns (faster kernel construction)
384+
# ] # first 'dim_m' columns (faster kernel construction)
381385

382386
L_inv_K_mn = self._nystroem_cholesky_factor(
383387
R_desc,
@@ -460,7 +464,7 @@ def _cho_factor_stable(self, M, pre_reg=False, eps_mag_max=1):
460464

461465
self.log.critical(
462466
'Failed to factorize despite strong regularization (max: {})!\nYou could try a larger sigma.'.format(
463-
10.0 ** eps_mag_max
467+
10.0**eps_mag_max
464468
)
465469
)
466470
print()
@@ -492,7 +496,7 @@ def solve(
492496
num_iters0 = task['solver_iters'] if 'solver_iters' in task else 0
493497

494498
# Number of inducing points to use for Nystrom approximation.
495-
max_memory_bytes = self._max_memory * 1024 ** 3
499+
max_memory_bytes = self._max_memory * 1024**3
496500
max_n_inducing_pts = Iterative.max_n_inducing_pts(
497501
n_train, n_atoms, max_memory_bytes
498502
)
@@ -849,13 +853,12 @@ def max_n_inducing_pts(n_train, n_atoms, max_memory_bytes):
849853
ny_factor = SQUARE_FACT * to_dof
850854

851855
n_inducing_pts = (
852-
np.sqrt(sq_factor ** 2 + 4.0 * ny_factor * max_memory_bytes) - sq_factor
856+
np.sqrt(sq_factor**2 + 4.0 * ny_factor * max_memory_bytes) - sq_factor
853857
) / (2 * ny_factor)
854858
n_inducing_pts = int(n_inducing_pts)
855859

856860
return min(n_inducing_pts, n_train)
857861

858-
859862
@staticmethod
860863
def est_memory_requirement(n_train, n_inducing_pts, n_atoms):
861864

0 commit comments

Comments (0)