Classical renderer backward method
Closed this issue · 5 comments
Hi, thank you for the great work!
I was wondering, would you be able to provide the backward() method for the classical renderer?
That way we could backpropagate through the renderer.
Honestly, we do not use the backward() method in this work, as the disparity map is produced in advance by a pretrained monocular depth estimation method. However, we provide the backward() code below. If you find it useful for your work, please consider citing this paper: "DoF-NeRF: Depth-of-Field Meets Neural Radiance Fields".
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import cupy
import re
# CUDA source template for the forward (scatter) pass of the renderer.
#
# SIZE_k(t), OFFSET_4(t, ...) and VALUE_4(t, ...) are not CUDA: they are
# placeholders that cupy_kernel() replaces with the sizes/strides of the
# tensors passed to it before the source is compiled.
#
# Each thread handles one pixel (intN, intY, intX) of weightCum. With
# r = |defocus| at that pixel, it visits every in-bounds neighbour inside a
# fixed (2 * 20 + 1)^2 window and adds
#     w = (0.5 + 0.5 * tanh(4 * (r - dist))) / (r * r + 0.2)
# to weightCum[neighbour] and w * image[pixel, c] (c = 0, 1, 2) to
# bokehCum[neighbour, c]. All accumulation goes through atomicAdd, so the
# output buffers must be zero-initialised by the caller.
kernel_Render_updateOutput = '''
extern "C" __global__ void kernel_Render_updateOutput(
const int n,
const float* image, // original image
const float* defocus, // signed defocus map
float* bokehCum, // cumulative bokeh image
float* weightCum // cumulative weight map
)
{
int intRadiusMax = 20; // the maximum blur radius to scatter. It must be fixed as the dynamic one cannot be used for backpropagation.
for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(weightCum) / SIZE_2(weightCum) / SIZE_1(weightCum) ) % SIZE_0(weightCum);
// const int intC = ( intIndex / SIZE_3(weightCum) / SIZE_2(weightCum) ) % SIZE_1(weightCum);
const int intY = ( intIndex / SIZE_3(weightCum) ) % SIZE_2(weightCum);
const int intX = ( intIndex ) % SIZE_3(weightCum);
float fltDefocus = VALUE_4(defocus, intN, 0, intY, intX);
float fltRadius = fabsf(fltDefocus);
for (int intDeltaY = -intRadiusMax; intDeltaY <= intRadiusMax; ++intDeltaY) {
for (int intDeltaX = -intRadiusMax; intDeltaX <= intRadiusMax; ++intDeltaX) {
int intNeighborY = intY + intDeltaY;
int intNeighborX = intX + intDeltaX;
if ((intNeighborY >= 0) && (intNeighborY < SIZE_2(weightCum)) && (intNeighborX >= 0) && (intNeighborX < SIZE_3(weightCum))) {
float fltDist = sqrtf((float)(intDeltaY)*(float)(intDeltaY) + (float)(intDeltaX)*(float)(intDeltaX));
float fltWeight = (0.5 + 0.5 * tanhf(4 * (fltRadius - fltDist))) / (fltRadius * fltRadius + 0.2);
atomicAdd(&weightCum[OFFSET_4(weightCum, intN, 0, intNeighborY, intNeighborX)], fltWeight);
atomicAdd(&bokehCum[OFFSET_4(bokehCum, intN, 0, intNeighborY, intNeighborX)], fltWeight * VALUE_4(image, intN, 0, intY, intX));
atomicAdd(&bokehCum[OFFSET_4(bokehCum, intN, 1, intNeighborY, intNeighborX)], fltWeight * VALUE_4(image, intN, 1, intY, intX));
atomicAdd(&bokehCum[OFFSET_4(bokehCum, intN, 2, intNeighborY, intNeighborX)], fltWeight * VALUE_4(image, intN, 2, intY, intX));
}
}
}
}
}
'''
# CUDA source template for the gradient of the loss w.r.t. the input image.
#
# SIZE_k / VALUE_4 / OFFSET_4 are placeholders expanded by cupy_kernel().
#
# Each thread handles one source pixel (intN, intY, intX) of the defocus map
# and gathers, over the same fixed (2 * 20 + 1)^2 window and with the same
# weight w as the forward kernel,
#     gradImage[pixel, c] = sum_neighbours gradBokehCum[neighbour, c] * w
# for c = 0, 1, 2. The result is written with a plain store (no atomics).
# The `image` pointer is part of the signature but is never read here.
kernel_Render_updateGradImage = '''
extern "C" __global__ void kernel_Render_updateGradImage(
const int n,
const float* image, // original image
const float* defocus, // signed defocus map
const float* gradBokehCum, // gradient of cumulative bokeh image
float* gradImage // gradient of original image
)
{
int intRadiusMax = 20; // the maximum blur radius to scatter. It must be fixed as the dynamic one cannot be used for backpropagation.
for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(defocus) / SIZE_2(defocus) / SIZE_1(defocus) ) % SIZE_0(defocus);
// const int intC = ( intIndex / SIZE_3(defocus) / SIZE_2(defocus) ) % SIZE_1(defocus);
const int intY = ( intIndex / SIZE_3(defocus) ) % SIZE_2(defocus);
const int intX = ( intIndex ) % SIZE_3(defocus);
float fltDefocus = VALUE_4(defocus, intN, 0, intY, intX);
float fltRadius = fabsf(fltDefocus);
float fltRadiusSquare = fltRadius * fltRadius + 0.2;
float fltGradImageR = 0.0;
float fltGradImageG = 0.0;
float fltGradImageB = 0.0;
for (int intDeltaY = -intRadiusMax; intDeltaY <= intRadiusMax; ++intDeltaY) {
for (int intDeltaX = -intRadiusMax; intDeltaX <= intRadiusMax; ++intDeltaX) {
int intNeighborY = intY + intDeltaY;
int intNeighborX = intX + intDeltaX;
if ((intNeighborY >= 0) & (intNeighborY < SIZE_2(defocus)) & (intNeighborX >= 0) & (intNeighborX < SIZE_3(defocus))) {
float fltDist = sqrtf((float)(intDeltaY)*(float)(intDeltaY) + (float)(intDeltaX)*(float)(intDeltaX));
float fltTanh = tanhf(4 * (fltRadius - fltDist));
float fltWeight = (0.5 + 0.5 * fltTanh) / fltRadiusSquare;
fltGradImageR += VALUE_4(gradBokehCum, intN, 0, intNeighborY, intNeighborX) * fltWeight;
fltGradImageG += VALUE_4(gradBokehCum, intN, 1, intNeighborY, intNeighborX) * fltWeight;
fltGradImageB += VALUE_4(gradBokehCum, intN, 2, intNeighborY, intNeighborX) * fltWeight;
}
}
}
gradImage[OFFSET_4(gradImage, intN, 0, intY, intX)] = fltGradImageR;
gradImage[OFFSET_4(gradImage, intN, 1, intY, intX)] = fltGradImageG;
gradImage[OFFSET_4(gradImage, intN, 2, intY, intX)] = fltGradImageB;
}
}
'''
# CUDA source template for the gradient of the loss w.r.t. the signed
# defocus map.
#
# SIZE_k / VALUE_4 / OFFSET_4 are placeholders expanded by cupy_kernel().
#
# Each thread handles one source pixel. With r = |defocus|, S = r * r + 0.2
# and t = tanh(4 * (r - dist)), the forward weight is w = (0.5 + 0.5 * t) / S,
# so
#     dw/d(defocus) = sign * (2 * (1 - t * t) / S - (1 + t) * r / S^2)
# where sign is -1 for negative defocus and +1 otherwise. The thread sums
#     (sum_c gradBokehCum[neighbour, c] * image[pixel, c]
#      + gradWeightCum[neighbour]) * dw/d(defocus)
# over the in-bounds neighbours of the fixed (2 * 20 + 1)^2 window and
# stores the total in gradDefocus[pixel].
# NOTE(review): fltWeight is computed inside the loop but never used.
kernel_Render_updateGradDefocus = '''
extern "C" __global__ void kernel_Render_updateGradDefocus(
const int n,
const float* image, // original image
const float* defocus, // signed defocus map
const float* gradBokehCum, // gradient of cumulative bokeh image
const float* gradWeightCum, // gradient of cumulative weight map
float* gradDefocus // gradient of signed defocus map
)
{
int intRadiusMax = 20; // the maximum blur radius to scatter. It must be fixed as the dynamic one cannot be used for backpropagation.
for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(defocus) / SIZE_2(defocus) / SIZE_1(defocus) ) % SIZE_0(defocus);
// const int intC = ( intIndex / SIZE_3(defocus) / SIZE_2(defocus) ) % SIZE_1(defocus);
const int intY = ( intIndex / SIZE_3(defocus) ) % SIZE_2(defocus);
const int intX = ( intIndex ) % SIZE_3(defocus);
float fltDefocus = VALUE_4(defocus, intN, 0, intY, intX);
float fltRadius = fabsf(fltDefocus);
float fltRadiusSquare = fltRadius * fltRadius + 0.2;
float dRadius_div_dDefocus = 1.0;
if (fltDefocus < 0) {
dRadius_div_dDefocus = -1.0;
}
float fltGradDefocus = 0.0;
for (int intDeltaY = -intRadiusMax; intDeltaY <= intRadiusMax; ++intDeltaY) {
for (int intDeltaX = -intRadiusMax; intDeltaX <= intRadiusMax; ++intDeltaX) {
int intNeighborY = intY + intDeltaY;
int intNeighborX = intX + intDeltaX;
if ((intNeighborY >= 0) & (intNeighborY < SIZE_2(defocus)) & (intNeighborX >= 0) & (intNeighborX < SIZE_3(defocus))) {
float fltDist = sqrtf((float)(intDeltaY)*(float)(intDeltaY) + (float)(intDeltaX)*(float)(intDeltaX));
float fltTanh = tanhf(4 * (fltRadius - fltDist));
float dWeight_div_dDefocus = dRadius_div_dDefocus * (2 * (1 - fltTanh * fltTanh) / fltRadiusSquare - (1 + fltTanh) * fltRadius / fltRadiusSquare / fltRadiusSquare);
float fltWeight = (0.5 + 0.5 * fltTanh) / fltRadiusSquare;
fltGradDefocus += VALUE_4(gradBokehCum, intN, 0, intNeighborY, intNeighborX) * VALUE_4(image, intN, 0, intY, intX) * dWeight_div_dDefocus;
fltGradDefocus += VALUE_4(gradBokehCum, intN, 1, intNeighborY, intNeighborX) * VALUE_4(image, intN, 1, intY, intX) * dWeight_div_dDefocus;
fltGradDefocus += VALUE_4(gradBokehCum, intN, 2, intNeighborY, intNeighborX) * VALUE_4(image, intN, 2, intY, intX) * dWeight_div_dDefocus;
fltGradDefocus += VALUE_4(gradWeightCum, intN, 0, intNeighborY, intNeighborX) * dWeight_div_dDefocus;
}
}
}
gradDefocus[OFFSET_4(gradDefocus, intN, 0, intY, intX)] = fltGradDefocus;
}
}
'''
def cupy_kernel(strFunction, objVariables):
    """Expand the SIZE_/OFFSET_/VALUE_ placeholders of a kernel template.

    The template is the module-level string named ``strFunction``. Three
    placeholder forms are substituted, in this order:

    * ``SIZE_k(t)``             -> ``objVariables['t'].size()[k]``
    * ``OFFSET_k(t, i0, ...)``  -> ``(((i0)*stride0)+((i1)*stride1)+...)``
    * ``VALUE_k(t, i0, ...)``   -> ``t[((i0)*stride0)+((i1)*stride1)+...]``

    where ``strideN`` is ``objVariables['t'].stride()[N]``. Inside an index
    expression ``{`` and ``}`` are turned into ``(`` and ``)``, which lets a
    template use grouping without closing the placeholder's own parenthesis.

    Args:
        strFunction: name of a module-level string holding the template.
        objVariables: mapping from the tensor names used in the template to
            objects exposing ``size()`` and ``stride()``.

    Returns:
        The template with every placeholder replaced by literal integers
        and index arithmetic.

    Raises:
        KeyError: if ``strFunction`` is not a module-level name, or a
            placeholder refers to a tensor missing from ``objVariables``.
    """
    strKernel = globals()[strFunction]

    def size_literal(objMatch):
        # SIZE_k(tensor): the k-th dimension as a literal integer.
        intArg = int(objMatch.group(2))
        strTensor = objMatch.group(4)
        return str(objVariables[strTensor].size()[intArg])
    # end

    def flat_index(objMatch):
        # Shared by OFFSET_k and VALUE_k: returns the tensor name and the
        # '+'-joined sum of index*stride terms for the first k indices.
        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')
        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        strIndex = [
            '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip()
            + ')*' + str(intStrides[intArg]) + ')'
            for intArg in range(intArgs)
        ]
        return strTensor, '+'.join(strIndex)
    # end

    def offset_expr(objMatch):
        return '(' + flat_index(objMatch)[1] + ')'
    # end

    def value_expr(objMatch):
        strTensor, strIndex = flat_index(objMatch)
        return strTensor + '[' + strIndex + ']'
    # end

    # Raw strings: '\(' and '\)' are regex escapes, not Python string escapes.
    strKernel = re.sub(r'(SIZE_)([0-4])(\()([^\)]*)(\))', size_literal, strKernel)
    strKernel = re.sub(r'(OFFSET_)([0-4])(\()([^\)]+)(\))', offset_expr, strKernel)
    strKernel = re.sub(r'(VALUE_)([0-4])(\()([^\)]+)(\))', value_expr, strKernel)

    return strKernel
