fidler-lab/polyrnn-pp-pytorch

End of sequence token in the RNN.

qTipTip opened this issue · 7 comments

Hi! First of all, thanks for a very interesting article!

Hope you don't mind me asking.
I am currently learning about recurrent neural networks, and I was wondering whether you could elaborate on how the EOS-token is implemented. You say that you one-hot-encode vertices as a DxD + 1 dimensional vector, where the last element signifies that the polygon is closed, i.e., that the current predicted vertex is equal to the first predicted vertex.

  1. How is this implemented? Do you do a check of equality between current predicted and first predicted at every time step, and set the bit accordingly?

  2. What happens after the EOS-token has been reached? Does the network keep predicting vertices until the max_sequence_length number of vertices has been predicted, or does the network stop in some way? In the former case, how do you "discard" the extra vertices generated after the EOS-token?

Cheers in advance!

Hi @qTipTip. In the paper, the output is represented as a one-hot encoding of (D*D)+1 elements, e.g., with the MLE model D is 28, so the end token can be stored at the 785th element. But what confuses me is that, when processing the output of the script 'generate_annotation.py', the author uses grid_size**2 as the end token of the predicted polygons, like this:

import numpy as np

def get_masked_poly(poly, grid_size):
    """
    NOTE: Numpy function

    Given a polygon of shape (N,), finds the first EOS token
    and masks the predicted polygon till that point
    """
    if np.max(poly) == grid_size**2:
        # There is an EOS token (value grid_size**2) in the prediction
        length = np.argmax(poly)    # index of the first EOS token (first occurrence of the max value)
        poly = poly[:length]

    return poly
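
For concreteness, the index convention I mean is roughly this (a toy sketch; the (row, col) values below are made up):

import numpy as np

D = 28                         # grid size in the MLE setting
row, col = 10, 5               # a made-up vertex location in the D x D grid
vertex_index = row * D + col   # => 285, the raveled index the network predicts
eos_index = D * D              # => 784, i.e. the 785th element of the (D*D)+1 one-hot output

one_hot = np.zeros(D * D + 1)
one_hot[vertex_index] = 1.0    # one-hot encoding of this vertex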

I don't know why the end token is the same value, grid_size**2, for every test image.
Hi @amlankar, can you help me figure out how the end token works at prediction time?
Looking forward to your reply, thanks sincerely.

Hi @Jacoobr!

Since Python is zero-indexed, the element at index grid_size**2 is in fact the 785th element (with D = 28). I guess the function you posted takes an array of length N and checks whether the value 784 (the index of the 785th element, i.e., the end-of-sequence token) occurs in the polygon. If it does, only the elements preceding its first occurrence are kept.
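
Concretely, using the get_masked_poly function you posted, I would expect a toy run to behave like this (the prediction values are made up, grid_size = 28):

import numpy as np

# made-up prediction: three vertex indices followed by EOS tokens (784)
pred = np.array([310, 340, 368, 784, 784, 784])

masked = get_masked_poly(pred, grid_size=28)
print(masked)   # => [310 340 368]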

Hi @qTipTip, thanks for your reply. I mean that np.max(poly) returns the max value of poly, not the index of the max value. Why, for every test image, is the max value of the vertices always 784 (the end token)? BTW, when I predict the polygon of a test image, I can set the length of the polygon (maybe smaller than 100) by changing the parameter '"max_poly_len": ,' in the mle.json script.

If I understand you correctly, you are wondering why they use np.max(poly) and not np.argmax(poly)?
Recall that the polygon is represented in terms of the raveled indices of the vertices in a flattened quantized output grid. That is, if D = 4 for instance, there are 16 possible vertex locations, represented by the numbers 0 through 15. Assume that max_poly_len = 10. A predicted polygon of length 6 (for instance) may therefore be represented as follows:

polygon = [3, 7, 12, 4, 8, 9, 16, 7, 3, 4]

Note that the end-of-sequence token (polygon[6] == 16) was emitted by the network. Therefore, the indices following the end-of-sequence token are discarded:

length = np.argmax(polygon)   # => 6, the index of the first EOS token
polygon = polygon[:length]    # => [3, 7, 12, 4, 8, 9]

The corresponding positions in the D x D output grid can then be recovered with

ys, xs = np.unravel_index(polygon, (D, D))   # row and column indices in the grid
# as (x, y) pairs: [(3, 0), (3, 1), (0, 3), (0, 1), (0, 2), (1, 2)]
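
If you then want pixel coordinates in the image crop, you'd scale these grid positions up. A minimal sketch, assuming a square crop of crop_size pixels (in Polygon-RNN++ the crop is 224 x 224 with D = 28, so each cell would span 8 pixels) and mapping each cell to its centre:

import numpy as np

D = 4
polygon = np.array([3, 7, 12, 4, 8, 9])
ys, xs = np.unravel_index(polygon, (D, D))   # row and column indices in the grid

crop_size = 224                              # assumed square crop size in pixels
cell = crop_size / D                         # pixels per grid cell
img_xy = np.stack([xs, ys], axis=1) * cell + cell / 2   # (x, y) pixel coordinates at cell centres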

Hi @qTipTip @Jacoobr,

Sorry for the late reply. To answer the initial question:

  1. At test time, we find the first EOS token and only take the polygon till that location. Since the poly is stored as a list of indices (as mentioned above), np.max works fine.

  2. We do not stop running the RNN right now if EOS has been reached, but in principle it is possible and should be done in a performance-critical setting (see the sketch below). To discard the extra vertices, we do what I described in 1, using the get_masked_poly function.
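
For illustration, here is a minimal sketch of what stopping at EOS during greedy decoding could look like; predict_next_vertex is a hypothetical stand-in for one RNN decoding step, not a function from this repo:

import numpy as np

def decode_polygon(predict_next_vertex, first_vertex, grid_size, max_poly_len=100):
    """Greedy decoding that stops as soon as the EOS token is emitted.

    predict_next_vertex is a hypothetical callable mapping the vertices so far
    to the next predicted index in [0, grid_size**2]; grid_size**2 is EOS.
    """
    eos = grid_size ** 2
    poly = [first_vertex]
    for _ in range(max_poly_len - 1):
        nxt = predict_next_vertex(poly)
        if nxt == eos:          # stop early instead of predicting up to max_poly_len
            break
        poly.append(nxt)
    return np.array(poly)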

Hi @qTipTip @amlankar, thank you guys. I got it.

@amlankar Thanks! I seem to have figured it out, but it was nice to have it confirmed! Thanks again for the great work!