try-alex-custom-embedding-with-no-bnorm
david-thrower commented
Kind of issue: Enhancement
Additional context: Alex developed a custom embedding:
```python
import tensorflow as tf


class CustomEmbedding(tf.keras.layers.Layer):
    def __init__(self, input_dim, output_dim, **kwargs):
        super(CustomEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim

    def build(self, input_shape):
        # Embedding table: one output_dim vector per index in [0, input_dim).
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            initializer='uniform',
            trainable=True,
            name='embeddings'
        )
        # Per-feature scaling factors; shape must be an iterable, not a bare int.
        self.scaling = self.add_weight(
            shape=(input_shape[-1],),
            initializer='ones',
            trainable=True,
            name='scaling'
        )
        super(CustomEmbedding, self).build(input_shape)

    def gaussian_one_hot(self, x, depth, sigma=0.1):
        # Soft, differentiable one-hot: a normalized Gaussian bump
        # centered at x over the index range [0, depth).
        y = tf.range(depth, dtype=tf.float32)
        x = tf.expand_dims(x, -1)  # Add an extra dimension for broadcasting
        vec = tf.exp(-tf.square(y - x) / (2 * tf.square(sigma)))
        vec = vec / tf.reduce_sum(vec, axis=-1, keepdims=True)
        return vec

    def call(self, inputs):
        scaled = tf.nn.softmax(inputs) * tf.exp(self.scaling)
        # tf.print(scaled)
        retrieve = self.gaussian_one_hot(scaled, self.input_dim)
        # Soft lookup: soft one-hot (b, i, k) against the embedding table (k, j).
        embedded = tf.einsum('bik,kj->bij', retrieve, self.embeddings)
        return embedded

    def compute_output_shape(self, input_shape):
        return tuple(input_shape) + (self.output_dim,)
```
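For reference, a minimal smoke test of the layer; the batch size and dimensions here are arbitrary assumptions chosen only to exercise the shapes:

```python
import tensorflow as tf

# Hypothetical dimensions, for illustration only.
layer = CustomEmbedding(input_dim=16, output_dim=8)
x = tf.random.uniform((4, 10))  # (batch=4, features=10), dummy input
out = layer(x)
print(out.shape)  # expected: (4, 10, 8)
```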
I want to try this without the BatchNormalization layer to see whether it fixes the gradient issue.
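A sketch of what that experiment could look like, assuming a toy model and dummy data; the surrounding layers, dimensions, and loss are illustrative assumptions, not the project's actual pipeline:

```python
import tensorflow as tf

# Toy model using CustomEmbedding (defined above) with no
# BatchNormalization layer anywhere in the stack.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    CustomEmbedding(input_dim=16, output_dim=8),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])

x = tf.random.uniform((4, 10))  # dummy batch
y = tf.random.uniform((4, 1))   # dummy targets

# Compute gradients explicitly to check for None or vanishing values.
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
grads = tape.gradient(loss, model.trainable_variables)
for var, grad in zip(model.trainable_variables, grads):
    magnitude = None if grad is None else float(tf.reduce_max(tf.abs(grad)))
    print(var.name, magnitude)
```

Printing the per-variable gradient magnitudes directly, rather than just watching the training loss, makes it easy to see whether the `embeddings` and `scaling` weights actually receive gradients once BatchNormalization is removed.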