normalizing flows questions
fkwlqm opened this issue · 2 comments
Tutorial: 11
Describe the bug
I do not understand the values of ldj in the following code snippet (the Dequantization class from the dequantization section). In the code below I have numbered some lines with markers like [1] for ease of reference.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Dequantization(nn.Module):

    def __init__(self, alpha=1e-5, quants=256):
        """
        Inputs:
            alpha - small constant that is used to scale the original input.
                    Prevents dealing with values very close to 0 and 1 when inverting the sigmoid
            quants - Number of possible discrete values (usually 256 for 8-bit image)
        """
        super().__init__()
        self.alpha = alpha
        self.quants = quants

    def forward(self, z, ldj, reverse=False):
        if not reverse:
            z, ldj = self.dequant(z, ldj)
            z, ldj = self.sigmoid(z, ldj, reverse=True)
        else:
            z, ldj = self.sigmoid(z, ldj, reverse=False)
            z = z * self.quants
            ldj += np.log(self.quants) * np.prod(z.shape[1:])
            z = torch.floor(z).clamp(min=0, max=self.quants - 1).to(torch.int32)
        return z, ldj

    def sigmoid(self, z, ldj, reverse=False):
        # Applies an invertible sigmoid transformation
        if not reverse:
            ldj += (-z - 2 * F.softplus(-z)).sum(dim=[1, 2, 3])  # --------- [5]
            z = torch.sigmoid(z)
            # Reversing scaling for numerical stability
            ldj -= np.log(1 - self.alpha) * np.prod(z.shape[1:])
            z = (z - 0.5 * self.alpha) / (1 - self.alpha)
        else:
            z = z * (1 - self.alpha) + 0.5 * self.alpha  # Scale to prevent boundaries 0 and 1
            ldj += np.log(1 - self.alpha) * np.prod(z.shape[1:])
            ldj += (-torch.log(z) - torch.log(1 - z)).sum(dim=[1, 2, 3])  # --------------- [4]
            z = torch.log(z) - torch.log(1 - z)
        return z, ldj

    def dequant(self, z, ldj):
        # Transform discrete values to continuous volumes
        z = z.to(torch.float32)
        z = z + torch.rand_like(z).detach()
        z = z / self.quants
        ldj -= np.log(self.quants) * np.prod(z.shape[1:])
        return z, ldj
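For context on what these values look like, here is a minimal usage sketch of the class above (my own toy example, not from the tutorial): ldj is a single scalar per example in the batch, accumulated across every transformation step.

dequant_layer = Dequantization()
x = torch.randint(0, 256, (2, 1, 4, 4))        # batch of 2 tiny fake 8-bit "images"
ldj = torch.zeros(x.shape[0])                  # log-det-jacobian starts at 0 for each image
z, ldj = dequant_layer(x, ldj, reverse=False)
print(z.shape, ldj.shape)                      # torch.Size([2, 1, 4, 4]) torch.Size([2])
print(ldj)                                     # one scalar per image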
Let us start with the smallest function, dequant:
def dequant(self, z, ldj):
    # Transform discrete values to continuous volumes
    z = z.to(torch.float32)
    z = z + torch.rand_like(z).detach()                  # ------ [1]
    z = z / self.quants                                  # ---------- [2]
    ldj -= np.log(self.quants) * np.prod(z.shape[1:])    # ----------- [3]
    return z, ldj
First, line [1] converts the discrete z to a continuous z by adding random uniform noise. We can do that because a few lines above it is shown that the two distributions are essentially the same: p(x) == E[p(x+u)] for u ~ U[0,1).
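As a quick numerical sanity check of that identity (my own toy example with a made-up continuous density, not something from the tutorial): the probability of a discrete value x equals the integral of the continuous density over [x, x+1), which is exactly the expectation over uniform noise.

import numpy as np

rng = np.random.default_rng(0)

# A made-up continuous density on [0, 256): a truncated exponential.
lam = 0.02
Z = 1 - np.exp(-lam * 256)
p_cont = lambda v: lam * np.exp(-lam * v) / Z

x = 3                                                          # a discrete pixel value
p_exact = (np.exp(-lam * x) - np.exp(-lam * (x + 1))) / Z      # integral of p_cont over [x, x+1)
p_mc = p_cont(x + rng.uniform(0, 1, size=100_000)).mean()      # E[p_cont(x + u)], u ~ U[0, 1)
print(p_exact, p_mc)                                           # the two agree to several decimals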
Second, line [2] divides z by quants (self-explanatory).
Third, line [3] calculates the log-det-Jacobian (ldj). I think this contribution should simply be log(1/quants), i.e. -log(quants). What is np.prod(z.shape[1:]) doing there? And why is it not present in lines [4] and [5]?
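As a sanity check on that reasoning (my own sketch, not from the tutorial), the -np.log(quants) * np.prod(z.shape[1:]) term can be compared against a brute-force log-det of the Jacobian of z / quants on a tiny tensor:

import numpy as np
import torch

# For the elementwise map f(z) = z / quants the Jacobian is diagonal with
# entries 1/quants, so each of the D = np.prod(z.shape[1:]) elements of one
# example contributes -log(quants); the prod factor just counts elements per example.
# (Lines [4] and [5] instead sum an elementwise term over dims [1, 2, 3],
# which does that per-element accumulation explicitly.)
quants = 256
z = torch.rand(1, 1, 2, 2)                                 # one tiny "image", D = 4
D = int(np.prod(z.shape[1:]))

analytic = -np.log(quants) * D                             # what line [3] adds per example

f = lambda v: v / quants
J = torch.autograd.functional.jacobian(f, z.flatten())     # D x D Jacobian matrix
numeric = torch.logdet(J).item()

print(analytic, numeric)                                   # both ≈ -22.18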
Additionally, I removed those extra np.prod(z.shape[1:]) factors from the ldj updates and ran the dequant <-> quant verification, and it still worked.
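Roughly, the round trip I checked looks like this (my own sketch, assuming the Dequantization class above; it may differ from the tutorial's exact verification cell). Note that the reconstruction only depends on z and never reads the ldj value, so it passes regardless of what is added to ldj:

dequant_layer = Dequantization()
x = torch.randint(0, 256, (4, 1, 8, 8))                               # fake 8-bit images
z, ldj = dequant_layer(x, torch.zeros(x.shape[0]), reverse=False)     # discrete -> continuous
x_rec, _ = dequant_layer(z, torch.zeros(x.shape[0]), reverse=True)    # continuous -> discrete
print(torch.equal(x.to(torch.int32), x_rec))                          # True (barring rare float edge cases)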
Nvm, I am slightly braindead sometimes.