pytorch/data

Support for proper Distributed & Multiprocessing Sharding


🚀 The feature

In MPI-based training, each process is independent of the others. Each training process might want to speed up dataloading using multiprocessing (MP). This requires data sharding to take place on two levels:

A. On a distributed level, usually resulting in big(ger) shards.
B. On a MP level later on, further splitting those big shards among worker processes.

While (A.) might shard on a coarser, logical scale (e.g. on years or months when working with climatological data), (B.) might shard directly on already loaded data (e.g. on indices into the previous shards).

Right now, combining distributed & MP sharding in torchdata faces two hurdles that need addressing:

  1. Due to an optional check in apply_sharding(), there can only be a single sharding pipe in the graph. This check, however, does not take into account whether a sharding pipe only operates on a specific sharding group / priority. This issue is already tracked by #1082. A simple fix would be to drop the check altogether (see the sketch after this list).
  2. torchdata assumes a single sharding (and distribution) model: namely, that distributed & MP shards live on the same logical level and that those shards are distributed to worker processes in a round-robin fashion. This is enforced in https://github.com/pytorch/data/blame/main/torchdata/dataloader2/utils/worker.py#L82, which prevents more general sharding strategies.
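
For illustration, here is roughly what a pipeline with two independent sharding levels looks like. This is a minimal sketch using the public sharding_filter API with sharding groups (SHARDING_PRIORITIES from torch.utils.data.datapipes.iter.sharding); load_shard is a hypothetical stand-in for the actual loading step. A graph of this shape is what the single-sharding-pipe check currently rejects when DataLoader2 applies sharding:

from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

def load_shard(name):
    # Hypothetical loading step standing in for reading a shard from disk/network.
    return {"name": name, "data": list(range(10))}

shards = IterableWrapper([f"shard-{i}" for i in range(12)])

# Level A: shard coarsely across distributed ranks.
pipe = shards.shuffle().sharding_filter(sharding_group_filter=SHARDING_PRIORITIES.DISTRIBUTED)
pipe = pipe.map(load_shard)

# Level B: shard again, this time across MP worker processes. A second
# sharding_filter with a different sharding group is exactly what the check
# described in 1. currently forbids.
pipe = pipe.sharding_filter(sharding_group_filter=SHARDING_PRIORITIES.MULTIPROCESSING)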

Overall, these two hurdles currently have to be worked around via monkey patching to enable more general sharding strategies (see the motivation below for a use case and an example of such a strategy). https://github.com/sehoffmann/atmodata/blob/6a7c2974a5de1354a7156d427bf53899fc6c0177/atmodata/patching.py shows which patches are needed.
Specifically:

  • The check in apply_sharding() needs to be removed
  • process_init_fn() should call apply_sharding() on the whole pipe, not only on non-dispatching branches.
  • pipe.repeat(n_workers).sharding_round_robin_dispatch() has to be used as a workaround to distribute the same shard to all workers (see the sketch below). For this, an additional pipe should be introduced (a plain dispatch()).
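
As a rough sketch of the workaround from the last bullet (assuming n_workers is known up front and matches the number of MP workers of the reading service): each shard is repeated n_workers times, and the round-robin dispatcher then hands one copy to every worker, so all workers end up seeing the same shard. A dedicated dispatch() pipe would express the same intent without the repeat trick.

from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

n_workers = 4  # must match num_workers of the reading service

shards = IterableWrapper(["1979-Jan", "1979-Feb", "1979-Mar"])

# Emit every shard n_workers times; the round-robin dispatcher sends one copy
# to each worker process, i.e. all workers receive the same shard.
pipe = shards.repeat(n_workers)
pipe = pipe.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)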

Instead of requiring monkey patches, torchdata should be less restrictive with respect to sharding and distribution strategies.

Motivation, pitch

I'm working with climatological timeseries data on the terabyte scale. The sharding and MP strategy that, in my humble opinion, makes the most sense for this use case looks like this:

  1. Shard (distributed) across the time dimension on a logical level. A single shard could e.g. represent a single month, be contained in a single file, and be multiple gigabytes in size. These shards are preloaded by the main process via the network and in parallel.
  2. The same shard is distributed to each worker process via shared memory (to reduce memory overhead). E.g. each worker process sees the same shard/month. Now this "super-shard" is sharded further among worker processes by accessing only a subset of the indices. The time-resolution could e.g. be 1h.
  3. Batches from individual workers are aggregated by the main thread again.

Overall, this pipeline roughly looks like this:

# Main Thread - Pre-loading
months = IterableWrapper(["1979-Jan", "1979-Feb", ..., "2020-Dec"])
pipe = months.shuffle().sharding_filter(DISTRIBUTED)
pipe = pipe.load_data().prefetch()
pipe = pipe.repeat(n_workers).round_robin_dispatch()

# Worker Process
pipe = pipe.unroll_indices() # -> yields (idx, data) tuples where data is the whole shard and idx are akin to enumerate()
pipe = pipe.shuffle().sharding_filter(MULTIPROCESSING)
pipe = pipe.do_work_on_sample()
pipe = pipe.batch()

# Main Thread - Post-process
pipe = pipe.non_replicable()  # non-replicable No-Op pipeline to force transfer to main thread
pipe = pipe.post_process()
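
Note that unroll_indices() above is not an existing torchdata datapipe; a minimal custom implementation of the idea could look like the following sketch (assuming each incoming shard supports len() and is indexable along the time dimension):

from torch.utils.data import IterDataPipe, functional_datapipe

@functional_datapipe("unroll_indices")
class UnrollIndices(IterDataPipe):
    # Hypothetical datapipe: for every incoming shard, yield (idx, shard) pairs,
    # one per time index, akin to enumerate(). The shard object itself is not
    # copied; a subsequent sharding_filter(MULTIPROCESSING) then splits the
    # indices among the worker processes.
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for shard in self.source_datapipe:
            for idx in range(len(shard)):
                yield idx, shard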

Why can't individual worker processes operate independently on their own shards from (1.), i.e. whole months?

Shards can be fairly big. If every worker operated on its own, independent shards, memory consumption might explode. Furthermore, worker processes might compete for shared network IO bandwidth. Also, depending on the shard size, there are potentially not that many shards in the dataset, which would impose an upper bound on the number of GPUs usable for training.

Why can't you reduce the shard size then, e.g. weeks instead of months?

We are cropping timeseries from those shards, so there is always some wasted data at the end (or start) of each shard from which we cannot crop. Reducing the shard size would increase the amount of data we would need to throw away. Furthermore, loading a few big shards via the network is much more efficient than loading many small shards, and we want to utilize our network interface as much as possible for maximum throughput.
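
As a back-of-envelope illustration (the crop length and shard counts below are hypothetical): crops cannot cross shard boundaries, so each shard removes roughly crop_len - 1 possible crop start positions, and the waste grows linearly with the number of shards.

def wasted_positions(n_shards, crop_len):
    # Each shard boundary removes roughly (crop_len - 1) possible crop starts.
    return n_shards * (crop_len - 1)

crop_len = 48  # e.g. 2-day crops at 1h resolution (hypothetical)
print(wasted_positions(12, crop_len))  # monthly shards over one year ->  564
print(wasted_positions(52, crop_len))  # weekly shards over one year  -> 2444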

Why can't you shard directly on the index level and then distribute in a round-robin fashion?

This would be horrendously slow: with index-level distributed sharding, every rank (and worker) would end up touching nearly every multi-gigabyte shard just to extract its own small subset of indices.

Overall, the difficulties with this kind of data stem from two facts:

  • Samples are not i.i.d. due to the timeseries structure. This results in an overlap of data among workers.
  • The data is too big to be kept in memory and needs to be fetched on demand from a filesystem that is potentially network-based.

Alternatives

No response

Additional context

No response