jnhwkim/ban-vqa

What does 'xhyk,bvk,bqk->bhvq' mean?

jasscia18 opened this issue · 6 comments

What does 'xhyk,bvk,bqk->bhvq' mean in this line of the code?

logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias

Please refer to the PyTorch documentation for torch.einsum.
It explains that the equation is written with lowercase letters (indices), each associated with one dimension of the operands and of the result.
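To make the documentation's rule concrete, here is a small sketch. It uses NumPy's np.einsum, which accepts the same equation strings as torch.einsum, so it runs without PyTorch. Each letter labels one axis; letters reordered in the output permute axes, and letters dropped from the output are summed over.

```python
import numpy as np

# A 2x3 matrix: index i labels axis 0 (size 2), index j labels axis 1 (size 3).
A = np.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

# Same letters, reordered in the output: a transpose.
At = np.einsum('ij->ji', A)

# A letter dropped from the output ('j' here) is summed over: row sums.
row_sums = np.einsum('ij->i', A)

# Every letter dropped: the sum of all entries.
total = np.einsum('ij->', A)

print(At.shape)   # (3, 2)
print(row_sums)   # [ 3 12]
print(total)      # 15
```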

I've read that documentation, but I don't understand how 'xhyk,bvk,bqk->bhvq' is computed.
How are these three tensors combined to get logits?

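einsum is a generalization of matrix multiplication: each axis is labeled with an index letter, and the letters determine which axes are multiplied together, which are summed over, and which are kept in the output.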
https://rockt.github.io/2018/04/30/einsum and https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/ may help.
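As a quick check of the "generalized matrix multiplication" view, the sketch below (NumPy again; the shapes are arbitrary) shows that 'ij,jk->ik' is an ordinary matrix product and that adding a shared batch index gives a batched matrix product.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))

# 'ij,jk->ik': j is shared by both operands and absent from the output,
# so it is summed over, which is ordinary matrix multiplication.
C = np.einsum('ij,jk->ik', A, B)

# A batch index b that appears in every operand and in the output is kept,
# not summed: this is a batched matrix multiplication.
X = rng.standard_normal((5, 2, 3))
Y = rng.standard_normal((5, 3, 4))
Z = np.einsum('bij,bjk->bik', X, Y)
```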

The einsum you asked about implements low-rank bilinear pooling (Eq. 2 in the paper).
[image: Eq. 2 from the paper]
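The equation itself was posted as an image. As a substitute, the following reads the subscripts of the einsum directly off the code (it is not a transcription of the paper's equation). Writing $H$ for self.h_mat, $\tilde v$ for v_ and $\tilde q$ for q_, every index in the output ($b, h, v, q$) is kept and every index missing from it ($x, y, k$) is summed; self.h_bias is then added with broadcasting:

$$
\text{logits}_{b,h,v,q} \;=\; \sum_{x}\sum_{y}\sum_{k} H_{x,h,y,k}\,\tilde v_{b,v,k}\,\tilde q_{b,q,k}
$$

Assuming the $x$ and $y$ axes have size 1 and writing $P_{h,k} = H_{1,h,1,k}$, this is, for each batch element $b$ and glimpse $h$, a matrix product of $\tilde v_b$ (shape $V \times K$) scaled column-wise by $P_{h}$ with $\tilde q_b^{\top}$ (shape $K \times Q$):

$$
\text{logits}_{b,h,v,q} \;=\; \sum_{k} P_{h,k}\,\tilde v_{b,v,k}\,\tilde q_{b,q,k} \;=\; \Big[\big(\tilde v_b \,\operatorname{diag}(P_{h})\big)\,\tilde q_b^{\top}\Big]_{v,q}
$$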
self.h_mat, q_, and v_ correspond to P, U^T X, and V^T Y, respectively.
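To answer "how are the three tensors combined" directly, this sketch computes the einsum three ways: with np.einsum, with explicit loops, and with a batched matrix product. The sizes are hypothetical, and the shape (1, H, 1, K) for h_mat is an assumption (size-1 x and y axes); h_bias is omitted.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
B, H, V, Q, K = 2, 3, 4, 5, 6
rng = np.random.default_rng(0)

h_mat = rng.standard_normal((1, H, 1, K))  # assumed shape: size-1 x and y axes (P)
v_ = rng.standard_normal((B, V, K))        # plays the role of V^T Y
q_ = rng.standard_normal((B, Q, K))        # plays the role of U^T X

logits = np.einsum('xhyk,bvk,bqk->bhvq', h_mat, v_, q_)

# The same computation as explicit loops: one loop per output index
# (b, h, v, q); the indices missing from the output (x, y, k) are summed.
ref = np.zeros((B, H, V, Q))
for b in range(B):
    for h in range(H):
        for v in range(V):
            for q in range(Q):
                s = 0.0
                for x in range(h_mat.shape[0]):
                    for y in range(h_mat.shape[2]):
                        for k in range(K):
                            s += h_mat[x, h, y, k] * v_[b, v, k] * q_[b, q, k]
                ref[b, h, v, q] = s

# Vectorized form: scale v_ by the per-glimpse vector P[h], then take a
# batched matrix product with q_ transposed: (B,H,V,K) @ (B,1,K,Q) -> (B,H,V,Q).
P = h_mat[0, :, 0, :]                                  # (H, K)
scaled_v = v_[:, None, :, :] * P[None, :, None, :]      # (B, H, V, K)
fast = scaled_v @ q_[:, None, :, :].transpose(0, 1, 3, 2)
```

The loop version is only for reading; the einsum and the matmul form compute the same values, and the matmul form makes the "bilinear" structure visible: for each glimpse, the projected features are reweighted by one row of P before the two sides are multiplied.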

For your information, x and y in 'xhyk,bvk,bqk->bhvq' are dummy indices kept for compatibility with the previous version of BAN.
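A minimal sketch of why the dummy indices do not change the result, assuming the x and y axes of h_mat have size 1: since x and y are absent from the output they are summed over, and summing over a size-1 axis is a no-op, so dropping those axes gives the same logits. The shapes below are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
h_mat = rng.standard_normal((1, 3, 1, 6))  # assumed: x and y axes have size 1
v_ = rng.standard_normal((2, 4, 6))
q_ = rng.standard_normal((2, 5, 6))

# Original 4-index form: x and y are not in the output, so they are summed.
with_dummies = np.einsum('xhyk,bvk,bqk->bhvq', h_mat, v_, q_)

# Drop the size-1 axes to get a 2-D (h, k) weight and use a 3-index equation.
h_2d = h_mat[0, :, 0, :]
without_dummies = np.einsum('hk,bvk,bqk->bhvq', h_2d, v_, q_)
```

If those axes were larger than 1, the einsum would simply sum the corresponding slices of h_mat, which is equivalent to using h_mat.sum(axis=(0, 2)) as the 2-D weight.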

@jasscia18 They are the features of the boxes and the embeddings of the entities.