refactor: simpler calculation of likelihoods
YodaEmbedding opened this issue · 1 comment
Regarding `CompressAI/compressai/entropy_models/entropy_models.py`, lines 458 to 469 at commit `1259aff`.
The following:

```python
sign = -torch.sign(lower + upper)
sign = sign.detach()
likelihood = torch.abs(
    torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)
)
```
...seems to be equivalent to:

```python
likelihood = torch.abs(torch.sigmoid(upper) - torch.sigmoid(lower))
```
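(To spell out the equivalence: by the identity `sigmoid(-x) = 1 - sigmoid(x)`, when `sign == -1` we get `sigmoid(-upper) - sigmoid(-lower) = (1 - sigmoid(upper)) - (1 - sigmoid(lower)) = sigmoid(lower) - sigmoid(upper)`, i.e. flipping the sign of both arguments only negates the difference, and the `abs` undoes that. So the `sign` factor has no effect on the forward value.)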
Furthermore, since `_logits_cumulative` is guaranteed to be a monotonically increasing scalar function on ℝ → ℝ, and `lower` and `upper` are (if I'm reading the code right) the same function evaluated at `inputs - half` and `inputs + half` respectively, we must have `lower <= upper`. Thus, we can remove the `abs` altogether:
```python
likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
```
Unless I'm mistaken, there should be no difference in numerical stability either.
"Proof" by experimentation:
```python
import torch

def likelihood_old(lower, upper):
    sign = -torch.sign(lower + upper)
    sign = sign.detach()
    likelihood = torch.abs(
        torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)
    )
    return likelihood

def likelihood_new(lower, upper):
    likelihood = torch.abs(torch.sigmoid(upper) - torch.sigmoid(lower))
    return likelihood
```
```python
>>> n = 100
>>> lower, upper = (torch.randn(n) * 10**torch.randn(n), torch.randn(n) * 10**torch.randn(n))
>>> torch.isclose(likelihood_old(lower, upper), likelihood_new(lower, upper))
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])
```
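One regime the random test above doesn't really exercise is float32 sigmoid saturation (large-magnitude arguments), which may be worth probing explicitly, e.g.:

```python
# Probe the regime where sigmoid(x) rounds to exactly 1.0 in float32,
# to check whether the sign trick matters for precision there.
lower = torch.tensor([20.0, 30.0, 88.0])
upper = torch.tensor([21.0, 31.0, 89.0])
print(likelihood_old(lower, upper))
print(likelihood_new(lower, upper))
```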
NOTE: I haven't checked if the gradients are the same. I should probably do that...
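Something along these lines should do it (a quick sketch, not yet run, reusing `likelihood_old` / `likelihood_new` from above):

```python
import torch

n = 100
lower = (torch.randn(n) * 10 ** torch.randn(n)).requires_grad_()
upper = (torch.randn(n) * 10 ** torch.randn(n)).requires_grad_()

# Sum the outputs so each function yields a scalar to differentiate,
# then compare the gradients w.r.t. both inputs.
grads_old = torch.autograd.grad(likelihood_old(lower, upper).sum(), (lower, upper))
grads_new = torch.autograd.grad(likelihood_new(lower, upper).sum(), (lower, upper))

print(all(torch.allclose(g_old, g_new) for g_old, g_new in zip(grads_old, grads_new)))
```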
Hi.
Correct, the current implementation follows the original from TensorFlow Compression (which has since changed). No issues with backpropagation through the sign? (`sign = sign.detach()`)
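(A toy example of what I mean — `x` here is just a stand-in input, not the actual model tensors:)

```python
import torch

x = torch.tensor([2.0, -3.0], requires_grad=True)
sign = (-torch.sign(x)).detach()  # autograd treats this as a constant
y = torch.sigmoid(sign * x).sum()
y.backward()
print(x.grad)  # gradient flows only through the sigmoid argument, scaled by sign
```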