chaiyujin/glow-pytorch

Hard to understand mean and logs in Split2d Module

Closed this issue · 2 comments

class Split2d(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.conv = Conv2dZeros(num_channels // 2, num_channels)

    def split2d_prior(self, z):
        h = self.conv(z)  # double the number of channels of z
        return thops.split_feature(h, "cross")  # split channels by odd and even index

    def forward(self, input, logdet=0., reverse=False, eps_std=None):
        if not reverse:
            z1, z2 = thops.split_feature(input, "split") # split channel by first half and second half
            mean, logs = self.split2d_prior(z1)
            logdet = GaussianDiag.logp(mean, logs, z2) + logdet
            return z1, logdet
        else:
            z1 = input
            mean, logs = self.split2d_prior(z1)
            z2 = GaussianDiag.sample(mean, logs, eps_std)
            z = thops.cat_feature(z1, z2)
            return z, logdet

This is the Split2d class in glow/modules.py,
and my question is on the forward function.

mean, logs = self.split2d_prior(z1)

As I understand it, self.split2d_prior() only splits the feature map along the channel dimension by odd and even index. But the returned values are called mean and logs (in this project, the log of the standard deviation, i.e. the square root of the variance).

Why is that? How can simply splitting a feature map by channel index produce two pieces with totally different meanings, one treated as the mean and the other as logs?

Hi, @Schwartz-Zha:

As far as I am concerned, the channels of the input feature map have been non-linearly projected many times by the previous layers.
Thus, the feature maps on different channels can represent totally different things. They are no longer the RGB channels of an image, which are strongly correlated.
Besides, each channel is computed by different weights in the previous layers. In fact, one could split the channels into mean and logs by halves rather than by odd and even index, if preferred. Equivalently, one could use two separate modules to compute mean and logs. It is just a matter of coding style.

The following code may illustrate the point:

import torch
import torch.nn as nn

z = torch.rand(1, 10)

# Option 1: use two separate layers.
fc0 = nn.Linear(10, 50)
fc1 = nn.Linear(10, 50)
mean = fc0(z)
logs = fc1(z)

# Option 2: use one layer, then split the result.
fc01 = nn.Linear(10, 100)
fc01.weight.data[:50, :] = fc0.weight.data  # the first half holds the weights of fc0
fc01.weight.data[50:, :] = fc1.weight.data  # the second half holds the weights of fc1
fc01.bias.data[:50] = fc0.bias.data         # copy the biases too, so outputs match
fc01.bias.data[50:] = fc1.bias.data
mean, logs = fc01(z).chunk(chunks=2, dim=1)
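The same reasoning carries over to the convolutional case in Split2d. Below is a minimal sketch (not the repo's actual thops.split_feature, whose exact semantics I'm assuming here) of how the "cross" split assigns output channels to mean and logs by even/odd index, compared with a half split:

```python
import torch

# Pretend this is the output of Conv2dZeros: batch 1, 8 channels, 4x4 spatial.
h = torch.rand(1, 8, 4, 4)

# "cross" split (assumed): even-indexed channels -> mean, odd-indexed -> logs.
mean_cross = h[:, 0::2, :, :]
logs_cross = h[:, 1::2, :, :]

# Half split: first half of channels -> mean, second half -> logs.
mean_half, logs_half = h.chunk(2, dim=1)

# Either way, each output channel was produced by its own set of conv
# weights, so the two schemes differ only by a fixed permutation of which
# weights the layer learns to use for mean vs. logs.
print(mean_cross.shape, logs_cross.shape)  # both torch.Size([1, 4, 4, 4])
```

Since the weights producing each channel are independent, the network simply learns whichever channels the split designates as mean and logs.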

Thank you for your patience. I understand this operation is also very common in VAEs. This is a fantastic repo and a very good starting point for normalizing flows. I'm learning every detail of normalizing flows from it
and may have more questions in the future (if you don't mind).