zalandoresearch/pytorch-vq-vae

Intentional weight sharing in ResidualStack?

andreaskoepf opened this issue · 1 comment

The list passed to nn.ModuleList in the ResidualStack class constructor (vae.ipynb#L324) contains a single Residual instance repeated self._num_residual_layers times, so every residual layer shares the same weights. Was this done intentionally?

self._layers = nn.ModuleList(
    [Residual(in_channels, num_hiddens, num_residual_hiddens)] * self._num_residual_layers)
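A minimal, framework-free sketch of why the line above shares weights: multiplying a list repeats references to the object it already holds rather than constructing new ones. `Block` is a hypothetical stand-in for `Residual`, and its `weight` attribute stands in for the layer's parameters.

```python
class Block:
    """Hypothetical stand-in for Residual; `weight` plays the role of its parameters."""
    def __init__(self):
        self.weight = 0.0


# `[obj] * n` builds a list of n references to the SAME object.
shared_layers = [Block()] * 3
# A comprehension calls Block() once per element, giving n distinct objects.
separate_layers = [Block() for _ in range(3)]

# "Train" only the first layer of each list.
shared_layers[0].weight = 1.0
separate_layers[0].weight = 1.0

print([b.weight for b in shared_layers])    # [1.0, 1.0, 1.0]
print([b.weight for b in separate_layers])  # [1.0, 0.0, 0.0]
```

Updating "layer 0" of the multiplied list changes every layer, because there is only one object behind all three entries.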

To create a new object for each layer, the code could be changed to:

self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                             for _ in range(self._num_residual_layers)])
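A sketch comparing both variants, assuming a `Residual` block made of ReLU, 3x3 conv, ReLU, 1x1 conv plus a skip connection (the notebook's definition may differ in details such as in-place ReLUs). The `share_weights` flag and the `num_params` helper are hypothetical, added only to put the two behaviours side by side.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Residual(nn.Module):
    # Assumed layout: ReLU -> 3x3 conv -> ReLU -> 1x1 conv, added back onto the input.
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super().__init__()
        self._block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels, num_residual_hiddens, kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(num_residual_hiddens, num_hiddens, kernel_size=1, bias=False),
        )

    def forward(self, x):
        return x + self._block(x)


class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers,
                 num_residual_hiddens, share_weights=False):
        super().__init__()
        self._num_residual_layers = num_residual_layers
        if share_weights:
            # Original code: one Residual instance referenced num_residual_layers times.
            layers = [Residual(in_channels, num_hiddens, num_residual_hiddens)] * num_residual_layers
        else:
            # Proposed fix: a fresh Residual, with its own parameters, per layer.
            layers = [Residual(in_channels, num_hiddens, num_residual_hiddens)
                      for _ in range(num_residual_layers)]
        self._layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return F.relu(x)


def num_params(module):
    # parameters() skips duplicate tensors, so a shared block is counted once.
    return sum(p.numel() for p in module.parameters())


shared = ResidualStack(8, 8, 2, 4, share_weights=True)
separate = ResidualStack(8, 8, 2, 4, share_weights=False)
print(num_params(shared), num_params(separate))  # 320 640
print(separate(torch.randn(1, 8, 5, 5)).shape)   # torch.Size([1, 8, 5, 5])
```

Both stacks run a residual block twice in `forward`, so the shared version raises no error; it silently has the capacity of a single block, which is why the bug is easy to miss.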

@andreaskoepf Thanks for your observation. I think it's a bug; I will fix it!