lulab/BATTER

error loading state_dict for TerminatorTagger

Closed this issue · 3 comments

Hi, thanks for developing this software; it sounds really useful for our research. I am trying to get it up and running and am getting an odd error. I installed all of the prerequisites in a conda environment, then tried running it on a CPU (it wouldn't recognize our GPU for some reason) with the included example data.

However, I get the following error:

./scripts/batter --fasta ./examples/S.aureus/GCF_000013425.1_ASM1342v1_genomic.fna --output ../staph.batter.out.bed --device cpu -rc -v
[2023-10-30 16:50:48,672] [tagging terminators] Initialize the model ...
[2023-10-30 16:50:48,770] [tagging terminators] Load model paramters from model/batter.mdl.pt ...
Traceback (most recent call last):
File "pathredacted/batter/batter/./scripts/batter", line 188, in <module>
main()
File "pathredacted/batter/batter/./scripts/batter", line 112, in main
tagger.load_state_dict(state_dict)
File "pathredacted/opt/envs/batter/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TerminatorTagger:
Unexpected key(s) in state_dict: "encoder.embeddings.position_ids".

Any ideas or suggestions on what may be going wrong? Thanks!

Hi Alexandra, if you are using the latest transformers package, I think conda install -c conda-forge transformers==4.18.0 should fix the problem. It seems later Longformer versions implement the "position_ids" variable differently and no longer keep it in the model state dict. Thanks for helping us improve our tool. (Also note that inference on CPU can be slow and may take some time to finish.)
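
For cases where downgrading transformers is not an option, a commonly used workaround is to drop the stale key from the checkpoint before loading it. The sketch below is an assumption, not something taken from the BATTER code: `drop_unexpected_keys` is a hypothetical helper, and it operates on a plain dict so that it only illustrates the filtering step.

```python
def drop_unexpected_keys(state_dict, unexpected=("encoder.embeddings.position_ids",)):
    """Return a copy of state_dict without the keys listed in `unexpected`.

    Hypothetical helper: the default key is the one named in the error above.
    """
    unexpected = set(unexpected)
    return {k: v for k, v in state_dict.items() if k not in unexpected}


# Toy checkpoint standing in for the real one (real values would be tensors).
checkpoint = {
    "encoder.embeddings.position_ids": [0, 1, 2],
    "encoder.embeddings.word_embeddings.weight": [[0.1, 0.2]],
}
cleaned = drop_unexpected_keys(checkpoint)
print(sorted(cleaned))
```

With PyTorch, the cleaned dict would then be passed to tagger.load_state_dict(cleaned). Calling load_state_dict(state_dict, strict=False) is another option, though it also silences keys that are genuinely missing.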

Thanks! I tried that and unfortunately have a different error now:

./scripts/batter --fasta ./examples/S.aureus/GCF_000013425.1_ASM1342v1_genomic.fna --output ../staph.batter.out.bed --device cpu -rc -v
[2023-10-31 13:09:05,559] [tagging terminators] Initialize the model ...
[2023-10-31 13:09:05,683] [tagging terminators] Load model paramters from model/batter.mdl.pt ...
[2023-10-31 13:09:05,828] [tagging terminators] Will use cpu for inference ...
[2023-10-31 13:09:05,830] [tagging terminators] Load sequences from ./examples/S.aureus/GCF_000013425.1_ASM1342v1_genomic.fna ...
[2023-10-31 13:09:06,121] [tagging terminators] Intermediate result will be saved to ../staph.batter.out.bed.tmp ...
[2023-10-31 13:09:06,131] [tagging terminators] processing NC_007795.1 ...
/thefolderpath/batter/batter/scripts/crf.py:380: UserWarning: where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead. (Triggered internally at /home/conda/feedstock_root/build_artifacts/pytorch-recipe_1696859578619/work/aten/src/ATen/native/TensorCompare.cpp:493.)
score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), next_score, score)
Traceback (most recent call last):
File "/thefolderpath//batter/batter/./scripts/batter", line 188, in <module>
main()
File "/thefolderpath/batter/batter/./scripts/batter", line 158, in main
inference(tagger, batched_tokens, batched_ivs, args.top_k, args.device) #, args.temperature)
File "/thefolderpath/batter/batter/./scripts/batter", line 82, in inference
tags, probs = tagging(batched_tokens, tagger, nbest, temperature)
File "/thefolderpath//batter/batter/./scripts/batter", line 23, in tagging
tags, scores = model.crf.decode(logits, attention_mask, nbest=nbest)
File "/thefolderpath/batter/batter/scripts/crf.py", line 124, in decode
return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)
File "/thefolderpath//batter/batter/scripts/crf.py", line 362, in _viterbi_decode_nbest
next_score = next_score.view(batch_size, -1, self.num_tags)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

I installed the dependencies with:
mamba install -n batter transformers==4.18.0 ushuffle pytorch cudatoolkit pyfaidx==0.7.1 numpy lightgbm bedtools pandas

and the environment has these versions:
conda list -n batter

Name Version Build Channel

_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
bedtools 2.31.0 hf5e1c6e_3 bioconda
biopython 1.81 py39hd1e30aa_1 conda-forge
brotli-python 1.1.0 py39h3d6467e_1 conda-forge
bzip2 1.0.8 h7f98852_4 conda-forge
c-ares 1.20.1 hd590300_1 conda-forge
ca-certificates 2023.7.22 hbcca054_0 conda-forge
certifi 2023.7.22 pyhd8ed1ab_0 conda-forge
charset-normalizer 3.3.1 pyhd8ed1ab_0 conda-forge
click 8.1.7 unix_pyh707e725_0 conda-forge
colorama 0.4.6 pyhd8ed1ab_0 conda-forge
cudatoolkit 9.2.148 h33e3169_12 conda-forge
dataclasses 0.8 pyhc8e2a94_3 conda-forge
filelock 3.13.1 pyhd8ed1ab_0 conda-forge
fsspec 2023.10.0 pyhca7485f_0 conda-forge
gmp 6.2.1 h58526e2_0 conda-forge
gmpy2 2.1.2 py39h376b7d2_1 conda-forge
huggingface_hub 0.18.0 pyhd8ed1ab_0 conda-forge
icu 73.2 h59595ed_0 conda-forge
idna 3.4 pyhd8ed1ab_0 conda-forge
importlib-metadata 6.8.0 pyha770c72_0 conda-forge
importlib_metadata 6.8.0 hd8ed1ab_0 conda-forge
jinja2 3.1.2 pyhd8ed1ab_1 conda-forge
joblib 1.3.2 pyhd8ed1ab_0 conda-forge
keyutils 1.6.1 h166bdaf_0 conda-forge
krb5 1.21.2 h659d440_0 conda-forge
ld_impl_linux-64 2.40 h41732ed_0 conda-forge
libabseil 20230802.1 cxx17_h59595ed_0 conda-forge
libblas 3.9.0 19_linux64_openblas conda-forge
libcblas 3.9.0 19_linux64_openblas conda-forge
libcurl 8.4.0 hca28451_0 conda-forge
libdeflate 1.18 h0b41bf4_0 conda-forge
libedit 3.1.20191231 he28a2e2_2 conda-forge
libev 4.33 h516909a_1 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-ng 13.2.0 h807b86a_2 conda-forge
libgfortran-ng 13.2.0 h69a702a_2 conda-forge
libgfortran5 13.2.0 ha4646dd_2 conda-forge
libhwloc 2.9.3 default_h554bfaf_1009 conda-forge
libiconv 1.17 h166bdaf_0 conda-forge
liblapack 3.9.0 19_linux64_openblas conda-forge
libnghttp2 1.55.1 h47da74e_0 conda-forge
libnsl 2.0.1 hd590300_0 conda-forge
libopenblas 0.3.24 pthreads_h413a1c8_0 conda-forge
libprotobuf 4.24.3 hf27288f_1 conda-forge
libsqlite 3.43.2 h2797004_0 conda-forge
libssh2 1.11.0 h0841786_0 conda-forge
libstdcxx-ng 13.2.0 h7e041cc_2 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libuv 1.46.0 hd590300_0 conda-forge
libxml2 2.11.5 h232c23b_1 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
lightgbm 4.1.0 py39h3d6467e_2 conda-forge
llvm-openmp 17.0.3 h4dfa4b3_0 conda-forge
markupsafe 2.1.3 py39hd1e30aa_1 conda-forge
mkl 2022.2.1 h84fe81f_16997 conda-forge
mpc 1.3.1 hfe3b2da_0 conda-forge
mpfr 4.2.1 h9458935_0 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
ncurses 6.4 h59595ed_2 conda-forge
networkx 3.2.1 pyhd8ed1ab_0 conda-forge
numpy 1.26.0 py39h474f0d3_0 conda-forge
openssl 3.1.4 hd590300_0 conda-forge
packaging 23.2 pyhd8ed1ab_0 conda-forge
pandas 2.1.2 py39hddac248_0 conda-forge
pip 23.3.1 pyhd8ed1ab_0 conda-forge
pyfaidx 0.7.1 pyh5e36f6f_0 bioconda
pysam 0.22.0 py39hcada746_0 bioconda
pysocks 1.7.1 pyha2e5f31_6 conda-forge
python 3.9.18 h0755675_0_cpython conda-forge
python-dateutil 2.8.2 pyhd8ed1ab_0 conda-forge
python-tzdata 2023.3 pyhd8ed1ab_0 conda-forge
python_abi 3.9 4_cp39 conda-forge
pytorch 2.0.0 cpu_mkl_py39h1d6e76c_103 conda-forge
pytz 2023.3.post1 pyhd8ed1ab_0 conda-forge
pyvcf3 1.0.3 pyhdfd78af_0 bioconda
pyyaml 6.0.1 py39hd1e30aa_1 conda-forge
readline 8.2 h8228510_1 conda-forge
regex 2023.10.3 py39hd1e30aa_0 conda-forge
requests 2.31.0 pyhd8ed1ab_0 conda-forge
sacremoses 0.0.53 pyhd8ed1ab_0 conda-forge
scipy 1.11.3 py39h474f0d3_1 conda-forge
setuptools 68.2.2 pyhd8ed1ab_0 conda-forge
six 1.16.0 pyh6c4a22f_0 conda-forge
sleef 3.5.1 h9b69904_2 conda-forge
sympy 1.12 pypyh9d50eac_103 conda-forge
tbb 2021.10.0 h00ab1b0_2 conda-forge
tk 8.6.13 h2797004_0 conda-forge
tokenizers 0.12.1 py39h4d2953e_1 conda-forge
tqdm 4.66.1 pyhd8ed1ab_0 conda-forge
transformers 4.18.0 pyhd8ed1ab_0 conda-forge
typing-extensions 4.8.0 hd8ed1ab_0 conda-forge
typing_extensions 4.8.0 pyha770c72_0 conda-forge
tzdata 2023c h71feb2d_0 conda-forge
urllib3 2.0.7 pyhd8ed1ab_0 conda-forge
ushuffle 1.2.2 py39hf95cd2a_7 bioconda
wheel 0.41.3 pyhd8ed1ab_0 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge
yaml 0.2.5 h7f98852_2 conda-forge
zipp 3.17.0 pyhd8ed1ab_0 conda-forge
zlib 1.2.13 hd590300_5 conda-forge
zstd 1.5.5 hfc55251_0 conda-forge

I figured it out. Once I pinned pytorch to 1.7.1, it seems to work.

mamba install -n batter transformers==4.18.0 ushuffle pytorch==1.7.1 cudatoolkit pyfaidx==0.7.1 numpy lightgbm bedtools pandas
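
Since both errors in this thread came down to package versions, a quick check of the environment against the pins can save a long run. This is a minimal sketch using only the standard library; `PINS` and `find_mismatches` are hypothetical names, and it assumes conda's pytorch package is visible to Python under the distribution name "torch".

```python
from importlib import metadata

# Pins taken from the install command above; conda's "pytorch" package
# is assumed to be exposed to Python as the "torch" distribution.
PINS = {"transformers": "4.18.0", "torch": "1.7.1", "pyfaidx": "0.7.1"}


def find_mismatches(pins, get_version=metadata.version):
    """Map each package whose installed version differs from its pin
    to a (wanted, found) pair; found is None if it is not installed."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        # Compare only the public version, ignoring local tags such as "+cpu".
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(PINS).items():
        print(f"{name}: wanted {wanted}, found {found}")
```

Run inside the batter environment, it prints one line per package that is missing or differs from its pin, and nothing if everything matches.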

$ ./scripts/batter --fasta ./examples/S.aureus/GCF_000013425.1_ASM1342v1_genomic.fna --output ../staph.batter.out.bed --device cpu -rc -v
[2023-10-31 14:40:03,598] [tagging terminators] Initialize the model ...
[2023-10-31 14:40:03,756] [tagging terminators] Load model paramters from model/batter.mdl.pt ...
[2023-10-31 14:40:03,964] [tagging terminators] Will use cpu for inference ...
[2023-10-31 14:40:03,967] [tagging terminators] Load sequences from ./examples/S.aureus/GCF_000013425.1_ASM1342v1_genomic.fna ...
[2023-10-31 14:40:03,969] [tagging terminators] Intermediate result will be saved to ../staph.batter.out.bed.tmp ...
[2023-10-31 14:40:03,970] [tagging terminators] processing NC_007795.1 ...
[2023-10-31 14:41:10,435] [tagging terminators] 50 K bases processed .
[2023-10-31 14:42:34,473] [tagging terminators] 100 K bases processed .
[2023-10-31 14:43:58,994] [tagging terminators] 150 K bases processed .
[2023-10-31 14:45:23,226] [tagging terminators] 200 K bases processed .
[2023-10-31 14:46:45,255] [tagging terminators] 250 K bases processed .
[2023-10-31 14:48:08,107] [tagging terminators] 300 K bases processed .
[2023-10-31 14:49:29,570] [tagging terminators] 350 K bases processed .
[2023-10-31 14:50:49,366] [tagging terminators] 400 K bases processed .
[2023-10-31 14:52:08,022] [tagging terminators] 450 K bases processed .
[2023-10-31 14:53:27,886] [tagging terminators] 500 K bases processed .
[2023-10-31 14:54:27,529] [tagging terminators] 550 K bases processed .
[2023-10-31 14:55:47,565] [tagging terminators] 600 K bases processed .
[2023-10-31 14:57:07,705] [tagging terminators] 650 K bases processed .
[2023-10-31 14:58:26,574] [tagging terminators] 700 K bases processed .
[2023-10-31 14:59:46,106] [tagging terminators] 750 K bases processed .
[2023-10-31 15:01:05,631] [tagging terminators] 800 K bases processed .
[2023-10-31 15:02:23,984] [tagging terminators] 850 K bases processed .
[2023-10-31 15:03:41,554] [tagging terminators] 900 K bases processed .
[2023-10-31 15:05:00,565] [tagging terminators] 950 K bases processed .
[2023-10-31 15:06:19,870] [tagging terminators] 1000 K bases processed .
[2023-10-31 15:07:38,049] [tagging terminators] 1050 K bases processed .
[2023-10-31 15:08:36,795] [tagging terminators] 1100 K bases processed .
[2023-10-31 15:09:54,950] [tagging terminators] 1150 K bases processed .
[2023-10-31 15:11:13,797] [tagging terminators] 1200 K bases processed .
[2023-10-31 15:12:31,553] [tagging terminators] 1250 K bases processed .
[2023-10-31 15:13:50,240] [tagging terminators] 1300 K bases processed .
[2023-10-31 15:15:09,017] [tagging terminators] 1350 K bases processed .
[2023-10-31 15:16:26,162] [tagging terminators] 1400 K bases processed .
[2023-10-31 15:17:43,380] [tagging terminators] 1450 K bases processed .
[2023-10-31 15:19:00,784] [tagging terminators] 1500 K bases processed .
[2023-10-31 15:20:17,722] [tagging terminators] 1550 K bases processed .
[2023-10-31 15:21:14,941] [tagging terminators] 1600 K bases processed .
[2023-10-31 15:22:32,252] [tagging terminators] 1650 K bases processed .
[2023-10-31 15:23:48,642] [tagging terminators] 1700 K bases processed .
[2023-10-31 15:25:04,661] [tagging terminators] 1750 K bases processed .
[2023-10-31 15:26:21,092] [tagging terminators] 1800 K bases processed .
[2023-10-31 15:27:37,524] [tagging terminators] 1850 K bases processed .
[2023-10-31 15:28:52,824] [tagging terminators] 1900 K bases processed .
[2023-10-31 15:30:09,993] [tagging terminators] 1950 K bases processed .
[2023-10-31 15:31:27,303] [tagging terminators] 2000 K bases processed .
[2023-10-31 15:32:44,539] [tagging terminators] 2050 K bases processed .
[2023-10-31 15:34:01,534] [tagging terminators] 2100 K bases processed .
[2023-10-31 15:34:58,252] [tagging terminators] 2150 K bases processed .
[2023-10-31 15:36:13,160] [tagging terminators] 2200 K bases processed .
[2023-10-31 15:37:27,930] [tagging terminators] 2250 K bases processed .
[2023-10-31 15:38:44,115] [tagging terminators] 2300 K bases processed .
[2023-10-31 15:39:58,373] [tagging terminators] 2350 K bases processed .
[2023-10-31 15:41:12,828] [tagging terminators] 2400 K bases processed .
[2023-10-31 15:42:27,737] [tagging terminators] 2450 K bases processed .
[2023-10-31 15:43:41,445] [tagging terminators] 2500 K bases processed .
[2023-10-31 15:44:54,949] [tagging terminators] 2550 K bases processed .
[2023-10-31 15:46:08,054] [tagging terminators] 2600 K bases processed .
[2023-10-31 15:47:23,628] [tagging terminators] 2650 K bases processed .
[2023-10-31 15:48:18,842] [tagging terminators] 2700 K bases processed .
[2023-10-31 15:49:32,705] [tagging terminators] 2750 K bases processed .
[2023-10-31 15:50:47,304] [tagging terminators] 2800 K bases processed .
[2023-10-31 15:51:32,210] [tagging terminators] Sort predictions ...
[2023-10-31 15:51:36,243] [tagging terminators] Merge predictions ...
[2023-10-31 15:51:36,243] [tagging terminators] Final results will be saved to ../staph.batter.out.bed .
[2023-10-31 15:51:37,130] [select intervals] load intervals from ../staph.batter.out.bed.tmp ...
[2023-10-31 15:51:37,130] [select intervals] picked intervals will be saved to ../staph.batter.out.bed .
[2023-10-31 15:51:37,995] [tagging terminators] Remove temporary results ...
[2023-10-31 15:51:38,026] [tagging terminators] all done .
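
To put a number on how slow CPU inference is here, the throughput can be estimated from the first and last progress lines of the log above. This is a rough sketch: `parse_progress` is a hypothetical helper that assumes the log format shown in this run.

```python
from datetime import datetime


def parse_progress(line):
    """Return (timestamp, bases) for a '... N K bases processed .' log line."""
    stamp = line[line.index("[") + 1 : line.index("]")]
    when = datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S,%f")
    kilobases = int(line.rsplit("]", 1)[1].split()[0])
    return when, kilobases * 1000


# First and last progress lines copied from the log above.
first = "[2023-10-31 14:41:10,435] [tagging terminators] 50 K bases processed ."
last = "[2023-10-31 15:50:47,304] [tagging terminators] 2800 K bases processed ."

t0, b0 = parse_progress(first)
t1, b1 = parse_progress(last)
rate = (b1 - b0) / (t1 - t0).total_seconds()
print(f"{rate:.0f} bases/s")
```

For this run that works out to roughly 660 bases per second, or a little over an hour for the S. aureus example genome on CPU.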