
Primary language: Python. License: MIT.

StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation

PyTorch code for the ECCV 2022 paper "StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation".

[Paper] [Model Card] [Demo](work in progress)

Update: The demo link is now live. It is a temporary 72-hour link; we will try to make it permanent on HuggingFace Spaces (GPU version).

Training

Prepare Repository:

Download the PororoSV dataset and associated files from here (updated) and save them in ./data/pororo/.
Download the FlintstonesSV dataset and associated files from here and save them in ./data/flintstones/.
The DiDeMoSV dataset is coming soon.
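Before training, it can help to confirm that the data folders are where the training scripts expect them. The check below is a minimal sketch, not part of the repository: it only tests whether ./data/pororo and ./data/flintstones exist relative to the repository root. The helper name missing_datasets is hypothetical, and the contents of each folder are not validated.

```python
from pathlib import Path

# Hypothetical helper, not part of the repository. The folder names come
# from this README; the files inside each folder depend on the downloaded
# archives and are NOT checked here.
EXPECTED_DATA_DIRS = {
    "pororo": Path("data/pororo"),
    "flintstones": Path("data/flintstones"),
}


def missing_datasets(repo_root="."):
    """Return the sorted names of datasets whose data folder does not exist."""
    root = Path(repo_root)
    return sorted(
        name
        for name, rel_path in EXPECTED_DATA_DIRS.items()
        if not (root / rel_path).is_dir()
    )


if __name__ == "__main__":
    missing = missing_datasets()
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All dataset folders found.")
```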

This repository contains separate folders for training StoryDALL-E based on the minDALL-E and DALL-E Mega models, i.e., the ./story_dalle/ and ./mega-story-dalle folders, respectively.

Training StoryDALL-E based on minDALL-E:

  1. To finetune the minDALL-E model for story continuation, first change to the corresponding folder:
    cd story-dalle
  2. Set the environment variables in train_story.sh to point to the right locations on your system. Specifically, change $DATA_DIR, $OUTPUT_ROOT and $LOG_DIR if they differ from the default locations.
  3. Download the pretrained checkpoint from here and save it in ./1.3B.
  4. Run the following command: bash train_story.sh <dataset_name>
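Steps 2 and 4 can also be scripted. The following is a minimal Python sketch, not part of the repository. It assumes that train_story.sh assigns DATA_DIR, OUTPUT_ROOT and LOG_DIR on their own lines (optionally prefixed with export), rewrites those assignment lines, and then runs the same bash train_story.sh <dataset_name> command. The helper names set_script_vars and run_training are hypothetical.

```python
import re
import shlex
import subprocess
from pathlib import Path

# Assumption: train_story.sh assigns each variable on its own line, e.g.
# "DATA_DIR=/path/to/data" or "export DATA_DIR=...". The variable names come
# from this README; the exact line format in the script is a guess.
TRAINING_VARS = ("DATA_DIR", "OUTPUT_ROOT", "LOG_DIR")


def set_script_vars(script_text, **values):
    """Return script_text with the NAME=... lines rewritten for the given vars."""
    for name, value in values.items():
        if name not in TRAINING_VARS:
            raise ValueError(f"unexpected variable: {name}")
        # Match only assignments at the start of a line, keeping any
        # leading indentation and "export " prefix.
        pattern = re.compile(rf"^([ \t]*(?:export[ \t]+)?){name}=.*$", re.MULTILINE)
        # Quote the value so paths with spaces stay a single shell word.
        replacement = f"{name}={shlex.quote(str(value))}"
        # A function replacement keeps backslashes in paths from being
        # interpreted as re.sub escape sequences.
        script_text, count = pattern.subn(
            lambda m, r=replacement: m.group(1) + r, script_text
        )
        if count == 0:
            raise KeyError(f"{name} is not assigned in the script")
    return script_text


def run_training(dataset_name, script="train_story.sh", **values):
    """Patch the script in place, then run: bash train_story.sh <dataset_name>."""
    path = Path(script)
    path.write_text(set_script_vars(path.read_text(), **values))
    subprocess.run(["bash", str(path), dataset_name], check=True)
```

For example, run_training("pororo", DATA_DIR="/mnt/data/pororo") would rewrite only the DATA_DIR line and then start training; the dataset name "pororo" is an assumption, so check train_story.sh for the names it accepts. Editing the script in place mirrors the manual step above, so keep a copy of the original if you want to restore the defaults later.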

Training StoryDALL-E based on DALL-E Mega:

  1. To finetune the DALL-E Mega model for story continuation, first change to the corresponding folder:
    cd mega-story-dalle
  2. Set the environment variables in train_story.sh to point to the right locations on your system. Specifically, change $DATA_DIR, $OUTPUT_ROOT and $LOG_DIR if they differ from the default locations.
  3. Pretrained checkpoints for the generative model and the VQGAN detokenizer are downloaded automatically upon initialization. Download the pretrained weights for the VQGAN tokenizer from here and place them in the same folder as the VQGAN detokenizer.
  4. Run the following command: bash train_story.sh <dataset_name>
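Step 3 amounts to a file copy. The following minimal sketch (not part of the repository) copies a downloaded tokenizer checkpoint into the folder that holds the auto-downloaded VQGAN detokenizer. This README names neither the checkpoint file nor the detokenizer location, so both are passed in as arguments; place_tokenizer_weights is a hypothetical helper.

```python
import shutil
from pathlib import Path


def place_tokenizer_weights(weights_file, detokenizer_dir):
    """Copy the VQGAN tokenizer weights into the VQGAN detokenizer's folder.

    Both paths are placeholders supplied by the caller: the README does not
    state the tokenizer file name or where the detokenizer is downloaded.
    Returns the path of the copied file.
    """
    src = Path(weights_file)
    target_dir = Path(detokenizer_dir)
    if not src.is_file():
        raise FileNotFoundError(src)
    if not target_dir.is_dir():
        raise NotADirectoryError(target_dir)
    dest = target_dir / src.name
    # copy2 also preserves file metadata such as modification time.
    shutil.copy2(src, dest)
    return dest
```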

Inference

Links to pretrained checkpoints and inference instructions coming soon!