google-deepmind/code_contests

AlphaCode Model Implementation Questions

darien-schettler opened this issue

@davidhchoi - I saw you mentioned in this closed issue that you would be able to answer questions regarding implementation details of AlphaCode. I am in the process of creating a model based on the architecture described in the AlphaCode paper and, initially, I want to create something as close to the original implementation as possible.

To that end, I had a few questions about the architecture and training details and was hoping to get them answered here.

  1. The Tokenizer.
  • The paper mentions that a SentencePiece tokenizer with a vocab size of 8000 is used.
  • Is the tokenization 'unigram' style or something else (BPE, another subword scheme, etc.)?
  • Are there any other tokenizer parameters I should be aware of? (My current setup is sketched in the first code block after this list.)
  2. Padding/truncation of encoder/decoder inputs.
  • The paper states that the pivot point is sampled within files, and that the text before the pivot is the input to the encoder while the text after it is the input to the decoder.
  • I know the encoder takes 1536 tokens and the decoder takes 768... so how are the two sides padded/truncated?
  • My intuition is that the encoder input would be padded/truncated at the beginning and the decoder input at the end. This keeps the encoder and decoder inputs contiguous across the pivot point.
  • e.g. for the text "a cat sat on the mat and [pivot] jumped into a box of paper" (assume the encoder takes 8 tokens, the decoder takes 4, and each word is one token)
    • Encoder inputs: ["[pad]", "a", "cat", "sat", "on", "the", "mat", "and"] (start of the text is padded/truncated)
    • Decoder inputs: ["jumped", "into", "a", "box"] (end of the text is padded/truncated)
  • Can you confirm that this is indeed how it works? Can you also share more details about how the pivot point is sampled? Is it chosen uniformly at random, or with some restriction, e.g. sampling the pivot position from a normal distribution centred on n_tokens/2 with some standard deviation (maybe n_tokens/10?), where n_tokens is the number of tokens in the file? (My current guess is sketched in the second code block after this list.)
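
For reference, here is a minimal sketch of how I am currently setting up the tokenizer. Only the vocab size of 8000 comes from the paper; the unigram model type, the character coverage, and the corpus path are my assumptions:

```python
import sentencepiece as spm

# Sketch of my current tokenizer setup. Only vocab_size=8000 is stated in the
# paper; model_type="unigram", character_coverage, and the corpus path are guesses.
spm.SentencePieceTrainer.train(
    input="pretraining_corpus.txt",   # hypothetical path to the training text
    model_prefix="alphacode_sp",
    vocab_size=8000,
    model_type="unigram",             # the open question: unigram vs. bpe vs. something else
    character_coverage=1.0,           # code likely needs full character coverage (guess)
)

sp = spm.SentencePieceProcessor(model_file="alphacode_sp.model")
print(sp.encode("def solve(n):\n    return n * 2", out_type=str))
```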
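
And here is a sketch of the pivot-based split I describe above, padding/truncating the encoder side on the left and the decoder side on the right, plus the kind of normal-distribution pivot sampler I am speculating about. All of the details here (pad handling, the sampling distribution, and its parameters) are my assumptions, not confirmed behaviour:

```python
import random

def sample_pivot(n_tokens, mean_frac=0.5, std_frac=0.1):
    """Speculative pivot sampler: normal distribution centred on n_tokens/2
    with std n_tokens/10, clamped so both sides are non-empty."""
    pivot = int(random.gauss(n_tokens * mean_frac, n_tokens * std_frac))
    return min(max(pivot, 1), n_tokens - 1)

def split_at_pivot(tokens, pivot, enc_len=1536, dec_len=768, pad_id=0):
    """Split tokens at `pivot`: the text before feeds the encoder, the text
    after feeds the decoder.  The encoder side is truncated/padded on the
    left and the decoder side on the right, so the two inputs stay
    contiguous across the pivot (my guess, not a confirmed detail)."""
    enc = tokens[:pivot][-enc_len:]               # keep the tokens nearest the pivot
    enc = [pad_id] * (enc_len - len(enc)) + enc   # pad at the beginning
    dec = tokens[pivot:][:dec_len]                # keep the tokens right after the pivot
    dec = dec + [pad_id] * (dec_len - len(dec))   # pad at the end
    return enc, dec

# Tiny worked example matching the one above (encoder length 8, decoder length 4).
words = "a cat sat on the mat and jumped into a box of paper".split()
enc, dec = split_at_pivot(words, pivot=7, enc_len=8, dec_len=4, pad_id="[pad]")
print(enc)  # ['[pad]', 'a', 'cat', 'sat', 'on', 'the', 'mat', 'and']
print(dec)  # ['jumped', 'into', 'a', 'box']
```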

I appreciate your support with this. I will probably run into other questions in the near future, which I will post here. I'm looking forward to understanding more and getting this working. Thanks in advance!