pytorch/data

[RFC] Performance Profiling Tools

NivekT opened this issue · 3 comments

NivekT commented

🚀 The feature

  1. Store usage statistics in Prefetcher

    • By tracking statistics within Prefetcher, we can reasonably determine whether the upstream or the downstream part of the pipeline is faster. For example, a buffer queue that is frequently empty suggests that consumers are faster than producers. Users can insert Prefetcher at various points in the pipeline to examine different behaviors. A common pattern we expect is checking whether the pipeline is IO bound or compute bound.
    • #1141
  2. DataLoader2 main process

    • Torch profilers (e.g. torch.profiler.profile) currently work with DataLoader2; however, they only track functions and DataPipes that are executed within the main process. We should still validate that the information they report is helpful when most of the computation takes place within the main process (e.g. when using InProcessReadingService or the dispatching process).
    • Once item 1 is completed, we can add APIs to DataLoader2 that fetch the relevant statistics from a Prefetcher's buffer, such as the one at the end of the main loop. This should let users check whether the model consumes samples faster than they are prepared.
    • PR pending
    • Tutorial pending
  3. DataLoader2 worker process profiling

    • The two main options under consideration are:
      1. Attaching the profiler to each worker process to get worker-level metrics/traces. This would allow us to use existing profilers without re-implementing their features.
      2. MultiprocessingReadingService can provide methods to retrieve and aggregate metrics from certain DataPipes (mainly Prefetcher).
  4. Integration with other tools (e.g. tracers)

    • We will likely want the main and worker processes to be visible within tracers (e.g. when integrated with TorchTNT).
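The buffer statistics from item 1, together with the per-worker aggregation from option 2 of item 3, can be illustrated with a small, torch-free sketch. Everything in it is hypothetical: `InstrumentedPrefetcher`, `stats()` and `merge_stats` are illustrative names rather than torchdata APIs, and a plain thread filling a bounded `queue.Queue` stands in for the real Prefetcher. The idea shown is the one stated above: on each consumer request, record how full the buffer is, then read a mostly empty buffer as "upstream is slower" and a mostly full one as "downstream is slower".

```python
import queue
import threading
import time

_DONE = object()  # sentinel marking the end of the source


class InstrumentedPrefetcher:
    """Hypothetical prefetcher that records buffer occupancy (not a torchdata API).

    A background thread pulls items from ``source`` into a bounded queue. Each
    time the consumer requests an item, we record how many items were already
    buffered. Occupancy near 0 means the consumer keeps waiting on the producer
    (upstream is slower); occupancy near ``buffer_size`` means the producer keeps
    waiting on the consumer (downstream is slower).
    """

    def __init__(self, source, buffer_size=4):
        self.source = source
        self.buffer_size = buffer_size
        self.occupancy = []  # items already buffered at each consumer request

    def _produce(self, buf):
        # Sketch only: a real implementation would also forward exceptions
        # and stop this thread when the consumer exits early.
        for item in self.source:
            buf.put(item)  # blocks while the buffer is full
        buf.put(_DONE)

    def __iter__(self):
        buf = queue.Queue(maxsize=self.buffer_size)
        threading.Thread(target=self._produce, args=(buf,), daemon=True).start()
        while True:
            buffered = buf.qsize()  # approximate snapshot, good enough for stats
            item = buf.get()
            if item is _DONE:
                return
            self.occupancy.append(buffered)
            yield item

    def stats(self):
        n = len(self.occupancy)
        if n == 0:
            return {"requests": 0, "mean_occupancy": 0.0, "empty_fraction": 0.0}
        return {
            "requests": n,
            "mean_occupancy": sum(self.occupancy) / n,
            "empty_fraction": self.occupancy.count(0) / n,
        }


def merge_stats(per_worker):
    """Hypothetical helper: combine stats from several prefetchers
    (e.g. one per worker process), weighting each by its request count."""
    total = sum(s["requests"] for s in per_worker)
    if total == 0:
        return {"requests": 0, "mean_occupancy": 0.0, "empty_fraction": 0.0}
    return {
        "requests": total,
        "mean_occupancy": sum(s["mean_occupancy"] * s["requests"] for s in per_worker) / total,
        "empty_fraction": sum(s["empty_fraction"] * s["requests"] for s in per_worker) / total,
    }


def slow_source(n, delay):
    for i in range(n):
        time.sleep(delay)  # simulate slow IO or decoding upstream
        yield i


# Slow producer, fast consumer: the buffer is almost always empty.
io_bound = InstrumentedPrefetcher(slow_source(20, 0.01))
io_items = list(io_bound)
print("io-bound:", io_bound.stats())

# Fast producer, slow consumer: the buffer stays close to full.
compute_bound = InstrumentedPrefetcher(range(20))
for _ in compute_bound:
    time.sleep(0.01)  # simulate a slow model step downstream
print("compute-bound:", compute_bound.stats())
print("merged:", merge_stats([io_bound.stats(), compute_bound.stats()]))
```

In a real multiprocess setup, each worker's stats would first have to be sent back to the main process; `merge_stats` only illustrates the aggregation step once they are there.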

Motivation, pitch

This set of tools and features aims to answer the following questions:

  1. Is my model training bottlenecked by data loading?
  2. If so, which part of the pipeline? IO? Compute?

Alternatives

No response

Additional context

Comments and suggestions are welcome.

ejguan commented

How about updating the title with [RFC] and pinning this issue?

It would be good to have the option to profile any pipe, not only Prefetcher. This could be achieved with a ProfilerPipe that wraps the source pipe and measures its timings, which also covers the Prefetcher case. Optionally, it could be assigned a label for easier reading/printing. Using graph manipulation functions, such profiling pipes could be attached either to all pipes or only to a selected subset.

To clarify what I meant, here's a rough sketch:

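import time

from torch.utils.data import IterDataPipe, functional_datapipe


@functional_datapipe('profile')
class ProfilerIterDataPipe(IterDataPipe):
    def __init__(self, dp, label=None, iters_per_measurement=1):
        self.dp = dp
        self.label = label  # optional name, for easier reading/printing
        self.measurements = []  # average seconds per element, one entry per window
        self.iters_per_measurement = iters_per_measurement

    def __iter__(self):
        i = 0
        elapsed = 0.0
        it = iter(self.dp)
        try:
            while True:
                # Time only the call into the source pipe, so the time the
                # consumer spends between yields is not counted.
                start = time.perf_counter()
                elem = next(it)
                elapsed += time.perf_counter() - start
                i += 1
                if i == self.iters_per_measurement:
                    self.measurements.append(elapsed / i)
                    i, elapsed = 0, 0.0
                yield elem
        except StopIteration:
            pass
        finally:
            # Flush a partially filled measurement window.
            if i > 0:
                self.measurements.append(elapsed / i)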

Obviously, this will not isolate the time spent waiting on upstream pipes from the wrapped pipe's own work, but by traversing the graph upwards, one could potentially infer that time as well.
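The upward traversal can be sketched for the simplest case: a linear chain where every stage is wrapped by a profiler. Because each wrapper times the whole `next()` call, its measurement includes all upstream stages, so subtracting the average of the stage directly upstream estimates a stage's own per-element cost. This is a minimal, torch-free sketch under that assumption. `Profiled`, `self_times` and `work` are hypothetical names, and the pipes are plain Python iterators rather than DataPipes. The subtraction does not hold as-is for branching graphs (e.g. zip or mux) or across Prefetcher/process boundaries, where upstream work overlaps with downstream work.

```python
import time
from statistics import mean


class Profiled:
    """Hypothetical, torch-free stand-in for the ProfilerIterDataPipe sketch:
    records the time of each ``next()`` call on the wrapped iterable."""

    def __init__(self, source, label):
        self.source = source
        self.label = label
        self.measurements = []  # seconds per element, cumulative over upstream

    def __iter__(self):
        it = iter(self.source)
        while True:
            start = time.perf_counter()
            try:
                elem = next(it)
            except StopIteration:
                return
            self.measurements.append(time.perf_counter() - start)
            yield elem


def self_times(chain):
    """Estimate each stage's own per-element cost in a *linear* pipeline.

    ``chain`` is ordered from source to sink. A stage's measurement includes
    all of its upstream stages, so subtracting the average of the stage
    directly upstream leaves that stage's own average cost.
    """
    result = {}
    upstream_avg = 0.0
    for stage in chain:
        avg = mean(stage.measurements)
        result[stage.label] = avg - upstream_avg
        upstream_avg = avg
    return result


def work(seconds):
    def fn(x):
        time.sleep(seconds)  # simulate per-element IO or compute
        return x
    return fn


source = Profiled(range(10), "source")
parse = Profiled(map(work(0.002), source), "parse")
decode = Profiled(map(work(0.02), parse), "decode")

items = list(decode)
for label, seconds in self_times([source, parse, decode]).items():
    print(f"{label}: {seconds * 1e3:.2f} ms/element")
```

Here one wrapper is attached to every stage by hand; with real DataPipes, the graph manipulation functions mentioned above would be the natural place to insert them.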