Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mps device option is selected if you run this repository in MacOS #300

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ Using gpu acceleration (requires proper gpu drivers for docker):
docker run --rm -it --gpus all -v /dev/dri:/dev/dri -v $PWD:/host rife:latest inference_video --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4
```

### Run in MAC

Before running the inference, you need to set the following environment variable:

````
export PYTORCH_ENABLE_MPS_FALLBACK=1
````

## Evaluation
Download [RIFE model](https://drive.google.com/file/d/1h42aGYPNJn2q8j_GVkS_yDu__G_UZ2GX/view?usp=sharing) or [RIFE_m model](https://drive.google.com/file/d/147XVsDXBfJPlyct2jfo9kpbL944mNeZr/view?usp=sharing) reported by our paper.

Expand Down
8 changes: 7 additions & 1 deletion benchmark/ATD12K.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from torch.nn import functional as F
from model.pytorch_msssim import ssim_matlab
from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model = Model()
model.load_model('train_log')
Expand Down
6 changes: 5 additions & 1 deletion benchmark/HD.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from model.RIFE import Model
from skimage.color import rgb2yuv, yuv2rgb
from yuv_frame_io import YUV_Read,YUV_Write
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
model.load_model('train_log')
Expand Down
6 changes: 5 additions & 1 deletion benchmark/HD_multi_4X.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from model.RIFE import Model
from skimage.color import rgb2yuv, yuv2rgb
from yuv_frame_io import YUV_Read,YUV_Write
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model(arbitrary=True)
model.load_model('RIFE_m_train_log')
Expand Down
6 changes: 5 additions & 1 deletion benchmark/MiddleBury_Other.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from torch.nn import functional as F
from model.pytorch_msssim import ssim_matlab
from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
model.load_model('train_log')
Expand Down
6 changes: 5 additions & 1 deletion benchmark/UCF101.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from torch.nn import functional as F
from model.pytorch_msssim import ssim_matlab
from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
model.load_model('train_log')
Expand Down
6 changes: 5 additions & 1 deletion benchmark/Vimeo90K.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from torch.nn import functional as F
from model.pytorch_msssim import ssim_matlab
from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
model.load_model('train_log')
Expand Down
6 changes: 5 additions & 1 deletion benchmark/testtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

model = Model()
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
Expand Down
9 changes: 8 additions & 1 deletion dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
import random
from torch.utils.data import DataLoader, Dataset


cv2.setNumThreads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class VimeoDataset(Dataset):
def __init__(self, dataset_name, batch_size=32):
self.batch_size = batch_size
Expand Down
11 changes: 8 additions & 3 deletions inference_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
if torch.backends.mps.is_available():
device = torch.device("mps")
os.environment["PYTORCH_ENABLE_MPS_FALLBACK"] = 1
elif torch.cuda.is_available():
device = torch.device("cuda")
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
else:
device = torch.device("cpu")
torch.set_grad_enabled(False)

parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--img', dest='img', nargs=2, required=True)
Expand Down
23 changes: 17 additions & 6 deletions inference_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from queue import Queue, Empty
from model.pytorch_msssim import ssim_matlab

np_version = [int(i) for i in np.__version__.split('.')]
if np_version[0] == 2 or (np_version[0] == 1 and np_version[1] >= 20):
np.float = float
np.int = int

warnings.filterwarnings("ignore")

def transferAudio(sourceVideo, targetVideo):
Expand Down Expand Up @@ -77,13 +82,21 @@ def transferAudio(sourceVideo, targetVideo):
if not args.img is None:
args.png = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():

if torch.backends.mps.is_available():
device = torch.device("mps")
if(args.fp16):
torch.set_default_tensor_type(torch.HalfTensor)
elif torch.cuda.is_available():
device = torch.device("cuda")
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
if(args.fp16):
torch.set_default_tensor_type(torch.cuda.HalfTensor)
else:
device = torch.device("cpu")
torch.set_grad_enabled(False)


try:
try:
Expand Down Expand Up @@ -115,11 +128,9 @@ def transferAudio(sourceVideo, targetVideo):
fps = videoCapture.get(cv2.CAP_PROP_FPS)
tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
videoCapture.release()
fpsNotAssigned = True
if args.fps is None:
fpsNotAssigned = True
args.fps = fps * (2 ** args.exp)
else:
fpsNotAssigned = False
videogen = skvideo.io.vreader(args.video)
lastframe = next(videogen)
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
Expand Down
12 changes: 8 additions & 4 deletions model/RIFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
from model.laplacian import *
from model.refine import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Model:
def __init__(self, local_rank=-1, arbitrary=False):
def __init__(self, local_rank=-1, mps=False, arbitrary=False):
if arbitrary == True:
self.flownet = IFNet_m()
else:
Expand All @@ -26,7 +30,7 @@ def __init__(self, local_rank=-1, arbitrary=False):
self.epe = EPE()
self.lap = LapLoss()
self.sobel = SOBEL()
if local_rank != -1:
if local_rank != -1 and mps is False:
self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

def train(self):
Expand All @@ -47,7 +51,7 @@ def convert(param):
}

if rank <= 0:
self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))))
self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location ='cpu')))

def save_model(self, path, rank=0):
if rank == 0:
Expand Down
7 changes: 5 additions & 2 deletions model/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import torch
if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def gauss_kernel(size=5, channels=3):
kernel = torch.tensor([[1., 4., 6., 4., 1],
Expand Down
6 changes: 5 additions & 1 deletion model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import torch.nn.functional as F
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EPE(nn.Module):
Expand Down
6 changes: 5 additions & 1 deletion model/oldmodel/IFNet_HD.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from model.warplayer import warp


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
Expand Down
7 changes: 6 additions & 1 deletion model/oldmodel/IFNet_HDv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from model.warplayer import warp


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
Expand Down
6 changes: 5 additions & 1 deletion model/oldmodel/RIFE_HD.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import torch.nn.functional as F
from model.loss import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
Expand Down
6 changes: 5 additions & 1 deletion model/oldmodel/RIFE_HDv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import torch.nn.functional as F
from model.loss import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
Expand Down
23 changes: 15 additions & 8 deletions model/pytorch_msssim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from math import exp
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
Expand Down Expand Up @@ -96,24 +101,26 @@ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full

padd = 0
(_, _, height, width) = img1.size()

img1 = torch.tensor(img1.unsqueeze(1), device='cpu')
img2 = torch.tensor(img2.unsqueeze(1), device='cpu')

if window is None:
real_size = min(window_size, height, width)
window = create_window_3d(real_size, channel=1).to(img1.device)
# Channel is set to 1 since we consider color images as volumetric images

img1 = img1.unsqueeze(1)
img2 = img2.unsqueeze(1)

mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
mu1 = torch.tensor(F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1), device=device)
mu2 = torch.tensor(F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1), device=device)

mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2

sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
sigma1_sq = torch.tensor(F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1), device=device) - mu1_sq
sigma2_sq = torch.tensor(F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1), device=device) - mu2_sq
sigma12 = torch.tensor(F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1), device=device) - mu1_mu2

C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
Expand Down
6 changes: 5 additions & 1 deletion model/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from model.warplayer import warp
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
Expand Down
6 changes: 5 additions & 1 deletion model/refine_2R.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
Expand Down
7 changes: 6 additions & 1 deletion model/warplayer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

backwarp_tenGrid = {}


Expand Down
Loading