Hard-coded input sequence length for the transformer blocks when using use_tf_gamma = True
zhhhhahahaha opened this issue · 14 comments
Hi! Thanks for your amazing code. I am trying to use the pre-trained model, but I found that when I set use_tf_gamma = True, I can only use the precomputed gamma positions for an input sequence of length 1536. Will you fix that later?
Also, the sanity check will fail. After running this
python test_pretrained.py
Traceback (most recent call last):
File "/home/ubuntu/enformer-pytorch/test_pretrained.py", line 11, in <module>
corr_coef = enformer(
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/enformer-pytorch/enformer_pytorch/modeling_enformer.py", line 450, in forward
x = trunk_fn(x)
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/enformer-pytorch/enformer_pytorch/modeling_enformer.py", line 144, in forward
return self.fn(x, **kwargs) + x
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/opt/conda/envs/seq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/enformer-pytorch/enformer_pytorch/modeling_enformer.py", line 269, in forward
positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma)
File "/home/ubuntu/enformer-pytorch/enformer_pytorch/modeling_enformer.py", line 123, in get_positional_embed
embeddings = torch.cat(embeddings, dim = -1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2047 but got size 3071 for tensor number 2 in the list.
The program raises the error described above because the input sequence length to the transformer blocks is 1024 for the test sample.
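The two sizes in the error message can be reproduced with simple arithmetic: a sequence of n positions has 2n - 1 distinct relative offsets, from -(n - 1) to n - 1. The sketch below only illustrates that count. The names `num_relative_positions`, `check_gamma_compat`, and `PRECOMPUTED_LEN` are hypothetical and are not part of enformer-pytorch; the value 1536 is taken from this thread.

```python
# Hypothetical helper names; this is not code from enformer-pytorch.

PRECOMPUTED_LEN = 1536  # length the precomputed gammas cover, per this thread


def num_relative_positions(seq_len: int) -> int:
    """Count the distinct relative offsets -(n - 1) .. (n - 1) for n positions."""
    return 2 * seq_len - 1


def check_gamma_compat(seq_len: int, precomputed_len: int = PRECOMPUTED_LEN) -> None:
    """Raise if a table built for precomputed_len cannot be concatenated
    with embeddings built for seq_len."""
    expected = num_relative_positions(seq_len)
    got = num_relative_positions(precomputed_len)
    if expected != got:
        raise ValueError(f"Expected size {expected} but got size {got}")


print(num_relative_positions(1024))  # 2047, the size of the other embeddings
print(num_relative_positions(1536))  # 3071, the size of the precomputed gammas
```

With a transformer input length of 1024, the other positional features have 2047 rows while the precomputed gamma features have 3071, which matches the sizes reported by `torch.cat` in the traceback.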
@zhhhhahahaha ah bummer
yea i can fix it, though it will take up a morning
can you force it off for now? (by setting use_tf_gamma = False)
@zhhhhahahaha want to see if 0.8.4 fixes it on your machine?
Thanks for your work!
But I think if we use input with different sequence lengths, we need to recompute the gamma position encoding because we have different
All in all, I think this is enough for using the pre-trained Enformer parameters with the same input sequence length. I will figure it out myself if I need to use different input sequence lengths (maybe by retraining the transformer blocks).
@zhhhhahahaha ah yea, we could expand the precomputed tf gammas for all sequence lengths from 1 - 1536, then index it out
i swear this is the last time i ever want to deal with tensorflow
@zhhhhahahaha if you have tensorflow environment installed and could get me that matrix, i can get this fixed in a jiffy
I haven't installed the tensorflow environment, and I have decided to retrain the model or just ignore these small rounding errors, thanks!
@johahi what do you think?
I haven't installed the tensorflow environment, and I have decided to retrain the model or just ignore these small rounding errors, thanks!
yea neither do i
the other option would be to check if https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.xlogy.html is equivalent to the tensorflow xlogy
then we use jax2torch
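A minimal sketch of what such an equivalence check could look like, written in plain Python so it runs without any of the frameworks installed. `xlogy_ref`, `max_abs_diff`, and `agree` are hypothetical names used only for illustration. In a real comparison the two callables would be scalar wrappers around, for instance, `jax.scipy.special.xlogy`, `torch.xlogy`, and `tf.math.xlogy`; that wiring is an assumption and is not shown here.

```python
import math


def xlogy_ref(x: float, y: float) -> float:
    """Reference xlogy: x * log(y), defined as 0 when x == 0."""
    if x == 0.0:
        return 0.0
    return x * math.log(y)


def max_abs_diff(f, g, points):
    """Largest absolute difference between f and g over the sample points."""
    return max(abs(f(x, y) - g(x, y)) for x, y in points)


def agree(f, g, points, rel_tol=1e-5, abs_tol=1e-8):
    """True if f and g match within tolerance at every sample point."""
    return all(
        math.isclose(f(x, y), g(x, y), rel_tol=rel_tol, abs_tol=abs_tol)
        for x, y in points
    )


# Sample grid with y > 0 so that a naive x * log(y) is also well defined.
grid = [(x, y) for x in (0.0, 0.5, 1.0, 3.0) for y in (1e-3, 0.5, 1.0, 10.0)]


def naive(x, y):
    return x * math.log(y)


print(agree(xlogy_ref, naive, grid))
print(max_abs_diff(xlogy_ref, naive, grid))
```

Two implementations that pass this kind of check on a coarse grid can still differ in the last bits of float32 output, so tightening the tolerances and sampling the actual argument range used by the gamma features would be the next step.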
@lucidrains i'll try jax2torch, will let you know if that works!
i don't know if the model performs well when it is used with cropped sequences, so just the tf-gammas for the original length of 1536 were fine for my use case...
@lucidrains from quick tests it seems like jax and torch have the same xlogy implementation (result after xlogy is allclose between them, but not between jax and tf or pt and tf), so this won't help, unfortunately
@johahi oh bummer, we'll just let the tf gamma hack work only for 1536 sequence length then
@johahi thanks for checking!