Dispatch mechanism may break when two libraries that use CUB and/or Thrust have been compiled for different sets of GPU architectures
elstehle opened this issue · 6 comments
The following describes a problem observed in more complex software projects, where different components (or libraries) use CUB and/or Thrust without isolating them through namespace customisation. The issue may also surface when linked libraries include CUB and/or Thrust, even if a library's dependency on CUB and/or Thrust is not apparent to the library user.
Is this the issue that I'm having?
If you are:
- linking against another library that uses CUB and/or Thrust, and
- your source files (or a second library) are using CUB and/or Thrust, and
- you are seeing an error like:
  - an exception like "merge_sort: failed on 2nd step: cudaErrorInvalidValue: invalid argument"
  - running your program under `cuda-memcheck` or `compute-sanitizer --tool memcheck` reports out-of-bounds global memory reads or global memory writes (into `temporary_storage`) within a CUB (or Thrust) kernel
  - an exception like "cudaErrorInvalidValue: invalid argument" thrown from a Thrust algorithm
- the issue you're running into is not deterministic. Whether you'll experience a problem or not is determined at load time(?). It may well be that you run your program once and everything works perfectly fine; you can run the affected Thrust/CUB algorithm hundreds of times in a loop without any issue. But when you run your program the next time, it will fail (consistently).
The root cause
Situation
- CUB is using tuning policies to determine the optimal "meta parameters" that are most efficient for a kernel on a specific GPU architecture.
- There's a compile-time and a run-time component to the tuning policies. I'll refer to the run-time component as the "dispatch mechanism".
- The compile-time component makes sure that, during a compilation pass for a specific GPU architecture, the kernel is compiled using the correct "meta parameters" for that specific architecture (e.g., selecting the correct meta parameters may be implemented using `__CUDA_ARCH__`). Such meta parameters are parameters like `BLOCK_THREADS` (the number of threads per thread block), `ITEMS_PER_THREAD` (the number of items processed by each thread), etc.
- At run time, the dispatch mechanism needs to configure the kernel launch of a CUB algorithm, i.e., it needs to configure the correct block size (corresponding to the kernel's `BLOCK_THREADS`) and the correct grid size. These run-time parameters need to match the parameters of the kernel that will actually get launched.
- To determine the GPU architecture that a kernel will get dispatched for, CUB uses `cudaFuncGetAttributes` on `cub::EmptyKernel` to query the closest architecture for which `EmptyKernel` was compiled, assuming that `EmptyKernel` has been compiled for exactly the same architectures as the kernels actually implementing the various algorithms (which usually is the case).
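A rough sketch of this dispatch-time query (hypothetical helper code, not CUB's actual implementation; `empty_kernel` stands in for `cub::EmptyKernel`):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for cub::EmptyKernel: compiled for the same -gencode set as
// the kernels that implement the actual algorithms.
__global__ void empty_kernel() {}

int main()
{
  // attr.ptxVersion reports the closest virtual architecture the kernel
  // was compiled for (e.g., 70 for compute_70, scaled here to 700).
  cudaFuncAttributes attr;
  if (cudaFuncGetAttributes(&attr, empty_kernel) == cudaSuccess)
    printf("EmptyKernel ptx version: %d\n", attr.ptxVersion * 10);
  return 0;
}
```

The dispatch mechanism then uses this value to select which architecture's tuning policy to configure the launch with, which is only correct if the queried `EmptyKernel` actually came from the same TU as the kernels being launched.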
Problem
CUB's kernels have weak external linkage. All kernels from all translation units being linked end up in the binary's fatbin. If there are multiple candidates for a kernel, the CUDA runtime seems to choose any qualifying candidate "at random".
compilation
nvcc -c -gencode arch=compute_52,code=compute_52 my_lib.cu
nvcc -c -gencode arch=compute_70,code=compute_70 main.cu
nvcc -o sort_test my_lib.o main.o && compute-sanitizer --tool memcheck ./sort_test
my_lib.cu
#include <thrust/scan.h>
void my_lib_scan(cudaStream_t stream)
{
// this can be an arbitrary library
// imagine it uses some thrust algorithms (e.g., a scan)
// and it comes pre-compiled for _some_ GPU architecture
// In this case, just including the header is sufficient for EmptyKernel to be compiled in this TU
}
main.cu
#include <thrust/sort.h>
#include <thrust/device_vector.h>
#include <iostream>
int main()
{
thrust::device_vector<int> d_vec(128 << 20);
thrust::sort(d_vec.begin(), d_vec.end());
cudaDeviceSynchronize();
std::cout << cudaGetLastError() << "\n";
}
output
Running on a V100
#RUN 0
compute-sanitizer --tool memcheck ./sort_test
========= COMPUTE-SANITIZER
cudaFuncGetAttributes(EmptyKernel): 700
0
DeviceRadixSortHistogramKernel: 700
DeviceRadixSortOnesweepKernel: 700
DeviceRadixSortOnesweepKernel: 700
DeviceRadixSortOnesweepKernel: 700
DeviceRadixSortOnesweepKernel: 700
0
========= ERROR SUMMARY: 0 errors
#RUN 1
compute-sanitizer --tool memcheck ./sort_test
========= COMPUTE-SANITIZER
cudaFuncGetAttributes(EmptyKernel): 520
0
DeviceRadixSortUpsweepKernel: 700
RadixSortScanBinsKernel: 700
RadixSortScanBinsKernel: 700
DeviceRadixSortUpsweepKernel: 700
RadixSortScanBinsKernel: 700
RadixSortScanBinsKernel: 700
DeviceRadixSortUpsweepKernel: 700
========= Invalid __global__ write of size 4 bytes
========= at 0x74d0 in cub/agent/agent_radix_sort_downsweep.cuh:264:void cub::AgentRadixSortDownsweep<cub::AgentRadixSortDownsweepPolicy<(int)512, (int)23, int, (cub::BlockLoadAlgorithm)3, (cub::CacheLoadModifier)0, (cub::RadixRankAlgorithm)2, (cub::BlockScanAlgorithm)2, (int)7, cub::RegBoundScaling<(int)512, (int)23, int>>, (bool)0, int, cub::NullType, unsigned int>::ScatterKeys<(bool)1>(unsigned int (&)[23], unsigned int (&)[23], int (&)[23], unsigned int)
========= by thread (125,0,0) in block (0,0,0)
[...]
Potential Solutions
Declare the CUB kernels `static`. Making sure that CUB kernels in translation unit A won't interfere with the kernels in translation unit B would be a viable solution. We currently have all the kernels from both translation units in the linked binary anyway. See the `cuobjdump` output below for the above code example.
cuobjdump sort_test -xptx all
Extracting PTX file and ptxas options 1: my_lib.sm_52.ptx -arch=sm_52 --generate-line-info
Extracting PTX file and ptxas options 2: main.sm_70.ptx -arch=sm_70 --generate-line-info
cat my_lib.sm_52.ptx |c++filt|grep .entry
.visible .entry void cub::EmptyKernel<void>()()
cat main.sm_70.ptx |c++filt|grep .entry
.visible .entry void thrust::cuda_cub::core::_kernel_agent<thrust::cuda_cub::__parallel_for::ParallelForAgent<thrust::cuda_cub::__uninitialized_fill::functor<thrust::device_ptr<int>, int>, unsigned long>, thrust::cuda_cub::__uninitialized_fill::functor<thrust::device_ptr<int>, int>, unsigned long>(thrust::cuda_cub::__uninitialized_fill::functor<thrust::device_ptr<int>, int>, unsigned long)(
.visible .entry void thrust::cuda_cub::core::_kernel_agent<thrust::cuda_cub::__parallel_for::ParallelForAgent<thrust::cuda_cub::__transform::unary_transform_f<int*, int*, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>, thrust::cuda_cub::__transform::unary_transform_f<int*, int*, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>(thrust::cuda_cub::__transform::unary_transform_f<int*, int*, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long)(
.visible .entry void thrust::cuda_cub::core::_kernel_agent<thrust::cuda_cub::__parallel_for::ParallelForAgent<thrust::cuda_cub::__transform::unary_transform_f<thrust::detail::normal_iterator<thrust::device_ptr<int> >, thrust::detail::normal_iterator<thrust::device_ptr<int> >, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>, thrust::cuda_cub::__transform::unary_transform_f<thrust::detail::normal_iterator<thrust::device_ptr<int> >, thrust::detail::normal_iterator<thrust::device_ptr<int> >, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>(thrust::cuda_cub::__transform::unary_transform_f<thrust::detail::normal_iterator<thrust::device_ptr<int> >, thrust::detail::normal_iterator<thrust::device_ptr<int> >, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long)(
.visible .entry void thrust::cuda_cub::core::_kernel_agent<thrust::cuda_cub::__parallel_for::ParallelForAgent<thrust::cuda_cub::__transform::unary_transform_f<int const*, thrust::device_ptr<int>, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>, thrust::cuda_cub::__transform::unary_transform_f<int const*, thrust::device_ptr<int>, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>(thrust::cuda_cub::__transform::unary_transform_f<int const*, thrust::device_ptr<int>, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long)(
.visible .entry void thrust::cuda_cub::core::_kernel_agent<thrust::cuda_cub::__parallel_for::ParallelForAgent<thrust::cuda_cub::__transform::unary_transform_f<thrust::device_ptr<int>, int*, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>, thrust::cuda_cub::__transform::unary_transform_f<thrust::device_ptr<int>, int*, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>(thrust::cuda_cub::__transform::unary_transform_f<thrust::device_ptr<int>, int*, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long)(
.visible .entry void thrust::cuda_cub::core::_kernel_agent<thrust::cuda_cub::__parallel_for::ParallelForAgent<thrust::cuda_cub::__transform::unary_transform_f<thrust::device_ptr<int>, thrust::device_ptr<int>, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>, thrust::cuda_cub::__transform::unary_transform_f<thrust::device_ptr<int>, thrust::device_ptr<int>, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long>(thrust::cuda_cub::__transform::unary_transform_f<thrust::device_ptr<int>, thrust::device_ptr<int>, thrust::cuda_cub::__transform::no_stencil_tag, thrust::identity<int>, thrust::cuda_cub::__transform::always_true_predicate>, long)(
.visible .entry void cub::EmptyKernel<void>()()
.visible .entry void cub::DeviceRadixSortSingleTileKernel<cub::DeviceRadixSortPolicy<int, cub::NullType, unsigned int>::Policy800, false, int, cub::NullType, unsigned int>(int const*, int*, cub::NullType const*, cub::NullType*, unsigned int, int, int)(
.visible .entry void cub::DeviceRadixSortUpsweepKernel<cub::DeviceRadixSortPolicy<int, cub::NullType, unsigned int>::Policy800, false, false, int, unsigned int>(int const*, unsigned int*, unsigned int, int, int, cub::GridEvenShare<unsigned int>)(
.visible .entry void cub::DeviceRadixSortUpsweepKernel<cub::DeviceRadixSortPolicy<int, cub::NullType, unsigned int>::Policy800, true, false, int, unsigned int>(int const*, unsigned int*, unsigned int, int, int, cub::GridEvenShare<unsigned int>)(
.visible .entry void cub::RadixSortScanBinsKernel<cub::DeviceRadixSortPolicy<int, cub::NullType, unsigned int>::Policy800, unsigned int>(unsigned int*, int)(
.visible .entry void cub::DeviceRadixSortDownsweepKernel<cub::DeviceRadixSortPolicy<int, cub::NullType, unsigned int>::Policy800, false, false, int, cub::NullType, unsigned int>(int const*, int*, cub::NullType const*, cub::NullType*, unsigned int*, unsigned int, int, int, cub::GridEvenShare<unsigned int>)(
.visible .entry void cub::DeviceRadixSortDownsweepKernel<cub::DeviceRadixSortPolicy<int, cub::NullType, unsigned int>::Policy800, true, false, int, cub::NullType, unsigned int>(int const*, int*, cub::NullType const*, cub::NullType*, unsigned int*, unsigned int, int, int, cub::GridEvenShare<unsigned int>)(
.visible .entry void cub::DeviceRadixSortHistogramKernel<cub::DeviceRadixSortPolicy<int, cub::NullType, unsigned int>::Policy800, false, int, unsigned int>(unsigned int*, int const*, unsigned int, int, int)(
.visible .entry void cub::DeviceRadixSortExclusiveSumKernel<cub::DeviceRadixSortPolicy<int, cub::NullType, unsigned int>::Policy800, unsigned int>(unsigned int*)(
.visible .entry void cub::DeviceRadixSortOnesweepKernel<cub::DeviceRadixSortPolicy<int, cub::NullType, unsigned int>::Policy800, false, int, cub::NullType, unsigned int, int, int>(int*, int*, unsigned int*, unsigned int const*, int*, int const*, cub::NullType*, cub::NullType const*, int, int, int)(
List of issues that may be linked to this root cause:
My only concern regarding the `static` specifier for kernels is that we'll drastically increase binary size:
#pragma once
#ifdef STATIC
#define SPECIFIER static
#else
#define SPECIFIER
#endif
template <class T>
SPECIFIER __global__ void kernel(){}
I have two TUs that both make the same call, `kernel<int><<<1, 1>>>();`. When compiled with nvcc, the result is the same:
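For context, the two TUs might look like this (a hypothetical reconstruction; the wrapper names `foo`/`bar` are inferred from the mangled names `_Z3foov`/`_Z3barv` in the `-rdc` output further down):

```cuda
// tu_1.cu -- first TU instantiating the kernel
#include "kernel.cuh"   // the header with the SPECIFIER macro above
void foo() { kernel<int><<<1, 1>>>(); }

// tu_2.cu -- second TU making the identical call
#include "kernel.cuh"
void bar() { kernel<int><<<1, 1>>>(); }

// main.cu -- drives both TUs
void foo();
void bar();
int main() { foo(); bar(); }
```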
:nvcc tu_1.cu tu_2.cu main.cu
:cuobjdump --dump-sass a.out | rg Function
Function : _Z6kernelIiEvv
Function : _Z6kernelIiEvv
:nvcc -DSTATIC tu_1.cu tu_2.cu main.cu
:cuobjdump --dump-sass a.out | rg Function
Function : _Z6kernelIiEvv
Function : _Z6kernelIiEvv
but when you provide the `-rdc` flag:
:nvcc -rdc=true tu_1.cu tu_2.cu main.cu
:cuobjdump --dump-sass a.out | rg Function
Function : _Z6kernelIiEvv
:nvcc -DSTATIC -rdc=true tu_1.cu tu_2.cu main.cu
:cuobjdump --dump-sass a.out | rg Function
Function : __nv_static_27__91103086_7_tu_1_cu__Z3foov__Z6kernelIiEvv
Function : __nv_static_27__83a59f68_7_tu_2_cu__Z3barv__Z6kernelIiEvv
So we'll have one kernel per TU in applications that use CUB. Moreover, I believe that `-rdc` is the default for nvc++:
:nvc++ tu_1.cu tu_2.cu main.cu
:cuobjdump --dump-sass a.out | rg Function
Function : _Z6kernelIiEvv
:nvc++ -DSTATIC tu_1.cu tu_2.cu main.cu
:cuobjdump --dump-sass a.out | rg Function
Function : _ZN27_INTERNAL_7_tu_1_cu__Z3foov6kernelIiEEvv
Function : _ZN27_INTERNAL_7_tu_2_cu__Z3barv6kernelIiEEvv
If there's multiple choices for a kernel, the CUDA runtime seems to choose any qualifying kernel candidate "at random".
Let me make sure I'm following what's going on here.
- `main.cu` and `my_lib.cu` are compiled with different archs and their object files are linked
- Both `main.cu` and `my_lib.cu` end up compiling `cub::EmptyKernel`
- `thrust::sort` in `main.cu` invokes `cudaFuncGetAttributes(..., cub::EmptyKernel)`
- We don't know if the `cub::EmptyKernel` we're querying comes from `main.o` or `my_lib.o`
- Therefore, the resulting arch from `cudaFuncGetAttributes` is non-deterministic

Is that right?
This piqued my curiosity and I went far down a rabbit hole.
TL;DR: There is something extremely odd going on here that I don't understand, and just making the kernel `static` does not fix the issue.
I captured my repro and results so far here: https://github.com/jrhemstad/cuda_arch_odr
The only thing that seems to work robustly is to make the linkage of both the kernel and the enclosing function to be internal.
If there's multiple choices for a kernel, the CUDA runtime seems to choose any qualifying kernel candidate "at random".
Let me make sure I'm following what's going on here.
- `main.cu` and `my_lib.cu` are compiled with different archs and their object files are linked
- Both `main.cu` and `my_lib.cu` end up compiling `cub::EmptyKernel`
- `thrust::sort` in `main.cu` invokes `cudaFuncGetAttributes(..., cub::EmptyKernel)`
- We don't know if the `cub::EmptyKernel` we're querying comes from `main.o` or `my_lib.o`
- Therefore, the resulting arch from `cudaFuncGetAttributes` is non-deterministic

Is that right?
That's exactly right.
TL;DR: There is something extremely odd going on here that I don't understand and just making the kernel `static` does not fix the issue.
Thanks for the reproducer and for summarising the results. This highlights that we need to be careful and thoroughly verify whichever solution we identify as a candidate.
In the case of your repro, I believe that `test()` needs to have internal linkage too. Otherwise - and, for simplicity, let's assume `kernel` has internal linkage - we'll have two `test()` candidates: (a) one from `a.cu` (which only sees `a.cu`'s `kernel`, compiled for sm 5.2) and (b) one from `b.cu` (which only sees `b.cu`'s `kernel`, compiled for sm 7.0). Apparently, at link time, one of the two `test()` implementations "wins" and provides the implementation of `test()` for all invocations from both `a.cu` and `b.cu`. `inline` apparently only lifts the ODR restriction but does not affect linkage.
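This "one definition wins" behaviour can be demonstrated on the host side alone. Here's a hypothetical, simplified analogue: two TUs provide differing definitions of the same inline function, mimicking what differing `-gencode` settings do to CUB's dispatch path (note this is deliberately an ODR violation):

```shell
# a.cpp and b.cpp each define the same inline function differently,
# simulating the same template compiled for different architectures.
cat > a.cpp <<'EOF'
inline int arch() { return 520; }   // "compiled for sm_52"
int arch_a() { return arch(); }
EOF
cat > b.cpp <<'EOF'
inline int arch() { return 700; }   // "compiled for sm_70"
int arch_b() { return arch(); }
EOF
cat > main.cpp <<'EOF'
#include <cstdio>
int arch_a();
int arch_b();
int main() { printf("a:%d b:%d\n", arch_a(), arch_b()); }
EOF
g++ a.cpp b.cpp main.cpp -o odr_demo && ./odr_demo
# Both calls report the SAME value: the linker kept a single copy of
# the weak symbol arch(), so one TU's definition "won" for both.
```

Which copy survives is a linker implementation detail, which is exactly why the CUB failure mode is non-deterministic across builds.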
However, it seems that if there's no ODR-use of the inline `test()` in `a.cu`, I don't find `test()` in the symbol table of `a.o` at all. Which may relate to (source):
For an inline function or inline variable (since C++17), a definition is required in every translation unit where it is odr-used.
This is the reason why declaring the kernel `static` was sufficient in my case. `my_lib.cu` only caused compilation of the `EmptyKernel`, but never actually invoked `PtxVersionUncached()` (the equivalent of `test()`). Hence, `PtxVersionUncached()` never made it onto the candidate list at link time to override `main.cu`'s version of `PtxVersionUncached()`:
readelf -sW my_lib.o | awk '$4 == "FUNC"' | c++filt|grep PtxVersion
#<nothing returned>
However, after adding an algorithm invocation to `my_lib.cu`, I ran into the same issue as described for `test()` in https://github.com/jrhemstad/cuda_arch_odr:
readelf -sW my_lib.o | awk '$4 == "FUNC"' | c++filt|grep PtxVersion
1148: 0000000000000000 277 FUNC WEAK DEFAULT 471 cub::PtxVersionUncached(int&)
1152: 0000000000000000 132 FUNC WEAK DEFAULT 473 cub::PtxVersionUncached(int&, int)
1153: 0000000000000000 38 FUNC WEAK DEFAULT 476 cub::PtxVersion(int&)::{lambda(int&)#1}::operator()(int&) const
1154: 0000000000000000 159 FUNC WEAK DEFAULT 478 cub::PtxVersion(int&)
1155: 0000000000000000 119 FUNC WEAK DEFAULT 517 cub::PerDeviceAttributeCache& cub::GetPerDeviceAttributeCache<cub::PtxVersionCacheTag>()
1156: 0000000000000000 361 FUNC WEAK DEFAULT 519 cub::PerDeviceAttributeCache::DevicePayload cub::PerDeviceAttributeCache::operator()<cub::PtxVersion(int&)::{lambda(int&)#1}>(cub::PtxVersion(int&)::{lambda(int&)#1}&&, int)
1214: 0000000000000000 14 FUNC WEAK DEFAULT 554 cub::PtxVersion(int&)::{lambda(int&)#1}&& std::forward<cub::PtxVersion(int&)::{lambda(int&)#1}>(std::remove_reference<cub::PtxVersion(int&)::{lambda(int&)#1}>::type&)
Also, I believe that means the full call path (e.g., `query_ptx()` -> `do_query_ptx()` -> `cudaFuncGetAttributes(kernel)`) would need to have internal linkage to make sure we're not catching a symbol from another TU along the path(?).
Similarly, we need to be careful not to query `PerDeviceAttributeCache` across TUs.
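Such a fully-internal call path could be sketched like this (hypothetical names mirroring the `query_ptx()`/`do_query_ptx()` example above; not CUB's actual implementation):

```cuda
#include <cuda_runtime.h>

namespace {  // anonymous namespace: everything below gets internal linkage

// Per-TU stand-in for cub::EmptyKernel
__global__ void empty_kernel() {}

// The attribute query; only ever sees THIS TU's empty_kernel
cudaError_t do_query_ptx(int &ptx_version)
{
  cudaFuncAttributes attr;
  cudaError_t error = cudaFuncGetAttributes(&attr, empty_kernel);
  if (error != cudaSuccess)
    return error;
  ptx_version = attr.ptxVersion * 10;  // e.g., 70 -> 700
  return cudaSuccess;
}

// Entry point of the query path; internal linkage as well, so the
// linker cannot substitute another TU's copy anywhere along the path
cudaError_t query_ptx(int &ptx_version)
{
  return do_query_ptx(ptx_version);
}

} // anonymous namespace
```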
If you add `-cudart shared` to the link lines, you also get a different set of results.
| Works? | Linker | kernel() annotation | test() | test() anon namespace |
|---|---|---|---|---|
| N | Static | static | | N |
| N | Static | inline | | N |
| Y | Static | static | static | N |
| N | Static | static | inline | N |
| Y | Static | static | | Y |
| N | Static | | | Y |
| N | Dynamic | static | | N |
| N | Dynamic | inline | | N |
| Y | Dynamic | static | static | N |
| N | Dynamic | static | inline | N |
| Y | Dynamic | static | | Y |
| N | Dynamic | | | Y |