mittagessen/kraken

Training on large binary dataset


I'm trying to train a recognition model on a very large dataset. To speed up training, I have compiled my dataset into multiple Arrow files of 10,000 lines each.

I initially assumed that ArrowIPCRecognitionDataset just kept a reference to the Arrow files in memory. However, when I looked at ArrowIPCRecognitionDataset's add method, I realized that the whole dataset is loaded into memory before training starts:

https://github.com/mittagessen/kraken/blob/a0c395727c011d3283b34b5f7a9ef6d85970e6d0/kraken/lib/dataset/recognition.py#LL191C1-L195C37

Wouldn't it be better to load the Arrow files one by one during training?

Nope, they aren't loaded into memory if a zero-copy concatenation is possible, i.e. if the table can be memory mapped. The principal case where this isn't possible is when using built-in fixed splits, because filtering a table requires it to be loaded in its entirety.

In any case, we need to iterate over the whole dataset before training to establish the alphabet after text transformations; that's where you might see some memory use, but it's transient.
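For illustration, that pass amounts to roughly the following (a minimal sketch that skips the text transformations and assumes the binary dataset's 'lines' column of structs with a 'text' field; the field names are my assumption, not copied from the kraken source):

from collections import Counter
import pyarrow as pa

alphabet = Counter()
with pa.memory_map('foo.arrow', 'rb') as source:
    table = pa.ipc.open_file(source).read_all()

# Walk the mapped table chunk by chunk; only the pages currently being
# read need to be resident, so the memory use of this pass is transient.
for chunk in table.column('lines').chunks:
    for text in chunk.field('text'):
        alphabet.update(text.as_py() or '')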

In any case, we need to iterate over the whole dataset before training to establish the alphabet after text transformations; that's where you might see some memory use, but it's transient.

That's not what I'm talking about.

Nope, they aren't loaded into memory if a zero-copy concatenation is possible, i.e. if the table can be memory mapped. The principal case where this isn't possible is when using built-in fixed splits, because filtering a table requires it to be loaded in its entirety.

That's odd. My dataset doesn't contain any splits, but the table still seems to get loaded into memory:

(Pdb) print(sys.getsizeof(self.arrow_table) * 0.000001)
12197.935502

This corresponds to the process memory consumption displayed by top. I tried with multiple datasets and always get the same issue.

Please reopen the issue.

You can't use either sys.getsizeof() or top to determine the actual memory use of an object/process that uses memory mapping, because the former reports whatever __sizeof__ returns and the latter shows virtual memory. Concatenating tables doesn't result in them being loaded completely into memory. Look at the size of the Arrow memory pool to convince yourself:

import sys
import pyarrow as pa

file1 = 'foo.arrow'
file2 = 'bar.arrow'

# Memory-map the first file; read_all() only builds table metadata on top
# of the mapping, so nothing is copied into the Arrow memory pool.
with pa.memory_map(file1, 'rb') as source:
    ds_table_1 = pa.ipc.open_file(source).read_all()
print(f'{pa.total_allocated_bytes()} {sys.getsizeof(ds_table_1)}')

with pa.memory_map(file2, 'rb') as source:
    ds_table_2 = pa.ipc.open_file(source).read_all()
print(f'{pa.total_allocated_bytes()} {sys.getsizeof(ds_table_2)}')

# Zero-copy concatenation: the combined table still points at the mapped
# buffers, so the allocation counter stays unchanged.
arrow_table = pa.concat_tables([ds_table_1, ds_table_2])
print(f'{pa.total_allocated_bytes()} {sys.getsizeof(arrow_table)}')
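To complement that with a figure for physical rather than virtual memory, here is a small standard-library sketch (not from the original thread, Unix-only) comparing peak resident set size against the Arrow allocation counter:

import resource
import pyarrow as pa

# Peak resident set size of the current process: this is the physical-memory
# figure to compare against, unlike top's virtual-memory column. On Linux
# ru_maxrss is reported in kilobytes, on macOS in bytes.
peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(f'peak RSS: {peak_rss} (unit is platform dependent)')

# Bytes actually allocated through Arrow's default memory pool; this stays
# near zero as long as the tables are only memory mapped.
print(f'arrow pool: {pa.total_allocated_bytes()} bytes')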

Thanks a lot! I will check that out tonight.

Maybe I'm not using the best way to measure memory consumption for memory-mapped objects. However, the out-of-memory errors I'm getting on the host (not the GPU) are not virtual, so I think there's an issue somewhere.

The skip_empty_lines switch might be the culprit here; the filter probably causes everything to get loaded. If you're sure there aren't any lines consisting of a single whitespace character or similar in there, you can just disable it.
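One way to observe the effect of such a filter (a sketch, not the code path kraken actually runs; the 'lines'/'text' names are assumptions about the binary dataset schema): the Arrow allocation counter stays near zero after a memory-mapped read but jumps once a compute filter materializes new buffers.

import pyarrow as pa
import pyarrow.compute as pc

# Memory-mapped read: only metadata is allocated, the data stays on disk.
with pa.memory_map('foo.arrow', 'rb') as source:
    table = pa.ipc.open_file(source).read_all()
print(f'after mmap read: {pa.total_allocated_bytes()} bytes')

# Emulating an empty-line filter: the compute kernels and Table.filter()
# build fresh buffers through the memory pool, so the text column and the
# filtered copy of the table end up in RAM.
texts = pc.struct_field(table['lines'], 'text')
mask = pc.not_equal(pc.utf8_trim_whitespace(texts), '')
filtered = table.filter(mask)
print(f'after filter: {pa.total_allocated_bytes()} bytes')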

773cc00 solved the issue, thanks!

Ok, perfect. It's just a shortcut that doesn't run the filter if there aren't any empty lines, but as soon as there's even a single one everything will get loaded again. I might have to filter manually instead of using pyarrow functions to avoid this.
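For what it's worth, one way that manual filtering could look without materializing a new table (a sketch under the same assumed 'lines'/'text'/'im' field names, not an actual kraken patch): keep an index list of non-empty rows and resolve each sample against the mapped table only when it is accessed.

import pyarrow as pa

with pa.memory_map('foo.arrow', 'rb') as source:
    table = pa.ipc.open_file(source).read_all()
lines = table.column('lines')

# One pass over the mapped table to collect indices of non-empty lines;
# only one struct scalar is materialized at a time, so memory use stays flat.
valid = [i for i in range(len(lines))
         if (lines[i]['text'].as_py() or '').strip()]

def get_sample(idx: int):
    # Resolve the filtered index back into the memory-mapped table on access.
    row = lines[valid[idx]]
    return row['text'].as_py(), row['im'].as_py()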