pytorch/xla

Is there a reason why the TPU colab examples use the multi-threading approach instead of multiprocessing?

Santosh-Gupta opened this issue · 28 comments

❓ Questions and Help

I am looking at the Colab examples ( https://github.com/pytorch/xla/tree/master/contrib/colab ) and I noticed that they use a multi-threading approach instead of a multi-processing approach. The following lines are taken from one of the Colab examples.

import torch_xla.distributed.data_parallel as dp
devices = (
    xm.get_xla_supported_devices(
        max_devices=num_cores) if num_cores != 0 else [])
model_parallel = dp.DataParallel(ResNet18, device_ids=devices)

These lines are described as the multithreading approach in the API guide

https://github.com/pytorch/xla/blob/master/API_GUIDE.md#running-on-multiple-xla-devices-with-multithreading

Multiple devices are acquired in the same process with xm.get_xla_supported_devices().

The model is wrapped in dp.DataParallel, which is then called with the training loop and the dataloader.
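For context, a rough sketch of that pattern, paraphrased from the API guide; ResNet18, train_loader, the loop-function signature, and the context.getattr_or helper are taken from the older examples and may not match the current API exactly:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp

devices = xm.get_xla_supported_devices(max_devices=8)
model_parallel = dp.DataParallel(ResNet18, device_ids=devices)

def train_loop_fn(model, loader, device, context):
  # One optimizer per device, created once and cached on the per-device context.
  optimizer = context.getattr_or(
      'optimizer', lambda: torch.optim.SGD(model.parameters(), lr=0.01))
  model.train()
  for data, target in loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.nll_loss(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)

# A single call runs the loop on every device, each in its own thread.
model_parallel(train_loop_fn, train_loader)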

I am wondering why the Colab notebooks use the multithreading approach when the API guide recommends the multi-processing one. From the API guide:

Running on multiple XLA devices using processes (see above) is preferred to using threads. If, however, you want to use threads then PyTorch/XLA has a DataParallel interface.

Perhaps the TPUs are unable to use the multiprocessing approach?

The TPU Colab notebooks use the DataParallel multithreading API mainly because the MultiProcessing API wasn't really mature then. But there is no reason why the Colab notebooks shouldn't be using the MultiProcessing APIs.

We'll update those notebooks soon, though the MP API should work on Colab as of now.

@jysohn23 Will multiprocessing work in a notebook now? When I have tried it in the past, it got stuck at spawning.

I haven't tried it myself, so I can't say for sure, to be honest. Was there any error you were seeing, @tmabraham?

I'll try to get to this and try this out myself sometime soon.

@jysohn23 Running test_train_mp_mnist.py directly in Colab, I get this error:

Exception                                 Traceback (most recent call last)
<ipython-input-5-fdb8d80a3138> in <module>()
    161 
    162 if __name__ == '__main__':
--> 163   xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
    164 
/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py in spawn(fn, args, nprocs, join, daemon)
    125   nprocs = _pre_fork_setup(nprocs)
    126   return torch.multiprocessing.spawn(
--> 127       _start_fn, args=(fn, args), nprocs=nprocs, join=join, daemon=daemon)

/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in spawn(fn, args, nprocs, join, daemon)
    169 
    170     # Loop on join until it returns True or raises an exception.
--> 171     while not spawn_context.join():
    172         pass

/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    110                 raise Exception(
    111                     "process %d terminated with exit code %d" %
--> 112                     (error_index, exitcode)
    113                 )
    114 

Exception: process 1 terminated with exit code 1

Is that the only error you get?

Can you link the notebook you used?

I think a good starting point would be to convert the example Colab notebooks to use multiprocessing; it may help isolate the issue.

@dlibenzi Yep, but none of the processes started.
@Santosh-Gupta I just copied this code into a Colab cell and ran that cell, with the TPU hardware accelerator.

Here's a notebook showing the issues with MP in Colab. I'm pretty sure these all apply to Jupyter/JupyterLab too.
The issue is that you can't use the spawn start method in a notebook: spawn tries to pickle the entry function, and that fails because functions defined in a notebook aren't linked to Python files.
If the entry function is defined in a script (say, imported from the sample .py file), it does launch and run, but unfortunately the output goes to the Colab logs, not the notebook. Just running the script with !python works too, showing that, aside from the spawning issue, everything else is fine.
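For reference, a minimal sketch of that script-based workaround in Colab; the module name mp_entry.py, the trivial entry function, and nprocs=8 are illustrative assumptions here:

%%writefile mp_entry.py
import torch_xla.core.xla_model as xm

def mp_fn(index):
  # Note: output from spawned processes lands in the Colab logs, not the notebook.
  print('process {} running on {}'.format(index, xm.xla_device()))

and then, in a separate cell:

import torch_xla.distributed.xla_multiprocessing as xmp
import mp_entry

xmp.spawn(mp_entry.mp_fn, args=(), nprocs=8)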

I gather the forkserver start method used by default for PyTorch DataLoaders works fine in Colab, given that people run PyTorch on Colab without problems. Though I guess there may be a reason spawn was used in torch_xla instead.

This is a hack, but PyTorch forcibly uses the spawn multiprocessing context for some reason.
I read that CUDA does not like fork, but we do not use CUDA, so it might be worth a try.

import torch
import multiprocessing

def mp_main(i):
  print('Hello from multiprocessing ({})'.format(i))

# Keep a reference to the original get_context, then monkeypatch it so that
# whatever start method is requested, a 'fork' context is returned instead.
_GETC = multiprocessing.get_context

def _getc(n):
  return _GETC('fork')

multiprocessing.get_context = _getc

if __name__ == '__main__':
  # torch.multiprocessing.spawn() internally asks for the 'spawn' context,
  # but with the patch above it silently gets 'fork'.
  torch.multiprocessing.spawn(mp_main)

With that hack it works fine, including with a start function defined in the notebook code.
I also tested the forkserver start method, but that fails with the same issue as spawn (the PyTorch docs state that CUDA tensor sharing works with forkserver).
As dlibenzi notes, PyTorch uses spawn explicitly, so even setting the Python multiprocessing start method won't affect torch.multiprocessing.spawn. torch.multiprocessing also handles errors from child processes nicely, so it's not entirely trivial to replace it with plain Python multiprocessing.

Oh, nice, thanks for testing it!
We do not use anything from torch.multiprocessing at the moment, so we could switch to Python's multiprocessing.
But I feel we might want to use some of its functionality at some point, so we are probably better off living with it.
Maybe we should check with the PyTorch folks about what the real issues with fork are, because if it is only a CUDA issue, it does not affect XLA/TPU, and we could have a nicer override than the hack above.

Yeah, the docs talk about using other start methods with torch.multiprocessing, but they only provide the convenience wrapper around spawn. Everything seemed to work fine when forced to fork, though there could be subtleties.
Since PyTorch apparently does support forkserver for CUDA tensors, the ability to choose a start method for spawn, while keeping its error handling and graceful cleanup, seems to make sense even for typical CUDA PyTorch use.

Maybe we should ask for something like this to be upstreamed?

diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py
index e084333671..21b6d01168 100644
--- a/torch/multiprocessing/spawn.py
+++ b/torch/multiprocessing/spawn.py
@@ -118,7 +118,7 @@ class SpawnContext:
         raise Exception(msg)
 
 
-def spawn(fn, args=(), nprocs=1, join=True, daemon=False):
+def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
     r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.
 
     If one of the processes exits with a non-zero exit status, the
@@ -142,6 +142,9 @@ def spawn(fn, args=(), nprocs=1, join=True, daemon=False):
         join (bool): Perform a blocking join on all processes.
         daemon (bool): The spawned processes' daemon flag. If set to True,
                        daemonic processes will be created.
+        start_method (string): The multiprocessing start method to be used
+            to create new processes. If CUDA is available and used, it must
+            be set to ``spawn``.
 
     Returns:
         None if ``join`` is ``True``,
@@ -149,7 +152,7 @@ def spawn(fn, args=(), nprocs=1, join=True, daemon=False):
 
     """
     _python_version_check()
-    mp = multiprocessing.get_context('spawn')
+    mp = multiprocessing.get_context(start_method)
     error_queues = []
     processes = []
     for i in range(nprocs):

Yeah, that's what I'd thought. The only issue I could see is if other things only work with spawn (apart from the documented issues with CUDA tensors), but my basic testing seemed fine.
It seems easier than trying to replicate spawn in torch_xla (and it works for others who want to use fork/forkserver with PyTorch).

Thanks @dlibenzi for the fix. Multiprocessing code works fine in Colab now. However, there is still an error: if I rerun the code cell, I get a RecursionError, which means I have to restart the session to run the multiprocessing code again. For me it's just a minor inconvenience, but obviously it could be a real problem for other tasks. Here is the error:

---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
<ipython-input-2-794d8484d6fc> in <module>()
    120 
    121 if __name__ == "__main__":
--> 122   xmp.spawn(train_loop,args=())
    123 

/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py in spawn(fn, args, nprocs, join, daemon)
    125   nprocs = _pre_fork_setup(nprocs)
    126   return torch.multiprocessing.spawn(
--> 127       _start_fn, args=(fn, args), nprocs=nprocs, join=join, daemon=daemon)

/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in spawn(fn, args, nprocs, join, daemon)
    150     """
    151     _python_version_check()
--> 152     mp = multiprocessing.get_context('spawn')
    153     error_queues = []
    154     processes = []

<ipython-input-2-794d8484d6fc> in _getc(n)
     18 _GETC = multiprocessing.get_context
     19 def _getc(n):
---> 20   return _GETC('fork')
     21 multiprocessing.get_context = _getc
     22 

<ipython-input-1-e6f0a9d5e79d> in _getc(n)
     18 _GETC = multiprocessing.get_context
     19 def _getc(n):
---> 20   return _GETC('fork')
     21 multiprocessing.get_context = _getc
     22 

... last 1 frames repeated, from the frame below ...

<ipython-input-1-e6f0a9d5e79d> in _getc(n)
     18 _GETC = multiprocessing.get_context
     19 def _getc(n):
---> 20   return _GETC('fork')
     21 multiprocessing.get_context = _getc
     22 

RecursionError: maximum recursion depth exceeded

Looks like you've re-run the patching code cell: on the second run _GETC is rebound to the already-patched _getc, so _getc ends up calling itself and recurses until the interpreter gives up. A slightly safer version would be (not tested in context):

if '_GETC' not in locals(): _GETC = multiprocessing.get_context

to prevent overwriting _GETC with the patched version. This obviously isn't intended as an actual fix, but it would at least be safer for testing.
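Put together, an idempotent version of the hack would look something like this (a sketch only; the guard makes re-running the cell harmless):

import multiprocessing

# Capture the original get_context only on the first run of the cell; on a
# re-run _GETC already exists and keeps pointing at the real function.
if '_GETC' not in globals():
  _GETC = multiprocessing.get_context

def _getc(method=None):
  # Ignore the requested start method and always hand back a 'fork' context.
  return _GETC('fork')

multiprocessing.get_context = _getc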

@mruberry Can you upstream the patch in the previous comment?

#1217 (comment)

@Santosh-Gupta As of now you should be able to use multiprocessing API on Colab with TPUs: #1347. Do notice that you'll need something like the following when spawning the processes:

xmp.spawn(..., start_method='fork')

Our Colab examples have been updated to use the MP API.
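For example, a minimal end-to-end sketch on Colab (the body of _mp_fn here is just a placeholder; the actual notebooks run a full training loop):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  # All XLA device access happens inside the spawned function.
  device = xm.xla_device()
  print('process {} running on {}'.format(index, device))

xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')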

@jysohn23 Running test_train_mp_mnist.py directly in Colab, I get this error (same traceback as quoted above):

Exception: process 1 terminated with exit code 1

Have you solved this problem?

It looks like you've copy-pasted the content of test_train_mp_mnist.py into Colab? As mentioned in #1217 (comment), you need to call xmp.spawn with start_method='fork' on Colab, since the script isn't a file and thus isn't pickle-able. Use this Colab instead: https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/mnist-training-xrt-1-15.ipynb.

Sorry, I missed your reply. Do you mean that you can't run the command twice on Colab? Which command are you referring to?

It didn't work on Colab even when using start_method='fork':

---------------------------------------------------------------------------

Exception                                 Traceback (most recent call last)

<ipython-input-12-7850f341f2f8> in <module>()
     14     xm.optimizer_step(optimizer)
     15 
---> 16 xmp.spawn(_mp_fn, args=(), start_method='fork')

2 frames

/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    117         msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
    118         msg += original_trace
--> 119         raise Exception(msg)
    120 
    121 

Exception: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 116, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 109, in _setup_replication
    xm.set_replication(str(device), [str(device)])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 194, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 181, in xla_replication_devices
    .format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8

The code I'm running:

def _mp_fn(index):
  device = xm.xla_device()
  para_loader = pl.ParallelLoader(train_loader, [device])

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.01)

  for data, target in para_loader.per_device_loader(device):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

xmp.spawn(_mp_fn, args=(), start_method='fork')

I'm using the nightly version of torch/xla.

Can you show all the code?
It seems like you are using XLA devices in the global scope, before calling xmp.spawn() ...
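A minimal sketch of the difference (the device access shown is just illustrative):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Wrong: touching an XLA device here, in global (parent-process) scope,
# leads to the "Cannot replicate if number of devices (1) is different
# from 8" error when the processes are spawned.
# device = xm.xla_device()

def _mp_fn(index):
  # Right: acquire the device only inside the spawned function.
  device = xm.xla_device()
  print('process {} on {}'.format(index, device))

xmp.spawn(_mp_fn, args=(), start_method='fork')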

Yes, and I now know I cannot call xm.xla_device() before calling xmp.spawn. Maybe adding a note to the torch_xla tutorial would be helpful to others.

Thanks, we added a section to the API guide: #1755

How do I use a GPU script on a Colab TPU? How do I enable CUDA on a TPU?

If you want to use a GPU, you need to change the Colab runtime type.
If you want to use a TPU and your code has a bunch of CUDA-specific instructions, you will need to change the code and follow the guidelines we outline in our Colab examples.
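For example, the typical change (a rough sketch only; MyModel, loader, optimizer, and loss_fn stand in for whatever the existing script already defines) is to target an XLA device instead of CUDA:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()        # instead of torch.device('cuda')
model = MyModel().to(device)    # instead of model.cuda()

for data, target in loader:
  data, target = data.to(device), target.to(device)  # instead of .cuda()
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  # On a single XLA device, xm.optimizer_step(optimizer, barrier=True) is
  # the usual replacement for optimizer.step(), forcing graph execution.
  xm.optimizer_step(optimizer, barrier=True)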