phlippe/uvadlc_notebooks

normalizing flows questions

fkwlqm opened this issue · 2 comments

Tutorial: 11

Describe the bug
I do not understand the values of ldj in the following dequantization code snippet.
In the code below I have numbered some lines (in square brackets) for ease of reference.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Dequantization(nn.Module):
  def __init__(self, alpha=1e-5, quants=256):
      """
      Inputs:
          alpha - small constant that is used to scale the original input.
                  Prevents dealing with values very close to 0 and 1 when inverting the sigmoid
          quants - Number of possible discrete values (usually 256 for 8-bit image)
      """
      super().__init__()
      self.alpha = alpha
      self.quants = quants

  def forward(self, z, ldj, reverse=False):
      if not reverse:
          z, ldj = self.dequant(z, ldj)
          z, ldj = self.sigmoid(z, ldj, reverse=True)
      else:
          z, ldj = self.sigmoid(z, ldj, reverse=False)
          z = z * self.quants
          ldj += np.log(self.quants) * np.prod(z.shape[1:])
          z = torch.floor(z).clamp(min=0, max=self.quants-1).to(torch.int32)
      return z, ldj

  def sigmoid(self, z, ldj, reverse=False):
      # Applies an invertible sigmoid transformation
      if not reverse:
          ldj += (-z-2*F.softplus(-z)).sum(dim=[1,2,3])  # --------- [5]
          z = torch.sigmoid(z)
          # Reversing scaling for numerical stability
          ldj -= np.log(1 - self.alpha) * np.prod(z.shape[1:])
          z = (z - 0.5 * self.alpha) / (1 - self.alpha)
      else:
          z = z * (1 - self.alpha) + 0.5 * self.alpha  # Scale to prevent boundaries 0 and 1
          ldj += np.log(1 - self.alpha) * np.prod(z.shape[1:])
          ldj += (-torch.log(z) - torch.log(1-z)).sum(dim=[1,2,3])  # --------------- [4]
          z = torch.log(z) - torch.log(1-z)
      return z, ldj

  def dequant(self, z, ldj):
      # Transform discrete values to continuous volumes
      z = z.to(torch.float32)
      z = z + torch.rand_like(z).detach()
      z = z / self.quants
      ldj -= np.log(self.quants) * np.prod(z.shape[1:])
      return z, ldj

Let us start with the smallest function, dequant:

    def dequant(self, z, ldj):
        # Transform discrete values to continuous volumes
        z = z.to(torch.float32)
        z = z + torch.rand_like(z).detach()  # ------ [1]
        z = z / self.quants # ---------- [2]
        ldj -= np.log(self.quants) * np.prod(z.shape[1:]) # ----------- [3]
        return z, ldj

First, line [1] converts discrete z to continuous z by adding random uniform noise. We can do that because a few lines above we show the two distributions are essentially the same ( p(x) == E[p(x+u)] with u ~ U[0,1) ).
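
For concreteness, here is a minimal sketch of steps [1] and [2] on a toy batch; the shape (2, 1, 2, 2) and quants=256 are made up for illustration:

import torch

quants = 256
z = torch.randint(0, quants, (2, 1, 2, 2))        # discrete pixel values in {0, ..., 255}
z = z.to(torch.float32)
z = z + torch.rand_like(z)                        # step [1]: each value v becomes continuous in [v, v+1)
z = z / quants                                    # step [2]: rescaled into [0, 1)
print(z.min().item() >= 0.0, z.max().item() < 1.0)  # True True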
Second, line [2] divides z by quants (self-explanatory).
Third, line [3] calculates the log-determinant of the Jacobian (ldj). I think the ldj should simply be log(1/quants), i.e. -log(quants). What is np.prod(z.shape[1:]) doing there? And why is it not present in lines [4] and [5]?
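
For reference, a quick check of what that factor evaluates to; the batch shape (8, 1, 28, 28) is made up for illustration, and the factor is simply the number of elements per sample:

import numpy as np
import torch

z = torch.zeros(8, 1, 28, 28)
print(np.prod(z.shape[1:]))                  # 784 = 1*28*28, the number of elements per sample
print(-np.log(256) * np.prod(z.shape[1:]))   # the constant added to every sample's ldj in line [3]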

Additionally, I removed those extra np.prod(z.shape[1:]) factors from the ldjs and ran the dequant <-> quant verification, and it still worked.
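
This is the kind of round-trip check I mean (a sketch assuming the Dequantization class above; the variable names are mine). Note that the returned ldj is never used in the comparison, only the reconstructed tensor is:

import torch

dequant = Dequantization()
x = torch.randint(0, 256, (4, 1, 8, 8))
ldj = torch.zeros(x.shape[0])
z, ldj = dequant(x, ldj, reverse=False)                        # discrete -> continuous
x_rec, _ = dequant(z, torch.zeros(x.shape[0]), reverse=True)   # continuous -> discrete
print(torch.equal(x.to(torch.int32), x_rec))                   # should print True; the ldj value plays no role here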

nvm, I am slightly braindead sometimes.