microsoft/mscclpp

[Bug] Hang when one GPU gets a new BasePtr while others do not.

FC-Li opened this issue · 1 comments

FC-Li commented

The smChans cache policy uses an unordered map with (base_ptr, num_bytes) as the key and smChans as the value. So when some GPUs decide to create a new channel and insert it into the map while other GPUs find the (base_ptr, num_bytes) key in their cache, a hang occurs.

Please see the code and comments below.

auto create_channel(void *buff, size_t num_bytes, cudaStream_t stream)
{
        size_t recvBytes;
        CUdeviceptr recvBasePtr;
        MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)buff));

        ChannelKey channel_key{(void*)recvBasePtr, recvBytes};
        auto iter = channels_.find(channel_key);

        if (iter == channels_.end())
        {
            // Some GPUs enter this branch
            auto smChans = p_.comm->createSmChans(buff, num_bytes, nullptr, 0);  // All GPUs need to rendezvous here, so a hang happens.

            using Handle = typename mscclpp::DeviceHandle<mscclpp::SmChannel>;
            Handle *d_smChans;
            cudaMallocAsync(&d_smChans, sizeof(Handle) * smChans.size(), stream);
            std::shared_ptr<Handle> SmChanDeviceHandleDevicePtr{d_smChans, mscclpp::CudaDeleter<Handle>{}};
            std::shared_ptr<std::vector<Handle>> SmChanDeviceHandleHostPtr = MscclppCommunicator::getSmChanDeviceHandleAsync(d_smChans, smChans, stream);

            Channel channel{
                .smChans = smChans,
                .SmChanDeviceHandleDevicePtr = SmChanDeviceHandleDevicePtr,
                .SmChanDeviceHandleHostPtr = SmChanDeviceHandleHostPtr};
            channels_.insert({channel_key, channel});
            return channel;
        }
        else
        {
            // Others enter this branch
            return iter->second;
        }
    }

std::vector<mscclpp::SmChannel> createSmChans(void *inputBuf, size_t inputBufBytes, void *outputBuf, size_t outputBufBytes)
    {
        mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuf, inputBufBytes, mscclpp::Transport::CudaIpc);
        mscclpp::RegisteredMemory outputBufRegMem;
        if (outputBufBytes)
        {
            outputBufRegMem = comm_->registerMemory(outputBuf, outputBufBytes, mscclpp::Transport::CudaIpc);
        }

        std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
        for (int r = 0; r < p_.ep_world_size; r++)
        {
            comm_->sendMemoryOnSetup(outputBufBytes ? outputBufRegMem : inputBufRegMem, r, 0);
            auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
            remoteRegMemories.push_back(remoteMemory);
        }
        comm_->setup();

        std::vector<mscclpp::SmChannel> smChans;

        for (size_t channel_id = 0; channel_id < NUM_CHANNELS_PER_CONNECTION; ++channel_id)
        {
            for (size_t connect_id = 0; connect_id < connections_.size(); ++connect_id)
            {
                smChans.emplace_back(smSemaphores_[connect_id][channel_id], remoteRegMemories[connect_id].get(),
                                     inputBufRegMem.data(),
                                     outputBuf);
            }
        }

        return smChans;
    }

I wonder if there is a way to solve this issue.

This is a known issue and we will fix it in a future version. We need to add more communication kernels to support this.
https://github.com/microsoft/mscclpp/blob/main/docs/design/nccl-over-mscclpp.md#limitations