sotiraslab/AgileFormer

A problem with nat_2d.py

williamsriver opened this issue · 2 comments

Hello, I encountered an issue while using your project.

In nat_2d.py, the following line imports functions from the NATTEN package:

from natten.functional import na2d_av, na2d_qk_with_bias

I am using natten 0.17 (the environment listing below shows 0.17.1+torch230cu121), and its natten.functional module does not contain the na2d_qk_with_bias function. As a workaround, I tried the following import in agileFormer_sys_2d.py instead:

from natten.na2d import NeighborhoodAttention2D

I would like to know whether this change affects the program's functionality, or whether there is a better solution.

My conda virtual environment:

python = 3.8
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
absl-py                   2.1.0                    pypi_0    pypi
adamp                     0.3.0                    pypi_0    pypi
ca-certificates           2024.3.11            h06a4308_0  
certifi                   2024.6.2                 pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
cmake                     3.20.3                   pypi_0    pypi
contextlib2               21.6.0                   pypi_0    pypi
einops                    0.8.0                    pypi_0    pypi
filelock                  3.14.0                   pypi_0    pypi
fsspec                    2024.6.0                 pypi_0    pypi
h5py                      3.11.0                   pypi_0    pypi
huggingface-hub           0.23.3                   pypi_0    pypi
idna                      3.7                      pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1  
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
markupsafe                2.1.5                    pypi_0    pypi
medpy                     0.5.1                    pypi_0    pypi
ml-collections            0.1.1                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
natten                    0.17.1+torch230cu121          pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
networkx                  3.1                      pypi_0    pypi
numpy                     1.24.4                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.5.40                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
openssl                   3.0.13               h7f8727e_2  
packaging                 24.0                     pypi_0    pypi
pillow                    10.3.0                   pypi_0    pypi
pip                       24.0             py38h06a4308_0  
python                    3.8.19               h955ad1f_0  
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
requests                  2.32.3                   pypi_0    pypi
safetensors               0.4.3                    pypi_0    pypi
scipy                     1.10.1                   pypi_0    pypi
setuptools                69.5.1           py38h06a4308_0  
simpleitk                 2.3.1                    pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0  
sympy                     1.12.1                   pypi_0    pypi
timm                      1.0.3                    pypi_0    pypi
tk                        8.6.14               h39e8969_0  
torch                     2.3.0+cu121              pypi_0    pypi
torchaudio                2.3.0+cu121              pypi_0    pypi
torchvision               0.18.0+cu121             pypi_0    pypi
tqdm                      4.66.4                   pypi_0    pypi
triton                    2.3.0                    pypi_0    pypi
tvdcn                     0.5.0                    pypi_0    pypi
typing-extensions         4.12.1                   pypi_0    pypi
urllib3                   2.2.1                    pypi_0    pypi
wheel                     0.43.0           py38h06a4308_0  
xz                        5.4.6                h5eee18b_1  
yacs                      0.1.8                    pypi_0    pypi
zlib                      1.2.13               h5eee18b_1  

@williamsriver I believe you are using the latest natten version, in which "na2d_qk_with_bias" has been replaced by "na2d_qk". Replacing "na2d_qk_with_bias" with "na2d_qk" should resolve the problem (I haven't tested it on this version, though). Let me know if you run into further issues.

Yes! Your suggestion solved the problem. The call now reads:
attn = na2d_qk(q, k, self.kernel_size, self.dilation, rpb=self.rpb)
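If nat_2d.py needs to run on both older and newer natten releases, the rename can be hidden behind a small shim. The sketch below is a hypothetical helper (na2d_qk_compat is not part of natten or AgileFormer). It assumes that older releases expose na2d_qk_with_bias(query, key, bias, kernel_size, dilation), matching how natten's own NeighborhoodAttention2D called it at the time, and that newer releases take the bias through the rpb keyword, as in the call above. The natten.functional module is passed in as an argument, so the helper can be tested without natten installed.

```python
from typing import Any, Optional


def na2d_qk_compat(functional: Any, q: Any, k: Any, rpb: Optional[Any],
                   kernel_size: int, dilation: int) -> Any:
    """Hypothetical helper: neighborhood-attention logits on old or new natten.

    `functional` is expected to be the `natten.functional` module.
    """
    if hasattr(functional, "na2d_qk_with_bias"):
        # Older releases (assumed signature): bias is the third positional argument.
        return functional.na2d_qk_with_bias(q, k, rpb, kernel_size, dilation)
    # Newer releases (e.g. 0.17.x): bias is passed via the `rpb` keyword.
    return functional.na2d_qk(q, k, kernel_size, dilation, rpb=rpb)
```

The older name is checked first on purpose. If a release ships both functions, the bias-taking one is used with its own argument order, and na2d_qk is never called with a keyword it might not accept. Inside the attention module, the call would become na2d_qk_compat(natten.functional, q, k, self.rpb, self.kernel_size, self.dilation) after an "import natten.functional".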