
Distributed Triton for Parallel Systems


đź‘‹ Hi, everyone!
We are ByteDance Seed team.

You can get to know us better through the following channels👇


Triton-distributed

Original Triton README | README in Chinese

Triton-distributed is a distributed compiler for computation-communication overlapping, built on top of OpenAI Triton.

With Triton-distributed, programmers can develop efficient kernels with performance comparable to highly optimized libraries such as Distributed-GEMM and FLUX. Triton-distributed currently targets mainly Nvidia and AMD GPUs, but it can also be ported to other hardware platforms. Feel free to contact us if you want to use Triton-distributed on your own hardware.

Getting started

Install Triton-distributed from source

Build Guide

How to use Triton-distributed

Triton-distributed provides a set of easy-to-use primitives to support the development of distributed kernels that overlap computation and communication. The primitives are divided into low-level and high-level primitives. We have released the low-level primitives and plan to release the high-level primitives in the future.

Triton-distributed Primitives
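
As a warm-up before the full example below, here is a minimal sketch (not from the repository) of the basic pattern the low-level primitives enable: a one-sided, non-blocking put to a peer rank, a fence, a signal set on the peer, and a wait for the peer's signal. The kernel name, buffer layout, and `peer` argument are illustrative assumptions; `tid` and `libshmem_device` are assumed to be importable from Triton-distributed's language extensions, as used in the all-to-all kernel below.

import triton
import triton.language as tl
# Assumption: `tid` and `libshmem_device` are provided by Triton-distributed's
# language extensions, as in the all-to-all example below.


@triton.jit
def push_and_wait_kernel(
    local_buf,    # local source data
    remote_buf,   # symmetric destination buffer (one slot of NUM_ELEMS per rank)
    flag,         # symmetric signal array, one entry per rank
    rank,
    peer,
    NUM_ELEMS: tl.constexpr,
    ELEMENT_SIZE: tl.constexpr = 2,
):
    thread_idx = tid(axis=0)

    # Non-blocking, block-scope put of this rank's data into its slot on `peer`.
    libshmem_device.putmem_nbi_block(
        remote_buf + rank * NUM_ELEMS,
        local_buf,
        NUM_ELEMS * ELEMENT_SIZE,
        peer,
    )
    # Order the put before the signal, then tell `peer` the data has landed
    # and wait until `peer` has pushed its data to us.
    libshmem_device.fence()
    if thread_idx == 0:
        libshmem_device.signal_op(
            flag + rank, 1, libshmem_device.NVSHMEM_SIGNAL_SET, peer,
        )
        libshmem_device.signal_wait_until(
            flag + peer, libshmem_device.NVSHMEM_CMP_EQ, 1,
        )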

Using these primitives, users can easily program communication kernels. For example, a low-latency AllToAll (with better latency than DeepEP for inference) is shown below. On 32 H800 GPUs, this example takes 137 us (128 tokens per rank, topk=8, hidden_size=7168, dtype=fp8), while DeepEP takes 182 us (note: DeepEP doesn't use NVLink for inference).

import triton
import triton.language as tl
# `tid` and `libshmem_device` come from Triton-distributed's language extensions
# (see the primitives document linked above).


@triton.jit
def all_to_all_kernel(
    data_src,
    data_dst,
    splits_src,
    splits_dst,
    signal,
    splits_cumsum,
    scale_src,
    scale_dst,
    rank: int,
    call_count: int,
    WITH_SCALE: tl.constexpr,
    WORLD_SIZE: tl.constexpr,
    HIDDEN: tl.constexpr,
    MAX_M: tl.constexpr,
    EXPERTS_PER_RANK: tl.constexpr,
    NUM_TOT_EXPERTS: tl.constexpr,
    ELEMENT_SIZE: tl.constexpr = 2,
    SCALE_ELEMENT_SIZE: tl.constexpr = 4,
):
    pid = tl.program_id(0)  # one program (thread block) per destination rank
    threadidx = tid(axis=0)  # thread index within the block

    # Experts owned by destination rank `pid`, and the contiguous row range of
    # local tokens routed to those experts.
    exp_st = pid * EXPERTS_PER_RANK
    exp_ed = exp_st + EXPERTS_PER_RANK

    m_st = tl.load(splits_cumsum + exp_st)
    m_ed = tl.load(splits_cumsum + exp_ed)
    num_rows_cur_block = m_ed - m_st

    src_off = m_st            # first local row to send
    dst_off = rank * MAX_M    # where our rows land in the destination's buffer

    # Per-expert token counts for this destination, derived from the cumulative sums.
    split_src_ptr = splits_src + exp_st
    off0 = exp_st + tl.arange(0, EXPERTS_PER_RANK)
    off1 = exp_st + tl.arange(0, EXPERTS_PER_RANK) + 1
    cumsum_sts = tl.load(splits_cumsum + off0)
    cumsum_eds = tl.load(splits_cumsum + off1)
    tl.store(split_src_ptr + tl.arange(0, EXPERTS_PER_RANK), cumsum_eds - cumsum_sts)

    # Double buffering: even/odd calls use alternating halves of the destination buffers.
    act_pos = call_count % 2
    data_dst_ptr = data_dst + act_pos * WORLD_SIZE * MAX_M * HIDDEN + dst_off * HIDDEN
    split_dst_ptr = splits_dst + act_pos * NUM_TOT_EXPERTS + rank * EXPERTS_PER_RANK
    signal_ptr = signal + act_pos * WORLD_SIZE + rank

    # Push this rank's token rows to peer `pid` (non-blocking, block-scope put).
    libshmem_device.putmem_nbi_block(
        data_dst_ptr,
        data_src + src_off * HIDDEN,
        num_rows_cur_block * HIDDEN * ELEMENT_SIZE,
        pid,
    )
    # Push the per-expert split counts to peer `pid`.
    libshmem_device.putmem_nbi_block(
        split_dst_ptr,
        split_src_ptr,
        EXPERTS_PER_RANK * 4,  # now we use `int32` for splits
        pid,
    )
    if WITH_SCALE:
        # Push the scales and piggyback the completion signal on the same put.
        scale_dst_ptr = scale_dst + act_pos * WORLD_SIZE * MAX_M + dst_off
        libshmem_device.putmem_signal_nbi_block(
            scale_dst_ptr,
            scale_src + src_off,
            num_rows_cur_block * SCALE_ELEMENT_SIZE,
            signal_ptr,
            call_count,
            libshmem_device.NVSHMEM_SIGNAL_SET,
            pid,
        )

    # Order the puts above before the signal, then let one thread set the remote
    # completion flag (unless the scale put already carried it) and wait for the
    # matching flag from peer `pid`.
    libshmem_device.fence()
    if threadidx == 0:
        if not WITH_SCALE:
            libshmem_device.signal_op(
                signal_ptr,
                call_count,
                libshmem_device.NVSHMEM_SIGNAL_SET,
                pid,
            )
        libshmem_device.signal_wait_until(
            signal + act_pos * WORLD_SIZE + pid,
            libshmem_device.NVSHMEM_CMP_EQ,
            call_count,
        )
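
The kernel above uses `tl.program_id(0)` as the destination rank, so it is launched with one program per peer. The host-side wrapper below is a hedged sketch, not repository code: the function name and argument handling are illustrative, and allocation of the symmetric-heap buffers (`data_dst`, `splits_dst`, `scale_dst`, `signal`) is omitted because it depends on the Triton-distributed runtime.

def launch_all_to_all(
    data_src, data_dst, splits_src, splits_dst,
    signal, splits_cumsum, scale_src, scale_dst,
    rank, call_count, world_size, num_tot_experts,
    max_m, hidden,
):
    # One program per destination rank; program `p` pushes this rank's tokens
    # routed to the experts owned by rank `p`.
    all_to_all_kernel[(world_size,)](
        data_src, data_dst,
        splits_src, splits_dst,
        signal, splits_cumsum,
        scale_src, scale_dst,
        rank, call_count,
        WITH_SCALE=True,
        WORLD_SIZE=world_size,
        HIDDEN=hidden,
        MAX_M=max_m,
        EXPERTS_PER_RANK=num_tot_experts // world_size,
        NUM_TOT_EXPERTS=num_tot_experts,
    )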

Users can also combine the communication part with the computation part to design overlapping kernels; a hedged sketch of the consumer side of such an overlap follows below. We provide example implementations in third_party/distributed/distributed/kernels.
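
The sketch below is not the repository's implementation; the kernel name, tile-to-rank mapping, and signaling convention are assumptions. It shows an AllGather-GEMM compute kernel that waits, per tile, for the signal set by a communication kernel before reading the gathered rows. It assumes `tid` and `libshmem_device` are in scope as in the all-to-all example above, that `M_PER_RANK` is a multiple of `BLOCK_M`, and that `M`, `N`, and `K` are multiples of the block sizes.

import triton
import triton.language as tl
# Assumption: `tid` and `libshmem_device` are in scope as in the all-to-all kernel above.


@triton.jit
def ag_gemm_consumer_kernel(
    a_ptr, b_ptr, c_ptr,     # A is the AllGather output (rows grouped by source rank)
    signal,                  # signal[r] is set once rank r's rows of A have arrived
    M, N, K,
    expected,                # signal value written by the communication kernel
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    M_PER_RANK: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # The rank whose gathered rows this tile reads; tiles are assumed not to
    # straddle per-rank row blocks (M_PER_RANK % BLOCK_M == 0).
    src_rank = (pid_m * BLOCK_M) // M_PER_RANK
    # One thread blocks until that rank's rows have been signalled as delivered,
    # then a barrier releases the rest of the program.
    if tid(axis=0) == 0:
        libshmem_device.signal_wait_until(
            signal + src_rank, libshmem_device.NVSHMEM_CMP_EQ, expected,
        )
    tl.debug_barrier()

    # Standard Triton GEMM tile loop (no boundary masks, given the divisibility assumption).
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc)

Overlap comes from running the communication kernel and this compute kernel concurrently (for example, on different streams), so tiles whose data has already arrived can start computing while later ranks' data is still in flight.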

Performance

Triton-distributed achieves performance comparable to or better than hand-tuned libraries.

AllGather GEMM on single node of H800x8


GEMM ReduceScatter on single node of H800x8


AllGather GEMM on 2 nodes of H800x8


GEMM ReduceScatter on 2 nodes of H800x8


Scaling of Distributed Flash-Decode from 1 GPU to 32 GPUs

The batch size is 1 (one query) for decoding.

Performance on Other Platforms

AMD GPUs

Roadmaps

Functionalities

  • Release low-level primitives
  • Release high-level primitives
  • Tutorials
  • Pre-built binary

Kernels

  • Release single-node GEMM TP overlapping kernels
  • Release single-node MoE TP overlapping kernels
  • Release single-node distributed Flash-Decoding kernels
  • Release single-node MoE EP overlapping kernels
  • Release cross-node GEMM TP overlapping kernels
  • Release cross-node MoE TP overlapping kernels
  • Release cross-node distributed Flash-Decoding kernels
  • Release cross-node EP all-to-all kernels (similar to DeepEP)
  • Provide tutorials for kernel implementation

Backends

Computation

  • Nvidia SM90a support
  • Nvidia SM80 support
  • Nvidia SM89 support
  • AMD CDNA3 support

Communication

  • NVLink
  • IB
  • PCIe

Performance

  • Performance report

License

The Triton-distributed project is under the MIT License. Part of our code is under the Apache-2.0 License:

  • third_party/distributed/distributed/kernels/flash_decode.py

Triton's original code is partially under the Apache-2.0 License; these files include:

  • include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h
  • lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
  • python/triton/_C/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h
  • utils/generate-test-checks.py

Citation

If you use Triton-distributed in a scientific publication, we encourage you to cite the related paper:

@misc{zheng2025tilelink,
      title={TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives},
      author={Size Zheng and Jin Fang and Xuegui Zheng and Qi Hou and Wenlei Bao and Ningxin Zheng and Ziheng Jiang and Dongyang Wang and Jianxi Ye and Haibin Lin and Li-Wen Chang and Xin Liu},
      year={2025},
}

Founded in 2023, the ByteDance Seed Team is dedicated to crafting the industry's most advanced AI foundation models. The team aspires to become a world-class research team and to make significant contributions to the advancement of science and society.

Discussion and Contribution

Please use issues or pull requests for discussion and contribution (see CONTRIBUTING.md).