brentyi/tilted

Visualizing method

chiehwangs opened this issue · 1 comment

Hi Brent,

Firstly, thanks for sharing this excellent and robust work!

Is it possible to share the method for visualizing "the structure-revealing L2-norm of interpolated features" in Figure 8? I am trying to draw inspiration from your work.

Best regards,
chieh

Hi @chiehwangs,

Thanks for your nice words!

For the code that generates these specific heatmaps, you can search the repo for "transform_feature_norms". Here's what's happening: for each sample along the ray, we compute the feature norm corresponding to each transform, and then alpha composite these norms (with the same weights we would use for compositing RGB). The final visualization shows the feature norms of a single transform, which we pick by argmax-ing the norm itself.
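To make the compositing step concrete, here is a minimal NumPy sketch. It assumes the compositing weights follow the standard NeRF quadrature, w_i = T_i (1 - exp(-σ_i δ_i)); in the repo these weights come from `probs.p_terminates`, whose exact computation isn't shown in this thread. The function names and toy inputs are hypothetical. The point is simply that one set of weights is applied to both the RGB values and the per-transform norms.

```python
import numpy as np


def compositing_weights(sigmas, deltas):
    """Standard NeRF quadrature weights (assumed), shape (rays, samples).

    w_i = T_i * (1 - exp(-sigma_i * delta_i)), where the transmittance
    T_i = exp(-sum_{j<i} sigma_j * delta_j) is the chance a ray reaches sample i.
    """
    optical_depth = sigmas * deltas
    alphas = 1.0 - np.exp(-optical_depth)
    # Exclusive cumulative sum: optical depth accumulated before each sample.
    accumulated = np.cumsum(optical_depth, axis=-1) - optical_depth
    transmittance = np.exp(-accumulated)
    return transmittance * alphas


def composite(per_sample, weights):
    """Alpha composite any per-sample quantity: (rays, samples, k) -> (rays, k)."""
    return np.einsum("rsk,rs->rk", per_sample, weights)


rng = np.random.default_rng(0)
rays, samples, transform_count = 4, 16, 3
sigmas = rng.uniform(0.0, 2.0, size=(rays, samples))  # toy densities
deltas = np.full((rays, samples), 0.1)                # toy sample spacing
weights = compositing_weights(sigmas, deltas)

# Hypothetical per-sample values: RGB colors and per-transform feature norms.
rgb_per_sample = rng.uniform(size=(rays, samples, 3))
norms_per_sample = rng.uniform(size=(rays, samples, transform_count))

rgb = composite(rgb_per_sample, weights)              # (rays, 3)
feature_norms = composite(norms_per_sample, weights)  # (rays, transform_count)
```

Because compositing is linear, each composited norm is a weighted sum of that transform's per-sample norms along the ray (with weights summing to at most one), so samples near the visible surface dominate the heatmap, just as they dominate the rendered color.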

The norms are first computed on a per-transform basis in the rendering pipeline. Note that the output shape here is (# rays, # transforms):

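```python
import functools

import jax.numpy as jnp
from einops import einsum, reduce

# Excerpt from the rendering pipeline: `primary_components`, `transform_count`,
# `num_rays`, `probs`, and `rgb` are defined by the surrounding code.
component_norms = functools.reduce(
    jnp.add,
    map(
        # Each component has shape (groups, rays, samples, channels); sum the
        # squared features over groups and channels, separately per transform.
        lambda a: reduce(
            a**2,
            "(g transform_count) rays samples channels -> rays samples transform_count",
            reduction="sum",
            transform_count=transform_count,
            rays=num_rays,
        ),
        primary_components,
    ),
)
assert (
    component_norms.shape
    == probs.p_terminates.shape + (transform_count,)
    == rgb.shape[:-1] + (transform_count,)
)
# Alpha composite with the ray-termination probabilities -> (rays, transform_count).
transform_feature_norm = einsum(
    component_norms,
    probs.p_terminates,
    "rays samples transform_count, rays samples -> rays transform_count",
)
```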
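For readers less familiar with einops, the sketch below spells out what the `reduce` pattern above does using only NumPy reshapes. It assumes, as the pattern and comment indicate, that each component's leading axis packs `g` groups times `transform_count` transforms, with the transform index varying fastest. `per_transform_sq_norms` and the toy components are hypothetical. Note that summing `a**2` without a square root yields squared L2 norms.

```python
import functools

import numpy as np


def per_transform_sq_norms(component, transform_count):
    """(groups, rays, samples, channels) -> (rays, samples, transform_count).

    Mirrors "(g transform_count) rays samples channels -> rays samples
    transform_count": the leading axis is split as (g, transform_count) with
    transform_count varying fastest, then squares are summed over g and channels.
    """
    groups, rays, samples, channels = component.shape
    g = groups // transform_count
    split = component.reshape(g, transform_count, rays, samples, channels)
    sq = np.sum(split**2, axis=(0, 4))  # (transform_count, rays, samples)
    return np.transpose(sq, (1, 2, 0))  # (rays, samples, transform_count)


rng = np.random.default_rng(0)
transform_count, rays, samples = 3, 5, 8
# Hypothetical components with different group and channel counts.
primary_components = [
    rng.normal(size=(2 * transform_count, rays, samples, 4)),
    rng.normal(size=(1 * transform_count, rays, samples, 6)),
]
component_norms = functools.reduce(
    np.add,
    (per_transform_sq_norms(a, transform_count) for a in primary_components),
)
p_terminates = rng.dirichlet(np.ones(samples), size=rays)  # toy (rays, samples)
transform_feature_norm = np.einsum("rst,rs->rt", component_norms, p_terminates)
```

Only the per-transform axis survives the reduction, which is why the composited result has shape (# rays, # transforms).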

These norms are then scaled to [0, 1]. We then keep only one transform at a time ((# rays, # transforms) => (# rays,)) and apply a colormap to convert the result to RGB for the final visualization:

From `tilted/visualize_nerf.py`, lines 188 to 199 in commit df4614a:

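```python
# Module-level imports assumed by this excerpt.
import matplotlib as mpl
import numpy as onp

# `image` holds the rendered feature norms, with transforms on the last axis.
if self.mode == "transform_feature_norm":
    # Rescale all norms jointly to [0, 1].
    image = image - image.min()
    image /= image.max()
    # Rank transforms by their overall norm (largest first) and keep one.
    image = image[
        ...,
        onp.argsort(
            -onp.linalg.norm(
                image.reshape((-1, image.shape[-1])), axis=0
            )
        )[self.transform_viz_index % image.shape[-1]],
    ]
    # Map [0, 1] values to uint8 RGBA with a matplotlib colormap.
    image = (mpl.colormaps[self.cmap](image) * 255.0).astype(onp.uint8)
```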
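The selection logic in this excerpt can be reproduced on a toy array with NumPy alone. In this sketch, `select_transform_channel` and `to_rgb_uint8` are hypothetical names, and the black-to-yellow ramp merely stands in for the matplotlib colormap. It shows that `transform_viz_index == 0` picks the transform whose map has the largest overall L2 norm (the "argmax" mentioned above), and that larger indices step down the ranking and wrap around.

```python
import numpy as np


def select_transform_channel(image, viz_index):
    """(..., T) feature-norm image -> (...) map for a single transform."""
    # Joint min-max rescaling to [0, 1] across all transforms.
    image = image - image.min()
    image = image / image.max()
    # L2 norm of each transform's channel over all pixels, ranked largest first.
    per_transform = np.linalg.norm(image.reshape(-1, image.shape[-1]), axis=0)
    ranking = np.argsort(-per_transform)
    return image[..., ranking[viz_index % image.shape[-1]]]


def to_rgb_uint8(values):
    """Hypothetical stand-in for a colormap: black (0) -> yellow (1) ramp."""
    values = np.clip(values, 0.0, 1.0)
    rgb = np.stack([values, values, np.zeros_like(values)], axis=-1)
    return (rgb * 255.0).astype(np.uint8)


rng = np.random.default_rng(0)
image = rng.uniform(size=(4, 6, 3))  # toy (H, W, transforms) norms
image[..., 2] *= 5.0                 # make transform 2 dominate

top = select_transform_channel(image, 0)      # transform 2
wrapped = select_transform_channel(image, 3)  # 3 % 3 == 0, also transform 2
colored = to_rgb_uint8(top)                   # (4, 6, 3) uint8
```

Because the min/max rescaling is shared across all transforms, relative brightness between transforms is preserved. A perfectly constant image would make the maximum zero after the shift, so a standalone version may want to guard that division.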

Hope that's helpful, and please let me know if anything's unclear!