Problems with `_BaseCouplingBlock` for grayscale inputs
kvarun95 opened this issue · 2 comments
It looks like splitting is not performed properly in `_BaseCouplingBlock` when its `channels` attribute is equal to 1: one of the two halves ends up with zero channels, which leads to errors later when the input is passed to the subnet. The offending lines are the following (in the definition of `_BaseCouplingBlock` in `modules/coupling_layers.py`):
self.split_len1 = self.channels // 2
self.split_len2 = self.channels - self.channels // 2
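Evaluating the two lines on their own shows the problem. Below is a minimal plain-Python sketch; `split_lengths` is a hypothetical helper that copies the arithmetic above and is not part of FrEIA.

```python
def split_lengths(channels: int) -> tuple:
    # Same arithmetic as the two lines quoted from _BaseCouplingBlock
    split_len1 = channels // 2
    split_len2 = channels - channels // 2
    return split_len1, split_len2


for c in (1, 2, 3, 4):
    # For c == 1 the first half has zero channels
    print(c, split_lengths(c))
```

With `channels == 1` the result is `(0, 1)`, so one half of the split is empty and whatever consumes it (the subnets) gets no input channels. Any `channels >= 2` gives two non-empty halves.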
I encountered this problem as well. A simple workaround is to apply a downsampling layer (e.g. Haar downsampling, from `reshapes.py`) before the coupling block. This folds the spatial dimensions into channels (each 2x2 patch of a single-channel image becomes 4 channels at half the resolution), and those channels can then be split by the coupling layer.
See "Guided Image Generation with Conditional Invertible Neural Networks" (by the authors of this repository!) for more details.
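Here is a minimal NumPy sketch of what a Haar downsampling step does to a grayscale batch. It assumes NCHW layout and even height and width. `haar_downsample` and `haar_upsample` are hypothetical helpers that implement a plain orthonormal 2x2 Haar transform for illustration. FrEIA's own module in `reshapes.py` may order or scale the four sub-bands differently.

```python
import numpy as np


def haar_downsample(x: np.ndarray) -> np.ndarray:
    """Orthonormal 2x2 Haar transform: (N, C, H, W) -> (N, 4C, H/2, W/2)."""
    _, _, h, w = x.shape
    assert h % 2 == 0 and w % 2 == 0, "height and width must be even"
    tl = x[:, :, 0::2, 0::2]  # top-left pixel of every 2x2 patch
    tr = x[:, :, 0::2, 1::2]  # top-right
    bl = x[:, :, 1::2, 0::2]  # bottom-left
    br = x[:, :, 1::2, 1::2]  # bottom-right
    avg = (tl + tr + bl + br) / 2    # low-pass band
    d_lr = (tl - tr + bl - br) / 2   # left/right difference
    d_tb = (tl + tr - bl - br) / 2   # top/bottom difference
    d_diag = (tl - tr - bl + br) / 2  # diagonal difference
    return np.concatenate([avg, d_lr, d_tb, d_diag], axis=1)


def haar_upsample(y: np.ndarray) -> np.ndarray:
    """Exact inverse of haar_downsample (the 4x4 transform is symmetric and orthonormal)."""
    n, c4, h2, w2 = y.shape
    c = c4 // 4
    avg, d_lr, d_tb, d_diag = (y[:, i * c:(i + 1) * c] for i in range(4))
    x = np.empty((n, c, 2 * h2, 2 * w2), dtype=y.dtype)
    x[:, :, 0::2, 0::2] = (avg + d_lr + d_tb + d_diag) / 2
    x[:, :, 0::2, 1::2] = (avg - d_lr + d_tb - d_diag) / 2
    x[:, :, 1::2, 0::2] = (avg + d_lr - d_tb - d_diag) / 2
    x[:, :, 1::2, 1::2] = (avg - d_lr - d_tb + d_diag) / 2
    return x


x = np.random.default_rng(0).normal(size=(8, 1, 28, 28))  # grayscale batch
y = haar_downsample(x)
print(y.shape)  # (8, 4, 14, 14)
channels = y.shape[1]
print(channels // 2, channels - channels // 2)  # 2 2: both halves are non-empty
```

The transform is orthonormal, so it is exactly invertible with a log-determinant of zero. That property makes it a convenient reshaping step before a channel-splitting coupling block.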
Alternatively, you could edit the code of `_BaseCouplingBlock` so that the split is made across the spatial dimensions in a checkerboard pattern (as in RealNVP) rather than across channels (which is what happens here, as in Glow). The code comments in `coupling_layers.py` also suggest that you 'prepend an i_RevNet_downsampling module', but I haven't tried this myself.
Glow: https://arxiv.org/abs/1807.03039
RealNVP: https://arxiv.org/abs/1605.08803
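The checkerboard alternative can be sketched as follows. This is not a patch to `_BaseCouplingBlock`. It is a standalone NumPy illustration of RealNVP-style checkerboard masking on single-channel input, assuming NCHW layout. `toy_subnet` is a hypothetical fixed function standing in for a learned network, and all other helper names are also hypothetical.

```python
import numpy as np


def checkerboard_mask(h: int, w: int) -> np.ndarray:
    # 1 where (row + col) is even, 0 elsewhere; broadcasts over (N, C, H, W)
    i, j = np.indices((h, w))
    return ((i + j) % 2 == 0).astype(np.float64)


def toy_subnet(z: np.ndarray):
    # Hypothetical stand-in for a learned net: returns log-scale s and shift t.
    # It is only ever called on masked input, which keeps the block invertible.
    s = 0.5 * np.tanh(np.roll(z, 1, axis=-1) + np.roll(z, 1, axis=-2))
    t = np.roll(z, 1, axis=-1) - np.roll(z, -1, axis=-2)
    return s, t


def coupling_forward(x: np.ndarray, mask: np.ndarray):
    s, t = toy_subnet(x * mask)
    # Pixels where mask == 1 pass through; the others get an affine update
    y = mask * x + (1 - mask) * (x * np.exp(s) + t)
    log_det = np.sum((1 - mask) * s, axis=(1, 2, 3))
    return y, log_det


def coupling_inverse(y: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # y * mask == x * mask, so the subnet sees exactly the same input as in forward
    s, t = toy_subnet(y * mask)
    return mask * y + (1 - mask) * (y - t) * np.exp(-s)


x = np.random.default_rng(1).normal(size=(4, 1, 8, 8))  # single channel
mask = checkerboard_mask(8, 8)
y, log_det = coupling_forward(x, mask)
print(log_det.shape)  # (4,)
```

The split never touches the channel axis, so one channel is enough. A second block using the complementary mask `1 - mask` would then update the pixels this one left unchanged.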
Thanks! Adding a downsampling layer does sound like the simplest way out.