vislearn/FrEIA

How to calculate loss for multiple outputs?

manuelknott opened this issue · 4 comments

It is not clear to me how to calculate the loss when I have an invertible graph with multiple output nodes.
For example, I implemented Glow with a multi-scale architecture where each z_i is a separate output node.
Based on the examples, I only see how to calculate the loss for a single output node.

Thank you very much in advance!

Thanks for your question. Let's assume your output is as follows:

(a, b), jac = inn(x)

Then, you can compute the standard maximum likelihood loss as:

loss = ((a ** 2).sum(-1) + (b ** 2).sum(-1)) / 2 - jac
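For reference, a loss of this form is the negative log-likelihood given by the change-of-variables formula. The sketch below assumes a standard normal prior on the latents and that `jac` is the per-sample log-determinant of the Jacobian; neither is spelled out above. The symbol D, the total number of latent dimensions, is introduced here for illustration.

```latex
\begin{aligned}
-\log p_X(x) &= -\log p_Z(a, b) - \log\left|\det \frac{\partial (a, b)}{\partial x}\right| \\
             &= \tfrac{1}{2}\lVert a \rVert^2 + \tfrac{1}{2}\lVert b \rVert^2
                + \tfrac{D}{2}\log(2\pi) - \mathrm{jac}
\end{aligned}
```

The constant term does not depend on the model parameters, so it is usually dropped during training. Because the prior factorizes over dimensions, every additional output node simply contributes its own squared-norm term.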

Does this answer your question?

Dear @fdraxler,
thank you for the quick answer. I think I did not express myself precisely enough, so let me try with an example. In a multi-scale architecture such as the one used in Glow, I have several Zs. Given 3-channel images with a resolution of 128x128 as inputs, the Squeeze and Split operations (L=5) result in the following tensor shapes:

X.shape = (B, 3, 128, 128)
Z1.shape = (B, 6, 64, 64)
Z2.shape = (B, 12, 32, 32)
Z3.shape = (B, 24, 16, 16)
Z4.shape = (B, 48, 8, 8)
Z5.shape = (B, 192, 4, 4)
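The shapes above can be reproduced with a few lines of arithmetic. This is a minimal sketch that assumes each level squeezes 2x2 spatial patches into channels and that every level except the last splits off half of the channels as a latent. The function name `multiscale_latent_shapes` is hypothetical and not part of FrEIA.

```python
def multiscale_latent_shapes(channels, height, width, levels):
    """Per-sample latent shapes for a Glow-style multi-scale stack (assumed layout)."""
    shapes = []
    c, h, w = channels, height, width
    for level in range(1, levels + 1):
        # Squeeze: each 2x2 spatial patch becomes 4x as many channels.
        c, h, w = c * 4, h // 2, w // 2
        if level < levels:
            # Split: half of the channels leave as Z_level, the rest continue.
            c //= 2
        # At the last level, everything that remains becomes the final latent.
        shapes.append((c, h, w))
    return shapes


shapes = multiscale_latent_shapes(3, 128, 128, levels=5)
total_dims = sum(c * h * w for c, h, w in shapes)

for i, shape in enumerate(shapes, start=1):
    print(f"Z{i}: {shape}")
print(total_dims == 3 * 128 * 128)
```

The last line checks that the latents together have exactly as many dimensions as the input, which is what an invertible network needs.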

Your formula does not apply to that without prior reshaping.

I have seen other implementations that either a) evaluate each Z against the prior individually and pass the resulting log_det_jac along until the last output, or b) reshape and concatenate all Zs at the end and calculate the loss from there.

It is unclear to me which is the most efficient way to calculate the loss, and whether the split block should forward any logdet in the latter case. I apologize if this is more of a general conceptual question than one specific to your framework. However, a small example of the best/recommended way to implement a multi-scale approach in FrEIA would help a lot.

Thank you very much for your help.

I think the two variants a and b are equivalent mathematically as long as you add up all the losses. Reshaping and concatenating is probably the easiest choice:

stacked = torch.cat([Z.reshape(B, -1) for Z in [Z1, Z2, Z3, Z4, Z5]], 1) # B x 3*128*128
loss_per_item = (stacked ** 2).sum(1) / 2 - jac
loss_per_batch = loss_per_item.mean(0)

Some people also divide the loss by the number of dimensions, which turns the sum in the second line into a mean (jac then needs to be divided by the same number).
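To check that variants a and b agree, here is a minimal sketch using random tensors in place of real model outputs. The latent shapes are taken from this thread; `zs` and `jac` are placeholders, and no FrEIA call is involved.

```python
import torch

torch.manual_seed(0)
B = 4

# Random stand-ins with the latent shapes from this thread (not real model outputs).
shapes = [(6, 64, 64), (12, 32, 32), (24, 16, 16), (48, 8, 8), (192, 4, 4)]
zs = [torch.randn(B, *s, dtype=torch.float64) for s in shapes]
jac = torch.randn(B, dtype=torch.float64)  # placeholder for the per-sample log-determinant

# Variant a: score each latent against the prior separately, then add everything up.
loss_a = sum((z ** 2).flatten(1).sum(1) / 2 for z in zs) - jac

# Variant b: flatten, concatenate, and score once.
stacked = torch.cat([z.flatten(1) for z in zs], dim=1)  # B x 3*128*128
loss_b = (stacked ** 2).sum(1) / 2 - jac

print(torch.allclose(loss_a, loss_b))

# Optional normalization: divide the whole per-item loss by the number of dimensions.
n_dims = stacked.shape[1]
loss_per_dim = loss_b / n_dims
loss_per_batch = loss_per_dim.mean(0)
```

Both variants sum the same squared entries, so they differ only in bookkeeping. Variant b keeps the loss code in one place, while variant a avoids building the concatenated tensor.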

Alright, got it. Thank you very much for the explanation.