pytorch/data

apply_sharding() check does not care about sharding priorities


๐Ÿ› Describe the bug

The following snippet, which I believe is valid, fails with a RuntimeError:

import torchdata.datapipes as dp
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import MultiProcessingReadingService, DataLoader2

pipe = dp.iter.IterableWrapper(range(10))
# First level: shard across distributed ranks
pipe = pipe.sharding_filter(SHARDING_PRIORITIES.DISTRIBUTED)
# Second level: shard each rank's data across its worker processes
pipe = pipe.sharding_filter(SHARDING_PRIORITIES.MULTIPROCESSING)

rs = MultiProcessingReadingService(num_workers=1)
dl = DataLoader2(pipe, reading_service=rs)

for x in dl:
    print(x)

RuntimeError: Sharding twice on a single pipeline is likely unintended and will cause data loss. Sharding already applied to ShardingFilterIterDataPipe while trying to apply to ShardingFilterIterDataPipe

I.e. it's currently not possible to first shard based on MPI rank and then further shard based on the (per-process) I/O worker rank, even though sharding_filter has mechanisms built in for that purpose.

This check is overly restrictive in my opinion.
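For illustration, here is a minimal pure-Python sketch of the two-level sharding being asked for, independent of torchdata: the stream is split round-robin by rank, and each rank's share is split again by worker id. The helpers `round_robin` and `nested_shard` and the sizes used are hypothetical and only model the idea; they are not torchdata APIs.

```python
from itertools import islice
from typing import Dict, List, Tuple


def round_robin(items: List[int], num_instances: int, instance_id: int) -> List[int]:
    """Keep every num_instances-th element, starting at position instance_id."""
    return list(islice(items, instance_id, None, num_instances))


def nested_shard(items: List[int], world_size: int, rank: int,
                 num_workers: int, worker_id: int) -> List[int]:
    # First level: split the stream across distributed ranks (e.g. MPI ranks).
    per_rank = round_robin(items, world_size, rank)
    # Second level: split this rank's share across its I/O worker processes.
    return round_robin(per_rank, num_workers, worker_id)


data = list(range(10))
world_size, num_workers = 2, 2
shards: Dict[Tuple[int, int], List[int]] = {
    (rank, worker): nested_shard(data, world_size, rank, num_workers, worker)
    for rank in range(world_size)
    for worker in range(num_workers)
}
for key, shard in sorted(shards.items()):
    print(key, shard)

# Every element lands on exactly one (rank, worker) pair.
assert sorted(x for shard in shards.values() for x in shard) == data
```

In this toy model, worker w on rank r ends up with the elements x where x % (world_size * num_workers) == r + world_size * w, so the nested split is still a clean partition of the data.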

Versions

main branch & torch nightly.

Note that the example above is only for demonstration purposes. In a real pipeline, these two sharding operations might take place in very different places, so replacing them with a single sharding_filter with the default priority is not an option.

ejguan commented

I agree that apply_sharding should support multi-level sharding within the same branch.

@ejguan set_graph_random_seed does not account for different sharding priorities either (https://github.com/pytorch/data/blob/main/torchdata/dataloader2/graph/settings.py#L31).

I find the intended behavior a bit problematic anyway:

The usual principle with respect to multiprocessing right now is that every worker executes the same pipeline. When a sharding filter is encountered, each worker only gets every n-th element. These two concepts are very easy to grasp.

However, now a pipe with a shuffle operation behaves differently depending on whether a sharding filter is present or not. E.g. for

from torchdata.datapipes.iter import IterableWrapper

pipe = IterableWrapper(range(10))
pipe = pipe.shuffle()

each worker will receive a different seed, whereas if we add a sharding filter afterwards, each worker will receive the same seed. I find this unintuitive and difficult to reason about. For example, what about custom sharding operations that only adhere to specific sharding priorities?

I would suggest letting the shuffle operation (or, for the default shuffle, the user) decide whether to use a worker-specific seed or the distributed/global seed. Specifically:

  1. Change the signature of set_seed to def set_seed(self, global_seed, worker_seed=None).
  2. Add use_worker_seed=False (or True) to the Shuffle(Iter)DataPipe constructor.
  3. Custom shuffle operations can implement their own rules for which seed to use.

Keeping behavior simple, and explicitly asking which behavior is wanted when there are two valid choices, is a better course of action than heuristically guessing what people want, which always introduces complexity and potential misunderstandings.
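A minimal sketch of what the proposed API could look like, modelled in plain Python rather than torchdata. `ProposedShuffler` is a hypothetical class (not torchdata's Shuffler), and `random.Random` stands in for whatever RNG a real shuffle DataPipe uses; it only shows the shape of `set_seed(self, global_seed, worker_seed=None)` plus a `use_worker_seed` constructor flag.

```python
import random
from typing import Iterable, Iterator, List, Optional


class ProposedShuffler:
    """Hypothetical shuffle stage sketching the proposal (not torchdata's Shuffler)."""

    def __init__(self, source: Iterable[int], use_worker_seed: bool = False) -> None:
        self.source = source
        self.use_worker_seed = use_worker_seed  # proposed constructor flag
        self._rng = random.Random()

    def set_seed(self, global_seed: int, worker_seed: Optional[int] = None) -> "ProposedShuffler":
        # The stage itself, not the graph traversal, decides which seed applies.
        if self.use_worker_seed and worker_seed is not None:
            self._rng.seed(worker_seed)
        else:
            self._rng.seed(global_seed)
        return self

    def __iter__(self) -> Iterator[int]:
        buffer: List[int] = list(self.source)
        self._rng.shuffle(buffer)
        return iter(buffer)


global_seed = 1234
# Default (use_worker_seed=False): both workers shuffle with the global seed.
shared = [list(ProposedShuffler(range(10)).set_seed(global_seed, worker_seed=w)) for w in range(2)]
assert shared[0] == shared[1]

# Opt-in: each worker shuffles with its own seed.
per_worker = [
    list(ProposedShuffler(range(10), use_worker_seed=True).set_seed(global_seed, worker_seed=w))
    for w in range(2)
]
print(per_worker)
```

Under this sketch, a custom shuffle operation would implement its own seed rule simply by overriding set_seed().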

It is also not clear to me how such shuffle operations are supposed to behave if one wants to set a fixed seed via DataLoader2.seed().

ejguan commented

> However, now a pipe with a shuffle operation behaves differently depending on whether a sharding filter is present or not. E.g. for
>
> pipe = IterableWrapper(range(10))
> pipe = pipe.shuffle()
>
> each worker will receive a different seed, whereas if we add a sharding filter afterwards, each worker will receive the same seed. I find this example unintuitive and also difficult to reason about.

This is intentional: we want to guarantee that the order of data is the same across workers before sharding_filter, so that each worker gets mutually exclusive examples.

> I would suggest letting the shuffle operation, or in case of the default shuffle, the user decide if it/he wants to use a worker specific or the distributed/global seed. Specifically:

It's more complicated than that. If you add another shuffle after sharding_filter, as in dp.shuffle().sharding_filter().shuffle(), the second shuffle will use worker-specific seeds, just like a plain dp.shuffle() without a sharding filter.
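Both points can be checked with a small pure-Python simulation. This is a sketch under stated assumptions: `random.Random` stands in for the shuffle DataPipe, list slicing stands in for sharding_filter, and the worker-seed derivation is hypothetical, not how torchdata actually derives seeds.

```python
import random
from typing import List


def seeded_shuffle(items: List[int], seed: int) -> List[int]:
    """Stand-in for a shuffle DataPipe with a fixed seed."""
    out = list(items)
    random.Random(seed).shuffle(out)
    return out


def worker_pipeline(data: List[int], global_seed: int,
                    num_workers: int, worker_id: int) -> List[int]:
    """Simulates dp.shuffle().sharding_filter().shuffle() inside one worker."""
    # Shuffle before sharding_filter: every worker must use the shared seed.
    stage1 = seeded_shuffle(data, global_seed)
    # sharding_filter: keep this worker's round-robin slice of the shared order.
    stage2 = stage1[worker_id::num_workers]
    # Shuffle after sharding_filter: a worker-specific seed is safe because it
    # only reorders elements this worker already owns.
    worker_seed = global_seed + 1 + worker_id  # hypothetical derivation
    return seeded_shuffle(stage2, worker_seed)


data = list(range(10))
num_workers = 2
outputs = [worker_pipeline(data, 42, num_workers, w) for w in range(num_workers)]
# Mutually exclusive and complete: no duplicates, nothing lost.
assert sorted(outputs[0] + outputs[1]) == data

# Counterexample: if the first shuffle produced different orders per worker,
# round-robin sharding would no longer partition the data. The orders are
# hand-picked so that the result is deterministic.
order_worker0 = [0, 1, 2, 3]
order_worker1 = [1, 0, 3, 2]
print(order_worker0[0::2], order_worker1[1::2])  # [0, 2] [0, 2]: 1 and 3 are never seen
```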

@ejguan I believe this has been fixed by pytorch/pytorch#97287. Is that correct?

No, sorry, I'm afraid not.

A fix could look like this:
https://github.com/ejguan/pytorch/blob/f2cea87c1f9741e78c60c456bb0cd0f22d0689f7/torch/utils/data/graph_settings.py#L65

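if len(sig.parameters) < 3:
    # Legacy apply_sharding() without sharding_group: it always shards
    dp.apply_sharding(num_of_instances, instance_id)
    sharded = True
else:
    # Proposed: apply_sharding() returns whether it sharded for this group
    sharded = dp.apply_sharding(num_of_instances, instance_id, sharding_group=sharding_group)
if sharded:
    applied = dp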

Here, apply_sharding() is supposed to return a boolean indicating whether the pipe is actually sharded for the given sharding group.
I would also recommend switching from a RuntimeError to a warning.
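The proposed behavior can be prototyped outside PyTorch. The following is a minimal sketch that assumes a linear pipeline where each stage exposes a `source` attribute. `Priority`, `Source`, `ShardingFilterSketch` and `apply_sharding_to_chain` are hypothetical names, not PyTorch APIs, and the real graph_settings code traverses a full DataPipe graph rather than a single chain. apply_sharding() returns whether the filter actually sharded for the given group, and applying one group twice emits a warning instead of raising a RuntimeError.

```python
import warnings
from enum import IntEnum
from typing import Iterable, Iterator


class Priority(IntEnum):
    """Illustrative stand-in for SHARDING_PRIORITIES."""
    DEFAULT = 1
    DISTRIBUTED = 2
    MULTIPROCESSING = 3


class Source:
    """Head of a linear toy pipeline (has no upstream `source`)."""

    def __init__(self, items: Iterable[int]) -> None:
        self.items = list(items)

    def __iter__(self) -> Iterator[int]:
        return iter(self.items)


class ShardingFilterSketch:
    """Hypothetical sharding filter whose apply_sharding() returns a bool."""

    def __init__(self, source, group: Priority = Priority.DEFAULT) -> None:
        self.source = source
        self.group = group
        self.num_of_instances = 1
        self.instance_id = 0

    def apply_sharding(self, num_of_instances: int, instance_id: int,
                       sharding_group: Priority = Priority.DEFAULT) -> bool:
        # Ignore groups this filter was not built for and report "not sharded".
        if sharding_group != self.group:
            return False
        self.num_of_instances = num_of_instances
        self.instance_id = instance_id
        return True

    def __iter__(self) -> Iterator[int]:
        for i, x in enumerate(self.source):
            if i % self.num_of_instances == self.instance_id:
                yield x


def apply_sharding_to_chain(pipe, num_of_instances: int, instance_id: int,
                            sharding_group: Priority = Priority.DEFAULT) -> None:
    """Walk a linear pipeline; warn instead of raising on double sharding."""
    applied = None
    node = pipe
    while node is not None:
        apply = getattr(node, "apply_sharding", None)
        if apply is not None and apply(num_of_instances, instance_id, sharding_group):
            if applied is not None:
                warnings.warn(
                    f"Sharding group {sharding_group.name} applied to more than one "
                    "sharding filter in the same branch; this may cause data loss."
                )
            applied = node
        node = getattr(node, "source", None)


# Two-priority scenario from the issue: DISTRIBUTED first, then MULTIPROCESSING.
pipe = Source(range(13))
pipe = ShardingFilterSketch(pipe, Priority.DISTRIBUTED)
pipe = ShardingFilterSketch(pipe, Priority.MULTIPROCESSING)
apply_sharding_to_chain(pipe, 3, 0, Priority.DISTRIBUTED)
apply_sharding_to_chain(pipe, 3, 0, Priority.MULTIPROCESSING)
print(list(pipe))  # [0, 9]
```

Because each filter reacts only to its own group, neither call trips the duplicate check; only applying the same group to two filters (for example two DEFAULT filters) produces the warning.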

We should also add the following test case to test_datapipe.py (https://github.com/ejguan/pytorch/blob/f2cea87c1f9741e78c60c456bb0cd0f22d0689f7/test/test_datapipe.py#L2875):

import torch.utils.data.datapipes as dp
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torch.utils.data.graph_settings import apply_sharding

numbers_dp = dp.iter.IterableWrapper(range(13))
sharded_dp = numbers_dp.sharding_filter(SHARDING_PRIORITIES.DISTRIBUTED)
sharded_dp = sharded_dp.sharding_filter(SHARDING_PRIORITIES.MULTIPROCESSING)
# Both calls should succeed once multi-level sharding is supported
apply_sharding(sharded_dp, 3, 0, SHARDING_PRIORITIES.DISTRIBUTED)
apply_sharding(sharded_dp, 3, 0, SHARDING_PRIORITIES.MULTIPROCESSING)