InterDigitalInc/CompressAI

refactor: simpler calculation of likelihoods

YodaEmbedding opened this issue · 1 comment

Regarding:

def _likelihood(self, inputs: Tensor) -> Tensor:
    half = float(0.5)
    v0 = inputs - half
    v1 = inputs + half
    lower = self._logits_cumulative(v0, stop_gradient=False)
    upper = self._logits_cumulative(v1, stop_gradient=False)
    sign = -torch.sign(lower + upper)
    sign = sign.detach()
    likelihood = torch.abs(
        torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)
    )
    return likelihood

The following:

sign = -torch.sign(lower + upper)
sign = sign.detach()
likelihood = torch.abs(
    torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)
)

...seems to be equivalent to:

likelihood = torch.abs(torch.sigmoid(upper) - torch.sigmoid(lower))
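
To see why: since sigmoid(-x) = 1 - sigmoid(x), flipping the sign of both arguments only negates the difference, and the abs cancels the negation:

abs(sigmoid(-upper) - sigmoid(-lower))
    = abs((1 - sigmoid(upper)) - (1 - sigmoid(lower)))
    = abs(sigmoid(upper) - sigmoid(lower))

(The only corner case is lower + upper == 0 exactly, where torch.sign returns 0 and the old expression collapses to abs(sigmoid(0) - sigmoid(0)) = 0.)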

Furthermore, since _logits_cumulative is guaranteed to be a monotonically increasing scalar function R -> R and v0 <= v1, we must have lower <= upper. Since sigmoid is also increasing, sigmoid(upper) - sigmoid(lower) >= 0, so we can remove the abs altogether:

likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)

Unless I'm mistaken, there should be no difference in numerical stability either.


"Proof" by experimentation:

import torch

def likelihood_old(lower, upper):
    sign = -torch.sign(lower + upper)
    sign = sign.detach()
    likelihood = torch.abs(
        torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)
    )
    return likelihood

def likelihood_new(lower, upper):
    likelihood = torch.abs(torch.sigmoid(upper) - torch.sigmoid(lower))
    return likelihood

>>> n = 100
>>> lower, upper = (torch.randn(n) * 10**torch.randn(n), torch.randn(n) * 10**torch.randn(n))
>>> torch.isclose(likelihood_old(lower, upper), likelihood_new(lower, upper))
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])

NOTE: I haven't checked if the gradients are the same. I should probably do that...
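
Something like the following sketch, reusing likelihood_old and likelihood_new from above, ought to do it (the test distribution and the use of double precision are arbitrary choices on my part):

import torch

n = 100
# Double precision keeps floating-point noise well below the default
# tolerances of torch.allclose.
base = torch.randn(n, dtype=torch.float64) * 10 ** torch.randn(n, dtype=torch.float64)
delta = torch.rand(n, dtype=torch.float64)
# Mirror the actual use case: lower <= upper elementwise, as produced by a
# monotone _logits_cumulative applied to v0 = x - 0.5 and v1 = x + 0.5.
lower = (base - delta).requires_grad_(True)
upper = (base + delta).requires_grad_(True)

grads_old = torch.autograd.grad(likelihood_old(lower, upper).sum(), (lower, upper))
grads_new = torch.autograd.grad(likelihood_new(lower, upper).sum(), (lower, upper))

print(all(torch.allclose(g_old, g_new) for g_old, g_new in zip(grads_old, grads_new)))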

Hi.
Correct, the current implementation follows the original from TensorFlow Compression (which has since changed). No issues with back-propagation from the sign? (sign = sign.detach())