OrigamiDream/gato

Dataset help

ThomasRochefortB opened this issue · 3 comments

Hey @OrigamiDream !

I am also looking into implementing the Gato paper. To complement your work, I decided to start with the datasets mentioned in the paper, investigating whether we can access the vision/language datasets used and whether we can train SOTA agents to generate the control datasets: torch-gato/datasets

From my initial investigation, we have open-source access to about 83.62% of the datasets/environments used during training (weighted by sample weight). If we substitute similar open-source variants for the private language/vision datasets, this number climbs to 94.32%. There is still a lot of work to do to build a training infrastructure for all the SOTA or near-SOTA agents that generate the expert data for the control environments.

Let me know if you are still interested in this implementation and how we could collaborate!
Cheers,
Thomas

Hi @ThomasRochefortB

I am so glad to see the datasets you've found.
Your investigation solves one of the hardest parts of this project.

I'm still interested in this project and still looking for a way to implement the training strategy in TensorFlow graph execution mode.
In the paper, they mix datasets across different tasks, e.g. Atari images (64x80) with their action tokens, and an arbitrary image (224x224) with its text tokens for caption generation.
This means the inputs to the main transformer blocks do not have a fixed shape, so they cannot be run in graph execution mode.
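
To make the shape problem concrete, here is a toy example of what I mean (my own illustration, not code from this repository; the token counts are only placeholders): a `tf.function` traced for one token-sequence shape has to retrace when a differently shaped sequence arrives.

```python
import tensorflow as tf

@tf.function
def transformer_step(tokens):
    # Stand-in for the main transformer blocks.
    return tf.reduce_mean(tokens, axis=1)

# e.g. a 64x80 Atari frame -> 20 patch tokens plus one action token
atari_seq = tf.random.uniform((1, 21, 768))
# e.g. a 224x224 image -> 196 patch tokens plus some caption tokens
caption_seq = tf.random.uniform((1, 220, 768))

transformer_step(atari_seq)    # traces a graph for shape (1, 21, 768)
transformer_step(caption_seq)  # different shape -> a second trace
print(transformer_step.experimental_get_tracing_count())  # 2
```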

I have two strategies to solve this problem.

  1. Dropping the performance advantages of graph execution mode and training the model end-to-end in eager mode (see the sketch after this list).
  2. Pre-training the embeddings for images, discrete actions and continuous values using self-supervised learning (which isn't mentioned in the paper at all) and transferring them to the main transformer blocks.
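
For option 1, a minimal sketch of the trade-off (using a toy Keras model as a stand-in, not the actual Gato implementation): compiling with `run_eagerly=True` skips `tf.function` tracing entirely, so variable-shaped multimodal batches can flow through plain Python control flow, at the cost of graph-mode performance.

```python
import tensorflow as tf

# Toy stand-in; the real model would be the full transformer stack.
toy_model = tf.keras.Sequential([
    tf.keras.layers.Dense(768, activation="gelu"),
    tf.keras.layers.Dense(768),
])

# run_eagerly=True disables graph tracing, which is exactly option 1's trade-off.
toy_model.compile(optimizer="adam", loss="mse", run_eagerly=True)
```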

I haven't decided which option would be best for me yet.

Thanks

@OrigamiDream

The way I understand the paper, all of the inputs are tokenized, which implies that the main transformer block sees a constant shape at its input.

The way I envision the implementation is to build a "dataloader" that combines the tokenization and embedding parts of the process (Sections 2.1 and 2.2 of the paper). The first sentence of Section 2.2 hints at that:
"After tokenization and sequencing, we apply a parameterized embedding function to each token to produce the final model input."

The dataloader would be conditional on the modality of the input data (vision vs environment observations vs text) and would apply the corresponding tokenization and embedding functions.

But at the end of it, the input to the transformer is a simple stream of tokens!
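
Roughly, the kind of dataloader I have in mind looks like the sketch below (every name here is a hypothetical placeholder of mine, not a function from the paper or from this repo): dispatch on modality, apply the matching tokenization/embedding, and emit one flat embedding sequence for the transformer.

```python
import tensorflow as tf

d_model = 768
text_embedding = tf.keras.layers.Embedding(input_dim=32000, output_dim=d_model)
discrete_embedding = tf.keras.layers.Embedding(input_dim=1024, output_dim=d_model)
# Simple strided conv as a stand-in for the per-patch image embedding.
patch_embedding = tf.keras.layers.Conv2D(filters=d_model, kernel_size=16, strides=16)

def embed_sample(sample):
    """Turn one raw sample into a (num_tokens, d_model) embedding sequence."""
    if sample["modality"] == "image":
        patches = patch_embedding(sample["pixels"][tf.newaxis])  # (1, H/16, W/16, d_model)
        return tf.reshape(patches, (-1, d_model))                # flattened patch tokens
    if sample["modality"] == "text":
        return text_embedding(sample["token_ids"])               # (num_text_tokens, d_model)
    if sample["modality"] == "discrete":                         # e.g. Atari action tokens
        return discrete_embedding(sample["token_ids"])
    raise ValueError(f"unknown modality: {sample['modality']}")

def build_token_stream(samples):
    """Concatenate per-modality embeddings into one stream for the transformer."""
    return tf.concat([embed_sample(s) for s in samples], axis=0)
```

The continuous-value tokenization from Section 2.1 would slot in as just another branch of the same dispatch.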

Hence, we must pre-train the ViT for the image patch tokens before building the stream of tokens.
As for the single ResNet block, as you said in #2, there are already well pre-trained ResNet backbones in the wild for 224x224 images.
However, for the Atari case, which uses 64x80 images, I couldn't find a well-trained backbone.
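
To make the per-patch embedding concrete, here is a rough sketch of the kind of embedder that would need pre-training (one small residual conv block per 16x16 patch; the layer choices are my guesses, not the paper's exact block), applied to a 64x80 Atari frame:

```python
import tensorflow as tf
from tensorflow.keras import layers

d_model, patch_size = 768, 16

def patch_block_embedder():
    """Rough guess at a per-patch embedder: one small residual conv block per
    16x16 patch, projected to d_model. Not the paper's exact block definition."""
    patch = layers.Input(shape=(patch_size, patch_size, 3))
    x = layers.Conv2D(128, 3, padding="same")(patch)
    skip = x
    x = layers.LayerNormalization()(x)  # normalization choice is a stand-in
    x = layers.Activation("gelu")(x)
    x = layers.Conv2D(128, 3, padding="same")(x)
    x = layers.Add()([x, skip])
    token = layers.Dense(d_model)(layers.Flatten()(x))
    return tf.keras.Model(patch, token)

embedder = patch_block_embedder()

# A 64x80 Atari frame splits into 4x5 = 20 patches of 16x16.
frame = tf.random.uniform((1, 64, 80, 3))
patches = tf.image.extract_patches(
    frame, sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1], padding="VALID")
patches = tf.reshape(patches, (-1, patch_size, patch_size, 3))  # (20, 16, 16, 3)
patch_tokens = embedder(patches)                                # (20, d_model)
print(patch_tokens.shape)
```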