fjarri/reikna

Array.get is very slow

flavianh opened this issue · 5 comments

Hey there, I have noticed that transferring an array from the device to the host is slow in a case I was able to reproduce. Have a look at the following snippet:

from reikna.cluda import any_api
from reikna.linalg import MatrixMul
import numpy as np
import time
api = any_api()
thr = api.Thread.create()

def minimal_example(shape):
    matrix_a = thr.array(shape, dtype=np.float32)       # uninitialized device arrays
    matrix_b = thr.array(shape, dtype=np.float32)
    matrix_output = thr.array(shape, dtype=np.float32)
    dot = MatrixMul(matrix_b, matrix_a, out_arr=matrix_output)
    dotc = dot.compile(thr)
    dotc(matrix_output, matrix_a, matrix_b)              # run the multiplication on the device
    matrix_output.get()                                  # copy the result back to the host

import cProfile

cProfile.run('minimal_example((10000, 10000))', sort='time')

This yields:

         94297 function calls (92282 primitive calls) in 13.481 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1   13.353   13.353   13.353   13.353 gpuarray.py:248(get)
     6388    0.007    0.000    0.009    0.000 sre_parse.py:182(__next)
   234/38    0.006    0.000    0.022    0.001 sre_parse.py:379(_parse)
      271    0.003    0.000    0.009    0.000 version.py:208(__init__)
      148    0.003    0.000    0.004    0.000 sre_compile.py:207(_optimize_charset)
      313    0.003    0.000    0.030    0.000 __init__.py:2088(find_on_path)
     5893    0.003    0.000    0.011    0.000 sre_parse.py:201(get)
   431/36    0.003    0.000    0.009    0.000 sre_compile.py:32(_compile)
15989/15770    0.003    0.000    0.003    0.000 {len}
        2    0.003    0.001    0.003    0.001 {posix.read}
      965    0.002    0.000    0.004    0.000 posixpath.py:60(join)
        1    0.002    0.002    0.002    0.002 {posix.fork}
      265    0.002    0.000    0.002    0.000 version.py:351(_cmpkey)
        4    0.002    0.000    0.002    0.001 collections.py:237(namedtuple)
        1    0.002    0.002    0.091    0.091 __init__.py:15(<module>)
      372    0.002    0.000    0.002    0.000 {posix.stat}
     6302    0.002    0.000    0.002    0.000 {method 'endswith' of 'str' objects}
  601/210    0.002    0.000    0.002    0.000 sre_parse.py:140(getwidth)
      276    0.002    0.000    0.019    0.000 __init__.py:2460(from_location)
4263/4262    0.001    0.000    0.002    0.000 {isinstance}

I have checked my GPU's bandwidth (I am using Amazon's smaller GPU instance, which has a GRID K520), and it gives me:


 Device 0: GRID K520
 Quick Mode

 Host to Device Bandwidth, 1 Device(s)
 PINNED Memory Transfers
   Transfer Size (Bytes)        Bandwidth(MB/s)
   33554432                     9390.1

 Device to Host Bandwidth, 1 Device(s)
 PINNED Memory Transfers
   Transfer Size (Bytes)        Bandwidth(MB/s)
   33554432                     8814.8

 Device to Device Bandwidth, 1 Device(s)
 PINNED Memory Transfers
   Transfer Size (Bytes)        Bandwidth(MB/s)
   33554432                     119122.4

My guess is that get takes a little too much time to copy a ~380 MB array. I suspected this was due to the thread synchronization step, but when I call thr.from_device(matrix_output, async=True) I don't get better results. So is the copy really that slow, or is there something I'm not getting?
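
For reference, a rough back-of-the-envelope check (a sketch, treating the reported 8814.8 MB/s device-to-host figure as roughly 8.8e9 bytes/s):

nbytes = 10000 * 10000 * 4       # float32 matrix: 400,000,000 bytes, ~381 MiB
bandwidth = 8814.8e6             # measured device-to-host bandwidth, in bytes/s
print(nbytes / bandwidth)        # ~0.045 s expected for the raw copy

Even allowing for pageable (non-pinned) host memory being noticeably slower than the pinned-memory benchmark, the copy alone should take well under a second, so the 13 s reported for get() cannot be pure transfer time.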

Nodd commented

I think that dotc is non-blocking, so get is waiting for the computation to finish before retrieving the data. You should be able to see this by first waiting for the computation, then getting the data.
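
A rough way to check this would be to time the call and the transfer separately (a sketch, assuming the variables from the snippet above are still in scope):

import time

t0 = time.time()
dotc(matrix_output, matrix_a, matrix_b)      # only enqueues the kernel
print('call returned after %.4f s' % (time.time() - t0))

t0 = time.time()
matrix_output.get()                          # blocks until the kernel finishes, then copies
print('get() returned after %.4f s' % (time.time() - t0))

If the first number is close to zero and the second is large, the call really is asynchronous, and get() is mostly paying for the matrix multiplication rather than for the copy.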

flavianh commented

How can I wait for completion? The documentation is not very explicit about that; see here:

__call__(*args, **kwds)
Execute the computation.
Nodd commented

I don't know the reikna API very well. Try passing async=False to thr = api.Thread.create()? (Of course, this is for profiling only.)

fjarri commented

@Nodd is correct: computations are executed asynchronously by default. You can either pass async=False to the Thread constructor (or to Thread.create()), or call thr.synchronize() where you need it. There is currently no option for doing that in the computation call itself.
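
A minimal sketch of the second option, assuming the variables from the original snippet (the synchronize call separates the kernel time from the transfer time):

import time

dotc(matrix_output, matrix_a, matrix_b)   # enqueue the matrix multiplication
thr.synchronize()                         # block until all queued work has finished

t0 = time.time()
result = matrix_output.get()              # now this measures only the device-to-host copy
print('get() took %.3f s' % (time.time() - t0))

Alternatively, creating the thread with async=False (for example api.Thread.create(async=False)) makes every call blocking, which is convenient for profiling but gives up the overlap you would normally want in production code.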

Seems to be resolved then. Closing.