Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a CPU version #27

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions unit_tests/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
```
pip install pytest
pytest <test_feature>.py

```
28 changes: 28 additions & 0 deletions unit_tests/test_splat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os
import sys
sys.path.insert(0, "..")
import torch
from utils.splat2d_cpu import splat2d
# import pytest


def test_splat_cpu():

rs = 64 # resolution
blank_img = torch.zeros([1, 3, 512, 512],device='cpu')
points = torch.rand([1, 943, 2],dtype = torch.float32,device='cpu')
colors = torch.rand([1, 943, 3],dtype = torch.float32,device='cpu')
sigma = torch.tensor([0.3000], device='cpu')
prop_obj_img = splat2d(blank_img, points, colors, sigma, False) # (N, C, H, W)
# import pdb;pdb.set_trace()

assert prop_obj_img!=blank_img

# assert im.shape == (1,3,512,512)
# #torch.save({"image":im, "z":z}, "fixtures/stylegan_ref.pth")
# #import pdb; pdb.set_trace()
# assert torch.allclose(im, ref["image"], atol=1e-4)


if __name__ == '__main__':
test_splat_cpu()
1 change: 1 addition & 0 deletions utils/splat2d_cpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .splat import *
36 changes: 36 additions & 0 deletions utils/splat2d_cpu/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch.autograd as ag
from .splat_cpu import splat_forward_cpu
__all__ = ['splat2d']


_splat = None


class Splat2DFunction(ag.Function):

@staticmethod
def forward(ctx, input, coordinates, values, sigma, soft_normalize=False):

assert 'FloatTensor' in coordinates.type() and 'FloatTensor' in values.type(), \
'Splat2D only takes float coordinates and values, got {} and {} instead.'.format(coordinates.type(), values.type())
assert coordinates.size(0) == values.size(0) and coordinates.size(1) == values.size(1), \
'coordinates should be size (N, num_points, 2) and values should be size (N, num_points, *), got {} and {} instead.'.format(coordinates.shape, values.shape)
assert input.size(0) == coordinates.size(0) and input.dim() == 4, 'input should be of size (N, *, H, W), got {} instead'.format(input.shape)
assert sigma.size(0) == input.size(0), 'sigma should be a tensor of size (N,)'

input = input.contiguous()
coordinates = coordinates.contiguous()
values = values.contiguous()
sigma = sigma.contiguous()

# Apply splatting
output = splat_forward_cpu(input, coordinates, values, sigma, soft_normalize)

return output

@staticmethod
def backward(ctx, grad_output):
raise NotImplementedError


splat2d = Splat2DFunction.apply
16 changes: 16 additions & 0 deletions utils/splat2d_cpu/splat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch.nn as nn

from .functional import splat2d

__all__ = ['Splat2D', 'splat2d']


class Splat2D(nn.Module):
def __init__(self):
super().__init__()

def forward(self, coordinates, values, sigma, height, width):
return splat2d(coordinates, values, sigma, height, width)

def extra_repr(self):
return ''
123 changes: 123 additions & 0 deletions utils/splat2d_cpu/splat_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
'''/*
* File : splat_cpu.py
* Author : Bill Peebles
* Email : peebles@berkeley.edu
* Modify to python : Rebecca Li
* Email : xiaoli@adobe.com
*/
'''
import torch
import numpy as np

def GaussianPDF( mu_1:float, mu_2, x_1, x_2, normalizer):
'''
Calculate the gausian normalized pixel value
'''
return torch.exp(normalizer * (torch.pow(x_1 - mu_1, 2.0) + torch.pow(x_2 - mu_2, 2.0)))

def SplatForward(
bottom_coordinates,
bottom_values,
bottom_sigma: float,
top_alpha_splats: float,
top_output: float,
num_points : int,
channels : int,
height : int,
width : int):

'''
Modify from https://github.com/wpeebles/gangealing/blob/main/utils/splat2d_cuda/src/splat_gpu_impl.cu
# self is a loop over batch and point in coordinates

Args:
bottom_coordinates: (n, num_points, 2), (x,y)-coordinates
bottom_values: (n, num_points, channels)
bottom_sigma: float
top_alpha_splats: (n, height,width)
top_output: (N, C, H, W)
Output:
top_alpha_splats: updated value.
top_output: updated value. (N, C, H, W) an element in the output

# pw = index % width
# ph = (index / width) % height
# c = (index / width / height) % channels

[TODO] use CPU multithread to speed up this function
reference: https://docs.python.org/3/library/multiprocessing.html

- Test the function:

.. code-block::

unit_tests/test_splat.py
'''
for n in range( int(bottom_coordinates[0])):
# n for batch id
for i in range(num_points):
# i for point id
stdev = bottom_sigma[n]
length = 2 * stdev
x_coord = bottom_coordinates[n][i][0]
y_coord = bottom_coordinates[n][i][1]
normalizer = - torch.pow(2 * stdev * stdev, -1.0)

# Ignore out-of-bounds points:
if (x_coord >= 0 and x_coord < width) and (y_coord >= 0 and y_coord < height):

# import pdb;pdb.set_trace()
t = int(torch.fmax(torch.tensor(0), torch.floor(y_coord - length)))
b = int(torch.fmin(torch.tensor(height - 1), torch.ceil(y_coord + length)))
l = int(torch.fmax(torch.tensor(0), torch.floor(x_coord - length)))
r = int(torch.fmin(torch.tensor(width - 1), torch.ceil(x_coord + length)))

for lh in range ( t,b+1):
for lw in range (l, r+1 ) :
alpha = GaussianPDF(x_coord, y_coord, float(lw), float(lh), normalizer)
current_alpha_splat = top_alpha_splats[n][lh][lw]
top_alpha_splats[n][lh][lw] = current_alpha_splat + alpha
for c in range (channels):
current_output = top_output[n][c][lh][lw]
top_output[n][c][lh][lw] = current_output + alpha * bottom_values[n][i][c]

return top_alpha_splats, top_output


def splat_forward_cpu( input, coordinates, values,
sigma, soft_normalize= True) :

nr_imgs = input.size(0)
nr_points = coordinates.size(1)
nr_channels = input.size(1)
top_count = nr_imgs * nr_points
height = input.size(2)
width = input.size(3)
alpha_splats = torch.zeros([nr_imgs, height, width], device=values.device)
output = input.clone()

alpha_splats, output = SplatForward(
bottom_coordinates = coordinates ,
bottom_values =values ,
bottom_sigma = sigma,
top_alpha_splats = alpha_splats,
top_output = output,
num_points = nr_points,
channels = nr_channels,
height = height ,
width = width,

)

alpha_splats = alpha_splats.view(nr_imgs, 1, height, width)

if soft_normalize:
alpha_splats = alpha_splats.clamp(1.0)

output = output / (alpha_splats + 1e-8)

return output




4 changes: 3 additions & 1 deletion utils/vis_tools/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import moviepy.editor
import plotly.graph_objects as go
import plotly.colors
from utils.splat2d_cuda import splat2d
# from utils.splat2d_cuda import splat2d
from utils.splat2d_cpu import splat2d

from utils.laplacian_blending import LaplacianBlender
from tqdm import tqdm
import ray
Expand Down