Rendering with different objects in a batch
Hi!
I was wondering how I would go about using the rendering functions (mostly render_texture_batch) when I want to pass a batch of different objects (meshes) as input.
PyTorch3D is able to do this (I'm not sure how, but they might be doing some padding); however, since nvdiffrast is more barebones, I'm not sure how it would work here.
Thank you!
Yes, it looks like nvdiffrast's range mode is the answer. Thanks!
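For context, here is a minimal sketch of what range mode expects (the tensors below are illustrative placeholders, not this repo's code). All meshes share a single vertex buffer and a single triangle buffer, and an int32 [num_meshes, 2] tensor of (first_triangle, triangle_count) pairs tells dr.rasterize which triangles go into which output image:

import torch
import nvdiffrast.torch as dr

glctx = dr.RasterizeCudaContext()

# pos_clip: [total_num_vertices, 4] clip-space positions of ALL meshes concatenated.
# tri:      [total_num_triangles, 3] int32 vertex indices of ALL meshes concatenated
#           (each mesh's indices must be offset by the number of vertices before it).
# ranges:   [num_meshes, 2] int32 CPU tensor of (first_triangle, triangle_count) per mesh.
pos_clip = torch.randn(1000, 4, device="cuda")
pos_clip[:, 3] = 1.0  # keep the homogeneous coordinate sane for this toy example
tri = torch.randint(0, 1000, (600, 3), dtype=torch.int32, device="cuda")
ranges = torch.tensor([[0, 300], [300, 300]], dtype=torch.int32)  # two meshes

rast, rast_db = dr.rasterize(glctx, pos_clip, tri, resolution=[480, 480], ranges=ranges)
# rast has shape [num_meshes, 480, 480, 4]: one image per entry in `ranges`.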
I have tried range mode, and it only appears to rasterize and interpolate the first minibatch index correctly, even though I'm passing the correct ranges tensor to the rasterize function (containing the start index and the triangle count per batch element). I saw some interest in this in a different issue, so I'm reopening this one; however, I'm not sure whether this is a bug in nvdiffrast or whether the documentation is misleading. The changes I made are as follows:
def render_texture_batch_range_mode(
    glctx,
    mtx,
    pos,
    pos_clip_ja,
    pos_idx,
    resolution,
    ranges,
    uv=None,
    uv_idx=None,
    tex=None,
    vtx_color=None,
    return_rast_out=False,
):
    if not isinstance(resolution, list):
        resolution = [resolution, resolution]
    posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], axis=1)
    tri = pos_idx.int().contiguous()
    cts_ranges = ranges.int().contiguous()
    assert len(pos_clip_ja.shape) == 2 and len(posw.shape) == 2 and ranges is not None

    # Range-mode rasterization: pos_clip_ja is [num_vertices, 4], tri is [num_triangles, 3],
    # cts_ranges is an int32 [num_meshes, 2] tensor of (first_triangle, triangle_count) pairs.
    rast_out, rast_out_db = dr.rasterize(
        glctx, pos_clip_ja.contiguous(), tri, resolution=resolution, ranges=cts_ranges
    )

    # Compute the depth (interpolate() and xfm_points() are helpers defined elsewhere in the repo).
    gb_pos, _ = interpolate(posw, rast_out, tri, rast_db=rast_out_db)
    shape_keep = gb_pos.shape
    gb_pos = gb_pos.reshape(shape_keep[0], -1, shape_keep[-1])
    gb_pos = gb_pos[..., :3]
    depth = xfm_points(gb_pos.contiguous(), mtx)
    depth = depth.reshape(shape_keep)[..., 2] * -1

    mask, _ = dr.interpolate(
        torch.ones(tri.shape).cuda(), rast_out, tri, rast_db=rast_out_db, diff_attrs="all"
    )
    mask = dr.antialias(mask, rast_out, pos_clip_ja.contiguous(), tri)

    # Compute texture or vertex-color interpolation.
    if vtx_color is None:
        texc, texd = dr.interpolate(
            uv, rast_out, uv_idx, rast_db=rast_out_db, diff_attrs="all"
        )
        color = dr.texture(
            tex,
            texc,
            texd,
            filter_mode="linear",
        )
        color = color * torch.clamp(rast_out[..., -1:], 0, 1)  # Mask out background.
    else:
        color, _ = dr.interpolate(vtx_color, rast_out, tri)
        color = color * torch.clamp(rast_out[..., -1:], 0, 1)  # Mask out background.

    if not return_rast_out:
        rast_out = None
    return {"rgb": color, "depth": depth, "rast_out": rast_out, "mask": mask}
Helper function for batching the mesh properties:
def batch_mesh_forward(self, object_names):
    props = ['pos', 'pos_idx', 'uv', 'uv_idx']
    mesh_result = {}
    count = {}
    start_index = {}
    # Load each unique mesh once.
    meshes = {o: self.mesh[o]() for o in set(object_names)}
    for p in props:
        count[p] = []
        mesh_result[p] = []
        for n in object_names:
            mesh_result[p].append(meshes[n][p])
            count[p].append(meshes[n][p].shape[0])
        mesh_result[p] = torch.cat(mesh_result[p], dim=0).cuda()
        count[p] = torch.tensor(count[p])
        start_index[p] = torch.roll(torch.cumsum(count[p], 0), shifts=1, dims=0)
        start_index[p][0] = 0
    mesh_result['tex'] = torch.stack([meshes[n]['tex'] for n in object_names])
    # ranges = (first triangle index, triangle count) per mesh; shape = [num_meshes, 2]
    mesh_result['ranges'] = torch.vstack((start_index['pos_idx'], count['pos_idx'])).T
    return mesh_result, count, start_index
Calling the above function:
mesh_result, count, start_index = self.batch_mesh_forward(object_names)

# Camera-view pose to NDC (clip-space) pose.
final_mtx_proj = torch.matmul(camera_params['projection_matrix'], obj_pose_T44)

# Transform each mesh's vertices with its own pose (the point cloud itself is correct).
pos_clip_ja = torch.zeros(mesh_result['pos'].shape[0], 4).cuda()
for i in range(self.batchsize):
    s = start_index['pos'][i]
    e = s + count['pos'][i]
    pos_clip_ja[s:e] = xfm_points(
        mesh_result['pos'][s:e].unsqueeze(0).contiguous(), final_mtx_proj[i].unsqueeze(0)
    )[0]

renders = render_texture_batch_range_mode(
    glctx=self.glctx,
    mtx=obj_pose_T44,
    pos=mesh_result['pos'],
    pos_clip_ja=pos_clip_ja,
    pos_idx=mesh_result['pos_idx'],
    uv=mesh_result['uv'],
    uv_idx=mesh_result['uv_idx'],
    tex=mesh_result['tex'].cuda(),
    ranges=mesh_result['ranges'],
    resolution=[480, 480],
)
I tried running it on some simple data with a batch size of 4 and the objects at z = -0.3 from the camera, and got these results (attached): clearly, only the first minibatch index is rendered correctly.
I opened another issue in the nvdiffrast repo with a minimal working example: https://github.com/NVlabs/nvdiffrast/issues/164
The problem is solved; please refer to NVlabs/nvdiffrast#164.
Are your results correct? Please share them; if this is something useful, I would love to include it in the main repo, though I do not want to step on any research you might be doing.
I just looked at your other post. Yeah, it looks like you needed to offset the vertex positions referenced in the triangle ids. :D
Yup, I've provided the modified code in the comment above, and the results after the fix (with an MWE) in the nvdiffrast issue as well.
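For anyone landing here with the same problem, the fix amounts to offsetting each mesh's triangle indices by the cumulative vertex count of the meshes concatenated before it. A rough sketch, assuming the per-mesh meshes[n]['pos'] / meshes[n]['pos_idx'] tensors from batch_mesh_forward above (this is illustrative, not the exact code from the nvdiffrast issue):

import torch

pos_list, tri_list, ranges = [], [], []
vert_offset, tri_offset = 0, 0
for n in object_names:
    m = meshes[n]
    pos_list.append(m['pos'])
    tri_list.append(m['pos_idx'] + vert_offset)         # <-- the crucial offset
    ranges.append([tri_offset, m['pos_idx'].shape[0]])   # (first triangle, count)
    vert_offset += m['pos'].shape[0]
    tri_offset += m['pos_idx'].shape[0]

pos = torch.cat(pos_list, dim=0).cuda()
pos_idx = torch.cat(tri_list, dim=0).int().cuda()
ranges = torch.tensor(ranges, dtype=torch.int32)  # range tensor stays on CPU for dr.rasterize

The same offsetting applies to uv_idx against the concatenated uv tensor, since range-mode interpolation indexes the attribute buffers through the same global indices.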