Skip to content

Commit f3de3cc

Browse files
committed
support pytorch >= 0.5, fix issue #7
1 parent dafc4df commit f3de3cc

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

train.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def train(train_loader, dev_loader, model, args):
134134
accuracy = 100.0 * corrects/args.batch_size
135135
print('Epoch[{}] Batch[{}] - loss: {:.6f} lr: {:.5f} acc: {:.3f}% ({}/{})'.format(epoch,
136136
i_batch,
137-
loss.data[0],
137+
loss.data,
138138
optimizer.state_dict()['param_groups'][0]['lr'],
139139
accuracy,
140140
corrects,
@@ -182,7 +182,7 @@ def eval(data_loader, model, epoch_train, batch_train, optimizer, args):
182182
target = Variable(target)
183183
logit = model(inputs)
184184
predicates = torch.max(logit, 1)[1].view(target.size()).data
185-
accumulated_loss += F.nll_loss(logit, target, size_average=False).data[0]
185+
accumulated_loss += F.nll_loss(logit, target, size_average=False).data
186186
corrects += (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
187187
predicates_all+=predicates.cpu().numpy().tolist()
188188
target_all+=target.data.cpu().numpy().tolist()
@@ -221,7 +221,8 @@ def main():
221221

222222
# load training data
223223
print("\nLoading training data...")
224-
train_dataset = AGNEWs(label_data_path=args.train_path, alphabet_path=args.alphabet_path)
224+
225+
train_dataset = AGNEWs(label_data_path=args.train_path, alphabet_path=args.alphabet_path, l0=args.l0)
225226
print("Transferring training data into iterator...")
226227
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True, shuffle=True)
227228

@@ -230,7 +231,7 @@ def main():
230231

231232
# load developing data
232233
print("\nLoading developing data...")
233-
dev_dataset = AGNEWs(label_data_path=args.val_path, alphabet_path=args.alphabet_path)
234+
dev_dataset = AGNEWs(label_data_path=args.val_path, alphabet_path=args.alphabet_path, l0=args.l0)
234235
print("Transferring developing data into iterator...")
235236
dev_loader = DataLoader(dev_dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True)
236237

0 commit comments

Comments
 (0)