lucidrains/x-transformers

Issue with torch.compile

scopello opened this issue · 4 comments

Hi @lucidrains,

I am trying to use torch.compile() with a model that wraps two x-transformers Encoders. When I run the following minimal example:

```python
import torch
import torch.nn as nn
from x_transformers import Encoder

class MyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder_1 = Encoder(dim=32, depth=2, heads=2)
        self.encoder_2 = Encoder(dim=32, depth=2, heads=2)

    def forward(self, x_1, x_2):
        out_1 = self.encoder_1(x_1)
        out_2 = self.encoder_2(x_2)
        return torch.cat([out_1, out_2], 1)

model = MyModel().cuda()

seq_len_1 = 8
seq_len_2 = 16
x_1 = torch.randn([1, seq_len_1, 32]).cuda()
x_2 = torch.randn([1, seq_len_2, 32]).cuda()
# Compile the model.
model = torch.compile(model)
out = model(x_1, x_2)
```

I get the following error:

```
TorchRuntimeError: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(1, s0, 128), grad_fn=), 'b n (h d) -> b h n d'), **{'h': 2}): unhashable type: non-singleton SymInt
```

which comes from:
https://github.com/lucidrains/x-transformers/blob/2a0ec67fbdad18d2bd5f8bf3d9bc20e705a58a6b/x_transformers/x_transformers.py#L801

Surprisingly, the model compiles successfully if I set seq_len_2 = seq_len_1, but I don't know why.

I am using einops 0.7.0rc1 and PyTorch 2.1.0.
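The `s0` in the error message is a symbolic size, and the failure is an attempt to hash it. Below is a simplified pure-Python illustration of that failure mode. It assumes, without checking the einops 0.7.0rc1 source, that the failing step is a cache lookup keyed on the input shape. `SymbolicDim`, `plan_uncached`, `plan_cached` and `plan` are hypothetical stand-ins, not einops or PyTorch APIs. As for why equal lengths compile: with its default `dynamic=None`, `torch.compile` in PyTorch 2.1 first specializes on concrete shapes and only switches to symbolic sizes after it sees the same code run with a new shape. One plausible but unverified explanation is that feeding 8 and 16 tokens through the same attention code triggers that dynamic recompile, while equal lengths never do.

```python
import functools


class SymbolicDim:
    """Hypothetical stand-in for a symbolic size such as PyTorch's SymInt.

    Like the non-singleton SymInt in the error message, it cannot be hashed.
    """

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        raise TypeError(f"unhashable type: symbolic dim {self.name}")


def plan_uncached(shape, h):
    """Toy 'recipe' for 'b n (h d) -> b h n d': return the output shape."""
    b, n, hd = shape
    return (b, h, n, hd // h)


# Memoize on (shape, h), similar in spirit to shape-keyed recipe caching.
# This is an illustration only, not einops' actual implementation.
plan_cached = functools.lru_cache(maxsize=1024)(plan_uncached)


def plan(shape, h, fallback=True):
    try:
        # lru_cache hashes its arguments, and hashing a tuple hashes every
        # element, so a single symbolic dim makes the lookup raise TypeError.
        return plan_cached(shape, h)
    except TypeError:
        if not fallback:
            raise
        # The uncached path never hashes the shape, so symbolic dims are fine.
        return plan_uncached(shape, h)
```

With concrete sizes the cached path works; with `SymbolicDim("s0")` in the shape, only the fallback succeeds. If the dynamic-shape reading above is right, `torch.compile(model, dynamic=False)`, which always specializes on concrete shapes at the cost of one recompile per new shape, might sidestep the error; this has not been tested against the example above.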

Thanks!

It seems to work fine on an A100, but not on an H100.

ah nice, yeah, that seems like an einops / PyTorch-specific error, but I'm not entirely sure
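If the problem is on the einops side, one hypothetical workaround is to perform the head split reported in the error (`'b n (h d) -> b h n d'` with `h=2`) using native tensor ops, so no einops code runs under the compiler. The helpers below are a sketch, not something x-transformers provides, and whether patching them in at the linked line fixes this particular error is untested.

```python
import torch


def split_heads(t: torch.Tensor, h: int) -> torch.Tensor:
    """Native equivalent of rearrange(t, 'b n (h d) -> b h n d', h=h)."""
    b, n, _ = t.shape
    # Split the last axis into (h, d), letting -1 infer d, then swap n and h.
    return t.reshape(b, n, h, -1).transpose(1, 2)


def merge_heads(t: torch.Tensor) -> torch.Tensor:
    """Native equivalent of rearrange(t, 'b h n d -> b n (h d)')."""
    b, _, n, _ = t.shape
    # Swap h and n back, then flatten (h, d) into a single axis.
    return t.transpose(1, 2).reshape(b, n, -1)
```

`reshape` is used instead of `view` so the helpers also accept non-contiguous inputs, at the cost of a possible copy.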

what is your use case btw? that's a really interesting network

oh, are you doing a two-tower architecture?

Thanks! This is for a model that requires encoders for two different modalities. Btw, would you expect any significant speedup from torch.compile if flash attention is enabled?
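One way to answer the speedup question empirically is to time the eager and compiled models on the same inputs. A minimal sketch using `torch.utils.benchmark`; `median_seconds` is a hypothetical helper, and the commented usage assumes the `MyModel`, `x_1` and `x_2` from the original example.

```python
import torch
from torch.utils import benchmark


def median_seconds(fn, *args, min_run_time=0.2):
    """Median wall time of fn(*args) in seconds (hypothetical helper)."""
    # Warm-up call; for a torch.compile'd model this triggers compilation,
    # so compile time is excluded from the measurement.
    fn(*args)
    # Timer synchronizes CUDA before reading the clock when CUDA is available.
    timer = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    return float(timer.blocked_autorange(min_run_time=min_run_time).median)


# Usage sketch, once compilation succeeds:
# eager = MyModel().cuda().eval()
# compiled = torch.compile(eager)
# with torch.no_grad():
#     print(median_seconds(eager, x_1, x_2), median_seconds(compiled, x_1, x_2))
```

With flash attention handling the attention kernel, any gain would have to come from fusing the remaining ops, so measuring at realistic sequence lengths and batch sizes matters more than at the toy sizes above.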