tensor-bridge is a lightweight library for transferring tensors between libraries with a native cudaMemcpy call and minimal overhead.
import torch
import jax
from tensor_bridge import copy_tensor
# PyTorch tensor
torch_data = torch.rand(2, 3, 4, device="cuda:0")
# Jax tensor
jax_data = jax.random.uniform(jax.random.key(123), shape=(2, 3, 4))
# Copy the Jax tensor into the PyTorch tensor (destination first)
copy_tensor(torch_data, jax_data)
# And the other way around
copy_tensor(jax_data, torch_data)
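Conceptually, copy_tensor does not convert data element by element. As described above, it issues a single cudaMemcpy between the two libraries' GPU buffers. The sketch below shows the same idea on host memory using only the Python standard library: ctypes.memmove stands in for cudaMemcpy, and an array.array and a ctypes array stand in for tensors owned by two different libraries. It is an analogy under those assumptions, not tensor-bridge's implementation.

```python
import array
import ctypes

# Analogy only: two independently allocated float32 buffers, as if owned
# by two different libraries.
src = array.array("f", [0.5, 1.5, 2.5, 3.5, 4.5, 5.5])  # "library A"
dst = (ctypes.c_float * 6)()                              # "library B", preallocated

src_addr, src_len = src.buffer_info()  # raw address and element count
nbytes = src_len * src.itemsize
# A raw copy only moves bytes, so both buffers must have the same size.
assert nbytes == ctypes.sizeof(dst), "sizes must match for a raw copy"

# One raw byte copy with no per-element conversion: ctypes.memmove plays
# the role that cudaMemcpy plays for device memory.
ctypes.memmove(ctypes.addressof(dst), src_addr, nbytes)
print(list(dst))
```

In this analogy the destination must already exist with the same byte size, and matching dtypes (float32 on both sides) is the caller's responsibility, since only raw bytes are moved.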
Check that copying works correctly with copy_tensor_with_assertion before starting experiments. copy_tensor_with_assertion raises an error if the copy doesn't work.
If copy_tensor_with_assertion raises an error, you need to make the tensors contiguous:
# PyTorch example
import torch
from tensor_bridge import copy_tensor_with_assertion
# tensors with different memory layouts raise an error
a = torch.rand(2, 3, device="cuda:0")
b = torch.rand(3, 2, device="cuda:0").transpose(0, 1)  # shape (2, 3), but non-contiguous
copy_tensor_with_assertion(a, b)  # AssertionError !!
# make both tensors contiguous
b = b.contiguous()
copy_tensor_with_assertion(a, b)
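Why does the transposed tensor fail? A plausible explanation, assuming copy_tensor copies the underlying memory verbatim as a raw cudaMemcpy would, is that strides are ignored: the bytes arrive unchanged, but the destination reads them with its own layout. The pure-Python model below (no GPU or framework needed) represents a 2-D tensor as a flat buffer plus strides. The helpers make_contiguous_strides and to_nested are hypothetical names for illustration only.

```python
# Model a tensor as a flat buffer plus shape and strides (in elements).
# These helpers are hypothetical illustrations, not tensor-bridge APIs.

def make_contiguous_strides(shape):
    """Row-major (C-order) strides for the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def to_nested(buffer, shape, strides):
    """Read a 2-D view as nested lists of its logical values."""
    rows, cols = shape
    return [
        [buffer[r * strides[0] + c * strides[1]] for c in range(cols)]
        for r in range(rows)
    ]


# b's storage is a contiguous 3x2 tensor [[0, 1], [2, 3], [4, 5]].
b_storage = [0, 1, 2, 3, 4, 5]
# transpose(0, 1) swaps shape and strides without moving any data.
s0, s1 = make_contiguous_strides((3, 2))
b_shape, b_strides = (2, 3), (s1, s0)

# Destination a: a contiguous (2, 3) tensor.
a_shape = (2, 3)
a_strides = make_contiguous_strides(a_shape)

# A raw copy moves the storage verbatim, ignoring b's strides.
print(to_nested(b_storage, b_shape, b_strides))        # what b logically holds
print(to_nested(list(b_storage), a_shape, a_strides))  # what a would read

# Making b contiguous rewrites its storage in row-major order first.
b_contig = [v for row in to_nested(b_storage, b_shape, b_strides) for v in row]
print(to_nested(list(b_contig), a_shape, a_strides))   # now matches b
```

After the raw copy, a reads [[0, 1, 2], [3, 4, 5]] while b logically holds [[0, 2, 4], [1, 3, 5]]; once b is made contiguous, the same verbatim copy produces the right values.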
Since copy_tensor_with_assertion performs an additional GPU-to-CPU transfer internally, make sure you switch to copy_tensor in your actual experiments. Otherwise, your training loop will be significantly slower.
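One way to follow this advice is to validate the first copy with the assertion variant and use the fast path afterwards. The sketch below takes both copy functions as parameters so it can be exercised without a GPU. make_checked_copier is a hypothetical helper, not part of tensor-bridge; in real use you would pass copy_tensor_with_assertion and copy_tensor.

```python
from typing import Any, Callable

CopyFn = Callable[[Any, Any], None]


def make_checked_copier(checked_copy: CopyFn, fast_copy: CopyFn) -> CopyFn:
    """Return a copy function that validates only until the first success.

    Hypothetical helper: checked_copy is expected to copy and raise on a
    bad copy (slow); fast_copy is used for every call after that.
    """
    validated = False

    def copy(dst: Any, src: Any) -> None:
        nonlocal validated
        if not validated:
            checked_copy(dst, src)  # raises if the copy is wrong
            validated = True        # only reached when the check passed
        else:
            fast_copy(dst, src)

    return copy


# Real usage (requires tensor-bridge and a GPU):
# copy = make_checked_copier(copy_tensor_with_assertion, copy_tensor)
```

This validates only once, so it assumes later calls use tensors with the same shapes and memory layouts; create a new copier if those change.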
Features:
- Fast inter-library tensor copy.
- Inter-GPU copy (I believe this is supported by the current implementation, but it has not been tested yet.)
Supported libraries:
- PyTorch
- Jax
- nnabla
If pip installation doesn't work, please try installing from source.
You can install a pre-built package:
pip install tensor-bridge
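If no pre-built package is available for your Python version, pip builds tensor-bridge from source. In that case, your machine needs nvcc to compile the native code and Cython to compile the .pyx files: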
pip install Cython==0.29.36
pip install tensor-bridge
Pre-built packages for other Python versions are in progress.
To install from source, your machine needs nvcc to compile the native code and Cython to compile the .pyx files:
git clone git@github.com:takuseno/tensor-bridge
cd tensor-bridge
pip install Cython==0.29.36
pip install -e .
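Before building from source, a quick preflight check can confirm that the build tools mentioned above are present. This is a convenience sketch, not part of tensor-bridge: check_build_prerequisites is a hypothetical name, and it only checks that nvcc is on PATH and that the pinned Cython version (0.29.36) is installed.

```python
import importlib.metadata
import shutil
from typing import List


def check_build_prerequisites(required_cython: str = "0.29.36") -> List[str]:
    """Return a list of problems found; an empty list means ready to build."""
    problems = []
    # nvcc compiles the native CUDA code.
    if shutil.which("nvcc") is None:
        problems.append("nvcc not found on PATH (install the CUDA toolkit)")
    # Cython compiles the .pyx files; the instructions pin a specific version.
    try:
        version = importlib.metadata.version("Cython")
    except importlib.metadata.PackageNotFoundError:
        problems.append(f"Cython is not installed (pip install Cython=={required_cython})")
    else:
        if version != required_cython:
            problems.append(f"Cython {version} found; the instructions pin {required_cython}")
    return problems


if __name__ == "__main__":
    issues = check_build_prerequisites()
    for issue in issues:
        print("-", issue)
    if not issues:
        print("Build prerequisites look OK.")
```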
Running the tests requires an NVIDIA GPU and the NVIDIA driver:
./bin/build-docker
./bin/test
To benchmark round trip copies between Jax and PyTorch:
./bin/build-docker
./bin/benchmark
These are the results on my local desktop with an RTX 4070:
Benchmarking copy_tensor...
Average compute time: 1.3043880462646485e-05 sec
Benchmarking copy via CPU...
Average compute time: 0.0016725873947143555 sec
Benchmarking dlpack...
Average compute time: 7.467031478881836e-05 sec
copy_tensor is surprisingly faster than DLPack. Looking at PyTorch's implementation, it seems that PyTorch performs an additional CUDA stream synchronization, which adds extra time.
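For a sense of scale, the averages above can be turned into relative speedups. This is plain arithmetic on the numbers reported for this one RTX 4070 run; other GPUs and tensor sizes will give different ratios.

```python
# Average times (seconds) reported by ./bin/benchmark on an RTX 4070.
timings = {
    "copy_tensor": 1.3043880462646485e-05,
    "copy via CPU": 0.0016725873947143555,
    "dlpack": 7.467031478881836e-05,
}

baseline = timings["copy_tensor"]
for name, seconds in timings.items():
    # Show each average in microseconds and as a multiple of copy_tensor.
    print(f"{name:>12}: {seconds * 1e6:8.1f} us  ({seconds / baseline:6.1f}x copy_tensor)")
```

In this run, copy_tensor averages about 13 microseconds, roughly 5.7x faster than DLPack and about 128x faster than copying through the CPU.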