diff --git a/.gitignore b/.gitignore index 9ad7fb552..1ec8cb945 100644 --- a/.gitignore +++ b/.gitignore @@ -8,9 +8,6 @@ data !data/vgg/download-and-align.py !data/download-lfw-subset.sh -models/facenet/*.t7 -models/dlib/shape_predictor_68_face_landmarks.dat - *.pyc *.mp4 @@ -21,10 +18,17 @@ evaluation/attic/*/*.csv evaluation/attic/*/*.pdf demos/web/bower_components -demos/web/unknown*.npy -models/openface/*.t7 -models/openface/*.pkl celeb-classifier* site +dist +openface.egg-info + +**/.idea +**/*.t7 +**/*.pt +**/*.pkl +**/*.dat +**/*.npy +**/*.png diff --git a/Dockerfile b/Dockerfile index 9e198d7f7..88e1632f4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,35 +1,42 @@ -FROM bamos/ubuntu-opencv-dlib-torch:ubuntu_14.04-opencv_2.4.11-dlib_19.0-torch_2016.07.12 -MAINTAINER Brandon Amos +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 -# TODO: Should be added to opencv-dlib-torch image. -RUN ln -s /root/torch/install/bin/* /usr/local/bin +ARG DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y \ curl \ git \ - graphicsmagick \ - libssl-dev \ - libffi-dev \ - python-dev \ - python-pip \ - python-numpy \ - python-nose \ - python-scipy \ - python-pandas \ - python-protobuf \ - python-openssl \ + software-properties-common \ + build-essential \ + cmake \ + pkg-config \ + python3 \ + python3-dev \ + python3-distutils \ + python3-pip \ + python3-opencv \ wget \ zip \ + libatlas-base-dev \ + libboost-all-dev \ + libopenblas-dev \ + liblapack-dev \ + libswscale-dev \ + libssl-dev \ + libffi-dev \ + libsm6 \ + libxext6 \ + libxrender1 \ && apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* ADD . /root/openface -RUN python -m pip install --upgrade --force pip -RUN cd ~/openface && \ - ./models/get-models.sh && \ - pip2 install -r requirements.txt && \ - python2 setup.py install && \ - pip2 install --user --ignore-installed -r demos/web/requirements.txt && \ - pip2 install -r training/requirements.txt +RUN python3 -m pip install --upgrade pip + +WORKDIR /root/openface + +RUN ./models/get-models.sh && \ + python3 -m pip install -r requirements.txt && \ + python3 -m pip install . +# python3 -m pip install --user --ignore-installed -r demos/web/requirements.txt && \ +# python3 -m pip install -r training/requirements.txt EXPOSE 8000 9000 -CMD /bin/bash -l -c '/root/openface/demos/web/start-servers.sh' diff --git a/batch-represent/batch_represent.py b/batch-represent/batch_represent.py new file mode 100644 index 000000000..9a45a6358 --- /dev/null +++ b/batch-represent/batch_represent.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +# +# Copyright 2015-2024 Carnegie Mellon University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
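+# Usage sketch (illustrative paths; the flags are defined in the parser below):
+#   python3 batch-represent/batch_represent.py -i data/raw --align_out data/aligned -o data/csv
+#   python3 batch-represent/batch_represent.py -i data/aligned --aligned -o data/csv --cpu
+# Either invocation writes reps.csv (one 128-D embedding per row) and
+# labels.csv into the --csv_out directory.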
+ +import argparse +import functools +import os +from collections import Counter + +import cv2 +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset, DataLoader + +import openface + +SUPPORTED_IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'bmp', 'gif'] +REPS_CSV_FILE = 'reps.csv' +LABELS_CSV_FILE = 'labels.csv' +PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +MODEL_DIR = os.path.join(PROJECT_DIR, 'models') +DEFAULT_DLIB_FACE_PREDICTOR_PATH = os.path.join(MODEL_DIR, 'dlib', 'shape_predictor_68_face_landmarks.dat') +DEFAULT_DLIB_FACE_DETECTOR_PATH = os.path.join(MODEL_DIR, 'dlib', 'mmod_human_face_detector.dat') +DEFAULT_OPENFACE_MODEL_PATH = os.path.join(MODEL_DIR, 'openface', 'nn4.small2.v1.pt') +IMG_DIM = 96 + + +class OpenFaceDataset(Dataset): + def __init__(self, aligned_dataset_dir, annotations_file=None, transform=None, target_transform=None): + self.dataset_dir = aligned_dataset_dir + if annotations_file is None: + class_folders = [sub.name for sub in os.scandir(aligned_dataset_dir) if sub.is_dir()] + img_label_list = [] + for class_name in class_folders: + class_path = os.path.join(aligned_dataset_dir, class_name) + for img in os.scandir(class_path): + if img.name.lower().split('.')[-1] in SUPPORTED_IMAGE_EXTENSIONS: + img_label_list.append({'filename': os.path.join(class_path, img.name), + 'label': class_name}) + self.img_labels = pd.DataFrame(img_label_list) + else: + self.img_labels = pd.read_csv(annotations_file) + self.transform = transform + self.target_transform = target_transform + + def __len__(self): + return len(self.img_labels) + + def __getitem__(self, idx): + img_path = self.img_labels.iloc[idx, 0] + bgr_img = cv2.imread(img_path) + if bgr_img is None: + raise Exception('Unable to load image: {}'.format(img_path)) + rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + label = self.img_labels.iloc[idx, 1] + if self.transform: + rgb_img = self.transform(rgb_img) + if self.target_transform: + label = self.target_transform(label) + return rgb_img, label + + +def transform_image(image): + if image is None: + return None + image = (image / 255.).astype(np.float32) + image = np.transpose(image, (2, 0, 1)) # channel-first ordering + return image + + +def get_or_add(key, dictionary): + if key in dictionary: + return dictionary.get(key) + else: + val = len(dictionary) + 1 + dictionary[key] = val + return val + + +def align_all_images(raw_dataset_dir, align_dir, align, landmark_indices, skip_multi=False): + class_folders = [sub.name for sub in os.scandir(raw_dataset_dir) if sub.is_dir()] + print('=== Detecting and aligning faces ===') + summary_str = '{:<16}{:>8}\n'.format('Name', 'Count') + summary_str += '-' * 24 + for class_name in class_folders: + raw_class_path = os.path.join(raw_dataset_dir, class_name) + aligned_class_path = os.path.join(align_dir, class_name) + os.makedirs(aligned_class_path, exist_ok=True) + aligned_count = 0 + for img in os.scandir(raw_class_path): + if img.name.lower().split('.')[-1] in SUPPORTED_IMAGE_EXTENSIONS: + img_path = os.path.join(raw_class_path, img.name) + bgr_img = cv2.imread(img_path) + if bgr_img is None: + print('Warning: Unable to load image: {}'.format(img_path)) + continue + rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + aligned_rgb_img = align.align(IMG_DIM, rgb_img, landmarkIndices=landmark_indices, + skipMulti=skip_multi) + if aligned_rgb_img is None: + print('Warning: Unable to find a face: {}'.format(img_path)) + continue + aligned_bgr_img = 
cv2.cvtColor(aligned_rgb_img, cv2.COLOR_RGB2BGR) + aligned_img_path = os.path.join(aligned_class_path, img.name) + cv2.imwrite(aligned_img_path, aligned_bgr_img) + aligned_count += 1 + summary_str += '\n{:<16}{:>8}'.format(class_name, aligned_count) + print(summary_str) + + +def main(args): + input_dataset_dir = args.input_dir + output_csv_dir = args.csv_out + reps_csv_path = os.path.join(output_csv_dir, REPS_CSV_FILE) + labels_csv_path = os.path.join(output_csv_dir, LABELS_CSV_FILE) + os.makedirs(output_csv_dir, exist_ok=True) + for csv_path in [reps_csv_path, labels_csv_path]: + if os.path.exists(csv_path): + os.remove(csv_path) + if args.aligned: + dataset = OpenFaceDataset(input_dataset_dir, transform=transform_image) + else: + output_align_dir = args.align_out + os.makedirs(output_align_dir, exist_ok=True) + if args.dlib_face_detector_type == 'CNN': + align = openface.AlignDlib(args.dlib_face_predictor_path, args.dlib_face_detector_path, + upsample=args.upsample) + else: + align = openface.AlignDlib(args.dlib_face_predictor_path, upsample=args.upsample) + landmark_map = { + 'outerEyesAndNose': openface.AlignDlib.OUTER_EYES_AND_NOSE, + 'innerEyesAndBottomLip': openface.AlignDlib.INNER_EYES_AND_BOTTOM_LIP + } + if args.landmarks not in landmark_map: + raise Exception('Landmarks unrecognized: {}'.format(args.landmarks)) + landmark_indices = landmark_map[args.landmarks] + + align_all_images(input_dataset_dir, output_align_dir, align, landmark_indices, args.skip_multi) + dataset = OpenFaceDataset(output_align_dir, transform=transform_image) + + dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=args.shuffle, num_workers=args.worker) + model = openface.OpenFaceNet() + if args.cpu: + model.load_state_dict(torch.load(args.openface_model_path)) + else: + model.load_state_dict(torch.load(args.openface_model_path, map_location='cuda')) + model.to(torch.device('cuda')) + model.eval() + + label_dict = {} + label_counter = Counter() + for step, (images, labels) in enumerate(dataloader): + print('=== Generating representations for batch {}/{} ==='.format(step, len(dataloader))) + if not args.cpu: + images = images.to(torch.device('cuda')) + reps = model(images) + reps = reps.cpu().detach().numpy() + + with open(reps_csv_path, 'a') as reps_file: + np.savetxt(reps_file, reps, fmt='%.8f', delimiter=',') + + label_counter.update(labels) + with open(labels_csv_path, 'a') as labels_file: + for label in labels: + labels_file.write('{},{}\n'.format(get_or_add(label, label_dict), label)) + print('Summary: Representations generated for {} images in total'.format(sum(label_counter.values()))) + print(dict(label_counter)) + print('Saving csv files to folder: "{}"'.format(output_csv_dir)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_dir', required=True, type=str, help='path to image dataset directory') + parser.add_argument('-o', '--csv_out', required=True, type=str, help='path to csv output directory') + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--aligned', action='store_true', + help='if flag is set, assume input images are already aligned') + group.add_argument('--align_out', type=str, help='save aligned images to the specified directory') + parser.add_argument('--dlib_face_predictor_path', type=str, default=DEFAULT_DLIB_FACE_PREDICTOR_PATH, + help='path to dlib face predictor model') + parser.add_argument('--dlib_face_detector_type', type=str, choices=['HOG', 'CNN'], default='CNN', + help='type 
of dlib face detector to be used')
+    parser.add_argument('--dlib_face_detector_path', type=str, default=DEFAULT_DLIB_FACE_DETECTOR_PATH,
+                        help='path to dlib CNN face detector model')
+    parser.add_argument('--upsample', type=int, default=1,
+                        help='number of times to upsample images before detection')
+    parser.add_argument('--openface_model_path', type=str, default=DEFAULT_OPENFACE_MODEL_PATH,
+                        help='path to pretrained OpenFace model')
+    parser.add_argument('--batch', type=int, default=64, help='batch size')
+    parser.add_argument('--worker', type=int, default=4, help='number of workers')
+    parser.add_argument('--shuffle', action='store_true', help='shuffle dataset')
+    parser.add_argument('--skip_multi', action='store_true', help='if flag is set, skip images in which multiple '
+                                                                  'faces are found, otherwise use the largest face')
+    parser.add_argument('--landmarks', type=str, choices=['outerEyesAndNose', 'innerEyesAndBottomLip'],
+                        default='outerEyesAndNose', help='landmarks to align to')
+    parser.add_argument('--cpu', action='store_true', help='run OpenFace model on CPU only')
+    arguments = parser.parse_args()
+
+    main(arguments)
diff --git a/conversion/convert_to_pytorch.py b/conversion/convert_to_pytorch.py
new file mode 100644
index 000000000..288e732c6
--- /dev/null
+++ b/conversion/convert_to_pytorch.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+#
+# Copyright 2015-2024 Carnegie Mellon University
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
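+# Conversion sketch: conversion/test_luatorch.lua (below, in this patch) re-saves
+# the LuaTorch model and dumps every layer's output for an all-ones input; this
+# script then copies the weights layer by layer into openface.OpenFaceNet and
+# checks the PyTorch forward pass against the dumped 'l26out.t7' reference
+# before saving the state dict as nn4.small2.v1.pt.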
+ +import numpy as np +import torch +import torchfile + +import openface + + +def get_param_tensor(layer_dict, param_name): + return torch.from_numpy(layer_dict[param_name]).float() + + +def copy_conv_layer_params(from_layer, to_layer, reshape_weight=None): + weight = get_param_tensor(from_layer, b'weight') + if reshape_weight is not None: + weight = torch.reshape(weight, reshape_weight) + bias = get_param_tensor(from_layer, b'bias') + to_layer.weight.copy_(weight) + to_layer.bias.copy_(bias) + + +def copy_bn_layer_params(from_layer, to_layer): + to_layer.weight.copy_(get_param_tensor(from_layer, b'weight')) + to_layer.bias.copy_(get_param_tensor(from_layer, b'bias')) + to_layer.running_mean.copy_(get_param_tensor(from_layer, b'running_mean')) + to_layer.running_var.copy_(get_param_tensor(from_layer, b'running_var')) + + +def copy_inception_params(from_modules, to_modules, conv_layers_indices, bn_layers_indices): + for branch_ind, layer_ind in conv_layers_indices: + copy_conv_layer_params(from_modules[branch_ind][layer_ind], to_modules[branch_ind][layer_ind]) + for branch_ind, layer_ind in bn_layers_indices: + copy_bn_layer_params(from_modules[branch_ind][layer_ind], to_modules[branch_ind][layer_ind]) + + +if __name__ == '__main__': + openface_model = openface.OpenFaceNet() + + # Load weights from lua torch model + # Load the pretrained model first with the latest LuaTorch and re-save it + lua_model = torchfile.load('nn4.small2.v1.resaved.t7') + lua_model_layers = lua_model[b'modules'] + + with torch.no_grad(): + copy_conv_layer_params(lua_model_layers[0], openface_model.conv1, (64, 3, 7, 7)) + copy_bn_layer_params(lua_model_layers[1], openface_model.bn1) + copy_conv_layer_params(lua_model_layers[5], openface_model.conv2, (64, 64, 1, 1)) + copy_bn_layer_params(lua_model_layers[6], openface_model.bn2) + copy_conv_layer_params(lua_model_layers[8], openface_model.conv3, (192, 64, 3, 3)) + copy_bn_layer_params(lua_model_layers[9], openface_model.bn3) + + incept3a_modules = [branch[b'modules'] for branch in lua_model_layers[13][b'modules'][0][b'modules']] + copy_inception_params(incept3a_modules, openface_model.incept3a.branches, + conv_layers_indices=((0, 0), (0, 3), (1, 0), (1, 3), (2, 1), (3, 0)), + bn_layers_indices=((0, 1), (0, 4), (1, 1), (1, 4), (2, 2), (3, 1))) + + incept3b_modules = [branch[b'modules'] for branch in lua_model_layers[14][b'modules'][0][b'modules']] + copy_inception_params(incept3b_modules, openface_model.incept3b.branches, + conv_layers_indices=((0, 0), (0, 3), (1, 0), (1, 3), (2, 1), (3, 0)), + bn_layers_indices=((0, 1), (0, 4), (1, 1), (1, 4), (2, 2), (3, 1))) + + incept3c_modules = [branch[b'modules'] for branch in lua_model_layers[15][b'modules'][0][b'modules']] + copy_inception_params(incept3c_modules, openface_model.incept3c.branches, + conv_layers_indices=((0, 0), (0, 3), (1, 0), (1, 3)), + bn_layers_indices=((0, 1), (0, 4), (1, 1), (1, 4))) + + incept4a_modules = [branch[b'modules'] for branch in lua_model_layers[16][b'modules'][0][b'modules']] + copy_inception_params(incept4a_modules, openface_model.incept4a.branches, + conv_layers_indices=((0, 0), (0, 3), (1, 0), (1, 3), (2, 1), (3, 0)), + bn_layers_indices=((0, 1), (0, 4), (1, 1), (1, 4), (2, 2), (3, 1))) + + incept4e_modules = [branch[b'modules'] for branch in lua_model_layers[17][b'modules'][0][b'modules']] + copy_inception_params(incept4e_modules, openface_model.incept4e.branches, + conv_layers_indices=((0, 0), (0, 3), (1, 0), (1, 3)), + bn_layers_indices=((0, 1), (0, 4), (1, 1), (1, 4))) + + 
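+        # incept5a/5b keep a single 3x3 conv tower plus pool and reduce branches,
+        # so their (branch, layer) index tuples differ from the earlier blocks.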
+        incept5a_modules = [branch[b'modules'] for branch in lua_model_layers[18][b'modules'][0][b'modules']]
+        copy_inception_params(incept5a_modules, openface_model.incept5a.branches,
+                              conv_layers_indices=((0, 0), (0, 3), (1, 1), (2, 0)),
+                              bn_layers_indices=((0, 1), (0, 4), (1, 2), (2, 1)))
+
+        incept5b_modules = [branch[b'modules'] for branch in lua_model_layers[20][b'modules'][0][b'modules']]
+        copy_inception_params(incept5b_modules, openface_model.incept5b.branches,
+                              conv_layers_indices=((0, 0), (0, 3), (1, 1), (2, 0)),
+                              bn_layers_indices=((0, 1), (0, 4), (1, 2), (2, 1)))
+
+        openface_model.ln.weight.copy_(get_param_tensor(lua_model_layers[24], b'weight'))
+        openface_model.ln.bias.copy_(get_param_tensor(lua_model_layers[24], b'bias'))
+
+    # Run forward pass
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    openface_model = openface_model.to(device)
+    openface_model.eval()
+    ones = torch.ones((1, 3, 96, 96), dtype=torch.float32)
+    ones = ones.to(device)
+    pytorch_out = openface_model(ones).squeeze(0)
+
+    # Compare results with the LuaTorch model's results
+    lua_out = torchfile.load('l26out.t7')
+
+    np.testing.assert_allclose(pytorch_out.cpu().detach().numpy(), lua_out, rtol=1e-03, atol=1e-05)
+
+    # Save model state dict
+    openface_model.to(torch.device('cpu'))
+    torch.save(openface_model.state_dict(), 'nn4.small2.v1.pt')
diff --git a/conversion/test_luatorch.lua b/conversion/test_luatorch.lua
new file mode 100644
index 000000000..5352f5dc6
--- /dev/null
+++ b/conversion/test_luatorch.lua
@@ -0,0 +1,18 @@
+require 'nn'
+require 'dpnn'
+
+torch.setdefaulttensortype('torch.FloatTensor')
+model = torch.load('/root/openface/models/openface/nn4.small2.v1.t7')
+torch.save('nn4.small2.v1.resaved.t7', model)
+
+model:evaluate()
+ones = torch.ones(1, 3, 96, 96)
+x = model:forward(ones)
+torch.save('l'..tostring(#model.modules)..'out.t7', x)
+
+-- Dump each truncated network's output, from the last layer down to layer 1.
+for i = #model.modules, 2, -1 do
+    model:remove(i)
+    x = model:forward(ones)
+    torch.save('l'..tostring(i-1)..'out.t7', x)
+end
diff --git a/demos/classifier.py b/demos/classifier.py
index e671ab883..63cfa35a9 100755
--- a/demos/classifier.py
+++ b/demos/classifier.py
@@ -198,7 +198,7 @@ def infer(args, multiple=False):
                                                           confidence))
         else:
             print("Predict {} with {:.2f} confidence.".format(person.decode('utf-8'), confidence))
-        if isinstance(clf, GMM):
+        if isinstance(clf, mixture.GaussianMixture):
             dist = np.linalg.norm(rep - clf.means_[maxI])
             print("  + Distance from the mean: {}".format(dist))
diff --git a/demos/classifier_new.py b/demos/classifier_new.py
new file mode 100755
index 000000000..51d0bb96f
--- /dev/null
+++ b/demos/classifier_new.py
@@ -0,0 +1,282 @@
+#!/usr/bin/env python3
+#
+# Copyright 2015-2024 Carnegie Mellon University
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
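+# Usage sketch (illustrative paths; see the argument parsers below):
+#   ./demos/classifier_new.py train features/
+#     (features/ holds the reps.csv and labels.csv produced by batch-represent)
+#   ./demos/classifier_new.py infer features/classifier.pkl someone.jpg --multi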
+
+import time
+
+start = time.time()
+
+import argparse
+import cv2
+import os
+import pickle
+import sys
+
+from operator import itemgetter
+
+import numpy as np
+np.set_printoptions(precision=2)
+import pandas as pd
+import torch
+
+import openface
+
+from sklearn.pipeline import Pipeline
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
+from sklearn.preprocessing import LabelEncoder
+from sklearn.svm import SVC
+from sklearn.model_selection import GridSearchCV
+from sklearn import mixture
+from sklearn.tree import DecisionTreeClassifier
+from sklearn.naive_bayes import GaussianNB
+
+PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+MODEL_DIR = os.path.join(PROJECT_DIR, 'models')
+DLIB_MODEL_DIR = os.path.join(MODEL_DIR, 'dlib')
+OPENFACE_MODEL_DIR = os.path.join(MODEL_DIR, 'openface')
+IMG_DIM = 96
+
+
+def get_rep(img_path, multiple=False):
+    start = time.time()
+    bgr_img = cv2.imread(img_path)
+    if bgr_img is None:
+        raise Exception('Unable to load image: {}'.format(img_path))
+
+    rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
+
+    if args.verbose:
+        print('  + Original size: {}'.format(rgb_img.shape))
+    if args.verbose:
+        print('Loading the image took {} seconds.'.format(time.time() - start))
+
+    start = time.time()
+
+    if multiple:
+        bbs = align.getAllFaceBoundingBoxes(rgb_img)
+    else:
+        bb1 = align.getLargestFaceBoundingBox(rgb_img)
+        bbs = [bb1]
+    if len(bbs) == 0 or (not multiple and bb1 is None):
+        raise Exception('Unable to find a face: {}'.format(img_path))
+    if args.verbose:
+        print('Face detection took {} seconds.'.format(time.time() - start))
+
+    reps = []
+    for bb in bbs:
+        start = time.time()
+        aligned_face = align.align(IMG_DIM, rgb_img, bb, landmarkIndices=openface.AlignDlib.OUTER_EYES_AND_NOSE)
+        if aligned_face is None:
+            raise Exception('Unable to align image: {}'.format(img_path))
+        if args.verbose:
+            print('Alignment took {} seconds.'.format(time.time() - start))
+            print('This bbox is centered at {}, {}'.format(bb.center().x, bb.center().y))
+
+        start = time.time()
+
+        aligned_face = (aligned_face / 255.).astype(np.float32)
+        aligned_face = np.expand_dims(np.transpose(aligned_face, (2, 0, 1)), axis=0)  # BCHW order
+        aligned_face = torch.from_numpy(aligned_face)
+        if not args.cpu:
+            aligned_face = aligned_face.to(torch.device('cuda'))
+        rep = model(aligned_face)
+        rep = rep.cpu().detach().numpy().squeeze(0)
+        if args.verbose:
+            print('Neural network forward pass took {} seconds.'.format(
+                time.time() - start))
+        reps.append((bb.center().x, rep))
+    sreps = sorted(reps, key=lambda x: x[0])
+    return sreps
+
+
+def train(args):
+    print('Loading embeddings.')
+    labels_file = os.path.join(args.workDir, 'labels.csv')
+    labels = pd.read_csv(labels_file, header=None).values[:, 1]
+    labels = np.array(labels)
+    reps_file = os.path.join(args.workDir, 'reps.csv')
+    embeddings = pd.read_csv(reps_file, header=None).values
+    le = LabelEncoder().fit(labels)
+    labels_num = le.transform(labels)
+    n_classes = len(le.classes_)
+    print('Training for {} classes.'.format(n_classes))
+
+    if args.classifier == 'LinearSvm':
+        clf = SVC(C=1, kernel='linear', probability=True)
+    elif args.classifier == 'GridSearchSvm':
+        print("""
+        Warning: In our experience, using a grid search over SVM hyper-parameters only
+        gives marginally better performance than a linear SVM with C=1 and
+        is not worth the extra computation of performing a grid search.
+ """) + param_grid = [ + {'C': [1, 10, 100, 1000], + 'kernel': ['linear']}, + {'C': [1, 10, 100, 1000], + 'gamma': [0.001, 0.0001], + 'kernel': ['rbf']} + ] + clf = GridSearchCV(SVC(C=1, probability=True), param_grid, cv=5) + elif args.classifier == 'GMM': # Doesn't work best + clf = mixture.GaussianMixture(n_components=n_classes) + + # ref: + # http://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html#example-classification-plot-classifier-comparison-py + elif args.classifier == 'RadialSvm': # Radial Basis Function kernel + # works better with C = 1 and gamma = 2 + clf = SVC(C=1, kernel='rbf', probability=True, gamma=2) + elif args.classifier == 'DecisionTree': # Doesn't work best + clf = DecisionTreeClassifier(max_depth=20) + elif args.classifier == 'GaussianNB': + clf = GaussianNB() + + # ref: https://jessesw.com/Deep-Learning/ + elif args.classifier == 'DBN': + from nolearn.dbn import DBN + clf = DBN([embeddings.shape[1], 500, labels_num[-1:][0] + 1], # i/p nodes, hidden nodes, o/p nodes + learn_rates=0.3, + # Smaller steps mean a possibly more accurate result, but the + # training will take longer + learn_rate_decays=0.9, + # a factor the initial learning rate will be multiplied by + # after each iteration of the training + epochs=300, # no of iternation + # dropouts = 0.25, # Express the percentage of nodes that + # will be randomly dropped as a decimal. + verbose=1) + + if args.ldaDim > 0: + clf_final = clf + clf = Pipeline([('lda', LDA(n_components=args.ldaDim)), + ('clf', clf_final)]) + + clf.fit(embeddings, labels_num) + + classifier_file = os.path.join(args.workDir, 'classifier.pkl') + print('Saving classifier to "{}"'.format(classifier_file)) + with open(classifier_file, 'wb') as f: + pickle.dump((le, clf), f) + + +def infer(args, multiple=False): + with open(args.classifierModel, 'rb') as f: + if sys.version_info[0] < 3: + (le, clf) = pickle.load(f) + else: + (le, clf) = pickle.load(f, encoding='latin1') + + for img in args.imgs: + print('\n=== {} ==='.format(img)) + reps = get_rep(img, multiple) + if len(reps) > 1: + print('List of faces in image from left to right') + for r in reps: + rep = r[1].reshape(1, -1) + bbx = r[0] + start = time.time() + predictions = clf.predict_proba(rep).ravel() + maxI = np.argmax(predictions) + person = le.inverse_transform([maxI]) + confidence = predictions[maxI] + if args.verbose: + print('Prediction took {} seconds.'.format(time.time() - start)) + if multiple: + print('Predict {} @ x={} with {:.2f} confidence.'.format(str(person[0]), bbx, + confidence)) + else: + print('Predict {} with {:.2f} confidence.'.format(str(person[0]), confidence)) + if isinstance(clf, mixture.GaussianMixture): + dist = np.linalg.norm(rep - clf.means_[maxI]) + print(' + Distance from the mean: {}'.format(dist)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dlibFacePredictor', type=str, help="Path to dlib's face predictor.", + default=os.path.join(DLIB_MODEL_DIR, 'shape_predictor_68_face_landmarks.dat')) + parser.add_argument('--dlibFaceDetectorType', type=str, choices=['HOG', 'CNN'], + help="Type of dlib's face detector to be used.", default='CNN') + parser.add_argument('--dlibFaceDetector', type=str, help="Path to dlib's CNN face detector.", + default=os.path.join(DLIB_MODEL_DIR, 'mmod_human_face_detector.dat')) + parser.add_argument('--upsample', type=int, help="Number of times to upsample images before detection.", default=1) + parser.add_argument('--networkModel', type=str, 
help='Path to pretrained OpenFace model.', + default=os.path.join(OPENFACE_MODEL_DIR, 'nn4.small2.v1.pt')) + parser.add_argument('--cpu', action='store_true', help='Run OpenFace models on CPU only.') + parser.add_argument('--verbose', action='store_true') + + subparsers = parser.add_subparsers(dest='mode', help='Mode') + trainParser = subparsers.add_parser('train', help='Train a new classifier.') + trainParser.add_argument('--ldaDim', type=int, default=-1) + trainParser.add_argument('--classifier', type=str, + choices=['LinearSvm', + 'GridSearchSvm', + 'GMM', + 'RadialSvm', + 'DecisionTree', + 'GaussianNB', + 'DBN'], + help='The type of classifier to use.', default='LinearSvm') + trainParser.add_argument('workDir', type=str, + help='The input work directory containing "reps.csv" and "labels.csv". Obtained from ' + 'aligning a directory with "align-dlib" and getting the representations with ' + '"batch-represent".') + + inferParser = subparsers.add_parser('infer', help='Predict who an image contains from a trained classifier.') + inferParser.add_argument('classifierModel', type=str, + help='The Python pickle representing the classifier. This is NOT the Torch network ' + 'model, which can be set with --networkModel.') + inferParser.add_argument('imgs', type=str, nargs='+', help='Input image.') + inferParser.add_argument('--multi', help='Infer multiple faces in image', action='store_true') + + args = parser.parse_args() + if args.verbose: + print('Argument parsing and import libraries took {} seconds.'.format( + time.time() - start)) + + if args.mode == 'infer' and args.classifierModel.endswith('.t7'): + raise Exception(""" +Torch network model passed as the classification model, +which should be a Python pickle (.pkl) + +See the documentation for the distinction between the Torch +network and classification models: + + http://cmusatyalab.github.io/openface/demo-3-classifier/ + http://cmusatyalab.github.io/openface/training-new-models/ + +Use `--networkModel` to set a non-standard Torch network model.""") + start = time.time() + + if args.dlibFaceDetectorType == 'CNN': + align = openface.AlignDlib(args.dlibFacePredictor, args.dlibFaceDetector, upsample=args.upsample) + else: + align = openface.AlignDlib(args.dlibFacePredictor, upsample=args.upsample) + model = openface.OpenFaceNet() + if args.cpu: + model.load_state_dict(torch.load(args.networkModel)) + else: + model.load_state_dict(torch.load(args.networkModel, map_location='cuda')) + model.to(torch.device('cuda')) + model.eval() + + if args.verbose: + print('Loading the dlib and OpenFace models took {} seconds.'.format( + time.time() - start)) + start = time.time() + + if args.mode == 'train': + train(args) + elif args.mode == 'infer': + infer(args, args.multi) diff --git a/demos/compare_new.py b/demos/compare_new.py new file mode 100755 index 000000000..600800efc --- /dev/null +++ b/demos/compare_new.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +# +# Copyright 2015-2024 Carnegie Mellon University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
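+# Usage sketch: prints the squared L2 distance between the embeddings of each
+# pair of input images, e.g. (illustrative paths):
+#   ./demos/compare_new.py images/examples/lennon-1.jpg images/examples/lennon-2.jpg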
+ +import time + +start = time.time() + +import argparse +import cv2 +import itertools +import os + +import numpy as np +np.set_printoptions(precision=2) +import torch + +import openface + +PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +MODEL_DIR = os.path.join(PROJECT_DIR, 'models') +DLIB_MODEL_DIR = os.path.join(MODEL_DIR, 'dlib') +OPENFACE_MODEL_DIR = os.path.join(MODEL_DIR, 'openface') +IMG_DIM = 96 + + +def get_rep(img_path): + if args.verbose: + print('Processing {}.'.format(img_path)) + bgr_img = cv2.imread(img_path) + if bgr_img is None: + raise Exception('Unable to load image: {}'.format(img_path)) + rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + + if args.verbose: + print(' + Original size: {}'.format(rgb_img.shape)) + + start = time.time() + bb = align.getLargestFaceBoundingBox(rgb_img) + if bb is None: + raise Exception('Unable to find a face: {}'.format(img_path)) + if args.verbose: + print(' + Face detection took {} seconds.'.format(time.time() - start)) + + start = time.time() + aligned_face = align.align(IMG_DIM, rgb_img, bb, + landmarkIndices=openface.AlignDlib.OUTER_EYES_AND_NOSE) + if aligned_face is None: + raise Exception('Unable to align image: {}'.format(img_path)) + if args.verbose: + print(' + Face alignment took {} seconds.'.format(time.time() - start)) + + start = time.time() + + aligned_face = (aligned_face / 255.).astype(np.float32) + aligned_face = np.expand_dims(np.transpose(aligned_face, (2, 0, 1)), axis=0) # BCHW order + aligned_face = torch.from_numpy(aligned_face) + if not args.cpu: + aligned_face = aligned_face.to(torch.device('cuda')) + + rep = net.forward(aligned_face) + rep = rep.cpu().detach().numpy().squeeze(0) + + if args.verbose: + print(' + OpenFace forward pass took {} seconds.'.format(time.time() - start)) + print('Representation:') + print(rep) + print('-----\n') + return rep + + +def compare(args): + for (img1, img2) in itertools.combinations(args.imgs, 2): + d = get_rep(img1) - get_rep(img2) + print('Comparing {} with {}.'.format(img1, img2)) + print( + ' + Squared l2 distance between representations: {:0.3f}'.format(np.dot(d, d))) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('imgs', type=str, nargs='+', help='Input images.') + parser.add_argument('--dlibFacePredictor', type=str, help="Path to dlib's face predictor.", + default=os.path.join(DLIB_MODEL_DIR, 'shape_predictor_68_face_landmarks.dat')) + parser.add_argument('--dlibFaceDetectorType', type=str, choices=['HOG', 'CNN'], + help="Type of dlib's face detector to be used.", default='CNN') + parser.add_argument('--dlibFaceDetector', type=str, help="Path to dlib's CNN face detector.", + default=os.path.join(DLIB_MODEL_DIR, 'mmod_human_face_detector.dat')) + parser.add_argument('--upsample', type=int, help="Number of times to upsample images before detection.", default=1) + parser.add_argument('--networkModel', type=str, help='Path to pretrained OpenFace model.', + default=os.path.join(OPENFACE_MODEL_DIR, 'nn4.small2.v1.pt')) + parser.add_argument('--cpu', action='store_true', help='Run OpenFace models on CPU only.') + parser.add_argument('--verbose', action='store_true') + + args = parser.parse_args() + if args.verbose: + print('Argument parsing and loading libraries took {} seconds.'.format( + time.time() - start)) + + start = time.time() + if args.dlibFaceDetectorType == 'CNN': + align = openface.AlignDlib(args.dlibFacePredictor, args.dlibFaceDetector, upsample=args.upsample) + else: + align = 
openface.AlignDlib(args.dlibFacePredictor, upsample=args.upsample) + net = openface.OpenFaceNet() + if args.cpu: + net.load_state_dict(torch.load(args.networkModel)) + else: + net.load_state_dict(torch.load(args.networkModel, map_location='cuda')) + net.to(torch.device('cuda')) + net.eval() + + if args.verbose: + print('Loading the dlib and OpenFace models took {} seconds.'.format( + time.time() - start)) + + compare(args) diff --git a/models/get-models.sh b/models/get-models.sh index 249d71435..65136dbae 100755 --- a/models/get-models.sh +++ b/models/get-models.sh @@ -32,16 +32,35 @@ if [ ! -f dlib/shape_predictor_68_face_landmarks.dat ]; then bunzip2 dlib/shape_predictor_68_face_landmarks.dat.bz2 [ $? -eq 0 ] || die "+ Error using bunzip2." fi +if [ ! -f dlib/mmod_human_face_detector.dat ]; then + printf "\n\n====================================================\n" + printf "Downloading dlib's public domain face detector model.\n" + printf "Reference: https://github.com/davisking/dlib-models\n\n" + printf "This will incur about 700KB of network traffic for the compressed\n" + printf "models that will decompress to about 700KB on disk.\n" + printf "====================================================\n\n" + wget -nv \ + http://dlib.net/files/mmod_human_face_detector.dat.bz2 \ + -O dlib/mmod_human_face_detector.dat.bz2 + [ $? -eq 0 ] || die "+ Error in wget." + bunzip2 dlib/mmod_human_face_detector.dat.bz2 + [ $? -eq 0 ] || die "+ Error using bunzip2." +fi mkdir -p openface -if [ ! -f openface/nn4.small2.v1.t7 ]; then +if [ ! -f openface/nn4.small2.v1.pt ]; then printf "\n\n====================================================\n" printf "Downloading OpenFace models, which are copyright\n" printf "Carnegie Mellon University and are licensed under\n" printf "the Apache 2.0 License.\n\n" - printf "This will incur about 100MB of network traffic for the models.\n" + printf "This will incur about 50MB of network traffic for the models.\n" printf "====================================================\n\n" + wget -nv \ + https://storage.cmusatyalab.org/openface-models/nn4.small2.v1.pt \ + -O openface/nn4.small2.v1.pt + [ $? -eq 0 ] || ( rm openface/nn4.small2.v1.pt* && die "+ nn4.small2.v1.pt: Error in wget." ) + wget -nv \ https://storage.cmusatyalab.org/openface-models/nn4.small2.v1.t7 \ -O openface/nn4.small2.v1.t7 @@ -92,6 +111,10 @@ checkmd5 \ dlib/shape_predictor_68_face_landmarks.dat \ 73fde5e05226548677a050913eed4e04 +checkmd5 \ + dlib/mmod_human_face_detector.dat \ + 8d2d36a0ab9adb57f4a866252fd9f047 + checkmd5 \ openface/celeb-classifier.nn4.small2.v1.pkl \ 199a2c0d32fd0f22f14ad2d248280475 @@ -99,3 +122,7 @@ checkmd5 \ checkmd5 \ openface/nn4.small2.v1.t7 \ c95bfd8cc1adf05210e979ff623013b6 + +checkmd5 \ + openface/nn4.small2.v1.pt \ + 8de23b5e35e49df171175d28847c67c4 diff --git a/openface/__init__.py b/openface/__init__.py index 4b7a0fc0b..9d380af7f 100644 --- a/openface/__init__.py +++ b/openface/__init__.py @@ -4,6 +4,7 @@ from .align_dlib import AlignDlib from .torch_neural_net import TorchNeuralNet +from .openfacenet import OpenFaceNet from . import data from . 
import helper diff --git a/openface/align_dlib.py b/openface/align_dlib.py index f56cfb61e..8525085bd 100644 --- a/openface/align_dlib.py +++ b/openface/align_dlib.py @@ -57,6 +57,7 @@ TPL_MIN, TPL_MAX = np.min(TEMPLATE, axis=0), np.max(TEMPLATE, axis=0) MINMAX_TEMPLATE = (TEMPLATE - TPL_MIN) / (TPL_MAX - TPL_MIN) +CNN_DETECTOR_CONF_THRESHOLD = 0.5 class AlignDlib: @@ -77,17 +78,24 @@ class AlignDlib: INNER_EYES_AND_BOTTOM_LIP = [39, 42, 57] OUTER_EYES_AND_NOSE = [36, 45, 33] - def __init__(self, facePredictor): + def __init__(self, facePredictor, faceDetector=None, upsample=1): """ Instantiate an 'AlignDlib' object. - :param facePredictor: The path to dlib's + :param facePredictor: The path to dlib's face predictor model :type facePredictor: str + :param faceDetector: The path to dlib's CNN face detector model, or None if using HOG detector + :type faceDetector: str """ assert facePredictor is not None - - self.detector = dlib.get_frontal_face_detector() self.predictor = dlib.shape_predictor(facePredictor) + if faceDetector is None: + self.detector_type = 'HOG' + self.detector = dlib.get_frontal_face_detector() + else: + self.detector_type = 'CNN' + self.detector = dlib.cnn_face_detection_model_v1(faceDetector) + self.upsample = upsample def getAllFaceBoundingBoxes(self, rgbImg): """ @@ -101,7 +109,11 @@ def getAllFaceBoundingBoxes(self, rgbImg): assert rgbImg is not None try: - return self.detector(rgbImg, 1) + if self.detector_type == 'HOG': + return self.detector(rgbImg, self.upsample) + elif self.detector_type == 'CNN': + return [mmod_rect.rect for mmod_rect in self.detector(rgbImg, self.upsample) + if mmod_rect.confidence > CNN_DETECTOR_CONF_THRESHOLD] except Exception as e: print("Warning: {}".format(e)) # In rare cases, exceptions are thrown. diff --git a/openface/openfacenet.py b/openface/openfacenet.py new file mode 100644 index 000000000..25c085ebd --- /dev/null +++ b/openface/openfacenet.py @@ -0,0 +1,175 @@ +# Copyright 2015-2024 Carnegie Mellon University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
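+# PyTorch port of the nn4.small2.v1 architecture: a stack of Inception-style
+# blocks over 96x96 RGB inputs, ending in an L2-normalized 128-dimensional
+# embedding (see OpenFaceNet.forward for the layer-by-layer mapping).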
+ +"""Module for Pytorch-based face recognition neural network.""" + +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F + + +class Inception(nn.Module): + def __init__(self, inputSize, reduceSize, outputSize, + kernelSize, kernelStride, reduceStride=None, + pool=None, activation=None, batchNorm=True, padding=True): + super().__init__() + self.branches = [] + if reduceStride is None: + reduceStride = [1] * len(reduceSize) + if pool is None: + pool = nn.MaxPool2d((3, 3), stride=(1, 1)) + if activation is None: + activation = nn.ReLU() + + # conv branches + for i in range(len(kernelSize)): + conv_branch = nn.Sequential() + # 1x1 conv + conv_branch.append(nn.Conv2d(inputSize, reduceSize[i], (1, 1), stride=reduceStride[i])) + if batchNorm: + conv_branch.append(nn.BatchNorm2d(reduceSize[i])) + conv_branch.append(activation) + # nxn conv + pad = np.floor_divide(kernelSize[i], 2) if padding else (0, 0) + conv_branch.append(nn.Conv2d(reduceSize[i], outputSize[i], kernelSize[i], + stride=kernelStride[i], padding=pad)) + if batchNorm: + conv_branch.append(nn.BatchNorm2d(outputSize[i])) + conv_branch.append(activation) + self.branches.append(conv_branch) + + # pool branch + pool_branch = nn.Sequential() + # pool + pool_branch.append(pool) + # 1x1 conv + i = len(kernelSize) + if len(reduceSize) > i and reduceSize[i] is not None: + pool_branch.append(nn.Conv2d(inputSize, reduceSize[i], (1, 1), stride=reduceStride[i])) + if batchNorm: + pool_branch.append(nn.BatchNorm2d(reduceSize[i])) + pool_branch.append(activation) + self.branches.append(pool_branch) + + # reduce branch + i = len(kernelSize) + 1 + if len(reduceSize) > i and reduceSize[i] is not None: + reduce_branch = nn.Sequential() + reduce_branch.append(nn.Conv2d(inputSize, reduceSize[i], (1, 1), stride=reduceStride[i])) + if batchNorm: + reduce_branch.append(nn.BatchNorm2d(reduceSize[i])) + reduce_branch.append(activation) + self.branches.append(reduce_branch) + + self.branches = nn.ModuleList(self.branches) + + def forward(self, x: Tensor) -> Tensor: + branch_out = [] + for branch in self.branches: + res = branch(x) + branch_out.append(res) + + # Depth concat with padding + out_height = max(res.shape[2] for res in branch_out) + out_width = max(res.shape[3] for res in branch_out) + for i, res in enumerate(branch_out): + pad_left = int((out_width - res.shape[3]) // 2) + pad_right = out_width - res.shape[3] - pad_left + pad_top = int((out_height - res.shape[2]) // 2) + pad_bottom = out_height - res.shape[2] - pad_top + branch_out[i] = F.pad(res, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0.) 
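+        # Every branch output now has the same spatial size, so the tensors can
+        # be concatenated along the channel axis. This mirrors the zero-padded
+        # depth concatenation of the original LuaTorch model.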
+ + out = torch.cat(branch_out, dim=1) + return out + + +class OpenFaceNet(nn.Module): + """ + Usage: + model = OpenFaceNet() + + # If load on CPU + model.load_state_dict(torch.load('nn4.small2.v1.pt')) + + # If load on GPU + model.load_state_dict(torch.load('nn4.small2.v1.pt', map_location='cuda:0')) # Pick the right GPU device number + model.to(torch.device('cuda')) + + # Loading model for inference only + model.eval() + """ + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, (7, 7), stride=(2, 2), padding=(3, 3)) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d((3, 3), stride=(2, 2), padding=(1, 1)) + self.lrn = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) + self.conv2 = nn.Conv2d(64, 64, (1, 1)) + self.bn2 = nn.BatchNorm2d(64) + self.conv3 = nn.Conv2d(64, 192, (3, 3), stride=(1, 1), padding=(1, 1)) + self.bn3 = nn.BatchNorm2d(192) + self.incept3a = Inception(inputSize=192, reduceSize=(96, 16, 32, 64), outputSize=(128, 32), + kernelSize=((3, 3), (5, 5)), kernelStride=((1, 1), (1, 1)), + pool=nn.MaxPool2d((3, 3), stride=(2, 2))) + self.incept3b = Inception(inputSize=256, reduceSize=(96, 32, 64, 64), outputSize=(128, 64), + kernelSize=((3, 3), (5, 5)), kernelStride=((1, 1), (1, 1)), + pool=nn.LPPool2d(2, (3, 3), stride=(3, 3))) + self.incept3c = Inception(inputSize=320, reduceSize=(128, 32, None, None), outputSize=(256, 64), + kernelSize=((3, 3), (5, 5)), kernelStride=((2, 2), (2, 2)), + pool=nn.MaxPool2d((3, 3), stride=(2, 2))) + self.incept4a = Inception(inputSize=640, reduceSize=(96, 32, 128, 256), outputSize=(192, 64), + kernelSize=((3, 3), (5, 5)), kernelStride=((1, 1), (1, 1)), + pool=nn.LPPool2d(2, (3, 3), stride=(3, 3))) + self.incept4e = Inception(inputSize=640, reduceSize=(160, 64, None, None), outputSize=(256, 128), + kernelSize=((3, 3), (5, 5)), kernelStride=((2, 2), (2, 2)), + pool=nn.MaxPool2d((3, 3), stride=(2, 2))) + self.incept5a = Inception(inputSize=1024, reduceSize=(96, 96, 256), outputSize=(384,), + kernelSize=((3, 3),), kernelStride=((1, 1),), + pool=nn.LPPool2d(2, (3, 3), stride=(3, 3))) + self.incept5b = Inception(inputSize=736, reduceSize=(96, 96, 256), outputSize=(384,), + kernelSize=((3, 3),), kernelStride=((1, 1),), + pool=nn.MaxPool2d((3, 3), stride=(2, 2))) + self.avgpool = nn.AvgPool2d((3, 3), stride=(1, 1)) + self.ln = nn.Linear(736, 128) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv1(x) # Layer 1 + x = self.bn1(x) # Layer 2 + x = self.relu(x) # Layer 3 + x = self.maxpool(x) # Layer 4 + x = self.lrn(x) # Layer 5 + x = self.conv2(x) # Layer 6 + x = self.bn2(x) # Layer 7 + x = self.relu(x) # Layer 8 + x = self.conv3(x) # Layer 9 + x = self.bn3(x) # Layer 10 + x = self.relu(x) # Layer 11 + x = self.lrn(x) # Layer 12 + x = self.maxpool(x) # Layer 13 + x = self.incept3a(x) # Layer 14 + x = self.incept3b(x) # Layer 15 + x = self.incept3c(x) # Layer 16 + x = self.incept4a(x) # Layer 17 + x = self.incept4e(x) # Layer 18 + x = self.incept5a(x) # Layer 19 + # Reshape to (-1, 736, 3, 3) # Layer 20 + x = self.incept5b(x) # Layer 21 + x = self.avgpool(x) # Layer 22 + # Reshape to (-1, 736) # Layer 23 + x = x.view((-1, 736)) # Layer 24 + x = self.ln(x) # Layer 25 + x = F.normalize(x, p=2, dim=1, eps=1e-10) # Layer 26 + return x diff --git a/requirements.txt b/requirements.txt index 93364acba..94b8d507c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,11 @@ -numpy >= 1.1, < 2.0 -scipy >= 0.13, < 0.17 -pandas >= 0.13, < 0.18 -scikit-learn >= 0.17, < 0.18 -nose >= 1.3.1, < 1.4 
-nolearn == 0.5b1 +numpy<2 +scipy +pandas +scikit-learn +opencv-python +dlib +torch +torchvision +torchfile +#nose >= 1.3.1, < 1.4 +#nolearn == 0.5b1 diff --git a/setup.py b/setup.py index efebe57ab..3e6cc8be4 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,8 @@ -from distutils.core import setup +from setuptools import setup setup( name='openface', - version='0.2.1', + version='0.3.2', description="Face recognition with Google's FaceNet deep neural network.", url='https://github.com/cmusatyalab/openface', packages=['openface'],
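
A minimal end-to-end sketch of the converted pipeline (not part of the patch):
it assumes models/get-models.sh has been run from the repository root and that
'face.jpg' is a hypothetical input image; all calls follow the APIs added above.

import cv2
import numpy as np
import torch

import openface

# HOG detector plus 68-point predictor; pass models/dlib/mmod_human_face_detector.dat
# as the second argument to use the CNN detector instead.
align = openface.AlignDlib('models/dlib/shape_predictor_68_face_landmarks.dat')

bgr = cv2.imread('face.jpg')  # hypothetical input image
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
crop = align.align(96, rgb, landmarkIndices=openface.AlignDlib.OUTER_EYES_AND_NOSE)
assert crop is not None, 'no face found'  # align() returns None on failure

net = openface.OpenFaceNet()
net.load_state_dict(torch.load('models/openface/nn4.small2.v1.pt'))  # CPU weights
net.eval()

x = torch.from_numpy(np.transpose(crop / 255., (2, 0, 1))[None].astype(np.float32))
with torch.no_grad():
    rep = net(x).squeeze(0).numpy()  # (128,) L2-normalized embedding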