Training takes too long with a pre-trained BERT embedding layer
beubeu13220 opened this issue · 1 comment
Hello everyone,
We are trying to integrate pre-trained BERT embeddings into our TFRS model.
Our model follows the same structure as the basic retrieval tutorial: https://www.tensorflow.org/recommenders/examples/basic_retrieval.
```python
from typing import Dict

import numpy as np
import tensorflow as tf
import tensorflow_recommenders as tfrs
from typing_extensions import Self  # typing.Self on Python 3.11+


class RetrievalUserModel(tf.keras.Model):
    def __init__(self: Self, users_vocab: np.ndarray) -> None:
        super().__init__()
        self.user_embedding: tf.keras.Sequential = tf.keras.Sequential(
            [
                tf.keras.layers.StringLookup(vocabulary=users_vocab, mask_token=None),
                # +1 row for the out-of-vocabulary index added by StringLookup.
                tf.keras.layers.Embedding(len(users_vocab) + 1, 32),
            ]
        )

    def call(self: Self, features: Dict[str, tf.Tensor]) -> tf.Tensor:
        return self.user_embedding(features["user_id"])


class RetrievalItemModel(tf.keras.Model):
    def __init__(
        self: Self, items_vocab: np.ndarray, categories_vocab: np.ndarray
    ) -> None:
        super().__init__()
        self.item_embedding: tf.keras.Sequential = tf.keras.Sequential(
            [
                tf.keras.layers.StringLookup(vocabulary=items_vocab, mask_token=None),
                tf.keras.layers.Embedding(len(items_vocab) + 1, 32),
            ]
        )
        self.categories_embedding: tf.keras.Sequential = tf.keras.Sequential(
            [
                tf.keras.layers.StringLookup(vocabulary=categories_vocab, mask_token=None),
                tf.keras.layers.Embedding(len(categories_vocab) + 1, 32),
            ]
        )
        # Frozen DistilBERT encoder (defined below); it runs on every batch.
        self.description_embedding: tf.keras.layers.Layer = BertEmbeddingLayer()

    def call(self: Self, features: Dict[str, tf.Tensor]) -> tf.Tensor:
        # Concatenate ID (32), category (32) and description (768) embeddings.
        return tf.concat(
            [
                self.item_embedding(features["item_id"]),
                self.categories_embedding(features["category_id"]),
                self.description_embedding(features["item_description"]),
            ],
            axis=1,
        )


class RetrievalModel(tfrs.models.Model):
    def __init__(
        self: Self,
        users_vocab: np.ndarray,
        items_vocab: np.ndarray,
        categories_vocab: np.ndarray,
        items_candidates: tf.data.Dataset,  # .batch() below needs a Dataset, not a Tensor
        batch_size: int = 128,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.user_model = tf.keras.Sequential(
            [RetrievalUserModel(users_vocab), tf.keras.layers.Dense(32)]
        )
        self.item_model = tf.keras.Sequential(
            [
                RetrievalItemModel(items_vocab, categories_vocab),
                tf.keras.layers.Dense(32),
            ]
        )
        self.task = tfrs.tasks.Retrieval(
            metrics=tfrs.metrics.FactorizedTopK(
                # Each candidate element is a feature dict with item_id,
                # category_id and item_description, mapped through item_model.
                candidates=items_candidates.batch(batch_size).map(self.item_model),
            ),
        )

    def compute_loss(
        self: Self, features: Dict[str, tf.Tensor], training: bool = False
    ) -> tf.Tensor:
        user_embedding = self.user_model(
            {
                "user_id": features["user_id"],
            }
        )
        item_embedding = self.item_model(
            {
                "item_id": features["item_id"],
                "category_id": features["category_id"],
                "item_description": features["item_description"],
            }
        )
        return self.task(user_embedding, item_embedding)
```
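The `+ 1` in each `Embedding` size comes from how `StringLookup` numbers tokens: with `mask_token=None` and the default single OOV bucket, index 0 is reserved for out-of-vocabulary strings and the vocabulary occupies indices 1 to `len(vocab)`. The plain-Python sketch below mimics that mapping on a made-up vocabulary; `string_lookup` is a hypothetical helper for illustration, not the Keras implementation.

```python
from typing import Dict, List, Sequence


def string_lookup(vocab: Sequence[str], tokens: Sequence[str]) -> List[int]:
    """Mimic StringLookup(vocabulary=vocab, mask_token=None) with one OOV bucket.

    Index 0 is the OOV bucket; vocabulary entries get indices 1..len(vocab).
    """
    index: Dict[str, int] = {token: i + 1 for i, token in enumerate(vocab)}
    return [index.get(token, 0) for token in tokens]


users_vocab = ["u1", "u2", "u3"]  # hypothetical vocabulary
ids = string_lookup(users_vocab, ["u2", "u3", "unknown"])
print(ids)  # [2, 3, 0]

# The largest possible index is len(vocab), so the Embedding table needs
# len(vocab) + 1 rows for every index 0..len(vocab) to be valid.
num_rows = len(users_vocab) + 1
assert max(ids) < num_rows
```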
where `BertEmbeddingLayer` is defined as:
```python
import tensorflow as tf
import tensorflow_hub as hub
from typing_extensions import Self  # typing.Self on Python 3.11+


class BertEmbeddingLayer(tf.keras.layers.Layer):
    def __init__(self: Self, use_normalize: bool = True) -> None:
        super().__init__()
        self.use_normalize = use_normalize
        self.preprocessing_layer = hub.KerasLayer(
            "https://tfhub.dev/jeongukjae/distilbert_multi_cased_preprocess/2",
            name="preprocessing",
        )
        self.encoder_layer = hub.KerasLayer(
            "https://tfhub.dev/jeongukjae/distilbert_multi_cased_L-6_H-768_A-12/1",
            trainable=False,  # weights are frozen, but the forward pass still runs
            name="BERT",
        )
        # Create the pooling layer once instead of on every call.
        self.pooling = tf.keras.layers.GlobalAveragePooling1D()

    def call(self: Self, inputs: tf.Tensor) -> tf.Tensor:
        encoder_inputs = self.preprocessing_layer(inputs)
        sequence_output = self.encoder_layer(encoder_inputs)["sequence_output"]
        # Average over real tokens only: input_mask is 0 on padding positions.
        pooled_output = self.pooling(sequence_output, mask=encoder_inputs["input_mask"])
        if self.use_normalize:
            pooled_output = self.normalize(pooled_output)
        return pooled_output

    def normalize(self: Self, embeddings: tf.Tensor) -> tf.Tensor:
        # L2-normalize each row; tf.linalg.normalize returns (normalized, norm).
        embeddings, _ = tf.linalg.normalize(embeddings, 2, axis=1)
        return embeddings
```
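`GlobalAveragePooling1D` called with a mask averages only over positions where the mask is 1, so padding tokens do not dilute the description embedding, and `tf.linalg.normalize(..., 2, axis=1)` then rescales each row to unit L2 norm. The plain-Python sketch below reproduces those two steps on made-up toy numbers to show what `BertEmbeddingLayer` returns for one description; `masked_mean` and `l2_normalize` are illustrative helpers, not TensorFlow functions.

```python
import math
from typing import List

Vector = List[float]


def masked_mean(sequence: List[Vector], mask: List[int]) -> Vector:
    """Average token vectors where mask == 1 (like GlobalAveragePooling1D with a mask).

    Assumes at least one position has mask == 1.
    """
    kept = [vec for vec, m in zip(sequence, mask) if m == 1]
    dim = len(sequence[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


def l2_normalize(vec: Vector) -> Vector:
    """Scale a vector to unit L2 norm (like tf.linalg.normalize(x, 2, axis=1) per row)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


# Toy "sequence_output" for one description: 4 token positions, hidden size 2.
# The last position is padding (mask 0), so its large values are ignored.
tokens = [[1.0, 2.0], [3.0, 2.0], [2.0, 2.0], [100.0, -100.0]]
input_mask = [1, 1, 1, 0]

pooled = masked_mean(tokens, input_mask)  # [2.0, 2.0]
embedding = l2_normalize(pooled)          # both entries ~0.7071
print(pooled, embedding)
```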
We decided to run BERT inside the model during training so that we don't have to compute
embeddings separately at inference time.
We run the following training:
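```python
# Each dataset element is one batch of batch_size rows.
cached_train = tf.data.Dataset.from_tensor_slices(dict(train_df[columns])).batch(batch_size)
cached_test = tf.data.Dataset.from_tensor_slices(dict(test_df[columns])).batch(batch_size)

# RetrievalModel's constructor needs the vocabularies and the candidate dataset.
model = RetrievalModel(users_vocab, items_vocab, categories_vocab, items_candidates)
model.compile(
    optimizer=tf.keras.optimizers.legacy.Adagrad(0.1)
)
model.fit(
    cached_train,
    epochs=3,
    validation_data=cached_test,
    # No batch_size argument: with a tf.data.Dataset, .batch() above sets it.
)
```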
The training set (`cached_train`) has 251,429 rows and the test set (`cached_test`) has 51,748 rows.
Without `BertEmbeddingLayer`, training this model with `batch_size=4096` takes less than 30 minutes on an AWS ml.g4dn.2xlarge instance (1 GPU, 16 GB).
Once `BertEmbeddingLayer` is included, training with `batch_size=4096` is impossible: the process is OOM-killed.
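For a sense of scale at `batch_size=4096`: the encoder's `sequence_output` alone holds batch × sequence length × hidden size float32 values, and each self-attention layer builds a batch × heads × sequence length × sequence length score tensor. The hidden size (768) and head count (12) come from the model name (`H-768_A-12`); the sequence length of 128 is an assumption, since the issue does not say what the preprocessing model pads to. This sketch only estimates tensor sizes; it does not by itself show whether the OOM kill came from GPU or host memory.

```python
GIB = 1024 ** 3
BYTES_PER_FLOAT32 = 4


def tensor_gib(*shape: int) -> float:
    """Size in GiB of a dense float32 tensor with the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count * BYTES_PER_FLOAT32 / GIB


batch, seq_len, hidden, heads = 4096, 128, 768, 12  # seq_len=128 is an assumption

# One layer's hidden states, e.g. the final sequence_output.
print(tensor_gib(batch, seq_len, hidden))          # 1.5
# One layer's attention scores: batch x heads x seq_len x seq_len.
print(tensor_gib(batch, heads, seq_len, seq_len))  # 3.0
```

Both sizes scale linearly with the batch size, so halving it to 2048 halves them.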
With `batch_size=2048`, the ETA reported by TensorFlow is about 500 hours.
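The figures above can be turned into per-step numbers. Assuming the 500 h is the per-epoch ETA from the Keras progress bar (which counts the steps remaining in the current epoch), the sketch below computes steps per epoch from the 251,429 training rows and the time per step that this ETA implies. `Dataset.batch` keeps a partial last batch by default, hence the ceiling.

```python
import math

TRAIN_ROWS = 251_429
TEST_ROWS = 51_748


def steps_per_epoch(rows: int, batch_size: int) -> int:
    """Batches produced by Dataset.batch(batch_size) with the last partial batch kept."""
    return math.ceil(rows / batch_size)


steps_2048 = steps_per_epoch(TRAIN_ROWS, 2048)      # 123
steps_4096 = steps_per_epoch(TRAIN_ROWS, 4096)      # 62
eval_steps_2048 = steps_per_epoch(TEST_ROWS, 2048)  # 26

eta_seconds = 500 * 3600  # reported ETA at batch_size=2048
seconds_per_step = eta_seconds / steps_2048
print(steps_2048, steps_4096, eval_steps_2048, round(seconds_per_step))  # 123 62 26 14634
```

Roughly four hours per training step is the number to compare against when profiling a single step.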
Switching to a more powerful ml.p3.2xlarge instance (1 V100 GPU, 8 vCPUs) does not reduce the ETA.
We also tried running the tokenization before the `fit` step, but this did not improve the ETA.
Using Hugging Face tokenizers and encoders instead did not help either.
The remaining option, which we would rather avoid because we want the model itself to produce embeddings at inference time, is to
compute the embeddings outside the model and use the result as model features, for example as the weights of an embedding layer:
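```python
from tensorflow.keras import layers

# xxx and your_embedding are placeholders for the vocabulary size,
# the embedding size and the precomputed embedding matrix.
embedding_layer = layers.Embedding(
    input_dim=xxx,
    output_dim=xxx,
    weights=[your_embedding],
    trainable=True,
)
```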
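If you go the precomputed route, the row order of the weight matrix has to match the indices `StringLookup` produces: row 0 for the OOV bucket (with `mask_token=None`), then one row per vocabulary entry in vocabulary order. A minimal plain-Python sketch, assuming the precomputed vectors are held in a dict keyed by item id; `precomputed` and `build_embedding_matrix` are hypothetical names, and zero vectors for OOV and missing items are a design choice, not a requirement.

```python
from typing import Dict, List, Sequence

Vector = List[float]


def build_embedding_matrix(
    vocab: Sequence[str], precomputed: Dict[str, Vector], dim: int
) -> List[Vector]:
    """Rows aligned with StringLookup(vocabulary=vocab, mask_token=None).

    Row 0 is the OOV bucket (zeros here); row i + 1 holds the vector for vocab[i].
    Items without a precomputed vector also fall back to zeros.
    """
    oov_row = [0.0] * dim
    rows = [list(precomputed.get(item, [0.0] * dim)) for item in vocab]
    return [oov_row] + rows


items_vocab = ["item_a", "item_b"]  # hypothetical vocabulary
precomputed = {"item_a": [0.6, 0.8], "item_b": [1.0, 0.0]}  # e.g. BERT outputs
matrix = build_embedding_matrix(items_vocab, precomputed, dim=2)
print(len(matrix))  # 3 == len(items_vocab) + 1
```

In Keras, a matrix like this (converted to a NumPy array of shape `(len(items_vocab) + 1, dim)`) is what would go in `weights=[...]`, with `input_dim=len(items_vocab) + 1` and `output_dim` equal to the encoder's hidden size (768 for the DistilBERT model above).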
We'd really appreciate your help.
Do you have any suggestions for getting a reasonable ETA, or advice on the correct way to use BERT in TFRS?
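That's the cruel truth: BERT is expensive to train with. Even with `trainable=False`, every batch still goes through the full encoder forward pass. Basically, you have three choices: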
- Try a smaller BERT, such as the Small BERT models on this page and this one. We tried this before; as I recall, it was trainable but slow.
- Use a machine with more GPU capacity, such as a p3.8xlarge with 4 GPUs, or even try distributed training across multiple machines (the tuning work could be tricky).
- Pre-generate all the embeddings and load them as the weights of an embedding layer. This removes the cost of running BERT on the fly during training; it is the same idea as the option you listed. You can also fine-tune BERT on your dataset before generating the embeddings, which generally gives better performance.
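One reason pre-generating helps: in the in-model setup, BERT runs on the description of every interaction row, so a popular item's description is re-encoded each time it appears in a batch. Precomputing encodes each unique description once. A minimal sketch of that step, assuming a hypothetical `encode_batch` function that wraps whichever (possibly fine-tuned) BERT encoder you use and returns one vector per input text; the toy encoder below only stands in for it.

```python
from typing import Callable, Dict, Iterable, List, Tuple

Vector = List[float]


def precompute_item_embeddings(
    rows: Iterable[Tuple[str, str]],
    encode_batch: Callable[[List[str]], List[Vector]],
    batch_size: int = 256,
) -> Dict[str, Vector]:
    """Encode each unique item's description once, in batches.

    rows: (item_id, item_description) pairs, possibly with repeated items.
    encode_batch: hypothetical wrapper around a BERT encoder.
    """
    descriptions: Dict[str, str] = {}
    for item_id, text in rows:
        descriptions.setdefault(item_id, text)  # keep the first description seen

    item_ids = list(descriptions)
    embeddings: Dict[str, Vector] = {}
    for start in range(0, len(item_ids), batch_size):
        chunk = item_ids[start:start + batch_size]
        vectors = encode_batch([descriptions[i] for i in chunk])
        embeddings.update(zip(chunk, vectors))
    return embeddings


# Toy encoder standing in for BERT: records batch sizes, returns [len(text)].
calls: List[int] = []


def fake_encode(texts: List[str]) -> List[Vector]:
    calls.append(len(texts))
    return [[float(len(t))] for t in texts]


interactions = [("a", "red shoe"), ("b", "hat"), ("a", "red shoe"), ("c", "scarf")]
result = precompute_item_embeddings(interactions, fake_encode, batch_size=2)
print(result)  # {'a': [8.0], 'b': [3.0], 'c': [5.0]}
print(calls)   # [2, 1]: 3 unique items encoded in 2 batches
```

The resulting dict can then be laid out as an embedding matrix in `StringLookup` order and loaded into an `Embedding` layer, so training no longer pays for the encoder at every step.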