hwjiang1510/LEAP

Minimum GPU for inference?

fefespn opened this issue · 3 comments

Hey,
I have a GPU with 23GB of RAM, and I get a CUDA OutOfMemoryError.
What is the minimum amount of GPU memory I need?

I ask because the checkpoints you published aren't compatible with FlashAttention.

Thanks a lot !

Hi,

Thanks for your interest in our work.

We train and test with 40GB GPUs. I think 23GB should be enough for inference. How did you set the evaluation config?
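
If it still runs out of memory, logging the peak allocation can show how close the run is to the 23GB limit. A minimal, generic PyTorch sketch (nothing LEAP-specific is assumed here):

import torch

torch.cuda.reset_peak_memory_stats()
# ... run one evaluation / inference forward pass here ...
peak_gb = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak GPU memory allocated: {peak_gb:.1f} GB")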

Hi,

In the demo_224_real.yaml file I set use_flash_attn to True and used the Kubric model you provided, but I get a missing-key error:

envs/leap/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LEAP:
Missing key(s) in state_dict: "encoder.transformers_cross.0.ln_1.weight", "encoder.transformers_cross.0.ln_1.bias", "encoder.transformers_cross.0.attn.Wq.weight", "encoder.transformers_cross.0.attn.Wq.bias", "encoder.transformers_cross.0.attn.Wkv.weight", "encoder.transformers_cross.0.attn.Wkv.bias", "encoder.transformers_cross.0.attn.out_proj.weight", "encoder.transformers_cross.0.attn.out_proj.bias", "encoder.transformers_cross.0.ln_2.weight", "encoder.transformers_cross.0.ln_2.bias", "encoder.transformers_cross.0.mlp.fc1.weight", "encoder.transformers_cross.0.mlp.fc1.bias", "encoder.transformers_cross.0.mlp.fc2.weight", "encoder.transformers_cross.0.mlp.fc2.bias", "encoder.transformers_cross.1.ln_1.weight", "encoder.transformers_cross.1.ln_1.bias", "encoder.transformers_cross.1.attn.Wq.weight", "encoder.transformers_cross.1.attn.Wq.bias", "encoder.transformers_cross.1.attn.Wkv.weight", "encoder.transformers_cross.1.attn.Wkv.bias", "encoder.transformers_cross.1.attn.out_proj.weight", "encoder.transformers_cross.1.attn.out_proj.bias", "encoder.transformers_cross.1.ln_2.weight", "encoder.transformers_cross.1.ln_2.bias", "encoder.transformers_cross.1.mlp.fc1.weight", "encoder.transformers_cross.1.mlp.fc1.bias", "encoder.transformers_cross.1.mlp.fc2.weight", "encoder.transformers_cross.1.mlp.fc2.bias", "encoder.transformers_self.0.ln_1.weight", "encoder.transformers_self.0.ln_1.bias", "encoder.transformers_self.0.attn.Wqkv.weight", "encoder.transformers_self.0.attn.Wqkv.bias", "encoder.transformers_self.0.attn.out_proj.weight", "encoder.transformers_self.0.attn.out_proj.bias", "encoder.transformers_self.0.ln_2.weight", "encoder.transformers_self.0.ln_2.bias", "encoder.transformers_self.0.mlp.fc1.weight", "encoder.transformers_self.0.mlp.fc1.bias", "encoder.transformers_self.0.mlp.fc2.weight", "encoder.transformers_self.0.mlp.fc2.bias", "encoder.transformers_self.1.ln_1.weight", "encoder.transformers_self.1.ln_1.bias", "encoder.transformers_self.1.attn.Wqkv.weight", "encoder.transformers_self.1.attn.Wqkv.bias", "encoder.transformers_self.1.attn.out_proj.weight", "encoder.transformers_self.1.attn.out_proj.bias", "encoder.transformers_self.1.ln_2.weight", "encoder.transformers_self.1.ln_2.bias", "encoder.transformers_self.1.mlp.fc1.weight", "encoder.transformers_self.1.mlp.fc1.bias", "encoder.transformers_self.1.mlp.fc2.weight", "encoder.transformers_self.1.mlp.fc2.bias", "lifting.transformer.0.ln_1.weight", "lifting.transformer.0.ln_1.bias", "lifting.transformer.0.self_attn.Wqkv.weight", "lifting.transformer.0.self_attn.Wqkv.bias", "lifting.transformer.0.ln_2.weight", "lifting.transformer.0.ln_2.bias", "lifting.transformer.0.cross_attn.Wq.weight", "lifting.transformer.0.cross_attn.Wq.bias", "lifting.transformer.0.cross_attn.Wkv.weight", "lifting.transformer.0.cross_attn.Wkv.bias", "lifting.transformer.0.cross_attn.out_proj.weight", "lifting.transformer.0.cross_attn.out_proj.bias", "lifting.transformer.0.ln_3.weight", "lifting.transformer.0.ln_3.bias", "lifting.transformer.0.mlp.fc1.weight", "lifting.transformer.0.mlp.fc1.bias", "lifting.transformer.0.mlp.fc2.weight", "lifting.transformer.0.mlp.fc2.bias", "lifting.transformer.1.ln_1.weight", "lifting.transformer.1.ln_1.bias", "lifting.transformer.1.self_attn.Wqkv.weight", "lifting.transformer.1.self_attn.Wqkv.bias", "lifting.transformer.1.ln_2.weight", "lifting.transformer.1.ln_2.bias", "lifting.transformer.1.cross_attn.Wq.weight", "lifting.transformer.1.cross_attn.Wq.bias", "lifting.transformer.1.cross_attn.Wkv.weight", 
"lifting.transformer.1.cross_attn.Wkv.bias", "lifting.transformer.1.cross_attn.out_proj.weight", "lifting.transformer.1.cross_attn.out_proj.bias", "lifting.transformer.1.ln_3.weight", "lifting.transformer.1.ln_3.bias", "lifting.transformer.1.mlp.fc1.weight", "lifting.transformer.1.mlp.fc1.bias", "lifting.transformer.1.mlp.fc2.weight", "lifting.transformer.1.mlp.fc2.bias", "lifting.transformer.2.ln_1.weight", "lifting.transformer.2.ln_1.bias", "lifting.transformer.2.self_attn.Wqkv.weight", "lifting.transformer.2.self_attn.Wqkv.bias", "lifting.transformer.2.ln_2.weight", "lifting.transformer.2.ln_2.bias", "lifting.transformer.2.cross_attn.Wq.weight", "lifting.transformer.2.cross_attn.Wq.bias", "lifting.transformer.2.cross_attn.Wkv.weight", "lifting.transformer.2.cross_attn.Wkv.bias", "lifting.transformer.2.cross_attn.out_proj.weight", "lifting.transformer.2.cross_attn.out_proj.bias", "lifting.transformer.2.ln_3.weight", "lifting.transformer.2.ln_3.bias", "lifting.transformer.2.mlp.fc1.weight", "lifting.transformer.2.mlp.fc1.bias", "lifting.transformer.2.mlp.fc2.weight", "lifting.transformer.2.mlp.fc2.bias", "lifting.transformer.3.ln_1.weight", "lifting.transformer.3.ln_1.bias", "lifting.transformer.3.self_attn.Wqkv.weight", "lifting.transformer.3.self_attn.Wqkv.bias", "lifting.transformer.3.ln_2.weight", "lifting.transformer.3.ln_2.bias", "lifting.transformer.3.cross_attn.Wq.weight", "lifting.transformer.3.cross_attn.Wq.bias", "lifting.transformer.3.cross_attn.Wkv.weight", "lifting.transformer.3.cross_attn.Wkv.bias", "lifting.transformer.3.cross_attn.out_proj.weight", "lifting.transformer.3.cross_attn.out_proj.bias", "lifting.transformer.3.ln_3.weight", "lifting.transformer.3.ln_3.bias", "lifting.transformer.3.mlp.fc1.weight", "lifting.transformer.3.mlp.fc1.bias", "lifting.transformer.3.mlp.fc2.weight", "lifting.transformer.3.mlp.fc2.bias".
Unexpected key(s) in state_dict: "encoder.transformers_cross.0.multihead_attn.in_proj_weight", "encoder.transformers_cross.0.multihead_attn.in_proj_bias", "encoder.transformers_cross.0.multihead_attn.out_proj.weight", "encoder.transformers_cross.0.multihead_attn.out_proj.bias", "encoder.transformers_cross.0.linear1.weight", "encoder.transformers_cross.0.linear1.bias", "encoder.transformers_cross.0.linear2.weight", "encoder.transformers_cross.0.linear2.bias", "encoder.transformers_cross.0.norm2.weight", "encoder.transformers_cross.0.norm2.bias", "encoder.transformers_cross.0.norm3.weight", "encoder.transformers_cross.0.norm3.bias", "encoder.transformers_cross.1.multihead_attn.in_proj_weight", "encoder.transformers_cross.1.multihead_attn.in_proj_bias", "encoder.transformers_cross.1.multihead_attn.out_proj.weight", "encoder.transformers_cross.1.multihead_attn.out_proj.bias", "encoder.transformers_cross.1.linear1.weight", "encoder.transformers_cross.1.linear1.bias", "encoder.transformers_cross.1.linear2.weight", "encoder.transformers_cross.1.linear2.bias", "encoder.transformers_cross.1.norm2.weight", "encoder.transformers_cross.1.norm2.bias", "encoder.transformers_cross.1.norm3.weight", "encoder.transformers_cross.1.norm3.bias", "encoder.transformers_self.0.self_attn.in_proj_weight", "encoder.transformers_self.0.self_attn.in_proj_bias", "encoder.transformers_self.0.self_attn.out_proj.weight", "encoder.transformers_self.0.self_attn.out_proj.bias", "encoder.transformers_self.0.linear1.weight", "encoder.transformers_self.0.linear1.bias", "encoder.transformers_self.0.linear2.weight", "encoder.transformers_self.0.linear2.bias", "encoder.transformers_self.0.norm1.weight", "encoder.transformers_self.0.norm1.bias", "encoder.transformers_self.0.norm2.weight", "encoder.transformers_self.0.norm2.bias", "encoder.transformers_self.1.self_attn.in_proj_weight", "encoder.transformers_self.1.self_attn.in_proj_bias", "encoder.transformers_self.1.self_attn.out_proj.weight", "encoder.transformers_self.1.self_attn.out_proj.bias", "encoder.transformers_self.1.linear1.weight", "encoder.transformers_self.1.linear1.bias", "encoder.transformers_self.1.linear2.weight", "encoder.transformers_self.1.linear2.bias", "encoder.transformers_self.1.norm1.weight", "encoder.transformers_self.1.norm1.bias", "encoder.transformers_self.1.norm2.weight", "encoder.transformers_self.1.norm2.bias", "lifting.transformer.0.multihead_attn.in_proj_weight", "lifting.transformer.0.multihead_attn.in_proj_bias", "lifting.transformer.0.multihead_attn.out_proj.weight", "lifting.transformer.0.multihead_attn.out_proj.bias", "lifting.transformer.0.linear1.weight", "lifting.transformer.0.linear1.bias", "lifting.transformer.0.linear2.weight", "lifting.transformer.0.linear2.bias", "lifting.transformer.0.norm1.weight", "lifting.transformer.0.norm1.bias", "lifting.transformer.0.norm2.weight", "lifting.transformer.0.norm2.bias", "lifting.transformer.0.norm3.weight", "lifting.transformer.0.norm3.bias", "lifting.transformer.0.self_attn.in_proj_weight", "lifting.transformer.0.self_attn.in_proj_bias", "lifting.transformer.1.multihead_attn.in_proj_weight", "lifting.transformer.1.multihead_attn.in_proj_bias", "lifting.transformer.1.multihead_attn.out_proj.weight", "lifting.transformer.1.multihead_attn.out_proj.bias", "lifting.transformer.1.linear1.weight", "lifting.transformer.1.linear1.bias", "lifting.transformer.1.linear2.weight", "lifting.transformer.1.linear2.bias", "lifting.transformer.1.norm1.weight", "lifting.transformer.1.norm1.bias", 
"lifting.transformer.1.norm2.weight", "lifting.transformer.1.norm2.bias", "lifting.transformer.1.norm3.weight", "lifting.transformer.1.norm3.bias", "lifting.transformer.1.self_attn.in_proj_weight", "lifting.transformer.1.self_attn.in_proj_bias", "lifting.transformer.2.multihead_attn.in_proj_weight", "lifting.transformer.2.multihead_attn.in_proj_bias", "lifting.transformer.2.multihead_attn.out_proj.weight", "lifting.transformer.2.multihead_attn.out_proj.bias", "lifting.transformer.2.linear1.weight", "lifting.transformer.2.linear1.bias", "lifting.transformer.2.linear2.weight", "lifting.transformer.2.linear2.bias", "lifting.transformer.2.norm1.weight", "lifting.transformer.2.norm1.bias", "lifting.transformer.2.norm2.weight", "lifting.transformer.2.norm2.bias", "lifting.transformer.2.norm3.weight", "lifting.transformer.2.norm3.bias", "lifting.transformer.2.self_attn.in_proj_weight", "lifting.transformer.2.self_attn.in_proj_bias", "lifting.transformer.3.multihead_attn.in_proj_weight", "lifting.transformer.3.multihead_attn.in_proj_bias", "lifting.transformer.3.multihead_attn.out_proj.weight", "lifting.transformer.3.multihead_attn.out_proj.bias", "lifting.transformer.3.linear1.weight", "lifting.transformer.3.linear1.bias", "lifting.transformer.3.linear2.weight", "lifting.transformer.3.linear2.bias", "lifting.transformer.3.norm1.weight", "lifting.transformer.3.norm1.bias", "lifting.transformer.3.norm2.weight", "lifting.transformer.3.norm2.bias", "lifting.transformer.3.norm3.weight", "lifting.transformer.3.norm3.bias", "lifting.transformer.3.self_attn.in_proj_weight", "lifting.transformer.3.self_attn.in_proj_bias".
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 27170) of binary: /home/zhenhui/anaconda3/envs/leap/bin/python
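
For reference, a quick way to see which attention layout a checkpoint was saved with is to inspect its keys directly. A minimal sketch (the checkpoint path and the "model" nesting are assumptions, not the repo's exact format):

import torch

ckpt = torch.load("path/to/kubric_checkpoint.pth", map_location="cpu")  # placeholder path
state_dict = ckpt.get("model", ckpt)  # some checkpoints nest the weights under "model"

# Flash-attention blocks register fused projections such as "attn.Wqkv" / "attn.Wq",
# while torch.nn.MultiheadAttention registers "in_proj_weight" / "out_proj.weight".
print("flash-attn layout:", any(".Wqkv" in k or ".Wq." in k for k in state_dict))
print("nn.MultiheadAttention layout:", any("in_proj_weight" in k for k in state_dict))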

So is the checkpoint you provided not compatible with the flash-attention model? Should I retrain with flash attention, or is there something else I should change?

Thanks for the great work, btw!

Hi, yes, the released model is trained without FlashAttention.
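
That matches the key mismatch above: the released weights use torch.nn.MultiheadAttention's parameter names, while the flash-attention path expects fused Wqkv / Wq / Wkv projections, so keeping use_flash_attn set to False in demo_224_real.yaml should let the checkpoint load. A minimal illustration of the naming difference (plain PyTorch, not the repo's code; the dimensions are arbitrary):

import torch.nn as nn

# torch.nn.MultiheadAttention registers exactly the names that appear as
# "unexpected" keys in the error above.
mha = nn.MultiheadAttention(embed_dim=768, num_heads=12)
print([name for name, _ in mha.named_parameters()])
# e.g. ['in_proj_weight', 'in_proj_bias', 'out_proj.weight', 'out_proj.bias']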