tandem-embeddings-with-freezable-weights
Opened this issue · 0 comments
david-thrower commented
Kind of issue: The bottleneck in the tandem embeddings may be that the embedding converges to an optimum well before the dense layers do. Consequently, the embedding gradients go to zero, and through the chain rule this may cascade and zero out the other gradients as well.
A solution to try may look like this:
import tensorflow as tf
import numpy as np
class TemporalEmbedding(tf.keras.layers.Layer):
    """Embedding that is trained for a fixed number of epochs, then frozen.

    The layer owns two embedding tables built with identical arguments:

    * ``embedding_1`` is trainable and produces the output while the layer is
      still in its training phase.
    * ``embedding_2`` is non-trainable. When the epoch budget
      (``compute_gradient_for_n_epochs``) is used up, the weights of
      ``embedding_1`` are copied into it once, and every later call reads
      from it instead.

    The layer switches to a separate non-trainable table instead of wrapping
    ``embedding_1`` in ``tf.stop_gradient``. This keeps the frozen output fixed
    even under momentum optimizers such as Adam, which keep moving a trainable
    variable after its gradient has dropped to zero.

    Epochs are counted by ``advance_epoch``. Pass ``freeze_callback()`` to
    ``Model.fit`` so that it is called at the end of every epoch. Counting
    cannot happen inside ``call``: Keras traces ``call`` into a graph only a
    handful of times, so a Python counter there counts traces, not epochs.
    Also, ``get_weights``/``set_weights`` cannot run inside a traced function.

    Args:
        vocab_size: Number of rows in each embedding table.
        embedding_dim: Width of each embedding vector.
        **kwargs: Forwarded unchanged to both ``tf.keras.layers.Embedding``
            constructors (not to this layer's base class).
    """

    def __init__(self, vocab_size, embedding_dim, **kwargs):
        super().__init__(trainable=True)
        self.compute_gradient_for_n_epochs = 7
        # Number of finished epochs reported through advance_epoch() so far.
        self.train_counter = 0
        self.embedding_1 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        self.embedding_2 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        self.embedding_2.trainable = False
        # Python-side mirror of `_frozen`, so freeze() can check state
        # without reading a variable.
        self._is_frozen = False
        # Switch that call() reads through tf.cond. Assigning it from eager
        # code (e.g. a callback) takes effect in already-traced functions.
        self._frozen = tf.Variable(False, trainable=False, name="frozen")

    def build(self, input_shape):
        # Create both tables up front. embedding_2 is never called while the
        # layer trains, so otherwise it would have no variables to receive
        # the copy made by freeze().
        self.embedding_1.build(input_shape)
        self.embedding_2.build(input_shape)
        super().build(input_shape)
        # Covers a budget of 0 (or less): freeze right away, never train.
        self._maybe_freeze()

    def set_compute_gradient_for_n_epochs(self, n: int):
        """Set how many epochs embedding_1 is trained before the layer freezes."""
        self.compute_gradient_for_n_epochs = n
        print(f"Training this layer for only {self.compute_gradient_for_n_epochs} epochs")
        # The new budget may already be used up.
        self._maybe_freeze()

    def advance_epoch(self):
        """Record one finished epoch and freeze once the budget is spent."""
        self.train_counter += 1
        self._maybe_freeze()

    def freeze_callback(self):
        """Return a callback for ``Model.fit`` that calls advance_epoch() after each epoch."""
        return tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: self.advance_epoch()
        )

    def _maybe_freeze(self):
        """Freeze if the epoch budget is used up and the tables already exist."""
        if self.built and self.train_counter >= self.compute_gradient_for_n_epochs:
            self.freeze()

    def freeze(self):
        """Copy embedding_1 into embedding_2 once and route call() to the copy.

        Calling it again after the first freeze does nothing.

        Raises:
            RuntimeError: If the layer has not been built yet, since neither
                table has weights to copy.
        """
        if self._is_frozen:
            return
        if not self.built:
            raise RuntimeError("TemporalEmbedding must be built before it can be frozen")
        print(f"Setting trained weights to untrainable model after {self.train_counter} epochs")
        for frozen_var, trained_var in zip(self.embedding_2.weights, self.embedding_1.weights):
            frozen_var.assign(trained_var)
        self._frozen.assign(True)
        self._is_frozen = True

    def call(self, inputs):
        # Choose the table with a graph-level conditional on a variable. A
        # Python `if` here would be decided once, at the time Keras traces call.
        return tf.cond(
            self._frozen,
            lambda: self.embedding_2(inputs),
            lambda: self.embedding_1(inputs),
        )


# Demo: random data; the embedding trains for 3 epochs, then stays frozen.
input_layer = tf.keras.layers.Input(shape=(100,))
# input_length is omitted: the sequences are 100 long, so the original
# input_length=10 contradicted the Input shape.
temporal_embedding_layer = TemporalEmbedding(vocab_size=10000, embedding_dim=64)
temporal_embedding_layer.set_compute_gradient_for_n_epochs(n=3)
temporal_embedding_layer_called = temporal_embedding_layer(input_layer)
flat = tf.keras.layers.Flatten()(temporal_embedding_layer_called)
output_layer = tf.keras.layers.Dense(10, activation='softmax')(flat)
model2 = tf.keras.Model(inputs=input_layer, outputs=output_layer)
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
model2.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
x_train = np.random.randint(10000, size=(200, 100))
# categorical_crossentropy expects one-hot targets (one class per row),
# not independent 0/1 draws per column.
y_train = np.eye(10)[np.random.randint(10, size=200)]
model2.fit(
    x_train,
    y_train,
    epochs=20,
    batch_size=32,
    # Drives the per-epoch counter that decides when the embedding freezes.
    callbacks=[temporal_embedding_layer.freeze_callback()],
)
Suggested Labels (If you don't know, that's ok): kind/enhancement