microsoft/mscclpp

[Bug] Proxy channel over CudaIPC on AMD GPUs

liangyuRain opened this issue · 2 comments

Hi, we have a simple program that runs successfully on NVIDIA A100 GPUs but fails on AMD MI250/MI210. The code uses a proxy channel over a CudaIPC connection. We wonder whether the proxy channel is still buggy on AMD GPUs.

The code does the following:

  • Rank 0: Initialize a cp.zeros data buffer and a cp.ones scratch buffer. Perform data += scratch, then putWithSignal the data buffer into rank 1's data buffer through the proxy channel.
  • Rank 1: Initialize a cp.zeros data buffer and wait for rank 0's signals.

The data buffer contains 1024*256 ints in total and is reduced and sent in chunks of 4*blockDim.x ints. We launch only one 128-thread threadblock. After the code finishes, both rank 0 and rank 1 should have their data buffers all ones. However, we noticed that rank 1's data buffer data[66048...196607] is still all zeros, while rank 0's data buffer is correctly all ones.

There is a __syncthreads() after the reduce, and a check that data[start].x == 1 never fails. Somehow the putWithSignal afterwards still puts all-zero data for some chunks. The same program runs successfully on NVIDIA GPUs, and also on AMD GPUs with a smaller nelem_total. We wonder if there is a bug in the proxy channel or in ROCm.
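For reference, the failing range lines up exactly with whole chunks; a quick back-of-the-envelope check in Python, using only the numbers above:

nelem_total = 1024 * 256                  # 262144 ints in total
chunk_ints = 4 * 128                      # 4 * blockDim.x = 512 ints per putWithSignal
num_chunks = nelem_total // chunk_ints    # 512 chunks overall

# Failing range reported on rank 1: data[66048 ... 196607]
print(66048 // chunk_ints, 196608 // chunk_ints - 1)   # chunks 129 through 383
print(66048 % chunk_ints, 196608 % chunk_ints)         # both 0, i.e. chunk-aligned

So the zeros on rank 1 cover a contiguous run of complete chunks (129 through 383) rather than a partial chunk, which is why we suspect the puts rather than the reduce.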

Platform info:

  • ROCm 6.2.1
  • Tested on MI250 and MI210
  • MSCCLPP commit 863a59936084b0dd88c221185841b8c773d17446
  • CuPy commit 0188dd8b16938fa835bcda797f70f9af2f8b4980

Kernel code

#include <mscclpp/proxy_channel_device.hpp>

const uint64_t nelem_total = 1024 * 256;
const uint64_t nint4_total = nelem_total / 4;

extern "C" __global__ void __launch_bounds__(1024)
    reduce_send_kernel(mscclpp::ProxyChannelDeviceHandle* send_proxy_channel,
                       int4 *scratch, int4 *data) {
    const int tid = threadIdx.x;

    const uint64_t nint4_per_send = blockDim.x;
    for (int i = 0; i < nint4_total / nint4_per_send; ++i) {
        const uint64_t start = i * nint4_per_send;
        __syncthreads();
        // Reduce the scratch buffer into the data buffer, making the data buffer all 1s.
        for (uint64_t offset = tid; offset < nint4_per_send; offset += blockDim.x) {
            int4 tmp = data[start + offset];
            int4 val = scratch[start + offset];
            tmp.x += val.x;
            tmp.y += val.y;
            tmp.z += val.z;
            tmp.w += val.w;
            data[start + offset] = tmp;
        }
        __syncthreads();
        // The printf never got triggered.
        if (tid == 0 && data[start].x != 1) 
            printf("Incorrect value %d at %llu\n", data[start].x, start * 4);
        __syncthreads();
        // After reduce, put the reduced data into rank 1's data buffer.
        if (tid == 0) 
            send_proxy_channel->putWithSignal(start * sizeof(int4), nint4_per_send * sizeof(int4));
        __syncthreads();
    }
}


extern "C" __global__ void __launch_bounds__(1024)
    recv_kernel(mscclpp::ProxyChannelDeviceHandle* recv_proxy_channel, int4 *data) {
    const int tid = threadIdx.x;

    const uint64_t nint4_per_send = blockDim.x;
    for (int i = 0; i < nint4_total / nint4_per_send; ++i) {
        if (tid == 0) recv_proxy_channel->wait();
        __syncthreads();
    }
}

Python code

import cupy as cp
import os
import struct

import mscclpp.comm as mscclpp_comm
from mscclpp import (
    ProxyService,
    Transport,
)
from mscclpp_mpi import MpiGroup, mpi_group
from mscclpp.utils import KernelBuilder


def create_group_and_connection(mpi_group: MpiGroup):
    group = mscclpp_comm.CommGroup(mpi_group.comm)
    remote_nghrs = list(range(group.nranks))
    remote_nghrs.remove(group.my_rank)
    connections = group.make_connection(remote_nghrs, Transport.CudaIpc)
    return group, connections

    
def main():
    # MPI group of 2
    mpi_group = MpiGroup(list(range(2)))
    nelem_total = 1024 * 256
    group, connections = create_group_and_connection(mpi_group)
    proxy_service = ProxyService()

    # Data buffer is all 0s
    memory = cp.zeros(nelem_total, dtype=cp.int32)

    peer = 1 - group.my_rank
    # Proxy channel over Transport.CudaIpc
    proxy_channel = group.make_proxy_channels(proxy_service, memory, connections)[peer]
    proxy_arr = cp.asarray(memoryview(proxy_channel.device_handle().raw), dtype=cp.uint8)
    proxy_ptr = proxy_arr.data.ptr

    if group.my_rank == 0:
        # Scratch buffer is all 1s
        scratch = cp.ones(nelem_total, dtype=cp.int32)

        file_dir = os.path.dirname(os.path.abspath(__file__))
        bug_kernel = KernelBuilder(
            file="reproduce_bug.cu",
            kernel_name="reduce_send_kernel",
            file_dir=file_dir,
        ).get_compiled_kernel()
        params = struct.pack("P", proxy_ptr) + struct.pack("P", scratch.data.ptr) + struct.pack("P", memory.data.ptr)
    else:
        file_dir = os.path.dirname(os.path.abspath(__file__))
        bug_kernel = KernelBuilder(
            file="reproduce_bug.cu",
            kernel_name="recv_kernel",
            file_dir=file_dir,
        ).get_compiled_kernel()
        params = struct.pack("P", proxy_ptr) + struct.pack("P", memory.data.ptr)

    proxy_service.start_proxy()
    group.barrier()

    # Launch one 128-thread threadblock
    bug_kernel.launch_kernel(
        params=params,
        nblocks=1,
        nthreads=128,
        shared=0,
        stream=None
    )
    cp.cuda.runtime.deviceSynchronize()

    expected = cp.ones(nelem_total, dtype=cp.int32)
    print(cp.nonzero(memory - expected))
    assert cp.array_equal(memory, expected)

    proxy_service.stop_proxy()
    

if __name__ == "__main__" :
    main()

Output

$ rm -rf __pycache__ ; mpirun -np 2 --tag-output python3 reproduce_bug.py
[1,0]<stdout>:(array([], dtype=int64),)
[1,1]<stdout>:(array([ 66048,  66049,  66050, ..., 196605, 196606, 196607]),)
[1,1]<stderr>:Traceback (most recent call last):
[1,1]<stderr>:  File ".../reproduce_bug.py", line 79, in <module>
[1,1]<stderr>:    main()
[1,1]<stderr>:  File ".../reproduce_bug.py", line 73, in main
[1,1]<stderr>:    assert cp.array_equal(memory, expected)

You need to use uncached GPU buffers for communication between AMD GPUs. @SreevatsaAnantharamu We need a Python binding of the ext alloc function.

@liangyuRain Please try out #423 and use mscclpp.utils.GpuBuffer instead of CuPy arrays for the buffers your communication happens on. The usage is the same as for CuPy ndarrays. We still cannot guarantee it will work, because CuPy doesn't officially support ROCm 6.x yet and we haven't tried CuPy with AMD MI250/MI210. Please let us know if you still encounter any issues.
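Roughly, the only change needed in the script above should be the two allocations. A minimal sketch, assuming GpuBuffer accepts a shape and dtype like a CuPy ndarray and can be filled in place:

import cupy as cp
from mscclpp.utils import GpuBuffer

nelem_total = 1024 * 256

# Data buffer is all 0s, but backed by uncached GPU memory
# (constructor arguments assumed to mirror cp.ndarray(shape, dtype))
memory = GpuBuffer(nelem_total, dtype=cp.int32)
memory[:] = 0

# Scratch buffer is all 1s (rank 0 only)
scratch = GpuBuffer(nelem_total, dtype=cp.int32)
scratch[:] = 1

# The rest of the script stays the same: memory.data.ptr and scratch.data.ptr
# still give the device pointers packed into the kernel params.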