photosynthesis-team/piq

Wrong total variation calculation

Dobatymo opened this issue · 2 comments

The total variation (l2 version) is calculated here as sqrt(sum(d_w**2 + d_h**2)). Shouldn't it be sum(sqrt(d_w**2 + d_h**2)) instead? See

piq/piq/tv.py

Lines 34 to 37 in 26d044e

elif norm_type == 'l2':
    w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
    h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
    score = torch.sqrt(h_variance + w_variance)

Now the problem is how to vectorize this correctly...
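A small hand-checkable case shows that the two orderings really give different numbers. The sketch below is plain Python on a 3x3 single-channel image. The helper names are hypothetical, and both variants are assumed to use forward differences on the pixels that have both a right and a lower neighbour, so that only the order of sum and square root differs.

```python
import math


def forward_differences(img):
    """Return (d_w, d_h) for every pixel that has a right and a lower neighbour."""
    rows, cols = len(img), len(img[0])
    return [
        (img[i][j + 1] - img[i][j], img[i + 1][j] - img[i][j])
        for i in range(rows - 1)
        for j in range(cols - 1)
    ]


def tv_sqrt_of_sum(img):
    # Square root applied once, after summing all squared differences.
    return math.sqrt(sum(dw * dw + dh * dh for dw, dh in forward_differences(img)))


def tv_sum_of_sqrt(img):
    # Per-pixel gradient norm first, then summed over pixels.
    return sum(math.sqrt(dw * dw + dh * dh) for dw, dh in forward_differences(img))


img = [
    [0, 3, 3],
    [4, 0, 0],
    [4, 0, 0],
]

# Per-pixel (d_w, d_h): (3, 4), (0, -3), (-4, 0), (0, 0)
print(tv_sqrt_of_sum(img))  # sqrt(9 + 16 + 9 + 16) = sqrt(50), about 7.07
print(tv_sum_of_sqrt(img))  # 5 + 3 + 4 + 0 = 12.0
```

The two results coincide only when at most one pixel has a non-zero gradient, so for real images the choice changes the score.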

Hi @Dobatymo
I believe both variants are equally common in the literature. The Wikipedia article has the summation outside, while other sources (see image) put it inside. The exact formula is included in the docs, so users can decide whether it fits their use case.
image

Feel free to close the issue if this answers your question!

Hi @zakajd, sorry, I missed the formula in the docs. However, both Wikipedia and the two references cited in the docs have the sum outside, and I am not aware of any formulation that has the sum inside. I only know the isotropic and anisotropic formulations, and both have the sum outside (although the placement only matters for the isotropic version). They differ only in the per-pixel norm that is summed.
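For reference, the two formulations are commonly written as below. The notation is an assumption made for this sketch: $x_{i,j}$ is the pixel at row $i$, column $j$ of a single-channel image, and the sums run over all pixels for which both neighbours exist.

```latex
% Isotropic: per-pixel l2 norm of the gradient, summed over pixels
TV_{\mathrm{iso}}(x) = \sum_{i,j} \sqrt{(x_{i+1,j} - x_{i,j})^2 + (x_{i,j+1} - x_{i,j})^2}

% Anisotropic: per-pixel l1 norm of the gradient, summed over pixels
TV_{\mathrm{aniso}}(x) = \sum_{i,j} \left( |x_{i+1,j} - x_{i,j}| + |x_{i,j+1} - x_{i,j}| \right)
```

In both cases the outer operation is a plain sum over pixels; only the norm applied to each pixel's pair of differences changes.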

EDIT: I would suggest

d_w = x[:, :, :-1, 1:] - x[:, :, :-1, :-1]
d_h = x[:, :, 1:, :-1] - x[:, :, :-1, :-1]
score = torch.sum(torch.sqrt(torch.square(d_w) + torch.square(d_h)), dim=(1, 2, 3))

For l2_squared, the order of the sum doesn't really matter either.
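One way to gain confidence in the vectorization is to compare the suggested slicing against a naive per-pixel loop. This is only a sketch: `tv_isotropic` and `tv_isotropic_loop` are hypothetical helper names, not part of piq, and the input is assumed to be a 4D tensor of shape (N, C, H, W).

```python
import torch


def tv_isotropic(x: torch.Tensor) -> torch.Tensor:
    """Vectorized form of the suggestion: sum of per-pixel gradient norms."""
    # Both differences are cropped to the same (H-1, W-1) grid so they align per pixel.
    d_w = x[:, :, :-1, 1:] - x[:, :, :-1, :-1]
    d_h = x[:, :, 1:, :-1] - x[:, :, :-1, :-1]
    return torch.sum(torch.sqrt(torch.square(d_w) + torch.square(d_h)), dim=(1, 2, 3))


def tv_isotropic_loop(x: torch.Tensor) -> torch.Tensor:
    """Slow reference: the same quantity, one pixel at a time."""
    n, c, h, w = x.shape
    out = torch.zeros(n, dtype=x.dtype)
    for b in range(n):
        for ch in range(c):
            for i in range(h - 1):
                for j in range(w - 1):
                    dw = x[b, ch, i, j + 1] - x[b, ch, i, j]
                    dh = x[b, ch, i + 1, j] - x[b, ch, i, j]
                    out[b] += torch.sqrt(dw * dw + dh * dh)
    return out


torch.manual_seed(0)
x = torch.rand(2, 3, 5, 4)  # small random batch: N=2, C=3, H=5, W=4
print(torch.allclose(tv_isotropic(x), tv_isotropic_loop(x)))
```

Note that cropping to the common (H-1, W-1) grid drops the differences along the last row and last column, which the current `l2` code includes, so the two scores are not directly comparable pixel for pixel.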