Question about DeviceSyncer and Ensuring Synchronization
hidva opened this issue · 1 comment
Hi there,
I've been studying the source code of mscclpp DeviceSyncer and have a question regarding its synchronization capabilities. Specifically, can DeviceSyncer really ensure synchronization? Here is an example:
# block_0                  # block_1
x_0 = 33                   x_1 = 77
device_syncer.sync(2)      device_syncer.sync(2)  #2
read x_1  #1
In the above example, does the device_syncer.sync at #2 guarantee that the store x_1 = 77 is visible to the read at #1? My concern is that the DeviceSyncer.sync function appears to involve only fences and relaxed memory-order operations. As stated in the CUDA documentation:
Memory fence functions only affect the ordering of memory operations by a thread; they do not, by themselves, ensure that these memory operations are visible to other threads
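My reading of this passage is that a fence only orders a thread's own memory operations; for another block to actually see them, it has to observe some atomic write that the fence ordered after the data. An illustrative sketch of that idiom with libcu++ (the names here are mine, not mscclpp's):

#include <cuda/atomic>

__device__ int x1 = 0;
__device__ cuda::atomic<int, cuda::thread_scope_device> flag{0};

// Writer side (plays the role of block_1 above).
__device__ void writer() {
  x1 = 77;  // plain store
  // Order the data store before the flag store.
  cuda::atomic_thread_fence(cuda::memory_order_release, cuda::thread_scope_device);
  flag.store(1, cuda::memory_order_relaxed);  // publish
}

// Reader side (plays the role of block_0 above).
__device__ void reader() {
  while (flag.load(cuda::memory_order_relaxed) == 0) {}  // observe the publish
  // Once the flag store is observed, the acquire fence makes `x1` visible.
  cuda::atomic_thread_fence(cuda::memory_order_acquire, cuda::thread_scope_device);
  int v = x1;  // guaranteed to read 77
  (void)v;
}

DeviceSyncer.sync does flip flag_ with a relaxed atomicStore, so the question is whether its fences are strong enough to play the role of the release/acquire fences above.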
I would appreciate any clarification on whether or not DeviceSyncer can ensure the visibility of memory operations between blocks in this context.
Thank you for your assistance!
I apologize, I forgot that __threadfence() is equivalent to a memory_order_seq_cst fence at device scope! After manually replacing it with a memory_order_relaxed fence, I can indeed observe the problem.
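For reference, this is the equivalence from the CUDA programming guide that I had forgotten:

#include <cuda/atomic>

__device__ void fences() {
  __threadfence();
  // ...is documented to behave the same as:
  cuda::atomic_thread_fence(cuda::memory_order_seq_cst, cuda::thread_scope_device);
}

The patch I tested: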
diff --git a/include/mscclpp/concurrency_device.hpp b/include/mscclpp/concurrency_device.hpp
index 6614b91..dc8d636 100644
--- a/include/mscclpp/concurrency_device.hpp
+++ b/include/mscclpp/concurrency_device.hpp
@@ -29,7 +29,8 @@ struct DeviceSyncer {
if (blockNum == 1) return;
if (threadIdx.x == 0) {
// Need a `__threadfence()` before to flip `flag`.
- __threadfence();
+ // __threadfence();
+ cuda::atomic_thread_fence(cuda::memory_order_relaxed, cuda::thread_scope_device);
unsigned int tmp = preFlag_ ^ 1;
if (atomicInc(&count_, maxOldCnt) == maxOldCnt) {
atomicStore(&flag_, tmp, memoryOrderRelaxed);
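With the fence relaxed, the stores issued before the flag flip are no longer guaranteed to be ordered ahead of it, so another block can observe the flipped flag_ without observing the data. Below is the full repro program I used: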
#include <stdlib.h>
#include <iostream>
#include <mscclpp/concurrency_device.hpp>
#define MSCCLPP_CUDATHROW(cmd)                                             \
  do {                                                                     \
    cudaError_t err = cmd;                                                 \
    if (err != cudaSuccess) {                                              \
      std::cerr << #cmd << ": " << cudaGetErrorString(err) << std::endl;   \
      abort();                                                             \
    }                                                                      \
  } while (false)
// Return the id of the SM this block is running on.
__device__ unsigned get_smid(void) {
  unsigned ret = 0;
  asm("mov.u32 %0, %%smid;" : "=r"(ret));  // `%` must be escaped as `%%` in inline PTX
  return ret;
}
__managed__ unsigned int bad_apple_cnt = 0;
unsigned int sm1_var_expected = 33;
unsigned int sm2_var_expected = 77;
#define SM2_VAR_EXPECTED 77
#define SM1_VAR_EXPECTED 33
__device__ mscclpp::DeviceSyncer dev_syncer;
__device__ unsigned int sm1_var = SM1_VAR_EXPECTED;
// Deliberately weak PTX stores/loads (st.weak with .wb, ld.weak with .ca) so the
// accesses carry no ordering or visibility guarantees of their own; whatever
// guarantee exists must come from DeviceSyncer.
__device__ void store_sm1_var(unsigned int val) {
  asm("st.weak.global.wb.u32 [sm1_var], %0;" :: "r"(val));
}
__device__ unsigned int sm2_var = SM2_VAR_EXPECTED;
__device__ void store_sm2_var(unsigned int val) {
  asm("st.weak.global.wb.u32 [sm2_var], %0;" :: "r"(val));
}
__device__ unsigned load_sm1_var() {
  unsigned ret = 0;
  asm("ld.weak.global.ca.u32 %0, [sm1_var];" : "=r"(ret));
  return ret;
}
__device__ unsigned load_sm2_var() {
  unsigned ret = 0;
  asm("ld.weak.global.ca.u32 %0, [sm2_var];" : "=r"(ret));
  return ret;
}
__global__ void zy_sync_test()
{
  unsigned cur_smid = get_smid();
  // Sanity-check the initial values set by the host before this launch.
  if (load_sm1_var() != SM1_VAR_EXPECTED) {
    __brkpt();
  }
  if (load_sm2_var() != SM2_VAR_EXPECTED) {
    __brkpt();
  }
  if (threadIdx.x == 0) {
    // printf("cur_smid=%u\n", cur_smid);
    // Partition blocks into two groups by the SM they landed on;
    // each group overwrites the other group's variable.
    if (cur_smid == 124) {
      // sm1_var = SM2_VAR_EXPECTED;
      store_sm1_var(SM2_VAR_EXPECTED);
    } else {
      // sm2_var = SM1_VAR_EXPECTED;
      store_sm2_var(SM1_VAR_EXPECTED);
    }
  }
  // __syncthreads();
  // __threadfence();
  dev_syncer.sync(gridDim.x);
  if (threadIdx.x == 0) {
    // After the sync, the other group's store should be visible; if we still
    // read the initial value, the sync failed to publish it.
    if (cur_smid == 124) {
      if (load_sm2_var() == SM2_VAR_EXPECTED) {
        atomicAdd(&bad_apple_cnt, 1);
        // __brkpt();
      }
    } else {
      if (load_sm1_var() == SM1_VAR_EXPECTED) {
        atomicAdd(&bad_apple_cnt, 1);
        // __brkpt();
      }
    }
  }
  __syncthreads();
}
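Each iteration launches four blocks: the block(s) on SM 124 overwrite sm1_var while the rest overwrite sm2_var, everyone meets at dev_syncer.sync(gridDim.x), and then each side checks whether the other side's store became visible. bad_apple_cnt counts the failures.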
// Host code
int main(int argc, char** argv)
{
  int device;
  cudaDeviceProp prop;
  cudaSetDevice(3);
  cudaGetDevice(&device);
  cudaGetDeviceProperties(&prop, device);
  int max_sm = 0;  // dynamic shared memory per block (none used here)
  int num_blks = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &num_blks,
      zy_sync_test,
      1024,     // threads per block
      max_sm);  // dynamic shared memory size
  std::cout << "device=" << device << ",maxThreadsPerMultiProcessor=" << prop.maxThreadsPerMultiProcessor
            << ",sharedMemPerMultiprocessor=" << prop.sharedMemPerMultiprocessor
            << ",max_sm=" << max_sm
            << ",num_blks=" << num_blks << std::endl;
  mscclpp::DeviceSyncer syncer = {};
  MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(dev_syncer, &syncer, sizeof(mscclpp::DeviceSyncer)));
  unsigned long i = 0;
  while (true) {
    // Reset both variables to their initial values before each run.
    MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(sm1_var, &sm1_var_expected, sizeof(sm1_var_expected)));
    MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(sm2_var, &sm2_var_expected, sizeof(sm2_var_expected)));
    zy_sync_test<<<4, 1024, max_sm>>>();  // 4 blocks of 1024 threads
    MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
    MSCCLPP_CUDATHROW(cudaGetLastError());
    ++i;
    if (i % 10000 == 0) {
      std::cout << "Do it again!" << i << ", bad_apple=" << bad_apple_cnt << std::endl;
    }
  }
  return 0;
}
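With the relaxed-fence patch applied, the bad_apple count keeps growing in my runs; with the original __threadfence() it stays at zero.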