Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MPS support #630

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _linear_to_mel(spectogram):

def _build_mel_basis():
assert hp.fmax <= hp.sample_rate // 2
return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels,
return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
fmin=hp.fmin, fmax=hp.fmax)

def _amp_to_db(x):
Expand Down
4 changes: 2 additions & 2 deletions face_detection/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __int__(self):
class FaceAlignment:
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
self.device = device
self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
self.flip_input = flip_input
self.landmarks_type = landmarks_type
self.verbose = verbose
Expand Down Expand Up @@ -76,4 +76,4 @@ def get_detections_for_batch(self, images):
x1, y1, x2, y2 = map(int, d[:-1])
results.append((x1, y1, x2, y2))

return results
return results
2 changes: 1 addition & 1 deletion face_detection/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, device, verbose):
logger = logging.getLogger(__name__)
logger.warning("Detection running on CPU, this may be potentially slow.")

if 'cpu' not in device and 'cuda' not in device:
if 'cpu' not in device and 'cuda' not in device and 'mps' not in device:
if verbose:
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
raise ValueError
Expand Down
2 changes: 1 addition & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def datagen(frames, mels):
yield img_batch, mel_batch, frame_batch, coords_batch

mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print('Using {} for inference.'.format(device))

def _load(checkpoint_path):
Expand Down
4 changes: 2 additions & 2 deletions preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def process_video_file(vfile, args, gpu_id):
preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))

for j, f in enumerate(preds):
i += 1
if f is None:
continue
i += 1

x1, y1, x2, y2 = f
cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])
Expand Down Expand Up @@ -110,4 +110,4 @@ def main(args):
continue

if __name__ == '__main__':
main(args)
main(args)
16 changes: 8 additions & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
librosa==0.7.0
numpy==1.17.1
opencv-contrib-python>=4.2.0.34
opencv-python==4.1.0.25
torch==1.1.0
torchvision==0.3.0
tqdm==4.45.0
numba==0.48
librosa
numpy
opencv-contrib-python
opencv-python
torch
torchvision
tqdm
numba