Skip to content

Commit 7f66196

Browse files
committed
Update model selection function.
1 parent 6b0f4bd commit 7f66196

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

gklearn/utils/model_selection_precomputed.py

Lines changed: 17 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ def model_selection_for_precomputed_kernel(datafile,
3030
datafile_y=None,
3131
extra_params=None,
3232
ds_name='ds-unknown',
33+
output_dir='outputs/',
3334
n_jobs=1,
3435
read_gm_from_file=False,
3536
verbose=True):
@@ -56,7 +57,7 @@ def model_selection_for_precomputed_kernel(datafile,
5657
model_type : string
5758
Type of the problem, can be 'regression' or 'classification'.
5859
NUM_TRIALS : integer
59-
Number of random trials of outer cv loop. The default is 30.
60+
Number of random trials of the outer CV loop. The default is 30.
6061
datafile_y : string
6162
Path of file storing y data. This parameter is optional depending on
6263
the given dataset file.
@@ -89,9 +90,9 @@ def model_selection_for_precomputed_kernel(datafile,
8990
"""
9091
tqdm.monitor_interval = 0
9192

92-
results_dir = '../notebooks/results/' + estimator.__name__
93-
if not os.path.exists(results_dir):
94-
os.makedirs(results_dir)
93+
output_dir += estimator.__name__
94+
if not os.path.exists(output_dir):
95+
os.makedirs(output_dir)
9596
# a string to save all the results.
9697
str_fw = '###################### log time: ' + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + '. ######################\n\n'
9798
str_fw += '# This file contains results of ' + estimator.__name__ + ' on dataset ' + ds_name + ',\n# including gram matrices, serial numbers for gram matrix figures and performance.\n\n'
@@ -209,7 +210,7 @@ def model_selection_for_precomputed_kernel(datafile,
209210
# threshold=np.inf,
210211
# floatmode='unique') + '\n\n'
211212

212-
fig_file_name = results_dir + '/GM[ds]' + ds_name
213+
fig_file_name = output_dir + '/GM[ds]' + ds_name
213214
if params_out != {}:
214215
fig_file_name += '[params]' + str(idx)
215216
plt.imshow(Kmatrix)
@@ -244,7 +245,7 @@ def model_selection_for_precomputed_kernel(datafile,
244245
str_fw += '\nall gram matrices are ignored, no results obtained.\n\n'
245246
else:
246247
# save gram matrices to file.
247-
# np.savez(results_dir + '/' + ds_name + '.gm',
248+
# np.savez(output_dir + '/' + ds_name + '.gm',
248249
# gms=gram_matrices, params=param_list_pre_revised, y=y,
249250
# gmtime=gram_matrix_time)
250251
if verbose:
@@ -450,7 +451,7 @@ def init_worker(gms_toshare):
450451
print()
451452
print('2. Reading gram matrices from file...')
452453
str_fw += '\nII. Gram matrices.\n\nGram matrices are read from file, see last log for detail.\n'
453-
gmfile = np.load(results_dir + '/' + ds_name + '.gm.npz')
454+
gmfile = np.load(output_dir + '/' + ds_name + '.gm.npz')
454455
gram_matrices = gmfile['gms'] # a list to store gram matrices for all param_grid_precomputed
455456
gram_matrix_time = gmfile['gmtime'] # time used to compute the gram matrices
456457
param_list_pre_revised = gmfile['params'] # list to store param grids precomputed ignoring the useless ones
@@ -603,8 +604,8 @@ def init_worker(gms_toshare):
603604
str_fw += 'training time with hyper-param choices who did not participate in calculation of gram matrices: {:.2f}s\n\n'.format(tt_poster)
604605

605606
# open file to save all results for this dataset.
606-
if not os.path.exists(results_dir):
607-
os.makedirs(results_dir)
607+
if not os.path.exists(output_dir):
608+
os.makedirs(output_dir)
608609

609610
# print out results as table.
610611
str_fw += printResultsInTable(param_list, param_list_pre_revised, average_val_scores,
@@ -613,11 +614,11 @@ def init_worker(gms_toshare):
613614
model_type, verbose)
614615

615616
# open file to save all results for this dataset.
616-
if not os.path.exists(results_dir + '/' + ds_name + '.output.txt'):
617-
with open(results_dir + '/' + ds_name + '.output.txt', 'w') as f:
617+
if not os.path.exists(output_dir + '/' + ds_name + '.output.txt'):
618+
with open(output_dir + '/' + ds_name + '.output.txt', 'w') as f:
618619
f.write(str_fw)
619620
else:
620-
with open(results_dir + '/' + ds_name + '.output.txt', 'r+') as f:
621+
with open(output_dir + '/' + ds_name + '.output.txt', 'r+') as f:
621622
content = f.read()
622623
f.seek(0, 0)
623624
f.write(str_fw + '\n\n\n' + content)
@@ -797,7 +798,7 @@ def parallel_trial_do(param_list_pre_revised, param_list, y, model_type, trial):
797798

798799

799800
def compute_gram_matrices(dataset, y, estimator, param_list_precomputed,
800-
results_dir, ds_name,
801+
output_dir, ds_name,
801802
n_jobs=1, str_fw='', verbose=True):
802803
gram_matrices = [
803804
] # a list to store gram matrices for all param_grid_precomputed
@@ -867,7 +868,7 @@ def compute_gram_matrices(dataset, y, estimator, param_list_precomputed,
867868
# threshold=np.inf,
868869
# floatmode='unique') + '\n\n'
869870

870-
fig_file_name = results_dir + '/GM[ds]' + ds_name
871+
fig_file_name = output_dir + '/GM[ds]' + ds_name
871872
if params_out != {}:
872873
fig_file_name += '[params]' + str(idx)
873874
plt.imshow(Kmatrix)
@@ -897,8 +898,8 @@ def compute_gram_matrices(dataset, y, estimator, param_list_precomputed,
897898
return gram_matrices, gram_matrix_time, param_list_pre_revised, y, str_fw
898899

899900

900-
def read_gram_matrices_from_file(results_dir, ds_name):
901-
gmfile = np.load(results_dir + '/' + ds_name + '.gm.npz')
901+
def read_gram_matrices_from_file(output_dir, ds_name):
902+
gmfile = np.load(output_dir + '/' + ds_name + '.gm.npz')
902903
gram_matrices = gmfile['gms'] # a list to store gram matrices for all param_grid_precomputed
903904
param_list_pre_revised = gmfile['params'] # list to store param grids precomputed ignoring the useless ones
904905
y = gmfile['y'].tolist()

0 commit comments

Comments
 (0)