-
Notifications
You must be signed in to change notification settings - Fork 0
/
cropping_faces.py
30 lines (28 loc) · 1.2 KB
/
cropping_faces.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
import os
from os import walk
import cv2
from tqdm import tqdm
def crop_face(datadir=None, path_to_save=None, img=None):
    """Detect and crop faces with facenet_pytorch's MTCNN.

    The mode is selected by ``img``:

    * Directory mode (``img is None``): walk ``datadir`` recursively and run
      MTCNN (``image_size=224``) on every file OpenCV can decode. MTCNN is
      given ``save_path=path_to_save/<name of the file's parent folder>/<file>``.
      Only the immediate parent folder name is kept, so deeper nesting is
      flattened into one level under ``path_to_save``. Files OpenCV cannot
      decode are skipped with a message instead of aborting the run.
    * Single-image mode (``img`` given): run MTCNN (``image_size=227``) on
      ``img`` and return its output together with the detection probability.

    Args:
        datadir: Root directory to scan (directory mode only).
        path_to_save: Output root directory; created if missing (directory
            mode only).
        img: Image for single-image mode. It should be in RGB channel order,
            which is the order directory mode feeds to MTCNN after converting
            OpenCV's BGR output.

    Returns:
        ``(face, prob)`` as returned by ``MTCNN(..., return_prob=True)`` in
        single-image mode; ``None`` in directory mode.

    Raises:
        TypeError: In directory mode, if ``datadir`` or ``path_to_save`` is None.
        FileNotFoundError: In directory mode, if ``datadir`` is not a directory.
    """
    if img is not None:
        return _crop_single_image(img, _select_device())
    _validate_directory_args(datadir, path_to_save)
    _crop_directory(datadir, path_to_save, _select_device())
    return None


def _select_device():
    """Return the CUDA device when available, otherwise the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _validate_directory_args(datadir, path_to_save):
    """Fail fast on missing or invalid arguments for directory mode."""
    if datadir is None or path_to_save is None:
        raise TypeError(
            "crop_face: datadir and path_to_save are required when img is None"
        )
    if not os.path.isdir(datadir):
        # os.walk silently yields nothing for a missing directory, which would
        # otherwise make a mistyped path look like an empty dataset.
        raise FileNotFoundError(
            f"crop_face: datadir is not a directory: {datadir!r}"
        )


def _crop_directory(datadir, path_to_save, device):
    """Crop every decodable image under datadir into path_to_save/<parent>/<file>."""
    os.makedirs(path_to_save, exist_ok=True)
    mtcnn = MTCNN(image_size=224, device=device)
    for dirpath, dirnames, filenames in tqdm(walk(datadir)):
        # Mirror every sub-folder name (flattened) under the output root.
        for dirname in dirnames:
            os.makedirs(os.path.join(path_to_save, dirname), exist_ok=True)
        if not filenames:
            continue
        # basename understands the platform separator, unlike split('/'),
        # so the class folder is found correctly on Windows too.
        class_type = os.path.basename(dirpath)
        target_dir = os.path.join(path_to_save, class_type)
        # Files directly under datadir map to a folder the loop above never
        # creates; make sure the destination exists.
        os.makedirs(target_dir, exist_ok=True)
        for file in filenames:
            source = os.path.join(dirpath, file)
            image = cv2.imread(source)
            if image is None:
                # imread returns None for unreadable/non-image files; passing
                # that to cvtColor would raise and abort the whole run.
                tqdm.write(f"crop_face: skipping unreadable file {source}")
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            mtcnn(image, save_path=os.path.join(target_dir, file))


def _crop_single_image(img, device):
    """Return ``(face, prob)`` from MTCNN(image_size=227) for a single image."""
    mtcnn = MTCNN(image_size=227, device=device)
    face, prob = mtcnn(img, return_prob=True)
    return face, prob