lilanxiao/Rotated_IoU

Bug in box1_in_box2 in box_intersection_2d: box1_in_box2(box1, box1) does not return [True, True, True, True] for certain inputs

XiaoyanQian opened this issue · 1 comment

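I looked up an alternative algorithm on Wikipedia (based on cross products) and used it to rewrite this function. Am I misunderstanding what the function does, or is there a bug in the original function you provided?
Do you have any suggestions for a fix?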

  • Original code
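import torch

def box1_in_box2(corners1: torch.Tensor, corners2: torch.Tensor):
    """Check which corners of box1 lie inside box2 by projecting onto two edges of box2.

    corners1, corners2: (B, N, 4, 2). Returns a bool tensor of shape (B, N, 4).
    """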
    a = corners2[:, :, 0:1, :]  # (B, N, 1, 2)
    b = corners2[:, :, 1:2, :]  # (B, N, 1, 2)
    d = corners2[:, :, 3:4, :]  # (B, N, 1, 2)
    ab = b - a                  # (B, N, 1, 2)
    am = corners1 - a           # (B, N, 4, 2)
    ad = d - a                  # (B, N, 1, 2)
    p_ab = torch.sum(ab * am, dim=-1)       # (B, N, 4)
    norm_ab = torch.sum(ab * ab, dim=-1)    # (B, N, 1)
    p_ad = torch.sum(ad * am, dim=-1)       # (B, N, 4)
    norm_ad = torch.sum(ad * ad, dim=-1)    # (B, N, 1)
    # NOTE: the expression looks ugly but is stable if the two boxes are exactly the same
    # also stable with different scale of bboxes
    cond1 = (p_ab / norm_ab > - 1e-6) * (p_ab / norm_ab < 1 + 1e-6)   # (B, N, 4)
    cond2 = (p_ad / norm_ad > - 1e-6) * (p_ad / norm_ad < 1 + 1e-6)   # (B, N, 4)
    return cond1*cond2
  • My modified code
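import torch

def box1_in_box2(corners1: torch.Tensor, corners2: torch.Tensor):
    """Check which corners of box1 lie inside box2, using cross products.

    corners1, corners2: (B, N, 4, 2), corners of each box in cyclic order.
    Returns a bool tensor of shape (B, N, 4).
    """
    p1 = corners2[:, :, 0:1, :]  # (B, N, 1, 2)
    p2 = corners2[:, :, 1:2, :]  # (B, N, 1, 2)
    p3 = corners2[:, :, 2:3, :]  # (B, N, 1, 2)
    p4 = corners2[:, :, 3:4, :]  # (B, N, 1, 2)
    p = corners1                 # (B, N, 4, 2)

    # cross(p2 - p1, p - p1) * cross(p4 - p3, p - p3):
    # non-negative iff p lies on or between the parallel edges p1p2 and p3p4
    x_in = \
        ((p2[...,0]-p1[...,0])*(p[...,1]-p1[...,1])-(p[...,0]-p1[...,0])*(p2[...,1]-p1[...,1])) * \
        ((p4[...,0]-p3[...,0])*(p[...,1]-p3[...,1])-(p[...,0]-p3[...,0])*(p4[...,1]-p3[...,1]))

    # cross(p3 - p2, p - p2) * cross(p1 - p4, p - p4):
    # non-negative iff p lies on or between the parallel edges p2p3 and p4p1
    y_in = \
        ((p3[...,0]-p2[...,0])*(p[...,1]-p2[...,1])-(p[...,0]-p2[...,0])*(p3[...,1]-p2[...,1])) * \
        ((p1[...,0]-p4[...,0])*(p[...,1]-p4[...,1])-(p[...,0]-p4[...,0])*(p1[...,1]-p4[...,1]))

    cond1 = x_in >= 0  # (B, N, 4)
    cond2 = y_in >= 0  # (B, N, 4)
    value = cond1 * cond2
    return value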
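Both versions classify each corner of box1 independently. The following scalar sketch in plain Python puts the two tests side by side; `inside_by_projection`, `inside_by_cross` and `rotated_box` are hypothetical names, and the corners are assumed to be in cyclic (here counter-clockwise) order, which both functions above rely on. The projection test needs the normalized projections to land within 1e-6 of [0, 1]. For box2's own corners these projections are exactly 0 or 1 in exact arithmetic, so the result rests on rounding error staying below that margin. In the cross-product test, each corner of box2 makes one factor of both products an exact 0.0 in floating point, which is a plausible reason the rewrite returns [True, True, True, True] for box1_in_box2(box1, box1).

```python
import math


def inside_by_projection(p, a, b, d, eps=1e-6):
    """Projection test (as in the original code); a, b, d are corners 0, 1, 3."""
    ab = (b[0] - a[0], b[1] - a[1])
    ad = (d[0] - a[0], d[1] - a[1])
    am = (p[0] - a[0], p[1] - a[1])
    # Normalized projections of am onto the edges ab and ad; "inside" means
    # both lie in [0, 1], widened by eps on each side.
    t_ab = (ab[0] * am[0] + ab[1] * am[1]) / (ab[0] ** 2 + ab[1] ** 2)
    t_ad = (ad[0] * am[0] + ad[1] * am[1]) / (ad[0] ** 2 + ad[1] ** 2)
    return -eps < t_ab < 1 + eps and -eps < t_ad < 1 + eps


def cross(o, e, p):
    """z-component of (e - o) x (p - o), in the same form as the issue's code."""
    return (e[0] - o[0]) * (p[1] - o[1]) - (p[0] - o[0]) * (e[1] - o[1])


def inside_by_cross(p, p1, p2, p3, p4):
    """Cross-product test (as in the modified code); corners in cyclic order."""
    # p1->p2 and p3->p4 point in opposite directions, so p is on or between
    # those edges exactly when the two cross products do not have opposite signs.
    x_in = cross(p1, p2, p) * cross(p3, p4, p)
    y_in = cross(p2, p3, p) * cross(p4, p1, p)
    return x_in >= 0 and y_in >= 0


def rotated_box(cx, cy, w, h, angle):
    """Hypothetical helper: counter-clockwise corners of a w x h box rotated by angle."""
    c, s = math.cos(angle), math.sin(angle)
    offsets = [(-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2)]
    return [(cx + dx * c - dy * s, cy + dx * s + dy * c) for dx, dy in offsets]


box = rotated_box(3.0, -2.0, 7.3, 0.4, 0.9)
a, b, _, d = box
print(inside_by_projection((3.0, -2.0), a, b, d), inside_by_cross((3.0, -2.0), *box))
# -> True True    (box center)
print(inside_by_projection((13.0, -2.0), a, b, d), inside_by_cross((13.0, -2.0), *box))
# -> False False  (point well outside the box)
# For each corner of the box, one factor of both products is exactly 0.0,
# e.g. cross(p1, p2, p2) subtracts two identical products, so rounding
# elsewhere cannot flip the result.
print([inside_by_cross(p, *box) for p in box])
# -> [True, True, True, True]
```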

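Thanks for the issue. I think the problem is caused by numerical error; mathematically, the two versions should be equivalent.
To be honest, all the code in this repo is a naive implementation, just a proof of concept written quick and dirty. I haven't done much optimization or testing for numerical stability (admittedly I'm not very good at this, sorry). I also never considered the corner case of passing the same box1 as both inputs.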

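If you can provide some test code showing that your implementation really is more robust than the original one, feel free to open a pull request and I can merge your implementation.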
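The maintainer asks for test code showing that the rewrite is more robust. The following is a sketch of such a script, assuming PyTorch and float32 inputs. `box1_in_box2_proj` and `box1_in_box2_cross` are hypothetical names for condensed copies of the two functions in this issue, and `random_rotated_corners` is a hypothetical generator whose corner order (counter-clockwise) and value ranges are assumptions, not the repo's own box-to-corner conversion. Each batch of boxes is passed as both arguments, and the script counts corners reported as outside their own box.

```python
import math

import torch


def box1_in_box2_proj(corners1, corners2, eps=1e-6):
    """Condensed copy of the original projection-based box1_in_box2."""
    a = corners2[:, :, 0:1, :]
    ab = corners2[:, :, 1:2, :] - a
    ad = corners2[:, :, 3:4, :] - a
    am = corners1 - a
    t_ab = torch.sum(ab * am, dim=-1) / torch.sum(ab * ab, dim=-1)  # (B, N, 4)
    t_ad = torch.sum(ad * am, dim=-1) / torch.sum(ad * ad, dim=-1)  # (B, N, 4)
    return (t_ab > -eps) & (t_ab < 1 + eps) & (t_ad > -eps) & (t_ad < 1 + eps)


def _cross(o, e, p):
    # z-component of (e - o) x (p - o), broadcast to (B, N, 4)
    return (e[..., 0] - o[..., 0]) * (p[..., 1] - o[..., 1]) - (p[..., 0] - o[..., 0]) * (e[..., 1] - o[..., 1])


def box1_in_box2_cross(corners1, corners2):
    """Condensed copy of the cross-product box1_in_box2 proposed in this issue."""
    p1, p2, p3, p4 = (corners2[:, :, i:i + 1, :] for i in range(4))
    x_in = _cross(p1, p2, corners1) * _cross(p3, p4, corners1)
    y_in = _cross(p2, p3, corners1) * _cross(p4, p1, corners1)
    return (x_in >= 0) & (y_in >= 0)


def random_rotated_corners(B, N, gen):
    """Hypothetical generator: (B, N, 4, 2) float32 corners of random rotated boxes."""
    center = torch.rand(B, N, 1, 2, generator=gen) * 100 - 50
    w = torch.rand(B, N, 1, generator=gen) * 20 + 0.1
    h = torch.rand(B, N, 1, generator=gen) * 20 + 0.1
    angle = torch.rand(B, N, 1, generator=gen) * 2 * math.pi
    # Unit-box offsets in counter-clockwise order.
    unit = torch.tensor([[-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]])
    dx, dy = unit[:, 0] * w, unit[:, 1] * h  # (B, N, 4)
    cos, sin = torch.cos(angle), torch.sin(angle)
    offsets = torch.stack([dx * cos - dy * sin, dx * sin + dy * cos], dim=-1)
    return center + offsets  # (B, N, 4, 2)


gen = torch.Generator().manual_seed(0)
corners = random_rotated_corners(64, 256, gen)
# Each box is tested against itself, so every corner should be reported inside.
in_proj = box1_in_box2_proj(corners, corners)
in_cross = box1_in_box2_cross(corners, corners)
print("projection:    corners reported outside =", int((~in_proj).sum()))
print("cross product: corners reported outside =", int((~in_cross).sum()))
```

For the cross-product version the count is zero by construction, since every corner of box2 zeroes one factor of each product exactly. The count for the projection version depends on the sampled boxes, so the script prints it rather than asserting a value. Because the cross-product check uses `>= 0` with no tolerance, this argument does not cover boxes that are only nearly identical (for example, corners recomputed through a different code path).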