[BUG] `no_nvlink` branch failed to compile
Closed this issue · 6 comments
Describe the bug
After adding `IntraNodePcie` metas, the RS kernel failed to compile.
To Reproduce
- Add `IntraNodePcie` metas in `gemm_v2_reduce_scatter.hpp`:
make_reduce_scatter_meta(_False{}, _IntraNodePcie{}),
make_reduce_scatter_meta(_True{}, _IntraNodePcie{}),
- Program failed to compile.
Expected behavior
Program should be able to compile.
Stack trace/logs
src/reduce_scatter/gemm_v2_reduce_scatter.hpp(352): error: too many initializer values
args.world_size,
^
detected during:
instantiation of "auto bytedance::flux::GemmV2ReduceScatter<GemmMetaT, GemmHParamsT>::to_gemm_args_impl(const bytedance::flux::GemmReduceScatterArguments &, void *) const [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>]" at line 483
instantiation of "auto bytedance::flux::GemmV2ReduceScatter<GemmMetaT, GemmHParamsT>::to_gemm_args(const std::any &, void *) const [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>]" at line 360 of /home/zihuaw/data/com.github/bytedance/flux-container-rebuild-new/include/flux/cuda/gemm_impls/gemm_v2_impl.hpp
instantiation of "auto bytedance::flux::GemmV2Impl<GemmMetaT, GemmHParamsT, DerivedImpl>::to_gemm_args(const std::any &, void *) const [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, DerivedImpl=bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>]" at line 68 of /home/zihuaw/data/com.github/bytedance/flux-container-rebuild-new/include/flux/cuda/gemm_impls/gemm_operator_base_default_impl.hpp
instantiation of "void bytedance::flux::GemmOperatorBaseDefaultImplMixin<DerivedImpl>::run(const std::any &, void *, void *) [with DerivedImpl=bytedance::flux::GemmV2Impl<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
instantiation of class "bytedance::flux::GemmOperatorBaseDefaultImplMixin<DerivedImpl> [with DerivedImpl=bytedance::flux::GemmV2Impl<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
instantiation of class "bytedance::flux::GemmOperatorBaseDefaultImplMixin<DerivedImpl> [with DerivedImpl=bytedance::flux::GemmV2Impl<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
instantiation of class "bytedance::flux::GemmV2Impl<GemmMetaT, GemmHParamsT, DerivedImpl> [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>, DerivedImpl=bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
instantiation of class "bytedance::flux::GemmV2ReduceScatter<GemmMetaT, GemmHParamsT> [with GemmMetaT=bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, GemmHParamsT=bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>]" at line 962 of /usr/include/c++/11/bits/unique_ptr.h
instantiation of "std::_MakeUniq<_Tp>::__single_object std::make_unique<_Tp,_Args...>(_Args &&...) [with _Tp=bytedance::flux::GemmV2ReduceScatter<bytedance::flux::GemmMeta<bytedance::flux::_FP16, bytedance::flux::_Sm89, bytedance::flux::_ReduceScatter, bytedance::flux::_RRR, bytedance::flux::_GemmV2, bytedance::flux::None, bytedance::flux::ReduceScatterMeta<bytedance::flux::_False, bytedance::flux::_IntraNodePcie>>, bytedance::flux::GemmHParams<bytedance::flux::GemmV2HParams<cute::tuple<cute::_64, cute::_64, cute::_32>, cute::tuple<cute::_16, cute::_8, cute::_16>, bytedance::flux::_StreamkSK>, bytedance::flux::None, cute::tuple<cute::_128, cute::_256, cute::_32>, bytedance::flux::_GemmStreamK, cute::_4, bytedance::flux::_RasterAlongN>>, _Args=<>]"
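For readers unfamiliar with this diagnostic: "too many initializer values" is what nvcc reports when a braced initializer list supplies more values than the target aggregate has members. The code at line 352 is not shown in this issue, so the following is only a minimal sketch of how that class of error can appear when a template is instantiated with a new meta, and one possible way to guard it with `if constexpr` (C++17). All names here (`IntraNodeTag`, `IntraNodePcieTag`, `NvlinkArgs`, `PcieArgs`, `ArgsFor`, `make_args`) are hypothetical and are not taken from flux.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical topology tags, standing in for the kind of meta tag added in the issue.
struct IntraNodeTag {};
struct IntraNodePcieTag {};

// Hypothetical argument aggregates: the PCIe variant has one member fewer.
struct NvlinkArgs { int rank; int world_size; int nnodes; };
struct PcieArgs   { int rank; int world_size; };

// Map each topology tag to its argument type.
template <class Topology> struct ArgsFor { using type = NvlinkArgs; };
template <> struct ArgsFor<IntraNodePcieTag> { using type = PcieArgs; };

// Broken pattern (kept commented out): always passes three values, so
// instantiating it with IntraNodePcieTag fails to compile because PcieArgs
// has only two members.
//
// template <class Topology>
// auto make_args_broken(int rank, int world_size, int nnodes) {
//   return typename ArgsFor<Topology>::type{rank, world_size, nnodes};
// }

// Guarded pattern: only the branch matching the tag is instantiated,
// so each aggregate receives exactly as many values as it has members.
template <class Topology>
auto make_args(int rank, int world_size, [[maybe_unused]] int nnodes) {
  using Args = typename ArgsFor<Topology>::type;
  if constexpr (std::is_same_v<Topology, IntraNodePcieTag>) {
    return Args{rank, world_size};
  } else {
    return Args{rank, world_size, nnodes};
  }
}

// Small self-check exercising both instantiations.
inline void demo() {
  auto nv = make_args<IntraNodeTag>(0, 8, 1);
  assert(nv.world_size == 8 && nv.nnodes == 1);

  auto pcie = make_args<IntraNodePcieTag>(1, 4, 1);
  static_assert(std::is_same_v<decltype(pcie), PcieArgs>);
  assert(pcie.rank == 1 && pcie.world_size == 4);
}
```

Because the failure in this sketch comes from template instantiation, it would show up for any target architecture once the new meta is registered. Whether the real cause in flux is the same mismatch is an assumption.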
Environment
Within the `nvcr.io/nvidia/pytorch:24.07-py3` container:
- CUDA: 12.5.82
- GCC: 11.4.0
- Python: 3.10.12
- PyTorch: 2.4.0 (custom build)
Additional context
After I removed the `no_nvlink` branch (`flux/src/reduce_scatter/gemm_v2_reduce_scatter.hpp`, lines 317 to 328 in 239a13d)
@lucifer1004 Thanks for your interest.
I don't think we have released `no_nvlink`/PCIe support yet.
Which GPU did you test on?
I tested on Ada GPUs. `Sm89_{}` has been added, and all tests run as long as `no_nvlink` is not changed to `false`.
@lucifer1004 It seems like you added your own sm89 implementation and are trying to make it work?
We have only released the sm89 GEMM kernel, not the AG/RS sm89 support.
Actually, there is no new implementation, just a single configuration change. But that is not the point here, because this `no_nvlink` issue happens at the compilation stage with both `--arch 89` and `--arch 80`.
We haven't released PCIe support yet, so neither 80 nor 89 with a PCIe interconnect is supposed to work, IIRC.
If the `no_nvlink` branch gets hit, that means your test machine is PCIe-based; removing that branch does not make it work on a PCIe machine.