99 changes: 34 additions & 65 deletions demo.py
@@ -2,85 +2,54 @@
 
 sys.path.append('./')
 
 from yolo.net.yolo_tiny_net import YoloTinyNet
+from tools.visualize import PredictionWindow
 import tensorflow as tf
 import cv2
 import numpy as np
+import argparse
 
-classes_name = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
-
-
-def process_predicts(predicts):
-    p_classes = predicts[0, :, :, 0:20]
-    C = predicts[0, :, :, 20:22]
-    coordinate = predicts[0, :, :, 22:]
-
-    p_classes = np.reshape(p_classes, (7, 7, 1, 20))
-    C = np.reshape(C, (7, 7, 2, 1))
-
-    P = C * p_classes
-
-    #print P[5,1, 0, :]
-
-    index = np.argmax(P)
-    index = np.unravel_index(index, P.shape)
-
-    class_num = index[3]
-
-    coordinate = np.reshape(coordinate, (7, 7, 2, 4))
-    max_coordinate = coordinate[index[0], index[1], index[2], :]
-
-    xcenter = max_coordinate[0]
-    ycenter = max_coordinate[1]
-    w = max_coordinate[2]
-    h = max_coordinate[3]
-
-    xcenter = (index[1] + xcenter) * (448 / 7.0)
-    ycenter = (index[0] + ycenter) * (448 / 7.0)
-
-    w = w * 448
-    h = h * 448
-
-    xmin = xcenter - w / 2.0
-    ymin = ycenter - h / 2.0
-
-    xmax = xmin + w
-    ymax = ymin + h
-
-    return xmin, ymin, xmax, ymax, class_num
-
-common_params = {'image_size': 448, 'num_classes': 20,
-                 'batch_size': 1}
-net_params = {'cell_size': 7, 'boxes_per_cell': 2, 'weight_decay': 0.0005}
-
-net = YoloTinyNet(common_params, net_params, test=True)
-
-image = tf.placeholder(tf.float32, (1, 448, 448, 3))
-predicts = net.inference(image)
-
-sess = tf.Session()
-
-np_img = cv2.imread('cat.jpg')
-resized_img = cv2.resize(np_img, (448, 448))
-np_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)
-
-np_img = np_img.astype(np.float32)
-np_img = np_img / 255.0 * 2 - 1
-np_img = np.reshape(np_img, (1, 448, 448, 3))
-
-saver = tf.train.Saver(net.trainable_collection)
-saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt')
-
-np_predict = sess.run(predicts, feed_dict={image: np_img})
-
-xmin, ymin, xmax, ymax, class_num = process_predicts(np_predict)
-class_name = classes_name[class_num]
-cv2.rectangle(resized_img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255))
-cv2.putText(resized_img, class_name, (int(xmin), int(ymin)), 2, 1.5, (0, 0, 255))
-cv2.imwrite('cat_out.jpg', resized_img)
-sess.close()
+
+# setup CLI argument parser
+parser = argparse.ArgumentParser()
+parser.add_argument('-s', '--source',
+                    help='either webcam or image',
+                    default='image')
+
+args = parser.parse_args()
+
+common_params = {'image_size': 448, 'num_classes': 20, 'batch_size': 1}
+net_params = {'cell_size': 7, 'boxes_per_cell': 2, 'weight_decay': 0.0005}
+
+window = PredictionWindow(common_params, net_params)
+cap = cv2.VideoCapture(0)
+
+
+def get_frame():
+    ret, frame = cap.read()
+
+    height, width = frame.shape[:2]
+
+    # crop the frame to a centered square before resizing to 448x448
+    x = (width - height) // 2
+    frame = frame[:, x:width - x, :]
+    frame = cv2.resize(frame, (448, 448))
+
+    stop = cv2.waitKey(1) & 0xFF == ord('q')
+    if stop:
+        cap.release()
+    return stop, frame
+
+
+def get_image():
+    image = cv2.imread('cat.jpg')
+    image = cv2.resize(image, (448, 448))
+    return True, image
+
+
+if args.source == 'image':
+    window.run(get_image)
+elif args.source == 'webcam':
+    window.run(get_frame)
+else:
+    print('please define a valid source, either "webcam" or "image"')
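
Note on the new CLI: the demo is now launched as "python demo.py --source image" (the default) or "python demo.py --source webcam". Because PredictionWindow.run() only expects a callback returning a (stop, frame) pair with a 448x448 BGR frame, other sources can be plugged in the same way. A minimal sketch for reading from a video file, assuming a hypothetical sample.mp4 and the callback contract shown above:

import cv2
import numpy as np

from tools.visualize import PredictionWindow

common_params = {'image_size': 448, 'num_classes': 20, 'batch_size': 1}
net_params = {'cell_size': 7, 'boxes_per_cell': 2, 'weight_decay': 0.0005}

window = PredictionWindow(common_params, net_params)
cap = cv2.VideoCapture('sample.mp4')  # hypothetical file, not part of this PR

def get_video_frame():
    ret, frame = cap.read()
    if not ret:
        # run() still draws the returned frame once after stop turns True,
        # so hand back a blank 448x448 frame instead of None
        return True, np.zeros((448, 448, 3), dtype=np.uint8)
    frame = cv2.resize(frame, (448, 448))
    stop = cv2.waitKey(1) & 0xFF == ord('q')
    if stop:
        cap.release()
    return stop, frame

window.run(get_video_frame)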
95 changes: 95 additions & 0 deletions tools/visualize.py
@@ -0,0 +1,95 @@
import cv2
import numpy as np
import tensorflow as tf

from yolo.net.yolo_tiny_net import YoloTinyNet

CLASS_NAMES = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]


class PredictionWindow(object):
    """Opens a window showing the model's predictions."""

    def __init__(self, common_params, net_params):
        self.net = YoloTinyNet(common_params, net_params, test=True)

        self.image_size = common_params['image_size']
        self.image = tf.placeholder(tf.float32, (1, self.image_size, self.image_size, 3))
        self.predicts = self.net.inference(self.image)

    def run(self, source_callback):
        sess = tf.Session()

        saver = tf.train.Saver(self.net.trainable_collection)
        saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt')

        stop = False
        while not stop:
            stop, frame = source_callback()

            orig = np.copy(frame)

            # normalize the BGR uint8 frame to an RGB float tensor in [-1, 1]
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = frame.astype(np.float32)
            frame = frame / 255.0 * 2 - 1
            frame = np.reshape(frame, (1, self.image_size, self.image_size, 3))

            predictions = sess.run(self.predicts, feed_dict={self.image: frame})
            boxes = self.process_predicts(predictions)

            for xmin, ymin, xmax, ymax, class_num in boxes:
                class_name = CLASS_NAMES[class_num]
                cv2.rectangle(orig, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255))
                cv2.putText(orig, class_name, (int(xmin), int(ymin)), 2, 1.5, (0, 0, 255))

            cv2.imshow('model predictions', orig)

        sess.close()
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def process_predicts(self, predicts):
        p_classes = predicts[0, :, :, 0:20]
        C = predicts[0, :, :, 20:22]
        coordinate = np.reshape(predicts[0, :, :, 22:], (7, 7, 2, 4))

        p_classes = np.reshape(p_classes, (7, 7, 1, 20))
        C = np.reshape(C, (7, 7, 2, 1))

        # class-specific confidence scores for every box in every cell
        P = C * p_classes

        max_val = np.max(P)

        boxes = []
        for y in range(7):
            for x in range(7):
                classes = P[y, x]
                index = np.argmax(classes)
                box_index, class_index = np.unravel_index(index, classes.shape)

                confidence = classes[box_index, class_index]

                # keep the cell's best box if it scores close to the global maximum
                if confidence > max_val * 0.8:
                    cx, cy, w, h = coordinate[y, x, box_index, :]

                    # convert cell-relative center and image-relative size to pixels
                    cx = (x + cx) * (448 / 7.0)
                    cy = (y + cy) * (448 / 7.0)
                    w = w * 448
                    h = h * 448

                    xmin = cx - w / 2.0
                    ymin = cy - h / 2.0
                    xmax = xmin + w
                    ymax = ymin + h

                    boxes.append([xmin, ymin, xmax, ymax, class_index])

        return boxes
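
For reference, process_predicts assumes the tiny-YOLO output layout used by the original demo: per grid cell, 20 class probabilities, 2 box confidences, and 2 boxes of (cx, cy, w, h), i.e. a (1, 7, 7, 30) tensor. Unlike the old single global argmax, it now keeps each cell's best box whenever its class score exceeds 0.8 of the global maximum, so several boxes can be drawn per frame. A worked example of the cell-to-pixel mapping above, with made-up numbers:

# Box predicted in grid cell (x=2, y=3): center (0.5, 0.5) relative to the cell,
# width/height (0.25, 0.5) relative to the whole 448x448 image.
x, y = 2, 3
cx, cy, w, h = 0.5, 0.5, 0.25, 0.5

cell = 448 / 7.0              # 64 pixels per grid cell
px = (x + cx) * cell          # 160.0: box center x in pixels
py = (y + cy) * cell          # 224.0: box center y in pixels
pw, ph = w * 448, h * 448     # 112.0, 224.0: box size in pixels

xmin, ymin = px - pw / 2.0, py - ph / 2.0   # (104.0, 112.0)
xmax, ymax = xmin + pw, ymin + ph           # (216.0, 336.0)
print(xmin, ymin, xmax, ymax)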