Use buffers instead of parameters in the Downsample layer
sarlinpe opened this issue · 2 comments
Thank you for sharing this great work and for open-sourcing your code!
I think that the "things to watch out for" listed in the readme could be avoided by defining the weights in the Downsample layer as buffers instead of parameters. See torch.nn.Module.register_buffer.
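For illustration, a minimal sketch of what a buffer-based definition could look like (the class name and kernel values here are placeholders, not the repo's actual code):

```python
import torch
import torch.nn as nn

class BlurFilterAsBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Fixed binomial blur kernel. Registering it as a buffer keeps it out
        # of model.parameters(), so the optimizer never updates it, while it
        # still moves with .to()/.cuda() and is saved in the state_dict.
        a = torch.tensor([1., 2., 1.])
        filt = a[:, None] * a[None, :]
        filt = filt / filt.sum()
        self.register_buffer('filt', filt)

m = BlurFilterAsBuffer()
print(list(m.parameters()))      # [] -> nothing for the optimizer to touch
print(list(m.named_buffers()))   # [('filt', tensor(...))]
```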
Thanks a lot for the suggestion! I integrated the change. Downsample now has a buffer self.filt and calls the functional F.conv2d in the forward pass. We no longer have to worry about it being overwritten or changed during training, so I removed those warnings from the readme.
The nn.Conv2d is still in the Downsample layer for legacy reasons -- optimizer.load_state_dict() doesn't play well if I remove it -- but the nn.Conv2d is not used at all, and its parameters add minimal memory to the saved files.
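For reference, a rough sketch of the described structure (the argument names, kernel values, and padding handling are assumptions, not the repo's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Downsample(nn.Module):
    def __init__(self, channels, stride=2):
        super().__init__()
        self.stride = stride
        self.channels = channels
        # Fixed binomial blur kernel stored as a buffer, not a parameter.
        a = torch.tensor([1., 2., 1.])
        filt = a[:, None] * a[None, :]
        filt = filt / filt.sum()
        # One copy of the kernel per channel for a depthwise (grouped) conv.
        self.register_buffer('filt', filt[None, None].repeat(channels, 1, 1, 1))

    def forward(self, x):
        # Functional conv with the buffered kernel; no learnable weights involved.
        pad = self.filt.shape[-1] // 2
        return F.conv2d(x, self.filt, stride=self.stride, padding=pad,
                        groups=self.channels)

# x = torch.randn(1, 64, 32, 32)
# y = Downsample(64)(x)   # -> shape [1, 64, 16, 16]
```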
Looks good, thanks for addressing that!