lucidrains/x-transformers

Can't pickle <class 'torch.nn.attention._SDPBackend'>: attribute lookup _SDPBackend on torch.nn.attention failed

Closed this issue · 3 comments

It seems that torch.save fails on some models:

Example, using torch==2.3.0+cu118:

from x_transformers import XTransformer
import torch
model = XTransformer(
    dim = 512,
    enc_num_tokens = 200,
    return_tgt_loss = True,
    enc_depth = 3,
    enc_heads = 8,
    enc_max_seq_len = 10,
    dec_num_tokens = 200,
    dec_depth = 3,
    dec_heads = 8,
    dec_max_seq_len = 200 + 1,
    tie_token_emb = True,      # tie embeddings of encoder and decoder
    enc_rotary_pos_emb = True,
    dec_rotary_pos_emb = True,
)

torch.save(model, 'xtransformer_test.pth')

This fails with:

    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <class 'torch.nn.attention._SDPBackend'>: attribute lookup _SDPBackend on torch.nn.attention failed
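
For context on the error message: pickle stores a class by reference, as a module name plus an attribute name that it looks up again. The message suggests the class reports itself as `_SDPBackend` in `torch.nn.attention`, while that lookup on the module fails. Below is a minimal, stdlib-only sketch of that mechanism. `fake_attention`, `_Backend`, and `can_pickle` are hypothetical stand-ins made up for illustration, not PyTorch or x-transformers code, and this is an assumption about the cause rather than a confirmed diagnosis.

```python
import pickle
import sys
import types

# Hypothetical stand-in for a library module (not a real package).
fake_attention = types.ModuleType("fake_attention")
sys.modules["fake_attention"] = fake_attention


class _Backend:
    """Stand-in for a class whose reported name is private."""


# The class claims to live at fake_attention._Backend ...
_Backend.__module__ = "fake_attention"
_Backend.__qualname__ = "_Backend"
# ... but the module only exposes it under the public name `Backend`.
fake_attention.Backend = _Backend


def can_pickle(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        return False


# An object holding such an instance cannot be pickled:
# pickle looks up fake_attention._Backend and does not find it.
holder = {"backend": fake_attention.Backend()}
print(can_pickle(holder))                    # False

# Plain data (for example a string naming the backend) pickles fine.
print(can_pickle({"backend": "Backend"}))    # True
```

If this is indeed what happens here, saving only `model.state_dict()` (tensors rather than the whole module object) may sidestep the problem, since the module's other attributes are then not pickled. That is untested against this exact setup.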

Ah, sorry, I should have checked the syntax of my code before creating an issue.

@guillaumeguy no problem

are you on the ML team at Lyft?

Sorry, I have run into the same problem, but in my case it happens during DDP training. I don't know what the cause is...