Skip to content

Commit

Permalink
Merge pull request #32 from uci-rendering/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
GuangyanCai authored Sep 5, 2023
2 parents 319162a + c96b22b commit bed597b
Show file tree
Hide file tree
Showing 8 changed files with 1,629 additions and 17 deletions.
437 changes: 437 additions & 0 deletions examples/2_forward/code.ipynb

Large diffs are not rendered by default.

111 changes: 101 additions & 10 deletions irtk/connectors/psdr_jit_connector.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
from ..connector import Connector
from ..scene import *
from ..config import *
from ..io import write_mesh
from collections import OrderedDict

import drjit
import psdr_jit
from drjit.scalar import Array3f
from drjit.cuda import Array3f as Vector3fC, Array3i as Vector3iC
from drjit.cuda.ad import Array3f as Vector3fD, Float32 as FloatD, Matrix4f as Matrix4fD, Matrix3f as Matrix3fD
from drjit.cuda.ad import Array3f as Vector3fD, Array1f as Vector1fD, Float32 as FloatD, Matrix4f as Matrix4fD, Matrix3f as Matrix3fD
from drjit.cuda.ad import Float32 as FloatD
import torch

import os

class PSDRJITConnector(Connector, connector_name='psdr_jit'):

backend = 'torch'
device = 'cuda'
ftype = torch.float32
itype = torch.long

def __init__(self):
super().__init__()

Expand Down Expand Up @@ -73,7 +67,7 @@ def renderC(self, scene, render_options, sensor_ids=[0], integrator_id=0):

images = []
for sensor_id in sensor_ids:
image = torch.zeros((h * w, c)).to(device).to(ftype)
image = to_torch_f(torch.zeros((h * w, c)))
for i in range(npass):
image_pass = integrator.renderC(cache['scene'], sensor_id).torch()
image += image_pass / npass
Expand Down Expand Up @@ -106,11 +100,48 @@ def renderD(self, image_grads, scene, render_options, sensor_ids=[0], integrator
drjit.backward(tmp)

for param_grad, drjit_param in zip(param_grads, drjit_params):
grad = drjit.grad(drjit_param).torch().to(device).to(ftype)
grad = to_torch_f(drjit.grad(drjit_param).torch())
grad = torch.nan_to_num(grad).reshape(param_grad.shape)
param_grad += grad

return param_grads
return param_grads

def forward_ad_mesh_translation(self, mesh_id, scene, render_options, sensor_ids=[0], integrator_id=0):
cache, drjit_params = self.update_scene_objects(scene, render_options)
assert len(drjit_params) == 0

P = FloatD(0.)
drjit.enable_grad(P)
psdr_mesh = cache['scene'].param_map[cache['name_map'][mesh_id]]
psdr_mesh.set_transform(Matrix4fD([[1.,0.,0.,P],[0.,1.,0.,0.],[0.,0.,1.,0.],[0.,0.,0.,1.],]))

cache['scene'].configure(sensor_ids)

npass = render_options['npass']
h, w, c = cache['film']['shape']
if type(integrator_id) == int:
integrator = list(cache['integrators'].values())[integrator_id]
elif type(integrator_id) == str:
integrator = cache['integrators'][integrator_id]
else:
raise RuntimeError('integrator_id is invalid: {integrator_id}')


image = to_torch_f(torch.zeros((h * w, c)))
grad_image = to_torch_f(torch.zeros((h * w, c)))

for j in range(npass):
drjit_image = integrator.renderD(cache['scene'], sensor_ids[0])
image += to_torch_f(drjit_image.torch()) / npass

drjit.set_grad(P, 1.0)
drjit.forward_to(drjit_image)
drjit_grad_image = drjit.grad(drjit_image)
grad_image += to_torch_f(drjit_grad_image.torch()) / npass

image = image.reshape(h, w, c)
grad_image = grad_image.reshape(h, w, c)
return image, grad_image

@PSDRJITConnector.register(Integrator)
def process_integrator(name, scene):
Expand Down Expand Up @@ -417,4 +448,64 @@ def enable_grad(drjit_param):
if param_name == 'radiance':
enable_grad(psdr_emitter.radiance.data)

return drjit_params


# Scene components specfic to psdr-jit
class MicrofacetBRDFPerVertex(ParamGroup):

def __init__(self, d, s, r):
super().__init__()

self.add_param('d', to_torch_f(d), is_tensor=True, is_diff=True, help_msg='diffuse reflectance')
self.add_param('s', to_torch_f(s), is_tensor=True, is_diff=True, help_msg='specular reflectance')
self.add_param('r', to_torch_f(r), is_tensor=True, is_diff=True, help_msg='roughness')

@PSDRJITConnector.register(MicrofacetBRDFPerVertex)
def process_microfacet_brdf_per_vertex(name, scene):
brdf = scene[name]
cache = scene.cached['psdr_jit']
psdr_scene = cache['scene']

# Create the object if it has not been created
if name not in cache['name_map']:
d = Vector3fD(brdf['d'])
s = Vector3fD(brdf['s'])
r = Vector1fD(brdf['r'])

psdr_bsdf = psdr_jit.MicrofacetBSDFPerVertex(s, d, r)
psdr_scene.add_BSDF(psdr_bsdf, name)
cache['name_map'][name] = f"BSDF[id={name}]"

psdr_brdf = psdr_scene.param_map[cache['name_map'][name]]

# Update parameters
updated = brdf.get_updated()
if len(updated) > 0:
for param_name in updated:
if param_name == 'd':
psdr_brdf.diffuseReflectance = Vector3fD(brdf['d'])
elif param_name == 's':
psdr_brdf.specularReflectance = Vector3fD(brdf['s'])
elif param_name == 'r':
psdr_brdf.roughness= Vector1fD(brdf['r'])
brdf.params[param_name]['updated'] = False

# Enable grad for parameters requiring grad
drjit_params = []

def enable_grad(drjit_param):
drjit.enable_grad(drjit_param)
drjit_params.append(drjit_param)

requiring_grad = brdf.get_requiring_grad()
if len(requiring_grad) > 0:
for param_name in requiring_grad:
if param_name == 'd':
enable_grad(psdr_brdf.diffuseReflectance)
elif param_name == 's':
enable_grad(psdr_brdf.specularReflectance)
elif param_name == 'r':
enable_grad(psdr_brdf.roughness)

return drjit_params
Loading

0 comments on commit bed597b

Please sign in to comment.