Skip to content

Commit fb8e0b2

Browse files
authored
Merge pull request #40 from timesler/expose_facial_landmarks_from_mtcnn
Expose MTCNN facial landmarks in .detect() method
2 parents 95c737f + d8f7fa6 commit fb8e0b2

File tree

3 files changed

+116
-107
lines changed

3 files changed

+116
-107
lines changed

models/mtcnn.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -299,66 +299,84 @@ def forward(self, img, save_path=None, return_prob=False):
299299
else:
300300
return faces
301301

302-
def detect(self, img):
303-
"""Detect all faces in PIL image and return bounding boxes.
302+
def detect(self, img, landmarks=False):
303+
"""Detect all faces in PIL image and return bounding boxes and optional facial landmarks.
304304
305305
This method is used by the forward method and is also useful for face detection tasks
306-
that require lower-level handling of bounding boxes (e.g., face tracking). The
307-
functionality of the forward function can be emulated by using this method followed by
308-
the extract_face() function.
306+
that require lower-level handling of bounding boxes and facial landmarks (e.g., face
307+
tracking). The functionality of the forward function can be emulated by using this method
308+
followed by the extract_face() function.
309309
310310
Arguments:
311311
img {PIL.Image or list} -- A PIL image or a list of PIL images.
312+
313+
Keyword Arguments:
314+
landmarks {bool} -- Whether to return facial landmarks in addition to bounding boxes.
315+
(default: {False})
312316
313317
Returns:
314318
tuple(numpy.ndarray, list) -- For N detected faces, a tuple containing an
315319
Nx4 array of bounding boxes and a length N list of detection probabilities.
316320
Returned boxes will be sorted in descending order by detection probability if
317321
self.select_largest=False, otherwise the largest face will be returned first.
318322
If `img` is a list of images, the items returned have an extra dimension
319-
(batch) as the first dimension.
323+
(batch) as the first dimension. Optionally, a third item, the facial landmarks,
324+
are returned if `landmarks=True`.
320325
321326
Example:
322327
>>> from PIL import Image, ImageDraw
323328
>>> from facenet_pytorch import MTCNN, extract_face
324329
>>> mtcnn = MTCNN(keep_all=True)
325-
>>> boxes, probs = mtcnn.detect(img)
330+
>>> boxes, probs, points = mtcnn.detect(img, landmarks=True)
326331
>>> # Draw boxes and save faces
327332
>>> img_draw = img.copy()
328333
>>> draw = ImageDraw.Draw(img_draw)
329-
>>> for i, box in enumerate(boxes):
330-
... draw.rectangle(box.tolist())
334+
>>> for i, (box, point) in enumerate(zip(boxes, points)):
335+
... draw.rectangle(box.tolist(), width=5)
336+
... for p in point:
337+
... draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10)
331338
... extract_face(img, box, save_path='detected_face_{}.png'.format(i))
332339
>>> img_draw.save('annotated_faces.png')
333340
"""
334341

335342
with torch.no_grad():
336-
batch_boxes = detect_face(
343+
batch_boxes, batch_points = detect_face(
337344
img, self.min_face_size,
338345
self.pnet, self.rnet, self.onet,
339346
self.thresholds, self.factor,
340347
self.device
341348
)
342349

343-
boxes, probs = [], []
344-
for box in batch_boxes:
350+
boxes, probs, points = [], [], []
351+
for box, point in zip(batch_boxes, batch_points):
345352
box = np.array(box)
353+
point = np.array(point)
346354
if len(box) == 0:
347355
boxes.append(None)
348356
probs.append([None])
357+
points.append(None)
349358
elif self.select_largest:
350-
box = box[np.argsort((box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]))[::-1]]
359+
box_order = np.argsort((box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]))[::-1]
360+
box = box[box_order]
361+
point = point[box_order]
351362
boxes.append(box[:, :4])
352363
probs.append(box[:, 4])
364+
points.append(point)
353365
else:
354366
boxes.append(box[:, :4])
355367
probs.append(box[:, 4])
368+
points.append(point)
356369
boxes = np.array(boxes)
357370
probs = np.array(probs)
371+
points = np.array(points)
358372

359373
if not isinstance(img, Iterable):
360374
boxes = boxes[0]
361375
probs = probs[0]
376+
points = points[0]
377+
378+
if landmarks:
379+
return boxes, probs, points
362380

363381
return boxes, probs
364382

0 commit comments

Comments
 (0)