Justin1904/Low-rank-Multimodal-Fusion

Is the implementation consistent with the math?


I'm having a hard time matching the math with the code:

self.audio_factor = Parameter(torch.Tensor(self.rank, self.audio_hidden + 1, self.output_dim))
self.video_factor = Parameter(torch.Tensor(self.rank, self.video_hidden + 1, self.output_dim))
self.text_factor = Parameter(torch.Tensor(self.rank, self.text_out + 1, self.output_dim))
self.fusion_weights = Parameter(torch.Tensor(1, self.rank))
self.fusion_bias = Parameter(torch.Tensor(1, self.output_dim))
......
......
fusion_audio = torch.matmul(_audio_h, self.audio_factor)
fusion_video = torch.matmul(_video_h, self.video_factor)
fusion_text = torch.matmul(_text_h, self.text_factor)
fusion_zy = fusion_audio * fusion_video * fusion_text

# output = torch.sum(fusion_zy, dim=0).squeeze()
# use linear transformation instead of simple summation, more flexibility
output = torch.matmul(self.fusion_weights, fusion_zy.permute(1, 0, 2)).squeeze() + self.fusion_bias

Here is my understanding:

  • self.video_hidden = $d_m$ in the paper, i.e., the dimension of the video embedding; similarly for the other modalities
  • self.output_dim = $d_h$ in the paper

_audio_h is of dim (batch_size, self.audio_hidden + 1) and
self.audio_factor is of dim (self.rank, self.audio_hidden + 1, self.output_dim),
so torch.matmul(_audio_h, self.audio_factor) is of dim (self.rank, batch_size, self.output_dim), which is also the dimension of fusion_zy. fusion_zy then goes through a linear transformation that collapses the "rank" dimension to produce output (a quick shape check is sketched below, after my questions). Is output the h in the paper? It seems the implementation follows Eq (6) in the paper, but not exactly:
(1) the summation is replaced by a weighted summation
(2) the order of product and summation is swapped. Eq (6) first does the summation and then the product, but the code reverses this.
Did I understand it correctly? It seems these changes make it not strictly equivalent to the math. What is the rationale for these changes?
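
For reference, here is a quick shape check for the reading above (my own sketch with made-up dimensions, not taken from the repo):

import torch

batch_size, rank, audio_hidden, output_dim = 8, 4, 16, 32

_audio_h = torch.randn(batch_size, audio_hidden + 1)
audio_factor = torch.randn(rank, audio_hidden + 1, output_dim)

# matmul broadcasts the 2-D input against the 3-D factor
fusion_audio = torch.matmul(_audio_h, audio_factor)
print(fusion_audio.shape)  # torch.Size([4, 8, 32]) = (rank, batch_size, output_dim)

# stand-in for the elementwise product with the video/text branches
fusion_zy = fusion_audio

fusion_weights = torch.randn(1, rank)
fusion_bias = torch.randn(1, output_dim)
output = torch.matmul(fusion_weights, fusion_zy.permute(1, 0, 2)).squeeze() + fusion_bias
print(output.shape)  # torch.Size([8, 32]) = (batch_size, output_dim)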

Thanks!

Yes, your interpretation of the code is correct. These are very good observations.

For 1), the weighted sum indeed makes it NOT strictly equivalent to the math. This is an artifact of a somewhat arbitrary attempt to make LMF more "parameterized". As you can see in the code you posted above, the alternative is the simple sum in the commented-out line. In my experiments, I did not observe much difference. I'd assume that if you remove fusion_weights and fusion_bias, you can probably make up for the model capacity by adding an additional linear layer in the downstream network (though that is still not mathematically equivalent to this).
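
To illustrate the relationship between the two heads (a sketch with made-up dimensions, not from the repo): if fusion_weights is fixed to all ones and fusion_bias to zero, the weighted sum reduces exactly to the commented-out simple summation.

import torch

rank, batch_size, output_dim = 4, 8, 32
fusion_zy = torch.randn(rank, batch_size, output_dim)

# simple summation over the rank dimension (the commented-out variant)
output_sum = torch.sum(fusion_zy, dim=0)

# weighted sum over the rank dimension, with the weights pinned to ones
ones = torch.ones(1, rank)
output_weighted = torch.matmul(ones, fusion_zy.permute(1, 0, 2)).squeeze()

print(torch.allclose(output_sum, output_weighted, atol=1e-6))  # True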

For 2), please refer to equation 8) in the paper. We actually started with 8), and hence the code, and later added 6) just to explain how we get from equation 5) to 8). You can try to verify that 5) == 8) by breaking them down into elementwise equations, but I'll skip the details here. Here is also an empirical example to show this.
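
The original example is not reproduced here, but a sketch along the following lines (dimensions made up, bias terms omitted) checks numerically that contracting the reconstructed full tensor W with the outer product Z, as in Eq 5)/6), matches the product-then-sum computation used in the code:

import torch

torch.manual_seed(0)
rank, d_a, d_v, d_t, d_h = 4, 3, 5, 7, 2

# rank-1 factors per modality (analogous to self.*_factor, without the +1 bias column)
w_a = torch.randn(rank, d_a, d_h)
w_v = torch.randn(rank, d_v, d_h)
w_t = torch.randn(rank, d_t, d_h)

# one input vector per modality (a single sample, analogous to _audio_h etc.)
z_a, z_v, z_t = torch.randn(d_a), torch.randn(d_v), torch.randn(d_t)

# Eq 5)/6): reconstruct W as a sum of outer products, then contract with Z = z_a ⊗ z_v ⊗ z_t
W = torch.einsum('rah,rvh,rth->avth', w_a, w_v, w_t)   # (d_a, d_v, d_t, d_h)
Z = torch.einsum('a,v,t->avt', z_a, z_v, z_t)          # (d_a, d_v, d_t)
h_full = torch.einsum('avth,avt->h', W, Z)             # (d_h,)

# what the implementation does: project each modality, multiply across modalities, sum over rank
h_lowrank = ((z_a @ w_a) * (z_v @ w_v) * (z_t @ w_t)).sum(dim=0)

print(torch.allclose(h_full, h_lowrank, atol=1e-4))    # True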