Whu-wxy/Non-local-U-Nets-2D-block

关于multi_head_attention_2d()的参数设置

Closed this issue · 4 comments

谢谢你的分享!请问在
multi_head_attention_2d(torch.nn.Module):
def init(self, in_channel, key_filters, value_filters,
output_filters, num_heads, dropout_prob=0.5, layer_type='SAME'):
中,key_filters, value_filters, num_heads的值如何确定?

num_heads能被key_filters和value_filters整除就行

谢谢回复!还有一个问题key_filters和value_filters这两个参数也是经验值吗?和input channel and output channel有关吗?这两个值对计算量影响蛮大

这两个参数文中没有具体说明,要看实验效果了。可以比input channel小一些,减少参数量