# -*- coding: utf-8 -*-
"""hw7_training.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1hqYLY-ku-kJGeMgg8akMp_10NVlXLTAM
"""
# Importing the necessary libraries
import sys, os, os.path
import numbers
import re
import math
import random
import copy
import gzip
import pickle
import logging
import argparse

import numpy as np
import cv2
import requests
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter

import torch
import torch.nn as nn
import torch.nn.functional
import torch.optim as optim
import torch.autograd as autograd
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as tvt
# Implementing a custom Dataset class for the training images
class mydataloader(torch.utils.data.Dataset):
    def __init__(self):
        # Assumes the working directory has already been changed to the image folder
        self.image_path = os.listdir()
        self.transform = tvt.Compose([tvt.ToTensor(), tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, idx):
        temp_image = Image.open(self.image_path[idx])
        temp_image = self.transform(temp_image).to(dtype=torch.float64)
        return temp_image
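# Hypothetical smoke test for the Dataset above (our own addition, not part of
# the original pipeline; call it manually if desired). Assumes the current
# working directory already holds the 64x64 training images.
def _check_dataset():
    ds = mydataloader()
    img = ds[0]
    print(len(ds), img.shape, img.dtype)  # expect torch.Size([3, 64, 64]) for 64x64 source images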
# Discriminator class for 64x64 RGB images (see fc1 for image dimensions);
# the pool* layers are strided convolutions that halve the spatial size
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 32, 3, padding=1)
        self.pool1 = nn.Conv2d(16, 16, 4, stride=2, padding=1)
        self.pool2 = nn.Conv2d(64, 64, 4, stride=2, padding=1)
        self.pool3 = nn.Conv2d(32, 32, 4, stride=2, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        x = self.pool1(torch.nn.functional.relu(self.conv1(x)))
        x = self.pool2(torch.nn.functional.relu(self.conv2(x)))
        x = self.pool3(torch.nn.functional.relu(self.conv3(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x
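# Quick shape check for the Discriminator (a minimal sketch we added; the
# helper name is our own): three stride-2 layers take 64x64 down to 8x8,
# so a batch of RGB images should map to one sigmoid score per image.
def _check_discriminator_shapes():
    d = Discriminator()
    out = d(torch.randn(4, 3, 64, 64))
    assert out.shape == (4, 1), out.shape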
# Generator class for producing 64x64 RGB images from random noise
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.convt1 = nn.ConvTranspose2d(256, 64, 4, stride=1)  # Takes random noise vectors of size 256x1x1
        self.bn1 = nn.BatchNorm2d(64)
        self.convt2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.convt3 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(16)
        self.convt4 = nn.ConvTranspose2d(16, 8, 4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(8)
        self.convt5 = nn.ConvTranspose2d(8, 3, 4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = torch.nn.functional.relu(self.bn1(self.convt1(x)))
        x = torch.nn.functional.relu(self.bn2(self.convt2(x)))
        x = torch.nn.functional.relu(self.bn3(self.convt3(x)))
        x = torch.nn.functional.relu(self.bn4(self.convt4(x)))
        x = self.tanh(self.convt5(x))
        return x
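# Matching shape check for the Generator (sketch, our own addition): 256x1x1
# noise upsamples through 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64, and tanh puts
# the output in [-1, 1] to match the Normalize((0.5,)*3, (0.5,)*3) transform.
def _check_generator_shapes():
    g = Generator()
    out = g(torch.randn(4, 256, 1, 1))
    assert out.shape == (4, 3, 64, 64), out.shape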
# Function for training the GAN with BCE loss
def run_code_for_training_bce(netG, netD, data_loader, epochs=10, batch_size=10):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Move the generator and discriminator networks to the selected device
    netG = netG.to(device)
    netD = netD.to(device)
    criterion = torch.nn.BCELoss()
    optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0001)  # Set up optimizer for Discriminator
    optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0002)  # Set up optimizer for Generator
    # Initialize lists to store the loss records for the Generator and the Discriminator
    loss_record = []
    G_loss_record = []
    D_loss_record = []
    for epoch in range(epochs):
        # Initialize running losses for Discriminator and Generator
        loss_D = 0.0
        loss_G = 0.0
        # Loop over mini-batches
        for iteration, data in enumerate(data_loader):
            image = data.to(device, dtype=torch.float)  # Load real images
            # Train discriminator with real images
            optimizerD.zero_grad()
            real_output = netD(image).squeeze()
            real_label = torch.ones_like(real_output, device=device)
            real_loss_D = criterion(real_output, real_label)  # Calculate BCE loss
            real_loss_D.backward()  # Perform backpropagation
            # Train discriminator with fake images
            noise = torch.randn(image.size(0), 256, 1, 1, device=device)  # image.size(0) so the fake batch matches the real batch
            fake_image = netG(noise)
            fake_output = netD(fake_image.detach()).squeeze()
            fake_label = torch.zeros_like(fake_output, device=device)
            fake_loss_D = criterion(fake_output, fake_label)  # Calculate BCE loss
            fake_loss_D.backward()
            # Update discriminator parameters using optimizer
            loss_D = loss_D + real_loss_D.item() + fake_loss_D.item()
            optimizerD.step()
            # Train generator
            optimizerG.zero_grad()
            g_output = netD(fake_image).squeeze()
            g_label = torch.ones_like(g_output, device=device)
            g_loss_G = criterion(g_output, g_label)  # Generator loss: make the discriminator label fakes as real
            g_loss_G.backward()
            # Update generator parameters using optimizer
            loss_G += g_loss_G.item()
            optimizerG.step()
            # Print and record the running losses every 50 mini-batches
            i = 50
            if (iteration + 1) % i == 0:
                running_loss = loss_D + loss_G
                print("\n[epoch:%d, batch:%5d] loss: %.3f D_loss: %.3f G_loss: %.3f" % (epoch + 1, iteration + 1, running_loss / float(i), loss_D / float(i), loss_G / float(i)))
                loss_record.append(running_loss / float(i))
                G_loss_record.append(loss_G / float(i))
                D_loss_record.append(loss_D / float(i))
                loss_G = 0.0
                loss_D = 0.0
    return netG, netD, loss_record, G_loss_record, D_loss_record
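# Optional helper (our own addition, not part of the original pipeline): draw a
# grid of samples from a trained generator and undo the (0.5, 0.5) normalization
# so pixels land back in [0, 1]. torchvision.utils.save_image writes the grid.
def save_sample_grid(netG, path="samples.jpg", n=16, device="cpu"):
    netG = netG.to(device).eval()
    with torch.no_grad():
        z = torch.randn(n, 256, 1, 1, device=device)
        fake = netG(z)  # tanh output in [-1, 1]
    torchvision.utils.save_image(fake * 0.5 + 0.5, path, nrow=4)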
# Function for calculating the gradient penalty for the Wasserstein distance loss
# (inspired by the PyTorch CycleGAN and pix2pix implementation)
def gradient_penalty(netD, real_data, fake_data, device):
    """Calculate the gradient penalty loss used in the WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real images
        fake_data (tensor array) -- generated images from the generator
        device (torch.device)    -- GPU / CPU device the tensors live on

    Returns the gradient penalty loss.
    """
    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).to(device)
    # Interpolate between real and fake samples
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates, device=device),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(batch_size, -1)
    # Penalize the squared deviation of the gradient norm from 1
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
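# Minimal sanity check for the gradient penalty (our own sketch; the helper
# name and tensor sizes are arbitrary). The WGAN-GP critic loss this supports is
#     E[D(fake)] - E[D(real)] + lambda * E[(||grad D(x_hat)||_2 - 1)^2],
# so the penalty term should come back as a non-negative scalar.
def _check_gradient_penalty():
    d = Discriminator()
    real = torch.randn(2, 3, 64, 64)
    fake = torch.randn(2, 3, 64, 64)
    gp = gradient_penalty(d, real, fake, torch.device("cpu"))
    assert gp.dim() == 0 and gp.item() >= 0.0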
# Function for training the GAN with Wasserstein loss and gradient penalty
def run_code_for_training_wasserstein(netG, netD, data_loader, n_critic=5, gp_lambda=10, epochs=10, batch_size=10):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    netG = netG.to(device)
    netD = netD.to(device)
    optimizerD = torch.optim.Adam(netD.parameters(), lr=0.00002, betas=(0.5, 0.9))
    optimizerG = torch.optim.Adam(netG.parameters(), lr=0.00003, betas=(0.5, 0.9))
    # Implementing learning rate schedulers to prevent plateau
    schedulerD = StepLR(optimizerD, step_size=100, gamma=0.9)
    schedulerG = StepLR(optimizerG, step_size=100, gamma=0.9)
    loss_record = []
    G_loss_record = []
    D_loss_record = []
    for epoch in range(epochs):
        running_loss_D = 0.0
        running_loss_G = 0.0
        for iteration, data in enumerate(data_loader):
            image = data.to(device, dtype=torch.float)
            # Critic (discriminator): n_critic updates per generator update
            for j in range(n_critic):
                optimizerD.zero_grad()
                # Evaluate the critic on real images
                real_output = netD(image)
                real_loss_D = -real_output.mean()
                # Evaluate the critic on fake images
                noise = torch.randn(image.size(0), 256, 1, 1, device=device)  # image.size(0) so fake images generated are consistent
                fake_image = netG(noise)
                fake_output = netD(fake_image.detach())
                fake_loss_D = fake_output.mean()  # Critic minimizes E[D(fake)] - E[D(real)]
                # Applying gradient penalty function for computing loss
                gp = gradient_penalty(netD, image, fake_image, device)
                # Computing total discriminator loss
                total_loss_D = real_loss_D + fake_loss_D + gp_lambda * gp
                # retain_graph=True keeps the generator's graph (reached through
                # fake_image in the gradient penalty) alive for the generator update below
                total_loss_D.backward(retain_graph=True)
                optimizerD.step()
                running_loss_D += total_loss_D.item()
            # Generator
            optimizerG.zero_grad()
            g_output = netD(fake_image)  # Score the last batch of fakes with the updated critic
            g_loss_G = -g_output.mean()
            g_loss_G.backward()
            optimizerG.step()
            running_loss_G += g_loss_G.item()
            # Print and record the running losses every 50 mini-batches
            if (iteration + 1) % 50 == 0:
                print("\n[epoch:%d, batch:%5d] loss: %.3f D_loss: %.3f G_loss: %.3f" % (epoch + 1, iteration + 1, (running_loss_D + running_loss_G) / 50, running_loss_D / 50, running_loss_G / 50))
                loss_record.append((running_loss_D + running_loss_G) / 50)
                G_loss_record.append(running_loss_G / 50)
                D_loss_record.append(running_loss_D / 50)
                running_loss_G = 0.0
                running_loss_D = 0.0
        # Update the learning rates at the end of each epoch
        schedulerD.step()
        schedulerG.step()
    return netG, netD, loss_record, G_loss_record, D_loss_record
# Code block for directory navigation
# If using Google Drive
os.chdir("/content/drive/MyDrive/BME 64600/hw7/pizzas/train")
print(os.getcwd())
list_images = os.listdir()
# print(list_images[:5]) #Confirming list of images
print(len(list_images))
# Main script for execution of training
if __name__ == '__main__':
    # Making sure that CUDA is available
    print(torch.cuda.is_available())
    data = mydataloader()
    data_loader = torch.utils.data.DataLoader(dataset=data, batch_size=50, shuffle=True)
    # Train GAN with BCE loss
    netG_bce = Generator()
    netD_bce = Discriminator()
    netG_bce, netD_bce, loss_record_bce, G_loss_record_bce, D_loss_record_bce = run_code_for_training_bce(netG_bce, netD_bce, data_loader, epochs=20, batch_size=10)
    # Train GAN with Wasserstein loss
    netG_wasserstein = Generator()
    netD_wasserstein = Discriminator()
    netG_wasserstein, netD_wasserstein, loss_record_wasserstein, G_loss_record_wasserstein, D_loss_record_wasserstein = run_code_for_training_wasserstein(netG_wasserstein, netD_wasserstein, data_loader, epochs=20, batch_size=10)
    # Saving the models
    print(os.getcwd())
    torch.save(netG_bce.state_dict(), '/content/drive/MyDrive/BME 64600/hw7/netG_bce')
    torch.save(netD_bce.state_dict(), '/content/drive/MyDrive/BME 64600/hw7/netD_bce')
    torch.save(netG_wasserstein.state_dict(), '/content/drive/MyDrive/BME 64600/hw7/netG_wasserstein')
    torch.save(netD_wasserstein.state_dict(), '/content/drive/MyDrive/BME 64600/hw7/netD_wasserstein')
    os.chdir("/content/drive/MyDrive/BME 64600/hw7")
    # Plotting the BCE loss figure
    plt.figure(1)
    plt.title("BCE GAN Loss vs. Iterations")
    plt.plot(loss_record_bce, label="Training Loss")
    plt.plot(G_loss_record_bce, label="Generator Loss")
    plt.plot(D_loss_record_bce, label="Discriminator Loss")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("BCE_train_loss.jpg")
    # Plotting the Wasserstein distance loss figure
    plt.figure(2)
    plt.title("Wasserstein GAN Loss vs. Iterations")
    plt.plot(loss_record_wasserstein, label="Training Loss")
    plt.plot(G_loss_record_wasserstein, label="Generator Loss")
    plt.plot(D_loss_record_wasserstein, label="Discriminator Loss")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("Wasserstein_train_loss.jpg")
    plt.show()