ctlearn-project/ctlearn

Understanding the CNN-RNN model

Closed this issue · 3 comments

I am trying to understand how the current cnn-rnn model processes its inputs.

If I am understanding correctly, the model receives three sets of features:

  • A list of images, each of which could either be true images from triggered telescopes or zero arrays as padding
  • A list of numbers indicating which images are from triggered telescopes
  • An auxiliary input which is discarded

Those three parameters always have a fixed length, which is equal to num_telescopes

To the LSTM we feed the whole list of images, specifying that the length of the sequence is equal to the number of triggered telescopes. I think this should not happen in the current version: the length of the sequence is not equal to the number of triggered telescopes, but to num_telescopes, since non-triggered telescopes are also fed to the LSTM.

Then the outputs of each LSTM cell (whose number is equal to num_telescopes) are concatenated and fed to a dense layer.

Is this correct?

I think your understanding is mostly right, with some minor clarifications:

A list of numbers indicating which images are from triggered telescopes

Yup, this is a binary vector which indicates which images (telescopes) in the batch are real (triggered) rather than padded. It has a 1 in the position of each triggered image and a 0 for each padded image, but it is not used directly in the cnn-rnn model (only to calculate sequence_length), because we already sort the sequence of images during processing so that all the triggered images go to the front of the sequence and the padding to the back.

```python
if self.sort_images_by == "trigger":
    # Sort the images, triggers, and grouped auxiliary inputs by
    # trigger, listing the triggered telescopes first
    images, triggers, aux_inputs = map(list,
            zip(*sorted(zip(images, triggers, aux_inputs),
                        reverse=True, key=itemgetter(1))))
elif self.sort_images_by == "size":
    # Sort images by size (sum of charge in all pixels) from largest to smallest
    images, triggers, aux_inputs = map(list,
            zip(*sorted(zip(images, triggers, aux_inputs),
                        reverse=True, key=lambda x: np.sum(x[0]))))
```
So we don't really care about the telescope ordering any more; all we need to know is how many telescopes are triggered in total (the sum of telescope_triggers).
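As a toy NumPy illustration of this point (hypothetical trigger values, not the actual ctlearn graph code): after sorting by trigger, the only information dynamic_rnn still needs is the count of real images.

```python
import numpy as np

# Hypothetical trigger vector for one event with num_telescopes = 8:
# 1 = triggered (real image), 0 = padded (all-zero image).
triggers = np.array([1, 1, 0, 1, 0, 0, 1, 0])

# After sorting by trigger, the real images come first...
sorted_triggers = np.sort(triggers)[::-1]

# ...so all dynamic_rnn needs is the number of real images.
sequence_length = int(np.sum(triggers))

print(sorted_triggers)  # [1 1 1 1 0 0 0 0]
print(sequence_length)  # 4
```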

An auxiliary input which is discarded

The initial plan was to use the auxiliary inputs like the telescope position by concatenating them to the image embeddings, but for now this is correct, we don't use them.

Those three parameters always have a fixed length, which is equal to num_telescopes

Yup, the inputs are all of fixed shape (this is required):

Telescope images: [batch_size, num_tel, width, length, depth]
Telescope triggers: [batch_size, num_tel]
Telescope aux inputs: [batch_size, num_tel, num_aux_inputs]

As you noted, num_tel here is the total (max) number of telescopes and is constant. We only use the actual number of triggered telescopes for sequence_length (see below).
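To make the three shapes concrete, here is a sketch with made-up sizes (batch_size, num_tel, image dimensions, and num_aux_inputs are all hypothetical, chosen only for illustration):

```python
import numpy as np

# Hypothetical sizes for illustration only.
batch_size, num_tel = 2, 8
width, length, depth = 54, 54, 1
num_aux_inputs = 3

# The three fixed-shape inputs described above:
images = np.zeros((batch_size, num_tel, width, length, depth), dtype=np.float32)
triggers = np.zeros((batch_size, num_tel), dtype=np.int32)
aux_inputs = np.zeros((batch_size, num_tel, num_aux_inputs), dtype=np.float32)

print(images.shape)  # (2, 8, 54, 54, 1)
```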

To the LSTM we feed the whole list of images,

First, the images are each fed into a set of convolutional and pooling layers (the "conv block") in order to process them into a 1-D vector embedding representing the image. It is this sequence of embeddings (plus auxiliary inputs) for each telescope which is the input to the LSTM. The weights of the conv block are identical for all telescopes and can also be pre-trained.

```python
output = cnn_block(tf.gather(telescope_data, telescope_index),
                   params=params, reuse=reuse, training=training)
```
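The weight-sharing pattern can be sketched in plain NumPy (a single shared projection stands in for the real conv block; all names and sizes here are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the conv block: one shared projection mapping each
# image to a 1-D embedding. The real model uses conv/pool layers, but
# the key point is that the SAME weights are applied to every telescope.
shared_weights = rng.normal(size=(54 * 54, 16))  # hypothetical sizes

def cnn_block(image):
    return image.reshape(-1) @ shared_weights  # -> embedding of shape (16,)

num_tel = 8
telescope_data = rng.normal(size=(num_tel, 54, 54))

# Equivalent of gathering each telescope and running the shared block:
embeddings = np.stack([cnn_block(telescope_data[i]) for i in range(num_tel)])
print(embeddings.shape)  # (8, 16)
```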

specifying that the length of the sequence is equal to the number of triggered telescopes. I think this should not happen in the current version: the length of the sequence is not equal to the number of triggered telescopes, but to num_telescopes, since non-triggered telescopes are also fed to the LSTM.

You're right about the actual length of the sequence (we pad the original image sequence up to the total number of telescopes with all-zero images). This is done because TF's dynamic_rnn takes an input tensor of constant size (in this case [batch_size, total num_telescopes, embedding length + num auxiliary params]).

However, dynamic_rnn also takes an argument sequence_length telling it how many of the elements in the input sequence are real (not padding). It uses this to essentially ignore all of the padded images (the weights/state are frozen once it reaches the true end of the sequence). See https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn.
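The "freezing" behavior can be imitated with a toy vanilla RNN in NumPy (this is a sketch of the semantics of dynamic_rnn's sequence_length argument, not the actual TF implementation; all names and sizes are made up):

```python
import numpy as np

def manual_rnn(inputs, sequence_length, W, U):
    """Toy vanilla RNN mimicking dynamic_rnn's sequence_length handling:
    once t >= sequence_length, the state is copied forward unchanged and
    the output for that step is zero."""
    batch, max_steps, _ = inputs.shape
    units = U.shape[0]
    state = np.zeros((batch, units))
    outputs = np.zeros((batch, max_steps, units))
    for t in range(max_steps):
        new_state = np.tanh(inputs[:, t] @ W + state @ U)
        mask = (t < sequence_length)[:, None]           # True where this step is real
        state = np.where(mask, new_state, state)        # freeze state past the end
        outputs[:, t] = np.where(mask, new_state, 0.0)  # zero output for padding
    return outputs, state

rng = np.random.default_rng(1)
W, U = rng.normal(size=(4, 3)), rng.normal(size=(3, 3))
inputs = rng.normal(size=(1, 6, 4))  # 6 telescope slots, only 4 real
outputs, final_state = manual_rnn(inputs, np.array([4]), W, U)
```

Note that the final state equals the state after the last real step, and the outputs for the padded steps are all zero, regardless of what the padded embeddings contain.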

As mentioned above, one key thing we need to do for the cnn-rnn model is sort the image sequence so that all of the triggered images are at the front and the padding at the end. We also experimented with sorting them by image size (the total charge) from largest to smallest, which also obviously puts all the zero-padded images at the back. Unfortunately, sorting them either way (by trigger or by size) destroys any information contained in the order of the inputs, but we found in practice that this didn't seem to be critical (this is what was done in the paper by the HESS deep learning group). It would be interesting to consider whether there are other, better ways to handle the input.
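A tiny worked example of the trigger sort (toy data, not the actual pipeline; the pattern matches the snippet quoted earlier):

```python
import numpy as np
from operator import itemgetter

# Toy event: three 2x2 "images", one of them all-zero padding.
images = [np.ones((2, 2)), np.zeros((2, 2)), 2 * np.ones((2, 2))]
triggers = [1, 0, 1]
aux_inputs = [[0.1], [0.0], [0.2]]

# Same sort as in the model code: triggered telescopes first.
# sorted() is stable, so equal-trigger images keep their relative order.
images, triggers, aux_inputs = map(list, zip(
    *sorted(zip(images, triggers, aux_inputs),
            reverse=True, key=itemgetter(1))))

print(triggers)  # [1, 1, 0] -> padding moved to the back
```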

Then the outputs of each LSTM cell (whose number is equal to num_telescopes) are concatenated and fed to a dense layer.

The output of the LSTM layer is a matrix of shape [batch_size, max_num_tel, output_size]. So each LSTM cell is essentially outputting [batch_size, output_size], as each cell receives 1 telescope input. We then flatten across the sequence to shape [batch_size, max_num_tel x output_size], so basically grouping all of the output from all telescopes into a single vector for each event, and then feed this into the dense layers.
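The flattening step is a single reshape (hypothetical sizes, sketched in NumPy):

```python
import numpy as np

batch_size, max_num_tel, output_size = 2, 8, 16  # hypothetical sizes
lstm_outputs = np.zeros((batch_size, max_num_tel, output_size))

# Flatten across the sequence: one combined vector per event,
# which is then fed into the dense layers.
flattened = lstm_outputs.reshape(batch_size, max_num_tel * output_size)
print(flattened.shape)  # (2, 128)
```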

@bryankim96 I notice I am confused. How does the dynamic_rnn currently know which images are padding and which are real? After going through the CNN, the bias terms make it so that the padding images turn into embeddings that are not necessarily null vectors.

The dynamic_rnn assumes that the real inputs are at the front and any padding is at the end (this is why we had to sort the images by trigger). For example, if num_telescopes is 10 but for a particular example sequence_length is 5 (imagine the batch size is 1), the LSTM layer will receive an array 'inputs' of shape [1, 10, num_units_embedding]. It will treat inputs[: , :5, :] as real images and ignore inputs[: , 5:, :].
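In slicing terms, using the same made-up numbers as above:

```python
import numpy as np

num_telescopes, num_units_embedding = 10, 32  # sizes from the example above
sequence_length = 5
inputs = np.zeros((1, num_telescopes, num_units_embedding))

real = inputs[:, :sequence_length, :]     # treated as real images
ignored = inputs[:, sequence_length:, :]  # ignored by the LSTM
print(real.shape, ignored.shape)  # (1, 5, 32) (1, 5, 32)
```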

You're 100% right that the embeddings for the blank images will be non-zero due to bias terms (unlike the original images). I don't think it causes any problems though, because they are completely ignored anyway (the weights/internal state of the LSTM are just copied forward once it goes past sequence_length).