yingkaisha/keras-unet-collection

filter_num in TransUNet

parniash opened this issue · 6 comments

Hi,

In TransUNet, filter_num sets the number of filters for each down- and upsampling level, right? However, if I use three levels, filter_num=[64, 128, 128], instead of the default filter_num=[64, 128, 128, 512], the number of parameters of the network increases and I get an OOM error. Is this a bug, or am I missing something?
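A quick way to sanity-check this is to build the model with each filter_num and compare total parameter counts. The sketch below assumes a placeholder (128, 128, 3) input and n_labels=2, with every other transunet_2d argument left at its default.

```python
# Minimal sketch: compare total parameter counts for the two filter_num
# settings from the question. The input shape and label count are placeholders.
from keras_unet_collection import models

for filters in ([64, 128, 128], [64, 128, 128, 512]):
    model = models.transunet_2d((128, 128, 3), filter_num=filters, n_labels=2,
                                name='transunet_{}_levels'.format(len(filters)))
    print(filters, '->', '{:,}'.format(model.count_params()), 'total parameters')
```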

If you run this on GPUs, then a possible reason is that your configuration is too big. Going from [64, 128, 128] --> [64, 128, 128, 256] adds a lot of weights.

The problem is that I go from [64, 128, 128, 256] --> [64, 128, 128] and the network gets larger! This doesn't make sense: I reduce the number of levels, yet the parameter count goes up.

Have you been able to reproduce this issue?

@parniash Would you mind sharing your code? I don't think the network would get bigger.

If you compile these two models (filter_num is the only difference):

```python
model = models.transunet_2d(
    (input_height, input_width, 3), filter_num=[32, 64, 128, 256], n_labels=n_classes,
    embed_dim=600, num_mlp=1000, num_heads=4, num_transformer=4,
    activation='ReLU', mlp_activation='GELU', output_activation='Sigmoid',
    batch_norm=True, pool=True, unpool='bilinear',
    backbone='ResNet50V2', weights='imagenet',
    freeze_backbone=True, freeze_batch_norm=True, name='transunet')
```

vs

```python
model = models.transunet_2d(
    (input_height, input_width, 3), filter_num=[64, 128, 256], n_labels=n_classes,
    embed_dim=600, num_mlp=1000, num_heads=4, num_transformer=4,
    activation='ReLU', mlp_activation='GELU', output_activation='Sigmoid',
    batch_norm=True, pool=True, unpool='bilinear',
    backbone='ResNet50V2', weights='imagenet',
    freeze_backbone=True, freeze_batch_norm=True, name='transunet')
```

The second one has more parameters, and I don't understand why.

If I plug in: (128, 128, 3), n_labels=2

Model 1: Total params: 30,508,410
Model 2: Total params: 29,812,794

So your second configuration is smaller; there is no problem.
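For reference, the two totals above can be reproduced with a sketch like the one below, assuming the same (128, 128, 3) input and n_labels=2; everything else is copied from the two configurations posted earlier.

```python
from keras_unet_collection import models

# Shared settings copied from the two configurations above.
common = dict(n_labels=2, embed_dim=600, num_mlp=1000, num_heads=4,
              num_transformer=4, activation='ReLU', mlp_activation='GELU',
              output_activation='Sigmoid', batch_norm=True, pool=True,
              unpool='bilinear', backbone='ResNet50V2', weights='imagenet',
              freeze_backbone=True, freeze_batch_norm=True)

model_1 = models.transunet_2d((128, 128, 3), filter_num=[32, 64, 128, 256],
                              name='transunet_1', **common)
model_2 = models.transunet_2d((128, 128, 3), filter_num=[64, 128, 256],
                              name='transunet_2', **common)

print('Model 1 total params:', '{:,}'.format(model_1.count_params()))
print('Model 2 total params:', '{:,}'.format(model_2.count_params()))
```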

I suspect you are looking at the number of trainable parameters: the second one has more trainable params because its output head is connected to 64 channels.

Try digging into your configurations with model.summary(). The total size of a model can be reduced in many ways.
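If it helps, here is a small sketch for splitting a model's parameter count into trainable and frozen parts; it works for any tf.keras model, and model_1 / model_2 refer to the two models built in the sketch above. model.summary() prints the same breakdown at the bottom of its table.

```python
from tensorflow.keras import backend as K

def param_breakdown(model):
    """Print trainable, non-trainable, and total parameter counts."""
    trainable = sum(K.count_params(w) for w in model.trainable_weights)
    frozen = sum(K.count_params(w) for w in model.non_trainable_weights)
    print('{}: trainable={:,}  non-trainable={:,}  total={:,}'.format(
        model.name, trainable, frozen, trainable + frozen))

# e.g. param_breakdown(model_1); param_breakdown(model_2)
# With freeze_backbone=True, the ResNet50V2 weights end up in the
# non-trainable count, so the trainable totals can rank differently
# from the overall totals.
```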