What does the 'xhyk,bvk,bqk->bhvq' mean???
jasscia18 opened this issue · 6 comments
What does this line in the code mean?
logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
Specifically, what does 'xhyk,bvk,bqk->bhvq' mean?
Please refer to the PyTorch doc for torch.einsum.
It says that each lowercase letter (index) is associated with one dimension of the operands and the result.
I've seen this document, but I don't understand how 'xhyk,bvk,bqk->bhvq' is calculated.
How do you combine these three tensors to get logits?
einsum is a generalized version of matrix multiplication that uses the indexes of axes.
https://rockt.github.io/2018/04/30/einsum and https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/ may help.
The einsum you asked about is the implementation of the low-rank bilinear pooling (eq 2 in the paper).
self.h_mat, q_ and v_ correspond to P, U^T X and V^T Y, respectively.
For your information, x and y in 'xhyk,bvk,bqk->bhvq' are dummy indexes kept for compatibility with the previous version of BAN.
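To make the index notation concrete, here is a minimal sketch that compares the einsum with an explicit broadcast-multiply-and-sum. The tensor sizes are made up for illustration, and the shapes (1, h, 1, k) for h_mat and (1, h, 1, 1) for h_bias are assumptions chosen so that x and y are size-1 dummy axes; check them against the actual module before relying on this.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: batch, number of boxes, number of tokens, heads, joint dim.
b, v, q, h, k = 2, 5, 3, 4, 6

# Assumed shapes: x and y are size-1 dummy axes of h_mat.
h_mat = torch.randn(1, h, 1, k)
h_bias = torch.randn(1, h, 1, 1)
v_ = torch.randn(b, v, k)  # projected visual features
q_ = torch.randn(b, q, k)  # projected question features

# The expression from the issue: k is shared by all three operands and is
# summed out; x and y do not appear in the output, so they are summed too.
logits = torch.einsum('xhyk,bvk,bqk->bhvq', h_mat, v_, q_) + h_bias

# The same computation written out by hand.
P = h_mat.sum(dim=(0, 2))                # (h, k), drops the dummy axes
prod = (P[None, :, None, None, :]        # (1, h, 1, 1, k)
        * v_[:, None, :, None, :]        # (b, 1, v, 1, k)
        * q_[:, None, None, :, :])       # (b, 1, 1, q, k)
manual = prod.sum(dim=-1) + h_bias       # (b, h, v, q)

print(logits.shape)
print(torch.allclose(logits, manual, atol=1e-5))
```

In words, under these assumptions: logits[b, h, v, q] is the sum over k of P[h, k] * v_[b, v, k] * q_[b, q, k], plus the per-head bias, so every (box, token) pair gets one score per head.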
@jasscia18 the features of boxes and the embeddings of entities.