-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDocClassifier.py
executable file
·95 lines (77 loc) · 3.41 KB
/
DocClassifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 15 10:05:14 2018
@author: deepak
"""
import numpy as np
from keras.optimizers import Adam, RMSprop
from keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
from keras.layers import GRU, BatchNormalization, Conv1D, MaxPooling1D
from keras.models import Model, load_model
from RocAucCallback import RocAucCallback
from NLPModel import GRU_CNN_Model
class DocClassifier:
    """Train a GRU-CNN document classifier and predict with its best checkpoint.

    ``fit`` builds a fresh ``GRU_CNN_Model`` (from ``NLPModel``), trains it
    with Keras, then reloads the model saved at ``best_model_path``;
    ``predict`` runs that reloaded model.

    Args:
        embedding_matrix: Pre-trained embedding weights. ``shape[0]`` is passed
            to the model as ``max_features`` and ``shape[1]`` as
            ``embed_size`` (presumably ``(vocab_size, embed_dim)`` — confirm
            against ``GRU_CNN_Model``).
        max_len: Input sequence length, passed as ``input_shape``.
        train_batch_size: Batch size used by ``fit``.
        test_batch_size: Batch size used by ``predict``.
        epochs: Maximum number of training epochs.
        callbacks: Keras callbacks to train with. When ``None`` (default),
            ``callbacks_set`` builds a ROC-AUC monitor, a best-``val_loss``
            checkpoint written to ``best_model_path``, and early stopping.
        best_model_path: File the best model is checkpointed to and reloaded
            from after training.
    """

    def __init__(
        self,
        embedding_matrix,
        max_len,
        train_batch_size,
        test_batch_size,
        epochs,
        callbacks=None,
        best_model_path="best_model.hdf5",
    ):
        self.embedding_matrix = embedding_matrix
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.epochs = epochs
        self.callbacks = callbacks
        self.best_model_path = best_model_path
        self.max_len = max_len
        # Populated by fit(); predict() checks it so an unfitted call fails
        # with a clear message.
        self.model_fit = None

    def get_pickable(self):
        """Return the instance's hyper-parameters as a plain dict.

        The embedding matrix and any trained model are not included.

        Returns:
            dict with keys ``max_len``, ``train_batch_size``,
            ``test_batch_size``, ``epochs``, ``callbacks`` and
            ``best_model_path``.
        """
        return {
            'max_len': self.max_len,
            'train_batch_size': self.train_batch_size,
            'test_batch_size': self.test_batch_size,
            'epochs': self.epochs,
            'callbacks': self.callbacks,
            'best_model_path': self.best_model_path,
        }

    def load_pickable(self, pkl):
        """Restore hyper-parameters from a dict produced by ``get_pickable``.

        Args:
            pkl: Mapping with the keys returned by ``get_pickable``.

        Raises:
            KeyError: If any expected key is missing.
        """
        self.max_len = pkl['max_len']
        self.train_batch_size = pkl['train_batch_size']
        self.test_batch_size = pkl['test_batch_size']
        self.epochs = pkl['epochs']
        self.callbacks = pkl['callbacks']
        self.best_model_path = pkl['best_model_path']

    def callbacks_set(self, x_valid, y_valid):
        """Return the Keras callbacks to train with.

        User-supplied ``self.callbacks`` are returned unchanged. Otherwise a
        default list is built:

        * ``RocAucCallback`` evaluated on ``(x_valid, y_valid)`` every epoch,
        * ``ModelCheckpoint`` keeping the lowest-``val_loss`` model at
          ``best_model_path``,
        * ``EarlyStopping`` on ``val_loss`` with a patience of 5 epochs.

        The defaults are rebuilt on every call and are not stored on the
        instance. As a result, repeated ``fit`` calls never reuse a callback
        bound to stale validation data, and ``get_pickable`` does not capture
        Keras objects.

        Args:
            x_valid: Validation inputs.
            y_valid: Validation targets.

        Returns:
            list of Keras callbacks.
        """
        # Only fall back to the defaults when the caller supplied none. The
        # default ModelCheckpoint is what writes the file fit() reloads.
        if self.callbacks is not None:
            return self.callbacks
        check_point = ModelCheckpoint(self.best_model_path, monitor="val_loss", verbose=1,
                                      save_best_only=True, mode="min")
        ra_val = RocAucCallback(validation_data=(x_valid, y_valid), interval=1)
        early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=5)
        return [ra_val, check_point, early_stop]

    def fit(self, x_train, x_valid, y_train, y_valid):
        """Train a fresh model, then reload the best checkpoint for prediction.

        Note the argument order: both input arrays first, then both targets.
        NaNs in any array are reported with ``print`` but do not stop
        training. After training, the model is reloaded from
        ``best_model_path``. If custom ``callbacks`` are supplied, they must
        write a model to that path.

        Args:
            x_train: Training inputs.
            x_valid: Validation inputs.
            y_train: Training targets.
            y_valid: Validation targets.
        """
        for name, arr in (("x_train", x_train), ("y_train", y_train),
                          ("x_valid", x_valid), ("y_valid", y_valid)):
            if np.any(np.isnan(arr)):
                print(name, "contains NaNs")
        model = GRU_CNN_Model(
            self.embedding_matrix,
            max_features=self.embedding_matrix.shape[0],
            embed_size=self.embedding_matrix.shape[1],
            input_shape=self.max_len,
        )
        self.model_history = model.fit(x_train, y_train, batch_size=self.train_batch_size,
                                       epochs=self.epochs, validation_data=(x_valid, y_valid),
                                       verbose=1, callbacks=self.callbacks_set(x_valid, y_valid))
        # Predict with the checkpointed best model, not the last-epoch weights.
        self.model_fit = load_model(self.best_model_path)

    def predict(self, x_test):
        """Predict with the model loaded by ``fit``.

        Args:
            x_test: Inputs to score.

        Returns:
            The output of the Keras model's ``predict``.

        Raises:
            AttributeError: If ``fit`` has not been called yet. The exception
                type is kept for compatibility with the previous behaviour.
        """
        # getattr guards instances created without running __init__.
        if getattr(self, "model_fit", None) is None:
            raise AttributeError("DocClassifier has no fitted model; call fit() first")
        return self.model_fit.predict(x_test, batch_size=self.test_batch_size, verbose=1)