EleutherAI/gpt-neox

Is there a way to train on the entire dataset for N epochs without specifying train-iters?

javirandor opened this issue · 4 comments

Hello everyone,

Thanks in advance for your help :)

I want to train a model on a specific dataset for N=1 epochs. Ideally, the model should see all of the data, and see it only once. I could not find a trivial way of doing this through the config. Maybe I am missing something?

Then, I tried to find a way to compute the value of train-iters that would give me an equivalent behaviour, but I found an unexpected gap between the value I compute and the value that actually works. This is what I did.

From the output of the script + some additional printing for debugging:

number of tokens per epoch: 2113896210

using:
 number of documents:       1347665
 sequence length:           2048
 total number of samples:   2064351

So I assume I am going to train on 2113896210 tokens. Since I am using 2048 as my sequence length, this results in 1032175.88 -> 1032176 chunks.

My current setup has 4 GPUs and a batch size of 32 with no gradient accumulation. So the chunks per GPU are 1032176 / 4 = 258044, and since we process 32 of them per iteration, 258044 / 32 = 8063.87 -> 8064 iterations are required.
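For reference, here is that arithmetic as a quick Python sketch (the variable names are mine for illustration, not config options):

```python
import math

# Back-of-the-envelope estimate of train-iters for a single pass over the data.
tokens_per_epoch = 2113896210   # from the log output above
seq_length = 2048
num_gpus = 4
micro_batch_size = 32           # no gradient accumulation

chunks = tokens_per_epoch / seq_length      # ~1032175.88 sequences
global_batch = num_gpus * micro_batch_size  # 128 sequences per iteration
iterations = chunks / global_batch          # ~8063.87
print(math.ceil(chunks), math.ceil(iterations))  # 1032176 8064
```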

However, if I set train-iters=8064, I see that the logs print

number of epochs:          2

which means that there is not enough data for 8064 iterations and some samples will be seen twice. I manually tried a bunch of numbers in the vicinity to account for potential rounding errors, but none of them works. I had to go down step by step, and 8023 is the highest value that results in the target 1 epoch.

Can someone help me understand where this offset comes from, or whether there is a native way of training until all the data has been seen?

Hello, to verify whether this is a bug, can you please divide your budget of iterations by 1.006 and let us know how many epochs that would correspond to? Thank you!

Thanks for your prompt reply! Btw, I am working with Pythia so I am using v1.0.

If I compute the number quickly, this would be 8064 / 1.008 = 8000. This results in 1 epoch, but not all the data will be seen, since I can go up to 8023 and still get 1 epoch.

Out of curiosity, may I ask what 1.008 stands for?

Happy to check more stuff to help debug this :)

8023 will be 1 epoch and 8024 will be 2, yes. This is expected behaviour with Megatron data pipelines.
Check this comment for reference, or read on below for the explanation.

The reason is that you would in general have n data sources, with associated weights determining the sampling probability of each source. Suppose you're training over 10000 sequences coming from data sources data1 and data2, each with probability 50%. When sampling, you roll the dice for samples from both sources, and the expected number of samples from each data source will be 5000. However, as you can imagine, this is a random process, and with a given seed you might actually end up sampling 4999 from data1 and 5001 from data2. In other words, you need to leave a margin in the number of sequences you'll sample from each dataset to account for this variance.
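To make that concrete, here is a tiny standalone simulation of the sampling process (this is not gpt-neox code, just an illustration):

```python
import numpy as np

rng = np.random.default_rng(seed=0)
num_sequences = 10000
weights = [0.5, 0.5]  # data1 and data2, each with probability 50%

# Draw which data source each of the 10000 sequences comes from.
counts = rng.multinomial(num_sequences, weights)
print(counts)  # usually close to, but not exactly, [5000 5000]
```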

A margin that works well in practice is 0.5% of the number of samples, and that's what Megatron uses. Hence why I asked you to check with 1.006 whether it would be one epoch, and why it's normal that 8023 (8064/8023 > 1.005) gives you 1 epoch worth of sample indices, while 8024 gives you 2 epochs (8064/8024 < 1.005).
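To see how the boundary falls with your numbers, here is a rough back-of-the-envelope reconstruction (an approximation of the index-building arithmetic, not the exact code path):

```python
import math

tokens_per_epoch = 2113896210
seq_length = 2048
global_batch = 4 * 32  # 4 GPUs x micro-batch 32, no accumulation

# Sequences available in one pass over the data (approximately; the real
# index builder reserves one token for the shifted labels).
samples_per_epoch = (tokens_per_epoch - 1) // seq_length

def epochs_needed(train_iters, buffer=1.005):
    requested = math.ceil(train_iters * global_batch * buffer)
    return math.ceil(requested / samples_per_epoch)

print(epochs_needed(8023))  # 1
print(epochs_needed(8024))  # 2
```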

Now of course, in your specific case with 1 data source, this 0.5% buffer is not useful -- there will be no variance in the number of sequences you'll have seen from your data source. Generally, my advice is that missing a few iterations won't really matter. However, if you really want to train exactly once on every sequence without this buffer, you can go to this line and turn the 1.005 factor into 1 (or remove it altogether). Make sure to re-enable it if you start training on more data sources.

More subjectively, re: why not make 1 data source an exception and disable the buffer in that case, my personal opinion is that it's preferable for the behaviour to be the same regardless of the number of data sources. :-) But if you want to do it automatically, you can simply add an if len(weights) == 1: condition that sets the buffer factor to 1.0 and keeps 1.005 otherwise; see the sketch below.
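Something along these lines, assuming the helper that scales the per-dataset sample counts looks roughly like this (the function name and signature here are illustrative and may differ between versions):

```python
import math

def get_weighted_num_samples(weights, num_samples):
    """Scale the requested sample count by each dataset's weight, adding a
    safety buffer to absorb sampling variance across data sources."""
    # With a single data source there is no sampling variance to absorb,
    # so the 0.5% buffer can be dropped.
    buffer_factor = 1.0 if len(weights) == 1 else 1.005
    return [int(math.ceil(num_samples * w * buffer_factor)) for w in weights]
```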

Thanks a lot for the insights and for taking the time to look into this!! It was really helpful :)