david-thrower/cerebros-core-algorithm-alpha

tandem-embeddings-with-freezable-weights

Opened this issue · 0 comments

Kind of issue: The bottleneck on the tandem embeddings may be that the embedding converges to an optimum well before the dense layers do. Consequently, the embedding gradients will zero out. This will cascade to zero out all the other gradients due to the chain rule.

A solution to try may look like this:

import tensorflow as tf
import numpy as np

class TemporalEmbedding(tf.keras.layers.Layer):
    """Embedding that trains for a fixed number of epochs, then freezes.

    Two embedding tables are held in tandem: ``embedding_1`` is trainable and
    ``embedding_2`` is not.  While the layer is in its training phase the
    output comes from ``embedding_1``.  At the start of the first epoch after
    ``compute_gradient_for_n_epochs`` epochs have run, the weights of
    ``embedding_1`` are copied into ``embedding_2`` and the output is switched
    to ``embedding_2``.

    The switch is stored in a non-trainable weight and read inside the graph.
    A plain Python ``if`` in ``call`` would be evaluated only when Keras
    traces the layer, so it would count traces rather than epochs and would
    never change the compiled training step.

    The epoch counter is advanced by the callback returned from
    ``epoch_callback()``; pass it to ``Model.fit(callbacks=[...])``.  Without
    that callback the layer never leaves its training phase.

    Args:
        vocab_size: Forwarded to both ``tf.keras.layers.Embedding`` layers as
            their first positional argument.
        embedding_dim: Forwarded to both ``Embedding`` layers as their second
            positional argument.
        **kwargs: Extra keyword arguments forwarded unchanged to both
            ``Embedding`` layers.
    """

    def __init__(self, vocab_size, embedding_dim, **kwargs):
        super().__init__(trainable=True)
        self.compute_gradient_for_n_epochs = 7
        # Number of epochs started so far (advanced by the epoch callback).
        self.train_counter = 0
        self.embedding_1 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        self.embedding_2 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        self.embedding_2.trainable = False
        # Python-side mirror of the switch so the weight copy happens once.
        self._embedding_frozen = False
        # Graph-visible switch: 0.0 -> use embedding_1, 1.0 -> use embedding_2.
        self._freeze_flag = self.add_weight(
            name="freeze_flag",
            shape=(),
            dtype="float32",
            initializer="zeros",
            trainable=False,
        )

    def set_compute_gradient_for_n_epochs(self, n: int):
        """Set how many epochs ``embedding_1`` is trained before freezing."""
        self.compute_gradient_for_n_epochs = n
        print(f"Training this layer for only {self.compute_gradient_for_n_epochs} epochs")

    def _on_epoch_begin(self, epoch, logs=None):
        """Advance the epoch counter and freeze once the budget is used up."""
        budget_spent = self.train_counter >= self.compute_gradient_for_n_epochs
        if budget_spent and not self._embedding_frozen:
            print(f"Setting trained weights to untrainable model (1) {self.train_counter}")
            # Runs eagerly between epochs, so both embeddings are already built.
            self.embedding_2.set_weights(self.embedding_1.get_weights())
            self._freeze_flag.assign(1.0)
            self._embedding_frozen = True
        self.train_counter += 1

    def epoch_callback(self):
        """Return a Keras callback that drives this layer's epoch counter."""
        return tf.keras.callbacks.LambdaCallback(on_epoch_begin=self._on_epoch_begin)

    def call(self, inputs):
        """Look up ``inputs`` in the trainable or the frozen table.

        Both lookups are evaluated so that both sub-layers are built on the
        first call; ``tf.where`` then selects one of them based on the
        non-trainable switch.
        """
        trainable_out = self.embedding_1(inputs)
        frozen_out = self.embedding_2(inputs)
        return tf.where(self._freeze_flag > 0.5, frozen_out, trainable_out)


input_layer = tf.keras.layers.Input(shape=(100,))
# The original passed input_length=10, which contradicts the (100,) input.
temporal_embedding_layer = TemporalEmbedding(vocab_size=10000, embedding_dim=64)
temporal_embedding_layer.set_compute_gradient_for_n_epochs(n=3)
temporal_embedding_layer_called = temporal_embedding_layer(input_layer)
flat = tf.keras.layers.Flatten()(temporal_embedding_layer_called)
output_layer = tf.keras.layers.Dense(10, activation='softmax')(flat)
model2 = tf.keras.Model(inputs=input_layer, outputs=output_layer)
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
model2.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])


x_train = np.random.randint(10000, size=(200, 100))
# categorical_crossentropy expects one-hot rows; draw one class per sample.
y_train = np.eye(10)[np.random.randint(10, size=200)]

model2.fit(
    x_train,
    y_train,
    epochs=20,
    batch_size=32,
    # The callback is what advances the layer's epoch counter.
    callbacks=[temporal_embedding_layer.epoch_callback()],
)

Suggested Labels (If you don't know, that's ok): kind/enhancement