-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLSTM_with_temporalpooling.py
More file actions
179 lines (160 loc) · 9.18 KB
/
LSTM_with_temporalpooling.py
File metadata and controls
179 lines (160 loc) · 9.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import torch
from sklearn.metrics import roc_auc_score
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader
from data import PatientDataset, collate_batch, load_splits  # project-local dataset utilities
from utils import EarlyStopping  # project-local early-stopping callback (saves checkpoint on improvement)
# Let cuDNN auto-tune kernels; fastest when input sizes are mostly stable across batches.
torch.backends.cudnn.benchmark = True
# Fix the seed so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(42)
# Deployment-specific configuration: filled in before running; empty placeholders here.
dataset_dir = ''
numerical_feature_list = []
categorical_feature_list = []
class LSTMPoolClassifier(nn.Module):
    """Bidirectional LSTM classifier with non-overlapping temporal pooling.

    The padded time-series batch is first down-sampled along the time axis by
    mean or max pooling over windows of ``pool_size`` steps, then fed through a
    bidirectional LSTM.  The final forward and backward hidden states of the
    last layer are concatenated, layer-normalized, and mapped to 2 class logits.
    """

    def __init__(self, input_dim=25, hidden_dim=128, lstm_layer=1, dropout_rate=0.2, pool_option='mean', pool_size=10):
        """
        Args:
            input_dim: number of time-series features per step.
            hidden_dim: LSTM hidden size per direction.
            lstm_layer: number of stacked LSTM layers.
            dropout_rate: dropout between LSTM layers and in the classifier head.
            pool_option: 'mean' selects mean pooling; any other value selects max pooling.
            pool_size: temporal pooling window (non-overlapping).
        """
        super().__init__()
        self.pool_option = pool_option
        self.pool_size = pool_size
        self.num_layers = lstm_layer
        self.lstm = nn.LSTM(input_size=input_dim,
                            hidden_size=hidden_dim,
                            num_layers=lstm_layer,
                            batch_first=True,
                            # nn.LSTM only applies inter-layer dropout when num_layers > 1;
                            # passing a nonzero value with one layer triggers a warning.
                            dropout=dropout_rate if lstm_layer > 1 else 0,
                            bidirectional=True)
        self.layer_norm = nn.LayerNorm(2 * hidden_dim)
        self.classifier = nn.Sequential(nn.Linear(2 * hidden_dim, 32),
                                        nn.ReLU(),
                                        nn.Dropout(dropout_rate),
                                        nn.Linear(32, 2))

    def _temporal_pool(self, x, reduce_fn):
        """Pool x over non-overlapping windows of self.pool_size time steps.

        Trailing steps that do not fill a whole window are dropped.
        Shared implementation for mean and max pooling.
        """
        batch_size, seq_len, input_dim = x.size()
        # Truncate so seq_len is divisible by pool_size (no-op when already divisible).
        new_seq_len = seq_len - (seq_len % self.pool_size)
        if new_seq_len < seq_len:
            x = x[:, :new_seq_len, :]
        x = x.view(batch_size, new_seq_len // self.pool_size, self.pool_size, input_dim)
        return reduce_fn(x)

    def temporal_mean_pooling(self, x):
        """Mean-pool (batch, seq, feat) over windows of self.pool_size steps."""
        return self._temporal_pool(x, lambda windows: torch.mean(windows, dim=2))

    def temporal_max_pooling(self, x):
        """Max-pool (batch, seq, feat) over windows of self.pool_size steps."""
        return self._temporal_pool(x, lambda windows: torch.max(windows, dim=2)[0])

    def forward(self, x, lengths):
        """Classify a padded batch of variable-length sequences.

        Args:
            x: (batch, max_len, input_dim) padded time-series features.
            lengths: (batch,) true sequence lengths; must live on CPU for
                pack_padded_sequence.

        Returns:
            (batch, 2) unnormalized class logits.
        """
        if self.pool_option == 'mean':
            x = self.temporal_mean_pooling(x)
        else:
            x = self.temporal_max_pooling(x)
        # Lengths shrink by the pooling factor.  Clamp to >= 1 so a sequence
        # shorter than one pooling window does not yield length 0, which
        # pack_padded_sequence rejects (its first pooled step may then contain
        # some padding — acceptable for these short sequences).
        new_lengths = torch.div(lengths, self.pool_size, rounding_mode='floor').clamp(min=1)
        # Pack so the LSTM skips computation on padded steps.
        packed_x = pack_padded_sequence(x, new_lengths, batch_first=True, enforce_sorted=False)
        output, (hn, cn) = self.lstm(packed_x)  # hn: (num_layers * 2, batch, hidden_dim)
        # Last layer's hidden states: forward direction at index -2, backward at -1.
        forward_hidden = hn[2 * self.num_layers - 2]
        backward_hidden = hn[2 * self.num_layers - 1]
        hn_concat = torch.cat([forward_hidden, backward_hidden], dim=1)  # (batch, 2*hidden_dim)
        hn_concat = self.layer_norm(hn_concat)
        return self.classifier(hn_concat)
def run_model(fold_id=0, batch=64, num_epochs=100, warmup_epochs=5, learning=0.0003, decay=0.0001,
              lstm_pool=10, pooling='mean', lstm_embedding=128, lstm_dropout=0.2, lstm_layer=1,
              numerical_norm='zscore', ts_norm='zscore', model_pth='model.pth'):
    """Train and validate an LSTMPoolClassifier on one cross-validation fold.

    Args:
        fold_id: index of the CV fold to take from load_splits().
        batch: mini-batch size.
        num_epochs: maximum number of training epochs.
        warmup_epochs: epochs of linear LR warm-up before ReduceLROnPlateau takes over.
        learning: base learning rate for AdamW.
        decay: AdamW weight decay.
        lstm_pool: temporal pooling window passed to the model.
        pooling: 'mean' or 'max' temporal pooling.
        lstm_embedding: LSTM hidden size.
        lstm_dropout: dropout rate inside the model.
        lstm_layer: number of stacked LSTM layers.
        numerical_norm: normalization mode for numerical features (forwarded to PatientDataset).
        ts_norm: normalization mode for time-series features (forwarded to PatientDataset).
        model_pth: checkpoint path used by EarlyStopping.

    Returns:
        Tuple (best_train_auc, best_val_auc) observed over all completed epochs.
    """
    splits_samples, splits_labels = load_splits(dataset_dir)
    train_files = splits_samples[fold_id][0]
    val_files = splits_samples[fold_id][1]
    train_data = PatientDataset(patient_ids=train_files, dataset_dir=dataset_dir,
                                numerical_cols=numerical_feature_list, numerical_norm=numerical_norm,
                                categorical_cols=categorical_feature_list,
                                ts_norm=ts_norm)
    val_data = PatientDataset(patient_ids=val_files, dataset_dir=dataset_dir,
                              numerical_cols=numerical_feature_list, numerical_norm=numerical_norm,
                              categorical_cols=categorical_feature_list,
                              ts_norm=ts_norm)
    train_loader = DataLoader(train_data, batch_size=batch, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_data, batch_size=batch, shuffle=False, collate_fn=collate_batch)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = LSTMPoolClassifier(input_dim=25, hidden_dim=lstm_embedding, lstm_layer=lstm_layer,
                               dropout_rate=lstm_dropout, pool_option=pooling, pool_size=lstm_pool)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning, weight_decay=decay)

    def warmup_lambda(ep):
        # Linear ramp to the base LR over warmup_epochs, then hold constant
        # (ReduceLROnPlateau handles decay after warm-up).
        return (ep + 1) / warmup_epochs if ep < warmup_epochs else 1.0

    warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)
    reduce_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5,
                                                                  threshold=0.0001)
    early_stopping = EarlyStopping(patience=10, delta=0, path=model_pth, verbose=False)
    res_train_auc, res_val_auc = 0, 0
    for epoch in range(num_epochs):
        # ---------- training ----------
        model.train()
        train_loss = 0.0
        train_corrects, train_total = 0, 0
        train_labels, train_predicted, train_possibilities = [], [], []
        for inputs_num, inputs_cat, inputs_ts, ts_lengths, labels in train_loader:
            # This model consumes only the time series; inputs_num / inputs_cat
            # come out of collate_batch but are not used here.
            inputs_ts = inputs_ts.to(device)
            labels = torch.as_tensor(labels, dtype=torch.long).to(device)
            optimizer.zero_grad()
            # BUG FIX: forward() takes (x, lengths) only — the original call
            # passed inputs_num/inputs_cat as extra arguments (TypeError).
            # ts_lengths stays on CPU as required by pack_padded_sequence.
            outputs = model(inputs_ts, ts_lengths)
            loss = criterion(outputs, labels)
            loss.backward()
            # grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # optional clipping
            optimizer.step()
            train_loss += loss.item() * labels.size(0)
            predictions = torch.max(outputs, 1)[1]  # hard class predictions
            probabilities = torch.softmax(outputs, dim=1)[:, 1]  # P(class 1), used for AUC
            train_labels.extend(labels.cpu().numpy())
            train_predicted.extend(predictions.view(-1).cpu().numpy())
            train_possibilities.extend(probabilities.detach().cpu().numpy())
            train_corrects += (predictions == labels).sum().item()
            train_total += labels.size(0)
        train_loss = train_loss / len(train_data)
        train_acc = train_corrects / train_total
        # ---------- validation ----------
        model.eval()
        val_loss = 0.0
        val_corrects, val_total = 0, 0
        val_labels, val_predicted, val_possibilities = [], [], []
        with torch.no_grad():
            for inputs_num, inputs_cat, inputs_ts, ts_lengths, labels in val_loader:
                inputs_ts = inputs_ts.to(device)
                labels = torch.as_tensor(labels, dtype=torch.long).to(device)
                outputs = model(inputs_ts, ts_lengths)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * labels.size(0)
                predictions = torch.max(outputs, 1)[1]
                probabilities = torch.softmax(outputs, dim=1)[:, 1]
                val_labels.extend(labels.cpu().numpy())
                val_predicted.extend(predictions.view(-1).cpu().numpy())
                val_possibilities.extend(probabilities.detach().cpu().numpy())
                val_corrects += (predictions == labels).sum().item()
                val_total += labels.size(0)
        val_loss = val_loss / len(val_data)
        val_acc = val_corrects / val_total
        # AUC needs positive-class probabilities, not hard class predictions.
        train_auc = roc_auc_score(train_labels, train_possibilities)
        val_auc = roc_auc_score(val_labels, val_possibilities)
        res_train_auc = max(res_train_auc, train_auc)
        res_val_auc = max(res_val_auc, val_auc)
        res_log = f'{train_loss:.4f} {val_loss:.4f} {train_acc:.4f} {val_acc:.4f} {train_auc:.4f} {val_auc:.4f}'
        print(f'{fold_id} Ep {epoch + 1}/{num_epochs} {res_log}')
        if epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            reduce_scheduler.step(val_loss)
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print('Early stopping triggered!')
            break
    return res_train_auc, res_val_auc