CyberZHG/keras-multi-head

Fitting MultiHeadAttention in memory for long sequences

Jsevillamol opened this issue · 1 comment

I am trying to train my own sequence tagging model based on this repository's implementation of MultiHeadAttention:

import keras.layers as ll
from keras import Model
from keras_pos_embd import TrigPosEmbedding
from keras_multi_head import MultiHeadAttention

inputs = ll.Input(shape=(None,))               # variable-length token sequences
x = ll.Embedding(10000, 1024)(inputs)          # vocabulary of 10000, embedding dim 1024
x = TrigPosEmbedding(mode='add')(x)            # add sinusoidal positional encodings
x = MultiHeadAttention(head_num=8)(x)          # self-attention with 8 heads
x = ll.Dense(units=512, activation='relu')(x)
x = ll.Dense(units=4, activation='softmax')(x) # per-token prediction over 4 tags
outputs = x
model = Model(inputs, outputs)
model.summary()

I have one big problem: the sequences in my training set are quite long (lengths of up to 20000 tokens), and when I attempt to train the model I get an OOM error.

The OOM happens when trying to allocate a [16, 20000, 20000] tensor. If my calculations are correct, just storing this tensor in float32 would take about 25 GB of memory!
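
For reference, a quick back-of-the-envelope check of that size (assuming the attention scores are stored in float32, i.e. 4 bytes per element):

# Size of the attention-score tensor reported in the OOM message.
batch_times_heads, seq_len = 16, 20000
bytes_per_float32 = 4
n_bytes = batch_times_heads * seq_len * seq_len * bytes_per_float32
print(f"{n_bytes / 1e9:.1f} GB")  # -> 25.6 GB for a single copy of the scores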

I was wondering if you have any suggestions on how to modify the code so that attention is computed in a more memory-efficient, chunked way, only keeping in memory a context of the length specified by a custom parameter.
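
Concretely, something like the following sketch is what I have in mind: split each long sequence into overlapping windows before feeding it to the model, so each attention matrix is only chunk_len x chunk_len (chunk_sequence, chunk_len, and stride are placeholder names of mine, not part of this repo):

def chunk_sequence(tokens, labels, chunk_len=512, stride=384):
    """Split one long sequence and its per-token labels into overlapping windows."""
    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_len, len(tokens))
        chunks.append((tokens[start:end], labels[start:end]))
        if end == len(tokens):
            break
        start += stride
    return chunks

# Example: a 20000-token sequence becomes 52 windows of at most 512 tokens,
# so each attention matrix is only 512 x 512 instead of 20000 x 20000.
windows = chunk_sequence(list(range(20000)), [0] * 20000)
print(len(windows), len(windows[0][0]))  # -> 52 512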

I also tried going to a lower level with keras_self_attention.SeqSelfAttention and its configurable attention width, but in the end it would still try to allocate a very large tensor on my GPU.
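
For reference, this is roughly what that attempt looked like (the attention_width value is just an example):

import keras.layers as ll
from keras_pos_embd import TrigPosEmbedding
from keras_self_attention import SeqSelfAttention

inputs = ll.Input(shape=(None,))
x = ll.Embedding(10000, 1024)(inputs)
x = TrigPosEmbedding(mode='add')(x)
# Attention is masked to a local window of width 256, but the full
# [batch, N, N] score tensor still seems to be allocated, hence the OOM.
x = SeqSelfAttention(attention_width=256, attention_activation='sigmoid')(x)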

PS: Awesome repo!

stale commented

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.