LeapLabTHU/DAT

Why set the reference point coordinates like this

simplify23 opened this issue · 5 comments


    def _get_ref_points(self, H_key, W_key, B, dtype, device):

        # Pixel-center coordinates: 0.5, 1.5, ..., H_key - 0.5 (likewise for W_key)
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing="ij",
        )
        ref = torch.stack((ref_y, ref_x), -1)  # H W 2, last dim ordered (y, x)
        # Normalize in place: divide by size -> (0, 1), then * 2 - 1 -> (-1, 1)
        ref[..., 1].div_(W_key).mul_(2).sub_(1)
        ref[..., 0].div_(H_key).mul_(2).sub_(1)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2

        return ref
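For a concrete feel of what _get_ref_points produces, here is a minimal pure-Python sketch of the same computation along a single axis. The helper name ref_points_1d is hypothetical and not part of DAT; the meshgrid, group, and batch expansion are left out.

```python
def ref_points_1d(n):
    """Hypothetical helper: normalized pixel-center coordinates along one axis."""
    # torch.linspace(0.5, n - 0.5, n) has a step of 1, so the centers are i + 0.5
    centers = [i + 0.5 for i in range(n)]
    # .div_(n).mul_(2).sub_(1): (0, n) -> (0, 1) -> (0, 2) -> (-1, 1)
    return [c / n * 2 - 1 for c in centers]


print(ref_points_1d(4))  # prints [-0.75, -0.25, 0.25, 0.75]
```

The points are symmetric about 0 and stay strictly inside (-1, 1), at a distance of 1/n from each border.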

I don't understand this line: ref[..., 1].div_(W_key).mul_(2).sub_(1). In particular, why use .mul_(2).sub_(1)?

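        if self.offset_range_factor > 0:
            # Per-axis unit offset: 1/Hk along y (rows), 1/Wk along x (columns)
            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk], device=device).reshape(1, 2, 1, 1)
            # tanh bounds the raw offset to (-1, 1); scaling gives (-s/Hk, s/Hk) and (-s/Wk, s/Wk)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, "b p h w -> b h w p")
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)  # B H W 2

        if self.no_off:
            # Disable deformation: zero all offsets in place
            offset = offset.fill_(0.0)
        if self.offset_range_factor >= 0:
            # Relative offset added directly to the normalized reference grid
            pos = offset + reference
        else:
            # Negative factor: the offset is not scaled; tanh squashes the sum into (-1, 1)
            pos = (offset + reference).tanh()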

if self.offset_range_factor >= 0:

reference: [0.5, H-0.5] -> [0.5/H, 1-0.5/H] -> [1/H, 2-1/H] -> [1/H-1, 1-1/H]
offset: [-1, 1] -> [-1/H,1/H] -> [-s/H, s/H] (s=self.offset_range_factor)
offset + reference: [-(1-1/H)-s/H, 1-1/H+s/H]
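The ranges above can be checked numerically. This sketch assumes an H-axis of length 8 and s = 2 (values chosen only for illustration) and uses math.tanh in place of offset.tanh(), mirroring offset.tanh().mul(1 / H).mul(s) along one axis.

```python
import math

H = 8    # illustrative feature-map height (assumption)
s = 2.0  # illustrative offset_range_factor (assumption)

# Normalized reference points along y: (i + 0.5) / H * 2 - 1
ref = [(i + 0.5) / H * 2 - 1 for i in range(H)]
print(min(ref), max(ref))  # 1/H - 1 and 1 - 1/H

# Raw network outputs of any magnitude: tanh squashes them into [-1, 1],
# then scaling by (1/H) * s maps them into [-s/H, s/H]
raw = [-100.0, -1.0, 0.0, 1.0, 100.0]
offsets = [math.tanh(r) * (1.0 / H) * s for r in raw]
print(min(offsets), max(offsets))

# Extreme sampling positions: the first/last reference point pushed outward
lo = ref[0] + min(offsets)
hi = ref[-1] + max(offsets)
print(lo, hi)  # -(1 - 1/H) - s/H and 1 - 1/H + s/H
```

With these values the extreme positions are -1.125 and 1.125, i.e. outside [-1, 1]; F.grid_sample handles such out-of-range locations according to its padding_mode (zeros by default).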

I mean, why? Why not just [-1/H, 1/H]? Why would scaling by offset_range_factor work better? Are there any experiments?

@simplify23


if self.offset_range_factor >= 0:

  • reference: the normalized base grid of reference coordinates.
  • offset: [-1, 1] -> [-1/H, 1/H] -> [-s/H, s/H] (s = self.offset_range_factor). Here, the offset is relative to the pixel position, and 1/H is the unit offset along the H-axis (up or down). Similarly, 1/W is the unit offset along the W-axis (left or right). Using a separate unit offset for each axis ensures that s has the same units along both axes.
  • offset + reference: the sampling positions, i.e. each reference position shifted by its relative offset.
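To make the per-axis units concrete, here is a small sketch assuming a non-square 8 x 16 key map and s = 2 (illustrative values only). It converts the maximum normalized offset on each axis back to pixels, using the fact that in normalized [-1, 1] coordinates one pixel spans 2/H (or 2/W), so an offset of s/H is s/2 pixels.

```python
Hk, Wk = 8, 16  # illustrative non-square key map (assumption)
s = 2.0         # illustrative offset_range_factor (assumption)

# Maximum normalized offset per axis, as with offset_range = [1/Hk, 1/Wk] scaled by s
max_off_y = s / Hk
max_off_x = s / Wk

# One pixel spans 2/Hk along y and 2/Wk along x in normalized coordinates
pixels_y = max_off_y / (2.0 / Hk)
pixels_x = max_off_x / (2.0 / Wk)
print(pixels_y, pixels_x)  # 1.0 1.0: the same pixel budget on both axes

# A single shared range of 1/Hk for both axes would give x twice the budget of y here
shared_pixels_x = (s / Hk) / (2.0 / Wk)
print(shared_pixels_x)  # 2.0
```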

The reference points are set to a uniform grid of pixel centers, from 0.5 to size - 0.5. Dividing by the size maps them into (0, 1), and .mul_(2).sub_(1) then maps them into [-1, +1] to meet the convention of the grid sampling operation F.grid_sample in PyTorch, where (-1, -1) denotes the top-left corner and (+1, +1) denotes the bottom-right corner. Experiments ablating the offset range factor s are in the ablation study section of the paper.
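The sketch below checks that the normalized reference points land exactly on pixel centers under the align_corners=False convention of F.grid_sample (PyTorch's default), where -1 and +1 are the outer edges of the border pixels. It is a pure-Python re-implementation of that convention's unnormalization rule, ((x + 1) * size - 1) / 2, not a call into PyTorch; the helper name unnormalize is hypothetical, and this thread does not state which align_corners setting DAT itself uses.

```python
def unnormalize(coord, size):
    """Hypothetical helper: map a normalized coordinate to a pixel index,
    following F.grid_sample's align_corners=False convention."""
    return ((coord + 1) * size - 1) / 2


W = 4
ref_x = [(i + 0.5) / W * 2 - 1 for i in range(W)]  # [-0.75, -0.25, 0.25, 0.75]
print([unnormalize(x, W) for x in ref_x])  # [0.0, 1.0, 2.0, 3.0]: pixel centers
print(unnormalize(-1.0, W), unnormalize(1.0, W))  # -0.5 3.5: outer edges
```

Note also that F.grid_sample reads the last grid dimension as (x, y), whereas _get_ref_points stacks (y, x), so the two channels have to be swapped before sampling.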

Kuanch commented

I came across this issue because I had the same question. The answer is in the paper:
[image: excerpt from the paper]

There are also ablation studies on page 8.