alxndrTL/mamba.py

Question about the get_data() function

peterphancong opened this issue · 1 comment

Thank you for your great implementation; I am preparing to apply it to my project.
However, I am confused about the functions get_data() and get_batch().
As I understand it, in get_data() you create a very long sequence by joining all the examples in the dataset, where each example is a long story (or document):
```python
for example in dataset:
    if example['tokens']:
        tokens = [vocab[token] for token in example['tokens']]
        data.extend(tokens)
```
It then splits this very long list into smaller batches with:
```python
def get_batch(data, seq_len, idx):
    src = data[:, idx:idx+seq_len]
    target = data[:, idx+1:idx+seq_len+1]
    return src, target
```
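A minimal sketch of how the two functions fit together, assuming (as the `data[:, ...]` indexing in `get_batch()` suggests) that the concatenated token list is reshaped into `batch_size` parallel rows before batching. The helpers `concat_documents()` and `batchify()` and the toy documents are hypothetical, and plain Python lists stand in for whatever tensor the script actually uses. It shows that a window of `seq_len` tokens can start in one document and end in the next.

```python
def concat_documents(docs, vocab):
    """Join all tokenized documents into one flat list of ids (the get_data() idea)."""
    data = []
    for tokens in docs:
        if tokens:  # skip empty documents, as in the original loop
            data.extend(vocab[token] for token in tokens)
    return data


def batchify(data, batch_size):
    """Hypothetical helper: cut the flat stream into batch_size contiguous rows.

    Trailing ids that do not fill a complete row are dropped.
    """
    n = len(data) // batch_size
    return [data[i * n:(i + 1) * n] for i in range(batch_size)]


def get_batch(rows, seq_len, idx):
    """Same slicing as data[:, idx:idx+seq_len], applied row by row."""
    src = [row[idx:idx + seq_len] for row in rows]
    target = [row[idx + 1:idx + seq_len + 1] for row in rows]  # shifted by one token
    return src, target


docs = [["the", "cat", "sat"], ["a", "dog", "ran", "far"], [], ["birds", "fly"]]
vocab = {tok: i for i, tok in enumerate(sorted({t for d in docs for t in d}))}

data = concat_documents(docs, vocab)  # 9 ids; document boundaries are gone
rows = batchify(data, batch_size=2)   # 2 rows of 4 ids; the 9th id is dropped
src, target = get_batch(rows, seq_len=3, idx=0)
```

Here row 0 of `src` is "the cat sat" while its target ends with the id of "a", which belongs to the second document, so the model is asked to predict across a document boundary with nothing marking it. Note that `idx + seq_len + 1` must not exceed the row length, otherwise the last target row comes out shorter.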
So each batch is just a slice of this joined sequence. This approach seems strange to me, because I would expect each document to be its own sequence, and a batch to be a set of documents.
Please correct me if my understanding is wrong.
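For comparison, a minimal sketch of the per-document batching described in the question, where each row of a batch is one document. Rows are right-padded to the longest document in the batch, and padded target positions are set to an ignore value so they are excluded from the loss. `PAD_ID` and `document_batch()` are hypothetical names; -100 is the default `ignore_index` of PyTorch's `CrossEntropyLoss`. Truncating very long documents to a maximum length would also be needed in practice and is omitted here.

```python
PAD_ID = 0            # hypothetical padding token id
IGNORE_INDEX = -100   # default ignore_index of torch.nn.CrossEntropyLoss


def document_batch(docs_ids, pad_id=PAD_ID, ignore_index=IGNORE_INDEX):
    """Build (src, target) where each row is one document, right-padded.

    Assumes every document has at least 2 tokens. Inputs and targets are one
    token shorter than the document, since target is input shifted by one.
    """
    max_len = max(len(d) for d in docs_ids) - 1
    src, target = [], []
    for ids in docs_ids:
        x, y = ids[:-1], ids[1:]
        pad = max_len - len(x)
        src.append(x + [pad_id] * pad)
        target.append(y + [ignore_index] * pad)  # padded targets are ignored by the loss
    return src, target


batch = [[8, 2, 7], [0, 3, 6, 4], [1, 5]]  # three tokenized documents
src, target = document_batch(batch)
```

The trade-off is compute spent on padding when document lengths vary widely, which is one reason many language-model training scripts pack documents into a single stream instead.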

Hello,
Yes indeed, tbh I was not behind the creation of this script and didn't look at it properly. You would be better off with a more classic data loader, like the one from nanogpt or even modded-nanogpt.
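A minimal sketch in the spirit of the nanoGPT loader (not its actual code). Documents are still concatenated into one stream, but an end-of-text separator is appended after each one, and each batch row is a window starting at a random offset rather than a fixed column slice. `EOT_ID`, `pack_documents()` and `sample_batch()` are hypothetical, and plain Python lists stand in for the memory-mapped array and tensors nanoGPT uses.

```python
import random

EOT_ID = 99  # hypothetical end-of-text token id


def pack_documents(docs_ids, eot_id=EOT_ID):
    """Concatenate tokenized documents, appending a separator after each one."""
    stream = []
    for ids in docs_ids:
        stream.extend(ids)
        stream.append(eot_id)
    return stream


def sample_batch(stream, seq_len, batch_size, rng=random):
    """Sample batch_size windows at random offsets; target is input shifted by one."""
    # randrange's upper bound is exclusive, so start + seq_len + 1 <= len(stream)
    starts = [rng.randrange(len(stream) - seq_len) for _ in range(batch_size)]
    src = [stream[i:i + seq_len] for i in starts]
    target = [stream[i + 1:i + seq_len + 1] for i in starts]
    return src, target


stream = pack_documents([[8, 2, 7], [0, 3, 6, 4], [1, 5]])
src, target = sample_batch(stream, seq_len=4, batch_size=2, rng=random.Random(0))
```

Windows still cross document boundaries, but the separator token gives the model an explicit signal that one document has ended and another has begun.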
Sorry for the late response.