undark-lab/swyft

Small difference between actual loss and defined loss

CHBKoenders opened this issue · 3 comments

Since the loss is calculated from interleaved, jointly drawn parameter-observation pairs (doubling the effective batch size), there should be an additional factor of 1/2 for the loss to approach the expectation value $$\mathbb{E}_{z \sim p(z),\, x \sim p(x \mid z),\, z' \sim p(z')} \left[\ln d(x, z) + \ln\left(1 - d(x, z')\right)\right].$$

loss = loss.sum(axis=0) / n_batch
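
For context, here is a minimal sketch of the kind of interleaved loss being discussed. The function name, shapes, and labels are illustrative assumptions, not swyft's actual implementation:

```python
import torch
import torch.nn.functional as F


def interleaved_bce_loss(logits_joint, logits_marginal, n_batch):
    """Sketch of an interleaved binary-cross-entropy loss.

    logits_joint:    classifier outputs d(x, z) for jointly drawn pairs,
                     shape (n_batch, n_params)
    logits_marginal: classifier outputs d(x, z') for shuffled pairs,
                     shape (n_batch, n_params)
    """
    # Interleave joint and shuffled pairs along the batch axis,
    # doubling the effective batch size to 2 * n_batch.
    logits = torch.stack([logits_joint, logits_marginal], dim=1)
    logits = logits.reshape(2 * n_batch, -1)

    # Labels: 1 for joint pairs, 0 for shuffled (marginal) pairs.
    labels = torch.stack(
        [torch.ones(n_batch), torch.zeros(n_batch)], dim=1
    ).reshape(2 * n_batch, 1).expand_as(logits)

    # Per-pair BCE terms; 2 * n_batch of them per parameter.
    loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")

    # Summing over the batch axis and dividing by n_batch (not 2 * n_batch)
    # is the step questioned above.
    return loss.sum(axis=0) / n_batch
```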

bkmi commented

Seems like a reasonable point; let me think about it to be sure. (I think you're right, though.)

bkmi commented

Say n_batch = 2.

The loss from the AALR method is:
l1 = BCE(theta1, x1, 1) + BCE(theta2, x1, 0)
l2 = BCE(theta2, x2, 1) + BCE(theta1, x2, 0)

The average over the batch is therefore (l1 + l2) / n_batch = TOTAL_LOSS.

--

Our loss function:
l1_us = BCE(theta1, x1, 1) + BCE(theta2, x1, 0) + BCE(theta2, x2, 1) + BCE(theta1, x2, 0)
l2_us = 0; this doesn't exist, since we move everything into groups of 4 along the second dimension

(l1_us + l2_us) / 2 = (l1 + l2 + 0) / n_batch = (l1 + l2) / n_batch == TOTAL_LOSS
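
For concreteness, the n_batch = 2 bookkeeping above can be checked numerically; the logits below are arbitrary placeholders and bce is a small helper, not a swyft function:

```python
import torch
import torch.nn.functional as F


def bce(logit, label):
    # Binary cross entropy for a single classifier output d(theta, x).
    return F.binary_cross_entropy_with_logits(logit, torch.tensor(label))


# Hypothetical classifier outputs for the n_batch = 2 example above.
d_t1_x1, d_t2_x1 = torch.tensor(0.3), torch.tensor(-1.2)  # pairs with x1
d_t2_x2, d_t1_x2 = torch.tensor(0.8), torch.tensor(-0.4)  # pairs with x2
n_batch = 2

# AALR-style per-example losses.
l1 = bce(d_t1_x1, 1.0) + bce(d_t2_x1, 0.0)
l2 = bce(d_t2_x2, 1.0) + bce(d_t1_x2, 0.0)
total_loss = (l1 + l2) / n_batch

# Grouped version: all four terms land in a single group of 4, so the
# "batch" dimension after grouping has length n_batch / 2 = 1.
l1_us = bce(d_t1_x1, 1.0) + bce(d_t2_x1, 0.0) + bce(d_t2_x2, 1.0) + bce(d_t1_x2, 0.0)
grouped_loss = l1_us / n_batch  # sum over groups, divided by n_batch

# Both averaging schemes agree.
assert torch.isclose(total_loss, grouped_loss)
```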

--

Is this wrong?

CHBKoenders commented

I looked into it in a little more detail, and everything checks out after all.

My main mistake was overlooking that the view of lnL, lnL.view(-1, 4, lnL.shape[-1]), changes the size of the batch dimension from 2*n_batch to n_batch/2. So dividing by n_batch at the end leaves you with the correct average!
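
A small self-contained check of that reshaping, using a random stand-in for lnL rather than swyft's actual tensor:

```python
import torch

n_batch, n_params = 8, 3

# Stand-in for the per-pair BCE terms: the batch dimension has length
# 2 * n_batch because joint and shuffled pairs are interleaved.
lnL = torch.randn(2 * n_batch, n_params)

# The view groups the terms in fours, so the leading dimension shrinks
# from 2 * n_batch to n_batch / 2 (not to n_batch).
grouped = lnL.view(-1, 4, lnL.shape[-1])
assert grouped.shape == (n_batch // 2, 4, n_params)

# Per-group loss, then the quoted "sum over axis 0, divide by n_batch".
per_group = grouped.sum(axis=1)          # shape (n_batch / 2, n_params)
loss = per_group.sum(axis=0) / n_batch

# This matches the AALR-style average (l1 + ... + l_n) / n_batch, where
# each l_i consists of one joint and one shuffled BCE term.
reference = lnL.view(n_batch, 2, n_params).sum(axis=1).mean(axis=0)
assert torch.allclose(loss, reference)
```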