minimaxir/aitextgen

Finetuning 355M or larger GPT-2 models / Gradient Checkpointing

minimaxir opened this issue · 11 comments

Gradient checkpointing must be implemented to avoid running out of GPU memory (OOM) when finetuning these models.

That is apparently done at the training level, and PyTorch has tricks to make it easy, but I am having difficulty getting it to work correctly.
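To make the idea concrete, here is a toy, framework-free sketch of gradient checkpointing on a chain of scalar functions. The layer list and both functions are hypothetical and purely illustrative. Plain backprop keeps every intermediate activation; the checkpointed version keeps only the input to each segment of `segment` layers and recomputes that segment's activations during the backward pass, trading extra compute for memory.

```python
import math

# Hypothetical 12-layer chain of scalar functions: each entry is (f, df/dx).
LAYERS = [
    (math.sin, math.cos),
    (lambda x: x * x, lambda x: 2 * x),
    (math.exp, math.exp),
] * 4


def grad_full(x):
    """Plain backprop: keep every activation from the forward pass."""
    acts = [x]
    for f, _ in LAYERS:
        acts.append(f(acts[-1]))
    g = 1.0
    # acts[i] is the input to layer i; walk the layers in reverse (chain rule).
    for (_, df), a in zip(reversed(LAYERS), reversed(acts[:-1])):
        g *= df(a)
    return acts[-1], g


def grad_checkpointed(x, segment=3):
    """Checkpointed backprop: keep only each segment's input, recompute the rest."""
    checkpoints = []
    a = x
    for i, (f, _) in enumerate(LAYERS):
        if i % segment == 0:
            checkpoints.append(a)  # store the input of every `segment`-th layer
        a = f(a)
    out = a
    g = 1.0
    for s in reversed(range(len(checkpoints))):
        layers = LAYERS[s * segment:(s + 1) * segment]
        # Recompute this segment's intermediate activations from its checkpoint.
        acts = [checkpoints[s]]
        for f, _ in layers[:-1]:
            acts.append(f(acts[-1]))
        for (_, df), a_in in zip(reversed(layers), reversed(acts)):
            g *= df(a_in)
    return out, g
```

With 12 layers and `segment=3`, the checkpointed pass holds at most 4 checkpoints plus one segment's worth of recomputed activations at a time, instead of all 13 activations.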

Are there any reference implementations for gradient checkpointing? I've heard it brought up as a PyTorch feature but I've not actually seen it in use.

I imagine you've already tried something like this, but I've looked at the PyTorch docs on gradient checkpointing (here and here) for implementation details.

With ordinary PyTorch modules, it seems it could be implemented in the forward pass with something like:

from torch.utils.checkpoint import checkpoint
def forward(self, inputs):
    # Activations inside self.model are recomputed during backward instead of stored
    return checkpoint(self.model, *inputs)

You could make this checkpointed forward pass conditional on the model chosen, or expose it as an optional parameter of the training class.
I see that in __init__ for the main aitextgen object there is a tf_gpt2 flag that could be passed to the trainer, so it knows which model is being trained and runs the checkpointed forward pass accordingly.
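A minimal sketch of the flag-controlled approach, assuming PyTorch 1.11+ (for the `use_reentrant` argument) and a model that can be expressed as a stack of blocks. `CheckpointedStack` and its `use_checkpointing` flag are hypothetical names standing in for the flag suggested above, not aitextgen or Transformers API. Wrapping the whole model in a single `checkpoint(...)` call saves little memory, because the backward pass then recomputes, and holds, every activation at once; checkpointing each block separately keeps only the block boundaries alive.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStack(nn.Module):
    """Hypothetical wrapper: run blocks in order, optionally checkpointing each one."""

    def __init__(self, blocks, use_checkpointing=False):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.use_checkpointing = use_checkpointing  # e.g. set from a train() param

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Only x (the block boundary) is kept; the block's internal
                # activations are recomputed from it during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```

The `self.training` check avoids the recomputation overhead at inference time, where no backward pass needs the activations.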

Yes, the correct implementation is something along those lines; apparently the Transformers GPT-2 forward() implementation is picky about its inputs.
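One plausible source of that pickiness, sketched under assumptions: in older PyTorch releases, and in the reentrant variant, `checkpoint()` passes only positional arguments through to the wrapped function, while transformer blocks typically take keyword arguments such as an attention mask and return a tuple. Wrapping the block in a closure that binds the keywords and returns just the hidden-state tensor sidesteps both. `ToyBlock` and `checkpointed_block` are hypothetical stand-ins, not the actual Transformers GPT-2 code.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block: keyword args in, tuple out."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden, attention_mask=None, use_cache=False):
        out = torch.tanh(self.proj(hidden))
        if attention_mask is not None:
            out = out * attention_mask.unsqueeze(-1)  # zero out masked positions
        present = out.detach() if use_cache else None  # mimics a cache entry
        return out, present


def checkpointed_block(block, hidden, attention_mask):
    """Bind keyword args in a closure so checkpoint() only sees positional tensors."""

    def run(h, mask):
        # Disable caching and keep only the hidden states, so the
        # checkpointed function returns a single tensor.
        return block(h, attention_mask=mask, use_cache=False)[0]

    return checkpoint(run, hidden, attention_mask, use_reentrant=False)
```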

I can give it another go.

I am getting OOM for even the smaller 124M model if the input file is bigger than 100 MB.
Also, strangely, splitting the file into smaller parts and trying to merge the resulting TokenDatasets raises this error:

/usr/local/lib/python3.6/dist-packages/aitextgen/TokenDataset.py in __init__(self, file_path, vocab_file, merges_file, texts, line_by_line, from_cache, header, save_cache, cache_destination, compress, block_size, tokenized_texts, text_delim, bos_token, eos_token, unk_token, pad_token, progress_bar_refresh_rate, **kwargs)
75 if tokenized_texts:
76 self.tokens = tokenized_texts
---> 77 self.num_subsets = self.tokens.shape[0] - block_size
78 self.block_size = block_size
79 self.file_path = "merged TokenDataset"

AttributeError: 'list' object has no attribute 'shape'
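The traceback shows `tokenized_texts` arriving as a plain Python list, which has no `.shape`. One possible fix, sketched as a standalone hypothetical helper (not necessarily how the maintainers fixed it, and assuming each merged dataset's tokens are a 1-D array or list of token IDs), is to concatenate the pieces into a single NumPy array before reading its length:

```python
import numpy as np


def merge_tokenized_texts(tokenized_texts, block_size):
    """Hypothetical helper: flatten a list of per-dataset token arrays before using .shape."""
    # A plain list has no .shape, so concatenate the pieces into one 1-D array first.
    tokens = np.concatenate([np.asarray(t) for t in tokenized_texts])
    num_subsets = tokens.shape[0] - block_size  # same formula as in the traceback
    return tokens, num_subsets
```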

> I am getting OOM for even the smaller 124M model if the input file is bigger than 100 MB.

The input dataset file is not related to these GPU OOM issues, so you are hitting something else. You should not get OOM on the 124M model unless you are using a GPU with little VRAM.

> Also, strangely, splitting the file into smaller parts and trying to merge the resulting TokenDatasets raises this error

That's unrelated, but it's a legitimate bug. Filed as #49.

My input file was around 20 GB and I got those OOM errors. I broke it down into 100 MB chunks with "split -b" and have had no issues running it since.
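`split -b` cuts at exact byte offsets, so a chunk can end mid-line (or even partway through a multi-byte UTF-8 character). A line-aware alternative, sketched with a hypothetical `split_by_lines` helper, only breaks between lines:

```python
def split_by_lines(path, out_prefix, max_bytes=100 * 1024 * 1024):
    """Hypothetical helper: split a text file into ~max_bytes chunks at line boundaries.

    A single line longer than max_bytes ends up alone in an oversized chunk.
    """
    paths = []
    out = None
    written = 0
    try:
        with open(path, "rb") as src:
            for line in src:
                # Start a new chunk if this line would overflow the current one.
                if out is None or (written and written + len(line) > max_bytes):
                    if out is not None:
                        out.close()
                    chunk_path = f"{out_prefix}{len(paths):04d}.txt"
                    out = open(chunk_path, "wb")
                    paths.append(chunk_path)
                    written = 0
                out.write(line)
                written += len(line)
    finally:
        if out is not None:
            out.close()
    return paths
```

On GNU systems, `split -C 100M` (`--line-bytes`) gives a similar line-preserving split from the shell.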

@ganeshkrishnan1 after splitting the file into 100 MB chunks, did you create several TokenDatasets and merge them? Or did you just train the model a little on each txt file separately?

I trained the model a bit, saved it, and then reloaded it to train on the next file.
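The chunk-by-chunk workflow described here can be sketched as a loop. `load`, `train`, and `save` are hypothetical placeholders for the library's own load/train/save calls, and passing `None` to `load` is assumed to mean "start from the pretrained model". Only one chunk's data has to be read at a time.

```python
from typing import Callable, Iterable, Optional, TypeVar

Model = TypeVar("Model")


def train_sequentially(
    chunk_paths: Iterable[str],
    load: Callable[[Optional[str]], Model],
    train: Callable[[Model, str], None],
    save: Callable[[Model, str], None],
    save_dir: str,
) -> None:
    """Train on one chunk at a time, saving and reloading the model between chunks."""
    resume_from: Optional[str] = None  # None: start from the pretrained weights
    for path in chunk_paths:
        model = load(resume_from)  # weights saved after the previous chunk
        train(model, path)  # only this chunk has to be read and tokenized
        save(model, save_dir)
        resume_from = save_dir
```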

Awesome, thanks

Gradient checkpointing currently works for me by just setting the GPT2Config property
config.gradient_checkpointing = True.
I can fine-tune the 355M model on a 6 GB RTX 2060 using this along with these additional optimizations:

  • setting config.use_cache = False
  • training with fp16=True (automatic mixed precision) and batch_size=1
  • switching from the Adam optimizer to SM3
  • adding move_metrics_to_cpu=True to train_params inside ai.train(...), though I'm unsure whether this makes a difference

In total, this uses ~5 GB of VRAM with a small training file.
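A back-of-the-envelope estimate suggests why the optimizer swap matters on a 6 GB card. It assumes fp32 master weights and fp32 gradients (as with PyTorch automatic mixed precision), 8 bytes of Adam state per parameter (`exp_avg` and `exp_avg_sq`), and SM3 accumulators small enough to ignore with no momentum buffer; it leaves out activations and the CUDA context entirely.

```python
def training_state_bytes(n_params: int, optimizer_bytes_per_param: int) -> int:
    """Rough size of weights + gradients + optimizer state; ignores activations."""
    weights = 4 * n_params  # fp32 master weights (AMP keeps parameters in fp32)
    grads = 4 * n_params  # fp32 gradients
    optimizer = optimizer_bytes_per_param * n_params
    return weights + grads + optimizer


N_PARAMS = 355_000_000  # "355M" GPT-2, rounded

adam = training_state_bytes(N_PARAMS, 8)  # exp_avg + exp_avg_sq, fp32 each
sm3 = training_state_bytes(N_PARAMS, 0)  # assumption: SM3 accumulators negligible, no momentum

print(f"Adam: {adam / 1e9:.2f} GB, SM3: {sm3 / 1e9:.2f} GB")
# prints "Adam: 5.68 GB, SM3: 2.84 GB"
```

Under these assumptions, Adam's state alone leaves well under 1 GB of a 6 GB card for activations and the CUDA context, while SM3 roughly halves the total; gradient checkpointing and batch_size=1 then keep activation memory small.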

Closing and unpinning due to the 0.4.0 release.