tallamjr/astronet

Stack encoder layers to match the Vaswani et al. paper

Closed this issue · 2 comments

In Vaswani et al. ("Attention Is All You Need"), the encoder block is stacked N times, with N = 6 in the paper.

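This can be seen in the paper's model-architecture figure, where the encoder block is marked "N×":
[image: Transformer model architecture figure]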

The TensorFlow documentation and its Transformer tutorial implement this as follows:

class Encoder(tf.keras.layers.Layer):
  """Transformer encoder: embedding + positional encoding + N stacked layers.

  Input ids are embedded, scaled by sqrt(d_model), offset by a positional
  encoding, passed through dropout, and then fed sequentially through
  `num_layers` independent `EncoderLayer` blocks (the "N x" encoder stack
  of Vaswani et al.).
  """

  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1):
    """Build the embedding, positional encoding and encoder stack.

    Args:
      num_layers: number of `EncoderLayer` blocks stacked in sequence.
      d_model: output width of the embedding; also passed to
        `positional_encoding` and to every `EncoderLayer`.
      num_heads: forwarded unchanged to every `EncoderLayer` (presumably
        the number of attention heads).
      dff: forwarded unchanged to every `EncoderLayer` (presumably the
        hidden width of its feed-forward sub-layer).
      input_vocab_size: input dimension (vocabulary size) of the
        `Embedding` layer.
      maximum_position_encoding: passed to `positional_encoding`;
        presumably the longest sequence length the encoding covers.
      rate: dropout rate, used for this layer's `Dropout` and forwarded
        to every `EncoderLayer`.
    """
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, 
                                            self.d_model)


    # The constructor runs once per iteration, so each of the `num_layers`
    # blocks is a separate layer with its own weights (no weight sharing).
    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):
    """Encode a batch of input ids.

    Args:
      x: embedding input ids; shape assumed (batch_size, input_seq_len),
        per the shape comments below (TODO confirm).
      training: forwarded to the dropout layer and to every `EncoderLayer`.
      mask: not used directly here; forwarded unchanged to every
        `EncoderLayer`.

    Returns:
      Tensor of shape (batch_size, input_seq_len, d_model).
    """

    seq_len = tf.shape(x)[1]

    # adding embedding and position encoding.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    # Scale embeddings by sqrt(d_model) before adding the positional
    # encoding, as in Vaswani et al.
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    # Only the first seq_len positions are used; pos_encoding is presumably
    # shaped (1, maximum_position_encoding, d_model) -- TODO confirm.
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    # Apply the stacked encoder layers in order; each consumes the previous
    # layer's output.
    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)

    return x  # (batch_size, input_seq_len, d_model)

Notice the lines:

    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]

and

    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)

    return x  # (batch_size, input_seq_len, d_model)

The same pattern could be applied in astronet, using `TransformerBlock` in place of `EncoderLayer`.

A diff of a possible implementation could look something like this:

diff --git a/astronet/t2/model.py b/astronet/t2/model.py
index 117b605..79c0013 100644
--- a/astronet/t2/model.py
+++ b/astronet/t2/model.py
@@ -12,7 +12,7 @@ class T2Model(keras.Model):
     num_heads --> Number of attention heads
     ff_dim    --> Hidden layer size in feed forward network inside transformer
     """
-    def __init__(self, input_dim, embed_dim, num_heads, ff_dim, num_filters, num_classes, **kwargs):
+    def __init__(self, input_dim, embed_dim, num_heads, ff_dim, num_filters, num_classes, num_layers=6, **kwargs):
         super(T2Model, self).__init__()
         self.input_dim      = input_dim
         self.embed_dim      = embed_dim
@@ -24,7 +24,10 @@ class T2Model(keras.Model):
 
         self.embedding      = ConvEmbedding(num_filters=self.num_filters, input_shape=input_dim)
         self.pos_encoding   = PositionalEncoding(max_steps=self.sequence_length, max_dims=self.embed_dim)
-        self.encoder        = TransformerBlock(self.embed_dim, self.num_heads, self.ff_dim)
+
+        self.num_layers     = num_layers
+        self.encoder        = [TransformerBlock(self.embed_dim, self.num_heads, self.ff_dim)
+                                for _ in range(self.num_layers)]
         # TODO : Branch off here, outputs_2, with perhaps Dense(input_dim[1]), OR vis this layer since
         # output should be: (batch_size, input_seq_len, d_model), see:
         # https://github.com/cordeirojoao/ECG_Processing/blob/master/Ecg_keras_v9-Raphael.ipynb
@@ -34,11 +37,14 @@ class T2Model(keras.Model):
         self.dropout2       = layers.Dropout(0.1)
         self.classifier     = layers.Dense(self.num_classes, activation="softmax")
 
-    def call(self, inputs, training=None):
+    def call(self, inputs, training, mask):
 
         x = self.embedding(inputs)
         x = self.pos_encoding(x)
-        x = self.encoder(x)
+
+        for layer in range(self.num_layers):
+            x = self.encoder[layer](x, training, mask)
+
         x = self.pooling(x)
         if training:
             x = self.dropout1(x, training=training)

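Note: `mask` may not be required, because the current multi-head attention implementation in `astronet/t2/attention.py` never references a mask. A mask is only used by the other (not yet integrated) multi-head attention implementation in `astronet/t2/multihead_attention.py`.

If `mask` is dropped, stop passing it to each `TransformerBlock`. If `TransformerBlock.call` does not accept a mask argument, passing it positionally will fail.

Either way, consider keeping defaults in the new signature, e.g. `def call(self, inputs, training=None, mask=None):`. The diff above removes the original `training=None` default, so any existing call that omits `training` (or `mask`) would otherwise raise a `TypeError`.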