wfs123456/CCTrans

Error on train.py - ShanghaiTech/part_A/

Closed this issue · 17 comments

Error when training on ShanghaiTech/part_A/.

I cloned your repo and made a small fix: the utils folder was missing from the repo, and train.py does from train_helper_PVT import Trainer, but I think the file is actually named train_helper_ALTGVT.py. I can open a pull request later to fix this.
Then I ran train.py, but as soon as the code reaches model creation, I get:
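Usually the simplest fix is to change the line to from train_helper_ALTGVT import Trainer. If you want train.py to work with either file name while the rename is in flux, here is a minimal sketch of a tolerant import. It assumes the helper module is called either train_helper_ALTGVT or train_helper_PVT. `import_first` is a hypothetical helper and is not part of the repo. The sketch deliberately re-raises errors coming from inside an existing module, so a missing dependency (such as the absent utils folder) is not hidden.

```python
import importlib


def import_first(candidates, attr):
    """Return `attr` from the first module in `candidates` that exists.

    A candidate is skipped only when that module itself is missing; an import
    error raised *inside* an existing module (e.g. a missing utils package)
    is re-raised instead of being silently masked.
    """
    for name in candidates:
        try:
            module = importlib.import_module(name)
        except ModuleNotFoundError as err:
            if err.name != name:
                raise  # the module exists, but one of its own imports failed
            continue
        return getattr(module, attr)
    raise ImportError(f"none of {candidates!r} could be imported")


# Hypothetical usage in train.py: prefer the renamed helper, fall back to the old name.
# Trainer = import_first(("train_helper_ALTGVT", "train_helper_PVT"), "Trainer")
```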

number of img: 300
number of img: 182
Traceback (most recent call last):
  File "train.py", line 63, in <module>
    trainer.setup()
  File "/train_folder/head_detection/CCTrans/train_helper_ALTGVT.py", line 77, in setup
    self.model = ALTGVT.alt_gvt_large(pretrained=True)
  File "/train_folder/head_detection/CCTrans/Networks/ALTGVT.py", line 549, in alt_gvt_large
    model = ALTGVT(
  File "/train_folder/head_detection/CCTrans/Networks/ALTGVT.py", line 496, in __init__
    super(ALTGVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads,
  File "/train_folder/head_detection/CCTrans/Networks/ALTGVT.py", line 483, in __init__
    super(PCPVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads,
  File "/train_folder/head_detection/CCTrans/Networks/ALTGVT.py", line 417, in __init__
    super(CPVTV2, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, mlp_ratios,
  File "/train_folder/head_detection/CCTrans/Networks/ALTGVT.py", line 309, in __init__
    _block = nn.ModuleList([
  File "/train_folder/head_detection/CCTrans/Networks/ALTGVT.py", line 310, in <listcomp>
    block_cls(
  File "/train_folder/head_detection/CCTrans/Networks/ALTGVT.py", line 236, in __init__
    super(GroupBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop,
TypeError: __init__() takes from 3 to 10 positional arguments but 11 were given

OK, this error goes away with timm==0.3.2 (I was using the latest timm version, 0.4.12).

After fixing this, I have an issue with another piece of code (torch version '1.10.0+cu113'):

 File "/root/miniconda3/envs/head_detection/lib/python3.8/site-packages/timm/models/layers/helpers.py", line 6, in <module>
    from torch._six import container_abcs
ImportError: cannot import name 'container_abcs' from 'torch._six' (/root/miniconda3/envs/head_detection/lib/python3.8/site-packages/torch/_six.py)

The fix for this issue is to upgrade timm, but if I do that I get the positional arguments error again.
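timm 0.3.2 fails on the line from torch._six import container_abcs because newer PyTorch releases no longer export that alias. One commonly used local workaround is to patch the import in timm/models/layers/helpers.py so it falls back to the standard library. This is an assumption on our part, not an official timm patch. The `_ntuple` / `to_2tuple` helper below approximates what that file uses the alias for. Keep in mind that editing an installed package is fragile and has to be redone after every reinstall.

```python
from itertools import repeat

# Compatibility shim (assumed workaround, not timm's own fix): newer torch
# versions dropped container_abcs from torch._six, so fall back to the stdlib.
try:
    from torch._six import container_abcs
except ImportError:  # also covers torch._six being absent or torch not installed
    import collections.abc as container_abcs


def _ntuple(n):
    """Approximation of timm's helper: scalars become n-tuples, iterables pass through."""
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_2tuple = _ntuple(2)
```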

The error is in the qk_scale parameter.
Your code:

super(GroupBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop,

Timm: https://github.com/rwightman/pytorch-image-models/blob/f7d210d759beb00a3d0834a3ce2d93f6e17f3d38/timm/models/vision_transformer.py#L213
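The traceback fits a positional-argument shift. GroupBlock passes qk_scale positionally to timm's Block. If the newer Block.__init__ no longer declares that parameter, every later argument moves one slot, and the call ends up one argument too long ("takes from 3 to 10 positional arguments but 11 were given"). Below is a minimal sketch of a version-tolerant call. It passes arguments by keyword and drops any argument the parent does not declare. OldBlock and NewBlock are hypothetical stand-ins whose signatures mirror the argument counts in the error message, not timm's actual source. `compatible_kwargs` and `make_block` are illustrative names, not repo or timm APIs.

```python
import inspect


class OldBlock:
    """Hypothetical stand-in for an older Block whose __init__ still accepts qk_scale."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None,
                 drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=None, norm_layer=None):
        self.dim, self.num_heads = dim, num_heads
        self.qk_scale, self.drop = qk_scale, drop


class NewBlock:
    """Hypothetical stand-in for a newer Block without qk_scale (3 to 10 positional args incl. self)."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False,
                 drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=None, norm_layer=None):
        self.dim, self.num_heads = dim, num_heads
        self.drop = drop


def compatible_kwargs(cls, **kwargs):
    """Keep only the keyword arguments that cls.__init__ declares."""
    params = inspect.signature(cls.__init__).parameters
    return {name: value for name, value in kwargs.items() if name in params}


def make_block(base_cls, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None,
               drop=0.0, attn_drop=0.0, drop_path=0.0):
    """Build base_cls by keyword so a removed parameter cannot shift the others."""
    kwargs = compatible_kwargs(
        base_cls, dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
        qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path,
    )
    return base_cls(**kwargs)
```

In the real GroupBlock, the same idea would mean calling super().__init__(**compatible_kwargs(Block, ...)) with keyword arguments instead of the positional list. Dropping qk_scale only preserves behavior when it is None, because timm's attention then falls back to the default head_dim ** -0.5 scale. If a custom qk_scale is set, the attention module itself would need patching instead.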

Thanks for the response. However, with timm==0.3.2 I get another error (cannot import name 'container_abcs' from 'torch._six').

Unfortunately, I have to use the latest version of torch in order to use latest-generation GPUs.

This was not an environment problem; I fixed it in the pull request. Also, is there a way I can contact you? I have a few questions about the implementation, and I think some parts are not consistent with the original paper. Hope to hear from you soon.

I would, but I'm unable to find your email :)

I understand now why you cannot see my email. When you include an address in the body of a reply sent by email, as you did, GitHub replaces it with @.***. Take a look at your message in the GitHub UI. You can find my address in my profile; please send me an email so we can continue discussing the paper.

Oh, I see, I can't view your email!
My Tencent email: @.*** (please remove all the '-' characters).

Thank you! This is my developer email: francescotdev@gmail.com