lucidrains/x-transformers

High memory usage compared to Huggingface and Autocast has no effect?

LarsHill opened this issue · 3 comments

I actually encountered a similar scenario.

The standard Huggingface bert-base-cased model trained with 16-bit mixed precision (using pytorch-lightning), a vocab size of 100K and a seq len of 1024 uses around 34 GB of GPU memory with a batch size of 8 (on an Nvidia A100). If I switch to full 32-bit precision, the memory usage almost doubles to 67 GB, which is expected.

However, if I use the x-transformers "Bert-like" implementation (mimicking the Huggingface config)

self.bert = TransformerWrapper(
    num_tokens=100_000,
    max_seq_len=1024,
    emb_dropout=0.1,
    tie_embedding=True,
    attn_layers=Encoder(
        dim=768,
        depth=12,
        heads=12,
        attn_flash=False,
        layer_dropout=0.1,  # stochastic depth - dropout entire layer
        attn_dropout=0.1,  # dropout post-attention
        ff_dropout=0.1,  # feedforward dropout
        use_abs_pos_emb=True,
    ),
)

the memory usage does not change if I switch between 16-mixed and 32 precision. The overall usage (same batch size and hardware) remains at a constant 52GB, which is substantially higher than the HF model with 34GB.
Why does the precision setting of the lightning trainer not affect the x-transformers implementation?

I would love to use the x-transformers implementation because of its large number of new features. However, I am wondering where these significant differences in GPU memory come from. And why does torch.autocast, which I think Lightning uses under the hood, show no effect?

Originally posted by @LarsHill in #35 (comment)

I believe x-transformers uses the PyTorch attention layer, which under certain conditions defaults to its memory-efficient attention implementation. Possibly unrelated, but I recently started having memory issues after updating x-transformers, on the same model that never gave me issues before. I'm trying to identify where the memory leak is, but it is tough to pinpoint.
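
For a rough sense of scale, here is a back-of-the-envelope sketch of how large two of the biggest activations would be for the configuration quoted in this issue (batch size 8, seq len 1024, vocab size 100K, 12 heads, depth 12). It is an estimate, not a measurement. It assumes 4 bytes per element in fp32 and 2 bytes in fp16, and that the attention score matrix is fully materialized in every layer, which is what I would expect with attn_flash=False. The helper name tensor_gib is made up for illustration.

```python
# Rough activation sizes for the configuration quoted in this issue.
# Assumptions (not from the thread): fp32 = 4 bytes and fp16 = 2 bytes per
# element, and the attention score matrix is fully materialized per layer.

BATCH, SEQ_LEN, VOCAB = 8, 1024, 100_000
HEADS, DEPTH = 12, 12
GIB = 1024 ** 3


def tensor_gib(num_elements, bytes_per_element):
    """Size in GiB of one tensor with the given element count and width."""
    return num_elements * bytes_per_element / GIB


# Output logits: one score per vocabulary entry for every token in the batch.
logit_elements = BATCH * SEQ_LEN * VOCAB

# Attention scores: one (seq_len x seq_len) matrix per head, per sample.
score_elements = BATCH * HEADS * SEQ_LEN * SEQ_LEN

for name, nbytes in (("fp32", 4), ("fp16", 2)):
    logits = tensor_gib(logit_elements, nbytes)
    scores = tensor_gib(score_elements, nbytes) * DEPTH
    print(f"{name}: logits {logits:.2f} GiB, "
          f"attention scores over {DEPTH} layers {scores:.2f} GiB")
```

These figures count a single copy of each tensor. Training keeps further copies for the backward pass (for example the softmax output and the logit gradients), so the real footprint is a multiple of this. As far as I know, autocast also runs softmax and the loss in fp32, so even when it is active those particular tensors would not shrink by half.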