apply_sharding() check does not care about sharding priorities
Opened this issue · 8 comments
🐛 Describe the bug
The following snippet, which is in my opinion valid, fails:

```python
import torchdata.datapipes as dp
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import MultiProcessingReadingService, DataLoader2

pipe = dp.iter.IterableWrapper(range(10))
pipe = pipe.sharding_filter(SHARDING_PRIORITIES.DISTRIBUTED)
pipe = pipe.sharding_filter(SHARDING_PRIORITIES.MULTIPROCESSING)

rs = MultiProcessingReadingService(num_workers=1)
dl = DataLoader2(pipe, reading_service=rs)

for x in dl:
    print(x)
```

```
RuntimeError: Sharding twice on a single pipeline is likely unintended and will cause data loss. Sharding already applied to ShardingFilterIterDataPipe while trying to apply to ShardingFilterIterDataPipe
```
I.e., it's currently not possible to first shard based on MPI rank and then further shard based on the (per-process) I/O worker rank, despite there being mechanisms built into `sharding_filter` for exactly that purpose.
This check is overly restrictive in my opinion.
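The two-level sharding being asked for composes cleanly; the following plain-Python sketch (just the every-n-th-element semantics of `sharding_filter`, not torchdata's actual implementation) shows that stacking a distributed-rank stage on top of a per-worker stage still partitions the data without loss:

```python
# Sketch (assumption: not torchdata code) of how two stacked sharding
# stages compose: first by distributed rank, then by per-process worker id.
# Each stage keeps every num_instances-th element of its input stream.
def shard(stream, num_instances, instance_id):
    for i, x in enumerate(stream):
        if i % num_instances == instance_id:
            yield x

def two_level_shard(data, world_size, rank, num_workers, worker_id):
    # Stage 1: shard across distributed ranks.
    per_rank = shard(data, world_size, rank)
    # Stage 2: shard that rank's stream across its dataloader workers.
    return list(shard(per_rank, num_workers, worker_id))

# All (rank, worker) combinations together cover range(10) exactly once.
shards = [two_level_shard(range(10), 2, r, 2, w)
          for r in range(2) for w in range(2)]
```

Since the shards are mutually exclusive and jointly exhaustive, there is no data loss to guard against here.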
Versions
main branch & torch nightly.
Notice that the above example is just for demonstration purposes. In a real pipeline these two sharding operations might take place in vastly different places. So replacing them with one sharding_filter with default priority is not an option.
Agree with you that `apply_sharding` can support multi-level sharding on the same branch.
@ejguan set_graph_random_seed does not account for different sharding priorities as well (https://github.com/pytorch/data/blob/main/torchdata/dataloader2/graph/settings.py#L31)
I find the intended behavior a bit problematic anyway:
The usual principle w.r.t. multiprocessing right now is that every worker executes the same pipeline. If a sharding filter is encountered, each worker only gets every i-th element. These two concepts are very easy to grasp.
However, now a pipe with a shuffle operation behaves differently depending on whether a sharding filter is present or not. E.g. for

```python
pipe = IterableWrapper(range(10))
pipe = pipe.shuffle()
```

each worker will receive a different seed, whereas if we add a sharding filter afterwards, each worker will receive the same seed. I find this unintuitive and difficult to reason about. E.g., what about custom sharding operations that only adhere to specific sharding priorities?
I would suggest letting the shuffle operation, or in the case of the default shuffle, the user, decide whether to use a worker-specific or the distributed/global seed. Specifically:

- Change the signature of `set_seed` to `def set_seed(self, global_seed, worker_seed=None)`.
- Add `use_worker_seed=False` (or `True`) to the `Shuffle(Iter)DataPipe` constructor.
- Custom shuffle operations can implement their own priorities w.r.t. which seed to use.
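To make the proposal concrete, here is a minimal sketch of what such an API could look like. The `Shuffler` class, the `set_seed(global_seed, worker_seed=None)` signature, and the `use_worker_seed` flag follow the suggestion above; they are hypothetical, not torchdata's current interface:

```python
import random

# Hypothetical sketch of the proposed API: the shuffler itself decides
# which of the two offered seeds it consumes.
class Shuffler:
    def __init__(self, source, use_worker_seed=False):
        self.source = source
        self.use_worker_seed = use_worker_seed
        self._seed = None

    def set_seed(self, global_seed, worker_seed=None):
        # Proposed signature: both seeds are always offered; the flag
        # (set by the user) picks which one takes effect.
        if self.use_worker_seed and worker_seed is not None:
            self._seed = worker_seed
        else:
            self._seed = global_seed

    def __iter__(self):
        buf = list(self.source)
        random.Random(self._seed).shuffle(buf)
        return iter(buf)

# With use_worker_seed=False (the default), every worker shuffles
# identically, so a downstream sharding filter still yields mutually
# exclusive shards regardless of where the shuffle sits in the graph.
w0 = Shuffler(range(10)); w0.set_seed(42, worker_seed=0)
w1 = Shuffler(range(10)); w1.set_seed(42, worker_seed=1)
```

This keeps the seeding decision local and explicit instead of inferring it from the presence of a downstream sharding filter.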
Keeping behaviors simple and explicitly asking which behavior is wanted when there are two valid choices is a better course of action than heuristically figuring out what people want, which always introduces complexity and potential misunderstandings.
It is also not clear to me right now how such shuffle operations are supposed to behave if one wants to set a fixed seed via `DataLoader2.seed()`.
> However, now a pipe with a shuffle operation behaves differently depending on whether a sharding filter is present or not. E.g. for `pipe = IterableWrapper(range(10)); pipe = pipe.shuffle()` each worker will receive a different seed, whereas if we add a sharding filter afterwards, each worker will receive the same seed. I find this example unintuitive and also difficult to reason about.
This is intentional: we want to guarantee that the order of data is the same across workers before the `sharding_filter`, so that each worker gets mutually exclusive examples.
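The guarantee can be illustrated with a small plain-Python sketch (not torchdata internals): when every worker shuffles with the same seed before sharding, the resulting shards partition the dataset; with per-worker seeds they generally would not.

```python
import random

# Sketch (assumption: plain Python, not torchdata code) of why the
# pre-shard shuffle must use the SAME seed in every worker. Each "worker"
# shuffles the full dataset, then keeps every num_workers-th element.
def worker_view(seed, worker_id, num_workers=2, n=10):
    buf = list(range(n))
    random.Random(seed).shuffle(buf)    # pre-shard shuffle
    return buf[worker_id::num_workers]  # sharding by worker id

# Shared seed across both workers: a clean partition of range(10).
combined = worker_view(seed=0, worker_id=0) + worker_view(seed=0, worker_id=1)
# With different per-worker seeds the two views would generally overlap
# and drop elements -- exactly the data loss the shared seed prevents.
```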
> I would suggest letting the shuffle operation, or in case of the default shuffle, the user decide if it wants to use a worker specific or the distributed/global seed.
It's more complicated than that. If you add another `shuffle` after `sharding_filter`, like `dp.shuffle().sharding_filter().shuffle()`, the second `shuffle` will use worker-specific seeds just like `dp.shuffle()` does.
@ejguan I believe this has been fixed by pytorch/pytorch#97287. Is that correct?
No, sorry, I'm afraid not.
A fix could look like this, replacing the check at https://github.com/ejguan/pytorch/blob/f2cea87c1f9741e78c60c456bb0cd0f22d0689f7/torch/utils/data/graph_settings.py#L65:

```python
if len(sig.parameters) < 3:
    sharded = dp.apply_sharding(num_of_instances, instance_id)
else:
    sharded = dp.apply_sharding(num_of_instances, instance_id, sharding_group=sharding_group)
if sharded:
    applied = dp
```

where `apply_sharding()` is supposed to return a boolean indicating whether the pipe will be sharded or not.
I would also recommend switching from a `RuntimeError` to a warning.
We should also include this test case in https://github.com/ejguan/pytorch/blob/f2cea87c1f9741e78c60c456bb0cd0f22d0689f7/test/test_datapipe.py#L2875:

```python
numbers_dp = dp.iter.IterableWrapper(range(13))
sharded_dp = numbers_dp.sharding_filter(SHARDING_PRIORITIES.DISTRIBUTED)
sharded_dp = sharded_dp.sharding_filter(SHARDING_PRIORITIES.MULTIPROCESSING)
torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 0, SHARDING_PRIORITIES.DISTRIBUTED)
torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 0, SHARDING_PRIORITIES.MULTIPROCESSING)
```
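For reference, the elements such a pipeline should yield can be worked out with plain slicing, assuming each filter keeps every n-th element of its input (here n=3 at both levels, instance 0):

```python
# Expected result of the test case above, worked out in plain Python under
# the assumption that each sharding stage keeps every 3rd element of its
# input (instance_id=0 at both the DISTRIBUTED and MULTIPROCESSING level).
data = list(range(13))
after_distributed = data[0::3]                   # [0, 3, 6, 9, 12]
after_multiprocessing = after_distributed[0::3]  # [0, 9]
print(after_multiprocessing)
```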