@@ -134,7 +134,7 @@ def train(train_loader, dev_loader, model, args):
134
134
accuracy = 100.0 * corrects / args .batch_size
135
135
print ('Epoch[{}] Batch[{}] - loss: {:.6f} lr: {:.5f} acc: {:.3f}% ({}/{})' .format (epoch ,
136
136
i_batch ,
137
- loss .data [ 0 ] ,
137
+ loss .data ,
138
138
optimizer .state_dict ()['param_groups' ][0 ]['lr' ],
139
139
accuracy ,
140
140
corrects ,
@@ -182,7 +182,7 @@ def eval(data_loader, model, epoch_train, batch_train, optimizer, args):
182
182
target = Variable (target )
183
183
logit = model (inputs )
184
184
predicates = torch .max (logit , 1 )[1 ].view (target .size ()).data
185
- accumulated_loss += F .nll_loss (logit , target , size_average = False ).data [ 0 ]
185
+ accumulated_loss += F .nll_loss (logit , target , size_average = False ).data
186
186
corrects += (torch .max (logit , 1 )[1 ].view (target .size ()).data == target .data ).sum ()
187
187
predicates_all += predicates .cpu ().numpy ().tolist ()
188
188
target_all += target .data .cpu ().numpy ().tolist ()
@@ -221,7 +221,8 @@ def main():
221
221
222
222
# load training data
223
223
print ("\n Loading training data..." )
224
- train_dataset = AGNEWs (label_data_path = args .train_path , alphabet_path = args .alphabet_path )
224
+
225
+ train_dataset = AGNEWs (label_data_path = args .train_path , alphabet_path = args .alphabet_path , l0 = args .l0 )
225
226
print ("Transferring training data into iterator..." )
226
227
train_loader = DataLoader (train_dataset , batch_size = args .batch_size , num_workers = args .num_workers , drop_last = True , shuffle = True )
227
228
@@ -230,7 +231,7 @@ def main():
230
231
231
232
# load developing data
232
233
print ("\n Loading developing data..." )
233
- dev_dataset = AGNEWs (label_data_path = args .val_path , alphabet_path = args .alphabet_path )
234
+ dev_dataset = AGNEWs (label_data_path = args .val_path , alphabet_path = args .alphabet_path , l0 = args . l0 )
234
235
print ("Transferring developing data into iterator..." )
235
236
dev_loader = DataLoader (dev_dataset , batch_size = args .batch_size , num_workers = args .num_workers , drop_last = True )
236
237
0 commit comments