Hard to understand mean and logs in Split2d Module
class Split2d(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.conv = Conv2dZeros(num_channels // 2, num_channels)

    def split2d_prior(self, z):
        h = self.conv(z)  # double the number of channels of z
        return thops.split_feature(h, "cross")  # split channels by odd and even index

    def forward(self, input, logdet=0., reverse=False, eps_std=None):
        if not reverse:
            z1, z2 = thops.split_feature(input, "split")  # split channels into first half and second half
            mean, logs = self.split2d_prior(z1)
            logdet = GaussianDiag.logp(mean, logs, z2) + logdet
            return z1, logdet
        else:
            z1 = input
            mean, logs = self.split2d_prior(z1)
            z2 = GaussianDiag.sample(mean, logs, eps_std)
            z = thops.cat_feature(z1, z2)
            return z, logdet
This is the Split2d class in glow/modules.py, and my question is about the forward function.

mean, logs = self.split2d_prior(z1)

As I understand it, self.split2d_prior() only splits the channels by odd and even index, yet the two returned tensors are called mean and logs (the log of the standard deviation in this project). Why is that? How can simply splitting a feature map by channel index produce two pieces with completely different meanings, so that one can be treated as mean and the other as logs?
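For context, here is a rough sketch of what I take the two split modes of thops.split_feature to do. This is only my reading for illustration, not the actual implementation; the exact behavior lives in thops.py.

import torch

def split_feature_sketch(tensor, mode="split"):
    # Hypothetical re-implementation, for illustration only.
    C = tensor.size(1)
    if mode == "split":
        # first half of the channels vs. second half
        return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
    elif mode == "cross":
        # even-indexed channels vs. odd-indexed channels
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]

So, if that reading is right, split2d_prior takes the conv output with 2*C channels, divides it into two C-channel tensors, and those two tensors are then used as the mean and log-std of a diagonal Gaussian over z2.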
Hi, @Schwartz-Zha:

In my understanding, the channels of the input feature map have already been non-linearly projected many times by the previous layers. Thus, the feature maps on different channels can represent completely different things; they are no longer the RGB channels of an image, which are expected to be strongly correlated. Besides, each channel is computed with different weights in the previous layers. In fact, one could split the channels into halves to get mean and logs, rather than by odd and even index, if preferred. Equivalently, one could use two separate modules to produce mean and logs. It comes down to coding style.
The following code may illustrate my point:

import torch
import torch.nn as nn

z = torch.rand(1, 10)

# Option 1: use two separate layers.
fc0 = nn.Linear(10, 50)
fc1 = nn.Linear(10, 50)
mean = fc0(z)
logs = fc1(z)

# Option 2: use one layer, then split the result.
fc01 = nn.Linear(10, 100)
fc01.weight.data[:50, :] = fc0.weight.data  # the first half are the weights of fc0
fc01.weight.data[50:, :] = fc1.weight.data  # the remainder are the weights of fc1
fc01.bias.data[:50] = fc0.bias.data         # copy the biases too, so the outputs match
fc01.bias.data[50:] = fc1.bias.data
mean, logs = fc01(z).chunk(chunks=2, dim=1)
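As a quick sanity check (with the biases copied as above), the two formulations should give the same result:

# Both ways of computing mean and logs agree.
assert torch.allclose(mean, fc0(z))
assert torch.allclose(logs, fc1(z))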
Thank you for your patience. I understand this operation is very common in VAEs. This is a fantastic repo and a very good starting point for normalizing flows. I'm learning every detail about normalizing flows from it, and may have more questions in the future (if you don't mind).