ganslate-team/ganslate

Val-Test metrics need to work accurately across batches and across data points


With the current implementation, if the batch size is 2 and there are 5 data points:

The final average is computed as:
(mean metric of the first batch (2 data points) + mean metric of the second batch (2 data points) + metric of the third batch (only 1 data point)) / 3

With a batch size of 1, it is computed as:
(metric of the 1st data point + ... + metric of the 5th data point) / 5
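The two aggregation schemes above can be reproduced in a few lines of plain Python. This is a minimal sketch with made-up metric values (not taken from ganslate); the helper names `batches`, `mean` and `mean_of_batch_means` are hypothetical.

```python
# Arbitrary illustrative per-data-point metric values for 5 data points
values = [1, 2, 3, 4, 10]

def batches(xs, batch_size):
    """Split xs into consecutive batches; the last batch may be smaller."""
    return [xs[i:i + batch_size] for i in range(0, len(xs), batch_size)]

def mean(xs):
    return sum(xs) / len(xs)

def mean_of_batch_means(xs, batch_size):
    """Average each batch first, then average the batch averages."""
    return mean([mean(b) for b in batches(xs, batch_size)])

print(mean_of_batch_means(values, 1))  # → 4.0 (plain per-data-point mean)
print(mean_of_batch_means(values, 2))  # → 5.0 ((1.5 + 3.5 + 10.0) / 3)
```

The same five values give different results depending only on the batch size.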

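These values are not equivalent: in the first case, the single data point in the last batch carries twice the weight of each of the other four. This needs to be fixed so that the result is the same regardless of implementation details such as the batch size.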
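One common way to make the result independent of batch size is to accumulate a running sum and a count of per-data-point values instead of averaging batch means. A minimal sketch, assuming per-data-point metric values are available for each batch; `RunningMean` and `aggregate` are hypothetical names and not ganslate's actual fix.

```python
class RunningMean:
    """Accumulates a sum and a count so the result does not depend on batch size."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch_values):
        # Add the individual values (not the batch average) so each data point has equal weight
        self.total += sum(batch_values)
        self.count += len(batch_values)

    def compute(self):
        if self.count == 0:
            raise ValueError("no values accumulated")
        return self.total / self.count

def aggregate(values, batch_size):
    """Feed values to a RunningMean batch by batch and return the overall mean."""
    meter = RunningMean()
    for i in range(0, len(values), batch_size):
        meter.update(values[i:i + batch_size])
    return meter.compute()

values = [1, 2, 3, 4, 10]
for bs in (1, 2, 5):
    print(bs, aggregate(values, bs))  # → 4.0 for every batch size
```

If only a batch average is available, multiplying it by the batch size before adding it to the total gives the same result.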

Is this without DDP? Do you think torchmetrics would take care of that?

Not dependent on DDP really, it's about how we aggregate metrics. Not sure if torchmetrics would take care of it; the issue lies within how we do metrics over multiple data points. I've fixed it for now and will push the update in a bit.

the issue lies within how we do metrics over multiple data points

You mean how we deal with data points within a batch, no? If that's the case, we could get rid of our own implementation and just use torchmetrics, since it handles batches anyway.

Nope, I mean doing it over, for example, multiple patients: we then average over each patient, so if we have already averaged over the batch, the final average will be incorrect.
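If metrics are first averaged per patient and then across patients, the per-data-point values have to be grouped by patient before any averaging, because a batch can contain data points from more than one patient. A sketch under that assumption, with hypothetical patient IDs and metric values; `per_patient_then_overall` is an illustrative name, not a ganslate function.

```python
from collections import defaultdict

# Hypothetical (patient_id, per-data-point metric) pairs in loading order
samples = [("A", 1.0), ("A", 3.0), ("B", 5.0), ("B", 7.0), ("B", 9.0)]

def per_patient_then_overall(samples, batch_size):
    """Group per-data-point values by patient, average per patient, then across patients."""
    per_patient = defaultdict(list)
    for i in range(0, len(samples), batch_size):
        batch = samples[i:i + batch_size]
        # Keep individual values; averaging here would mix patients within a batch
        for patient_id, value in batch:
            per_patient[patient_id].append(value)
    patient_means = {p: sum(v) / len(v) for p, v in per_patient.items()}
    overall = sum(patient_means.values()) / len(patient_means)
    return patient_means, overall

print(per_patient_then_overall(samples, 3))  # → ({'A': 2.0, 'B': 7.0}, 4.5)
```

Because nothing is averaged at the batch level, the per-patient and overall results are the same for any batch size.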