usuyama/pytorch-unet

Use ConvTranspose2d instead

imadtyx opened this issue

Here in the U-Net code you have used an upsampling layer. Instead, you should use ConvTranspose2d.

UpSampling2D (nn.Upsample in PyTorch) just scales the image up using nearest-neighbour or bilinear interpolation, so nothing is learned. The advantage is that it's cheap.

Conv2DTranspose (nn.ConvTranspose2d in PyTorch) is a convolution operation whose kernel is learned while training your model, just like a normal Conv2d. It also upsamples its input, but the key difference is that the model can learn the best upsampling for the job.
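For reference, a minimal sketch contrasting the two in PyTorch (the 128→64 channel counts and the 2x scale factor are placeholders, not this repo's actual configuration):

```python
import torch
import torch.nn as nn

# Fixed upsampling: interpolation followed by a regular convolution to mix
# channels. The resize itself has no learnable parameters, so it is cheap.
up_fixed = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
    nn.Conv2d(128, 64, kernel_size=3, padding=1),
)

# Learned upsampling: the transposed convolution's kernel is trained, so the
# network can learn whichever interpolation best suits the task.
up_learned = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

x = torch.randn(1, 128, 16, 16)
print(up_fixed(x).shape)    # torch.Size([1, 64, 32, 32])
print(up_learned(x).shape)  # torch.Size([1, 64, 32, 32])
```

Both produce the same output shape; the difference is only whether the upsampling weights are fixed or trained.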

Thanks for the great suggestion! Please feel free to submit a PR.

@imadtyx Can I ask you a question: is the second x2 really required in the implementation there? https://github.com/milesial/Pytorch-UNet/blob/2f62e6b1c8e98022a6418d31a76f6abd800e5ae7/unet/unet_parts.py#L56C28-L56C28

I'm running an experiment with checkpointing, and storing fewer inputs would be better.
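For context, a minimal sketch of what I mean, assuming the Up.forward structure from the linked file (the torch.utils.checkpoint call is my experiment, not the repo's code). It shows why x2 appears twice: the first use only needs its shape, the second needs its values:

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def up_forward(up, conv, x1, x2):
    # x1: decoder feature map, x2: encoder skip connection
    x1 = up(x1)
    # First use of x2: only its spatial *shape* is needed, to pad x1 to match
    diffY = x2.size(2) - x1.size(2)
    diffX = x2.size(3) - x1.size(3)
    x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                    diffY // 2, diffY - diffY // 2])
    # Second use of x2: its *values* are needed for the skip concatenation
    x = torch.cat([x2, x1], dim=1)
    # Checkpointing recomputes conv's intermediate activations during the
    # backward pass, so they don't have to be stored in the forward pass
    return checkpoint(conv, x, use_reentrant=False)
```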