bytedance/flux

[BUG] `no_nvlink` branch failed to compile

Closed this issue · 6 comments

Describe the bug
After I added `IntraNodePcie` metas, the reduce-scatter (RS) kernel failed to compile.

To Reproduce

  1. Add `IntraNodePcie` metas in gemm_v2_reduce_scatter.hpp:
          make_reduce_scatter_meta(_False{}, _IntraNodePcie{}),
          make_reduce_scatter_meta(_True{}, _IntraNodePcie{}),
  2. Build the project; compilation fails with the error below.

Expected behavior
The program should compile.

Stack trace/logs

src/reduce_scatter/gemm_v2_reduce_scatter.hpp(352): error: too many initializer values
                 args.world_size,
                 ^
          detected during:
            instantiation of "auto bytedance::flux::GemmV2ReduceScatter<GemmMetaT, GemmHParamsT>::to_gemm_args_impl(const bytedance::flux::GemmReduceScatterArguments &, void *) const [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>]" at line 483
            instantiation of "auto bytedance::flux::GemmV2ReduceScatter<GemmMetaT, GemmHParamsT>::to_gemm_args(const std::any &, void *) const [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>]" at line 360 of /home/zihuaw/data/com.github/bytedance/flux-container-rebuild-new/include/flux/cuda/gemm_impls/gemm_v2_impl.hpp
            instantiation of "auto bytedance::flux::GemmV2Impl<GemmMetaT, GemmHParamsT, DerivedImpl>::to_gemm_args(const std::any &, void *) const [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, DerivedImpl=bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>]" at line 68 of /home/zihuaw/data/com.github/bytedance/flux-container-rebuild-new/include/flux/cuda/gemm_impls/gemm_operator_base_default_impl.hpp
            instantiation of "void bytedance::flux::GemmOperatorBaseDefaultImplMixin<DerivedImpl>::run(const std::any &, void *, void *) [with DerivedImpl=bytedance::flux::GemmV2Impl<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
            instantiation of class "bytedance::flux::GemmOperatorBaseDefaultImplMixin<DerivedImpl> [with DerivedImpl=bytedance::flux::GemmV2Impl<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
            instantiation of class "bytedance::flux::GemmOperatorBaseDefaultImplMixin<DerivedImpl> [with DerivedImpl=bytedance::flux::GemmV2Impl<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
            instantiation of class "bytedance::flux::GemmV2Impl<GemmMetaT, GemmHParamsT, DerivedImpl> [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, DerivedImpl=bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
            instantiation of class "bytedance::flux::GemmV2ReduceScatter<GemmMetaT, GemmHParamsT> [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
            instantiation of "std::_MakeUniq<_Tp>::__single_object std::make_unique<_Tp,_Args...>(_Args &&...) [with _Tp=bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>, _Args=<>]"
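
The error is consistent with how `if constexpr` behaves inside a template: a branch whose condition is false for a given instantiation is discarded and never type-checked. With the default (NVLink) metas, the `no_nvlink` branch is discarded, so its initializer is never checked; adding `_IntraNodePcie` metas makes `no_nvlink` true, and the branch is instantiated. The caret under `args.world_size` suggests (an inference from the log, not confirmed by the source) that the released C-operand argument type accepts only the first three values. The sketch below reproduces this pattern; `NvlinkTag`, `PcieTag`, `ThreeFieldCArgs`, `FiveFieldCArgs`, `accepts_five_values`, and `make_c_args` are hypothetical stand-ins, not flux names.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins, NOT flux's real types.
struct NvlinkTag {};
struct PcieTag {};

// Shaped like a plain (ptr, null default, stride) C-operand argument.
struct ThreeFieldCArgs {
  const float* ptr_C;
  float null_default;
  long stride_C;
};

// Shaped like what the no_nvlink branch tries to build.
struct FiveFieldCArgs {
  const float* ptr_C;
  float null_default;
  long stride_C;
  int world_size;
  bool use_gemmk;
};

// Detects whether T{ptr, 0.0f, stride, world_size, use_gemmk} is well-formed;
// "too many initializers" becomes a substitution failure (false) here.
template <typename T, typename = void>
struct accepts_five_values : std::false_type {};
template <typename T>
struct accepts_five_values<
    T, std::void_t<decltype(T{std::declval<const float*>(), 0.0f, 0L, 0, false})>>
    : std::true_type {};

template <typename CommTag, typename CArgs>
CArgs make_c_args(const float* ptr, long stride, int world_size, bool use_gemmk) {
  constexpr bool no_nvlink = std::is_same_v<CommTag, PcieTag>;
  if constexpr (no_nvlink) {
    // Instantiated only for PcieTag, so only then is the five-value
    // initializer checked against CArgs.
    static_assert(accepts_five_values<CArgs>::value,
                  "no_nvlink branch needs an argument type with five fields");
    return CArgs{ptr, 0.0f, stride, world_size, use_gemmk};
  } else {
    // Discarded for PcieTag; NvlinkTag instantiations never check the branch above.
    (void)world_size;
    (void)use_gemmk;
    return CArgs{ptr, 0.0f, stride};
  }
}
```

Instantiating `make_c_args<PcieTag, ThreeFieldCArgs>` fails to compile, which mirrors the reported build, while `make_c_args<NvlinkTag, ThreeFieldCArgs>` compiles because the offending branch is discarded. The `static_assert` only turns the raw initializer-count error into a readable message; it does not add PCIe support.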

Environment

Within the nvcr.io/nvidia/pytorch:24.07-py3 container:

  • CUDA: 12.5.82
  • GCC: 11.4.0
  • Python: 3.10.12
  • PyTorch: 2.4.0 (custom build)

Additional context
After I removed the no_nvlink branch below, the program compiled, but the GEMM+RS tests then failed with mismatched results whenever the TP size is larger than 1:

    if constexpr (no_nvlink) {
      return EvtDArgumentType{
          {args.beta},  // Beta
          {ptr_C,
           typename Base::ElementC(0),
           stride_C,
           args.world_size,
           args.reduce_scatter_args.use_gemmk},  // C
          {{{args.alpha}}, {}, {}},  // compute0 args
          {}  // compute1 args
      };
    } else {
      // ... (else branch omitted)
    }
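
For reference, the result a GEMM+RS test compares against can be computed on the CPU. The sketch below assumes a common tensor-parallel layout rather than flux's actual test code: rank r holds a K-shard of A and B, the partial M x N products are summed across ranks, and rank r keeps the r-th block of M / world_size rows. With world_size == 1 the reduce-scatter is an identity, so a bug confined to the cross-rank reduction path would only show up when TP > 1. `Matrix`, `matmul`, and `reference_gemm_rs` are hypothetical names, and the sketch does not model any PCIe-specific path.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helpers: a CPU reference under an assumed layout, not flux code.
// Assumption: rank r holds A_r (M x K_local) and B_r (K_local x N), row-major.
using Matrix = std::vector<float>;  // row-major storage

// Plain triple-loop GEMM: returns the M x N product of a (M x K) and b (K x N).
Matrix matmul(const Matrix& a, const Matrix& b, std::size_t m, std::size_t k, std::size_t n) {
  Matrix c(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t p = 0; p < k; ++p)
      for (std::size_t j = 0; j < n; ++j)
        c[i * n + j] += a[i * k + p] * b[p * n + j];
  return c;
}

// Expected per-rank output of GEMM + reduce-scatter along M: sum the partial
// products over all ranks, then give rank r the r-th block of M / W rows.
std::vector<Matrix> reference_gemm_rs(const std::vector<Matrix>& a_shards,
                                      const std::vector<Matrix>& b_shards,
                                      std::size_t m, std::size_t k_local, std::size_t n) {
  const std::size_t world_size = a_shards.size();
  assert(world_size > 0 && b_shards.size() == world_size && m % world_size == 0);

  Matrix sum(m * n, 0.0f);  // reduction across ranks
  for (std::size_t r = 0; r < world_size; ++r) {
    const Matrix partial = matmul(a_shards[r], b_shards[r], m, k_local, n);
    for (std::size_t i = 0; i < m * n; ++i) sum[i] += partial[i];
  }

  const std::size_t rows = m / world_size;  // rows owned by each rank
  std::vector<Matrix> out(world_size);
  for (std::size_t r = 0; r < world_size; ++r) {
    const auto first = sum.begin() + static_cast<std::ptrdiff_t>(r * rows * n);
    out[r].assign(first, first + static_cast<std::ptrdiff_t>(rows * n));
  }
  return out;
}
```

For example, with two ranks, M = 2, K_local = 1, N = 2, A_0 = [1; 2], B_0 = [1 1], A_1 = [3; 4], B_1 = [1 0], the summed product is [[4, 1], [6, 2]], so rank 0 should hold [4, 1] and rank 1 should hold [6, 2].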

@lucifer1004 Thanks for your interest.
I don't think we have released no_nvlink/PCIe support yet.

Which GPU did you test on?

I tested on Ada GPUs. I added `_Sm89{}`, and all tests run as long as no_nvlink is not modified to false.

@lucifer1004 It seems you added your own sm89 implementation and are trying to make it work?
We have only released the sm89 GEMM kernel, not sm89 support for AG/RS.

Actually, there is no implementation, just a single configuration change. But that is beside the point: this no_nvlink issue occurs at compile time with both --arch 89 and --arch 80.

We haven't released PCIe support yet, so neither sm80 nor sm89 with a PCIe interconnect is supposed to work, IIRC.
If the no_nvlink branch gets hit, your test machine is PCIe-based, and removing that branch does not make it work on a PCIe machine.