LeapLabTHU/DAT

How to run the model

SITUSITU opened this issue · 1 comment

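import torch

if __name__ == "__main__":
    # Smoke test placed at the bottom of DAT.py, where the DAT class is defined
    x = torch.ones((2, 3, 224, 224))  # dummy batch: two 3-channel 224x224 images
    model = DAT()  # default constructor arguments, no config file
    y = model(x)
    print(y.shape)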

I tried to run the model with the code above to learn how it works, but got the following error:

  File "Model\DAT\DAT.py", line 232, in <module>
    model = DAT()
  File "Model\DAT\DAT.py", line 134, in __init__
    use_dwc_mlps[i])
  File "Model\DAT\DAT.py", line 59, in __init__
    no_off, fixed_pe, stage_idx)
  File "Model\DAT\DAT_Block.py", line 201, in __init__
    nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, kk // 2, groups=self.n_group_channels),
  File "env\lib\site-packages\torch\nn\modules\conv.py", line 446, in __init__
    False, _pair(0), groups, bias, padding_mode, **factory_kwargs)
  File "env\lib\site-packages\torch\nn\modules\conv.py", line 132, in __init__
    (out_channels, in_channels // groups, *kernel_size), **factory_kwargs))
RuntimeError: Trying to create tensor with negative dimension -96: [-96, 1, 9, 9]

Process finished with exit code 1
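One plausible reading of the -96 in the message: a depthwise `nn.Conv2d(c, c, k, groups=c)` stores a weight of shape `[c, c // c, k, k]`, so a channel count of -96 with a 9x9 kernel gives exactly `[-96, 1, 9, 9]`. The sketch below assumes, hypothetically, that `n_group_channels` is the stage width divided by the group count, and that the first stage is 96 channels wide with a group value of -1. These numbers are inferred from the error message, not read from the repository, and `offset_conv_weight_shape` is a made-up helper for illustration.

```python
from typing import List


def offset_conv_weight_shape(stage_dim: int, n_groups: int, kernel_size: int) -> List[int]:
    """Hypothetical helper: weight shape of the depthwise offset conv.

    Assumes channels per group = stage_dim // n_groups and a depthwise conv
    (groups == channels), as in the nn.Conv2d call from the traceback.
    """
    n_group_channels = stage_dim // n_groups  # e.g. 96 // -1 == -96
    # nn.Conv2d(c, c, k, groups=c) allocates a weight of shape [c, c // c, k, k]
    return [n_group_channels, n_group_channels // n_group_channels, kernel_size, kernel_size]


print(offset_conv_weight_shape(96, -1, 9))  # [-96, 1, 9, 9]
print(offset_conv_weight_shape(384, 3, 5))  # [128, 1, 5, 5]
```

With a positive group count the shape is valid; with the -1 placeholder the first dimension goes negative, which is what PyTorch refuses to allocate.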

Hello @SITUSITU,

This short test code may not work without the proper configuration. To reproduce the conference version of DAT, please check out the config files in commit 566a593 so that the models are built correctly.

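In the conference version of DAT, the first two stages contain no DMHA (deformable multi-head attention) blocks, so I set a placeholder value of -1 for them to avoid misuse; building the model with the default arguments therefore produces the negative dimension you see. By the way, DAT++, a new version with an extended paper, has been released, and you are welcome to try it.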
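If you build the model from your own arguments instead of the provided config files, a small consistency check can catch this kind of mismatch before any layer is constructed. The sketch below is hypothetical: the block codes (`"D"` for a DMHA block, `"L"` and `"S"` for other block types), the `STAGE_SPEC` and `GROUPS` values, and `check_groups` are illustrative assumptions, not DAT's actual configuration keys or API. It only encodes the rule from the reply: stages without DMHA may keep the -1 placeholder, while stages with DMHA need a positive group count.

```python
from typing import List, Sequence

# Hypothetical stage layout: "D" marks a DMHA block, "L"/"S" stand for other
# block types. The first two stages contain no "D", as described in the reply.
STAGE_SPEC = [["L", "S"], ["L", "S"], ["L", "D", "L", "D"], ["L", "D"]]
# Assumed group counts per stage; -1 is a placeholder for stages without DMHA.
GROUPS = [-1, -1, 3, 6]


def check_groups(stage_spec: Sequence[Sequence[str]], groups: Sequence[int]) -> List[int]:
    """Return the indices of stages that use DMHA.

    Raises ValueError if a DMHA stage has a non-positive group count, or if
    the two sequences have different lengths.
    """
    if len(stage_spec) != len(groups):
        raise ValueError("stage_spec and groups must have the same length")
    dmha_stages = []
    for i, (blocks, n_groups) in enumerate(zip(stage_spec, groups)):
        if "D" in blocks:
            if n_groups <= 0:
                raise ValueError(f"stage {i} has DMHA blocks but groups={n_groups}")
            dmha_stages.append(i)
    return dmha_stages


print(check_groups(STAGE_SPEC, GROUPS))  # [2, 3]

# A layout that places a DMHA block in stage 0 while keeping the -1 placeholder
try:
    check_groups([["L", "D"]] + STAGE_SPEC[1:], GROUPS)
except ValueError as err:
    print(err)  # stage 0 has DMHA blocks but groups=-1
```

Failing early with a descriptive message is easier to debug than a negative tensor dimension raised deep inside `nn.Conv2d`.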