Should the mask option to AttentionLayers(boolean) increase memory?
blasscoc opened this issue · 1 comment
Should the mask option in AttentionLayers be using up so much memory?
Behavior:
When I use the mask variant, the memory consumption is 6326MiB; without the mask, it is 1364MiB.
import torch

from x_transformers.x_transformers import AttentionLayers

attn_config = {
    "dim": 128,
    "depth": 4,
    "heads": 6,
    "ff_mult": 8,
    "attn_flash": True,
}

encoder = AttentionLayers(**attn_config)
encoder.cuda()

x = torch.randn(10, 4000, 128)
mask = torch.ones(10, 4000).bool()

x = x.to('cuda')
mask = mask.to('cuda')

while 1:
    with torch.no_grad():
        # output = encoder(x, mask=mask)
        output = encoder(x)
I traced the memory consumption down to the call to scaled_dot_product_attention in the flash_attn function and saw that the increase happens there. I verified that the mask, if initially bool, keeps that dtype throughout, and that no significant copies are being made along the way.
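For reference, a minimal way to compare the two cases is with PyTorch's built-in memory counters. The sketch below reuses encoder, x, and mask from the script above; peak_forward_mib is just a helper name for this illustration.

def peak_forward_mib(encoder, x, mask=None):
    # Peak CUDA memory (in MiB) seen during one no-grad forward pass.
    # The peak counter is reset first, but already-resident tensors
    # (weights, inputs) still count toward the reported number.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        if mask is not None:
            encoder(x, mask=mask)
        else:
            encoder(x)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

print("without mask:", peak_forward_mib(encoder, x))
print("with mask:   ", peak_forward_mib(encoder, x, mask))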
The memory consumption scales roughly with the square of the sequence length times the number of heads, so it can become quite large. Interestingly, casting the mask from bool to float at line 238 of attend.py resulted in only a modest increase in memory consumption, which was unexpected.
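To put the scaling in numbers, assuming the mask ends up materialized as an attention bias of shape (batch, heads, seq, seq) somewhere inside scaled_dot_product_attention (my guess, not confirmed from the kernel source), a back-of-envelope calculation for the shapes above:

batch, heads, seq = 10, 6, 4000
elements = batch * heads * seq * seq            # 960,000,000 entries
print(elements / 2**20, "MiB as bool")          # ~915 MiB (1 byte per entry)
print(elements * 4 / 2**20, "MiB as float32")   # ~3662 MiB (4 bytes per entry)

Whatever the exact allocation is, anything that grows with seq**2 * heads adds up quickly at a sequence length of 4000.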