# end
# @cupy.util.memoize(for_each_device=True)
@cupy.memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
    """Compile ``strKernel`` and return its ``strFunction`` entry point.

    The call is wrapped in ``cupy.memoize(for_each_device=True)``, so calling
    it again with the same (name, source) pair reuses the earlier result.

    Args:
        strFunction: name of the ``extern "C"`` kernel inside the source.
        strKernel: fully expanded CUDA source (see ``cupy_kernel``).

    Returns:
        The kernel object obtained from ``cupy.RawModule.get_function``.
    """
    # return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
    objModule = cupy.RawModule(code=strKernel)
    return objModule.get_function(strFunction)
# end
def _launch_render_kernel(strFunction, n, objVariables):
    """Expand, compile and launch one of the render kernels over ``n`` items.

    ``objVariables`` serves two purposes: it supplies the sizes/strides that
    ``cupy_kernel`` bakes into the source, and its values (in insertion
    order) become the pointer arguments following ``n``. The dictionary
    must therefore list the tensors in the same order as the kernel
    signature.
    """
    intThreads = 512

    cupy_launch(strFunction, cupy_kernel(strFunction, objVariables))(
        grid=((n + intThreads - 1) // intThreads, 1, 1),
        block=(intThreads, 1, 1),
        # The kernels declare `const int n`; pass an explicit 32-bit scalar
        # rather than a Python int so the argument width matches.
        args=[cupy.int32(n)] + [objTensor.data_ptr() for objTensor in objVariables.values()],
    )
# end


class _FunctionRender(torch.autograd.Function):
    """Differentiable scatter renderer backed by the CUDA kernels above.

    ``forward`` returns the *unnormalised* accumulated colour ``bokeh_cum``
    (same shape as ``image``) and the accumulated weight ``weight_cum``
    (same shape as ``defocus``). ``backward`` returns the gradients with
    respect to ``image`` and ``defocus``.

    The first argument of both static methods is the autograd context; it is
    named ``self`` here to keep the original signatures.

    NOTE(review): the kernels read channels 0..2 of ``image`` and channel 0
    of ``defocus`` through ``float*`` pointers, so float32 tensors shaped
    (N, 3, H, W) and (N, 1, H, W) are presumably expected -- nothing in this
    class verifies that.
    """

    @staticmethod
    def forward(self, image, defocus):
        """Scatter every pixel of ``image`` according to ``defocus``.

        Raises:
            NotImplementedError: if ``defocus`` is not a CUDA tensor.
        """
        if not defocus.is_cuda:
            raise NotImplementedError()
        # end

        self.save_for_backward(image, defocus)

        # The kernel accumulates with atomicAdd, so start from zeros.
        bokeh_cum = torch.zeros_like(image)
        weight_cum = torch.zeros_like(defocus)

        _launch_render_kernel('kernel_Render_updateOutput', weight_cum.nelement(), {
            'image': image,
            'defocus': defocus,
            'bokehCum': bokeh_cum,
            'weightCum': weight_cum,
        })

        return bokeh_cum, weight_cum
    # end

    @staticmethod
    def backward(self, grad_bokeh_cum, grad_weight_cum):
        """Return ``(grad_image, grad_defocus)``; an entry is ``None`` when
        the corresponding input does not require a gradient."""
        image, defocus = self.saved_tensors

        assert grad_bokeh_cum.is_cuda is True and grad_weight_cum.is_cuda is True

        grad_image = torch.zeros_like(image) if self.needs_input_grad[0] else None
        grad_defocus = torch.zeros_like(defocus) if self.needs_input_grad[1] else None

        if grad_image is not None:
            _launch_render_kernel('kernel_Render_updateGradImage', defocus.nelement(), {
                'image': image,
                'defocus': defocus,
                'gradBokehCum': grad_bokeh_cum,
                'gradImage': grad_image,
            })
        # end

        if grad_defocus is not None:
            _launch_render_kernel('kernel_Render_updateGradDefocus', defocus.nelement(), {
                'image': image,
                'defocus': defocus,
                'gradBokehCum': grad_bokeh_cum,
                'gradWeightCum': grad_weight_cum,
                'gradDefocus': grad_defocus,
            })
        # end

        return grad_image, grad_defocus
    # end
# end
def FunctionRender(image, defocus):
    """Functional entry point for the scatter renderer.

    Applies ``_FunctionRender`` to ``image`` and ``defocus`` and returns its
    two outputs, ``(bokeh_cum, weight_cum)``, unchanged.
    """
    return _FunctionRender.apply(image, defocus)
# end
class ModuleRenderScatter(torch.nn.Module):
    """Parameter-free module wrapping the scatter renderer.

    ``forward`` calls ``FunctionRender`` and divides the accumulated colour
    by the accumulated weight to obtain the rendered image.
    """

    def __init__(self):
        super().__init__()
    # end

    def forward(self, image, defocus):
        """Return the weight-normalised rendering of ``image``."""
        accumulated, weights = FunctionRender(image, defocus)
        return accumulated / weights
    # end
# end
if __name__ == '__main__':
    # Smoke test for the backward pass: treat a random image and a random
    # defocus map as the trainable parameters and pull the rendered result
    # towards an all-zero target with Adam. Requires a CUDA device.
    # NOTE(review): the pasted source lost its indentation; print(loss) is
    # assumed to belong to the loop body -- confirm against the original.
    module = ModuleRenderScatter().cuda()
    image = torch.rand(1, 3, 200, 200).cuda().requires_grad_(True)
    defocus = torch.rand(1, 1, 200, 200).cuda().requires_grad_(True)
    gt = torch.zeros(1, 3, 200, 200).cuda()
    optimizer = torch.optim.Adam(params=[image, defocus], lr=1e-2, eps=1e-8)
    for i in range(100):
        optimizer.zero_grad()
        pred = module(image, defocus)
        # Mean absolute error between the rendering and the zero target.
        loss = (pred - gt).abs().mean()
        # Runs _FunctionRender.backward for both image and defocus.
        loss.backward()
        optimizer.step()
        print(loss)
Wow many thanks for this!
Do you have a link to the paper "DoF-NeRF: Depth-of-Field Meets Neural Radiance Fields"?
I can't seem to find it.
https://github.com/zijinwuzijin/DoF-NeRF
The paper was accepted by ACM MM 2022 and will be made public soon.
Thank you!
Wow many thanks for this! Do you have a link to the paper "DoF-NeRF: Depth-of-Field Meets Neural Radiance Fields"? I can't seem to find it.
The full paper of DoF-NeRF is now available on arXiv.
https://arxiv.org/abs/2208.00945
We will release its code and data soon.