Skip to content

Commit 9666f44

Browse files
Merge branch 'fani-lab:main' into main
2 parents 99ba9c0 + 4e33b37 commit 9666f44

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

src/mdl/fnn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -144,7 +144,7 @@ def learn(self, teamsvecs, splits, prev_model):
144144
y_ = self.model.forward(X)
145145
# if self.cfg.l == 'csl': csl_criterion(y_.squeeze(), y.squeeze(), index)
146146
# else:
147-
loss = self.bxe(y_, y)
147+
loss = self.bxe(y_, y).sum(dim=1).mean() #look at train loss for the reason
148148
if self.is_bayesian: loss += Fnn.btorch.get_kl_loss(self.model) / y.shape[0]
149149
#how about the loss of cdp for each class/expert? cdp_loss
150150
v_loss += loss.item()

src/mdl/ntf.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -46,11 +46,9 @@ def evaluate(self, teamsvecs, splits, on_train=False, per_epoch=False, per_insta
4646

4747
predfiles = [f'{self.output}/f{foldidx}.{pred_set}.pred'] #the first file as a hook
4848
if per_epoch: predfiles += [f'{self.output}/{_}' for _ in os.listdir(self.output) if re.match(f'f{foldidx}.{pred_set}.e\d+.pred$', _)]
49-
for i, predfile in enumerate(sorted(sorted(predfiles), key=len)):
50-
epoch = f'e{i-1}.' if i > 0 else '' #the first file is non-epoch-based but the rest are
51-
filename = f'{self.output}/f{foldidx}.{pred_set}.{epoch}'
52-
Y_ = Ntf.torch.load(f'{filename}pred')['y_pred']
53-
log.info(f'Evaluating predictions at {filename}pred ... for {metrics}')
49+
for i, predfile in enumerate(sorted(sorted(predfiles), key=len)): #the first file is/should be non-epoch-based
50+
Y_ = Ntf.torch.load(predfile)['y_pred']
51+
log.info(f'Evaluating predictions at {predfile} ... for {metrics}')
5452

5553
log.info(f'{metrics.trec} ...')
5654
df, df_mean = metric.calculate_metrics(Y, Y_, per_instance, metrics)
@@ -59,7 +57,7 @@ def evaluate(self, teamsvecs, splits, on_train=False, per_epoch=False, per_insta
5957
log.info("['aucroc'] and curve values (fpr, tpr) ...")
6058
aucroc, fpr_tpr = metric.calculate_auc_roc(Y, Y_)
6159
df_mean.loc['aucroc'] = aucroc
62-
with open(f'{filename}pred.eval.roc.pkl', 'wb') as outfile: pickle.dump(fpr_tpr, outfile)
60+
with open(f'{predfile}.eval.roc.pkl', 'wb') as outfile: pickle.dump(fpr_tpr, outfile)
6361

6462
if (m:=[m for m in metrics.other if 'skill_coverage' in m]): #since this metric comes with topks str like 'skill_coverage_2,5,10'
6563
log.info(f'{m} ...')
@@ -70,9 +68,9 @@ def evaluate(self, teamsvecs, splits, on_train=False, per_epoch=False, per_insta
7068
df = pd.concat([df, df_skc], axis=0)
7169
df_mean = pd.concat([df_mean, df_mean_skc], axis=0)
7270

73-
if per_instance: df.to_csv(f'{filename}pred.eval.per_instance.csv', float_format='%.5f')
74-
log.info(f'Saving file per fold as {filename}pred.eval.mean.csv')
75-
df_mean.to_csv(f'{filename}pred.eval.mean.csv')
71+
if per_instance: df.to_csv(f'{predfile}.eval.per_instance.csv', float_format='%.5f')
72+
log.info(f'Saving file per fold as {predfile}.eval.mean.csv')
73+
df_mean.to_csv(f'{predfile}.eval.mean.csv')
7674
if i == 0: # non-epoch-based only, as there is different number of epochs for each fold model due to earlystopping
7775
fold_mean = pd.concat([fold_mean, df_mean], axis=1)
7876
if per_instance: fold_mean_per_instance = fold_mean_per_instance.add(df, fill_value=0)

0 commit comments

Comments
 (0)