pytorch/examples

SequenceParallel sharding seems wrong

marib00 opened this issue · 1 comment

Context

  • PyTorch version: 2.5 (nightly)
  • Operating System and version: Ubuntu 24.04 LTS

Your Environment

  • Installed using source? [no]:
  • Are you planning to deploy it using docker container? [no]:
  • Is it a CPU or GPU environment?: GPU

Expected vs Current Behavior

According to the documentation, torch.distributed.tensor.parallel.SequenceParallel should shard on the sequence dimension, i.e. [B, T, C] -> [B, T//_world_size, C], but it seems to be tiling instead, i.e. [B, T, C] -> [B, T*_world_size, C].
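
For reference, the two shapes correspond to the two ways a per-rank tensor can be interpreted as a DTensor sharded on dim 1. Below is a minimal sketch that only illustrates the shape difference (it reuses device, device_mesh, B, T and n_embd from the repro further down, and makes no claim about what SequenceParallel actually does internally):

x = torch.randn(B, T, n_embd).to(device)

# distribute_tensor treats x as the global tensor: each rank keeps a
# [B, T//_world_size, n_embd] slice and the global shape stays [B, T, n_embd].
dt_sharded = distribute_tensor(x, device_mesh, [Shard(1)])
print(dt_sharded.shape, dt_sharded.to_local().shape)

# DTensor.from_local treats x as this rank's local shard: the local shape stays
# [B, T, n_embd] but the global shape becomes [B, T*_world_size, n_embd] (the tiling above).
dt_tiled = DTensor.from_local(x, device_mesh, [Shard(1)])
print(dt_tiled.shape, dt_tiled.to_local().shape)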

Steps to Reproduce

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import destroy_process_group
from torch.distributed._tensor import DTensor, Shard, Replicate, distribute_tensor, distribute_module, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput

device = "cuda"
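# WORLD_SIZE is set by the launcher (e.g. torchrun --nproc_per_node=<num_gpus>)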
_world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh(device_type=device, mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.l1 = nn.Linear(n_embd, 12)
        self.l2 = nn.Linear(12, 11)

    def forward(self, x):
        x0 = self.ln1(x)  # inside LayerNorm.forward the DTensor shape is (B, T*_world_size, n_embd); I'd expect (B, T//_world_size, n_embd) instead!
        x1 = self.l1(x0)
        x2 = F.relu(x1)
        x3 = self.l2(x2)
        return x3

B, T, n_embd = 64, 16, 1024

data = torch.randn(B, T, n_embd).to(device)
model = MyModel().to(device)

model = parallelize_module(
    model,
    device_mesh=device_mesh,
    parallelize_plan={
        "ln1": SequenceParallel(use_local_output=False),
        "l1": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),), 
            ),
        "l1": ColwiseParallel(),
        "l2": RowwiseParallel(),
    }
)

out = model(data)
destroy_process_group()

Closing because I meant to open this in pytorch/pytorch 🤦‍♂️

pytorch/pytorch#129355