HarukiYqM/Non-Local-Sparse-Attention

关于代码实现 bucket_score 变量的细节疑惑?

Deep-imagelab opened this issue · 3 comments

对照着论文描述和作图,我一步步仔细调试了您的代码,您的代码写的非常好!
我这里有个疑问,就是关于 bucket_score 变量(如下 我贴了您的代码),它是求得的不同bucket之间的相关性权重,并在softmax归一化后用score表示,用于了与y_att_buckets矩阵相乘,这一步我很明白。
# unormalized attention score
raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match) # [N, n_hashes, num_chunks, chunk_size, chunk_size*3]

    # softmax
    bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)
    score = torch.exp(raw_score - bucket_score)  # (after softmax)
    bucket_score = torch.reshape(bucket_score, [N, self.n_hashes, -1])
    
    # attention
    ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets)  # [N, n_hashes, num_chunks, chunk_size, C]
    ret = torch.reshape(ret, (N, self.n_hashes, -1, C*self.reduction))

我主要不明白的是后续的代码,以上求得的ret是multi-round的,需要将multi-round这一维融合起来才能得到最终输出NCHW尺寸的特征,我不太明白后续为什么要用bucket_score进行softmax归一化后加权求和呢?这个bucket_score是 “不同bucket之间的相关性权重”,这里再用来求解multi-round维度的加权求和(如下 我贴了您的代码),总感觉怪怪的。
# recover the original order
ret = torch.reshape(ret, (N, -1, C
self.reduction)) # [N, n_hashesHW,C]
bucket_score = torch.reshape(bucket_score, (N, -1,)) # [N,n_hashesHW]
ret = batched_index_select(ret, undo_sort) # [N, n_hashesHW,C]
bucket_score = bucket_score.gather(1, undo_sort) # [N,n_hashesHW]

    # weighted sum multi-round attention
    ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction))  # [N, n_hashes*H*W,C]
    bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))
    probs = nn.functional.softmax(bucket_score, dim=1)
    ret = torch.sum(ret * probs, dim=1)

我个人觉得是,这里的multi-round,其实一定程度上是类似于Transformer中multi-head的,仿照它的操作,直接将multi-round维和channel维合并为multi-roundchannel,再用11 Conv映射到channel是不是应该更合理呢?

顺便还有一个细节问题,我还想请教一下,代码里有一句 hash_codes = hash_codes.detach()
为什么要加上这一句呢?detach()是用于切断梯度传播的,这里加上这一句是要切断此处的梯度传播吗?
我试过把这句注释掉,代码也能正常训练,求解。

Hi, 这里使用softmax的motivation是来源于hash操作的随机性。这段代码的意思是衡量每个round所分到bucket元素之间的亲和力,如果分到的bucket元素更related,那么那个round的权重更大。supplementary里有visualization,可以参考理解。1x1conv 是learnable的,我认为应该也是可行的。detach应该可以去掉,hashcode并没有gradient回传,感谢指出。

Hi, 这里使用softmax的motivation是来源于hash操作的随机性。这段代码的意思是衡量每个round所分到bucket元素之间的亲和力,如果分到的bucket元素更related,那么那个round的权重更大。supplementary里有visualization,可以参考理解。1x1conv 是learnable的,我认为应该也是可行的。detach应该可以去掉,hashcode并没有gradient回传,感谢指出。

好的,感谢回复