SequenceParallel sharding seems wrong
marib00 opened this issue · 1 comment
marib00 commented
Context
- PyTorch version: 2.5 (nightly)
- Operating System and version: Ubuntu 24.04 LTS
Your Environment
- Installed using source? [no]:
- Are you planning to deploy it using docker container? [no]:
- Is it a CPU or GPU environment?: GPU
Expected vs Current Behavior
According to the documentation, torch.distributed.tensor.parallel.SequenceParallel
should shard activations on the sequence dimension, i.e. [B, T, C] -> [B, T//_world_size, C],
but it appears to tile them instead, i.e. [B, T, C] -> [B, T*_world_size, C].
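For reference, a minimal sketch (not from the original report; it assumes a 1-D mesh, a launcher that sets WORLD_SIZE, and T divisible by the world size) of the local shape that a Shard(1) placement should produce:

import os
import torch
from torch.distributed._tensor import distribute_tensor, Shard, init_device_mesh

world_size = int(os.environ["WORLD_SIZE"])  # set by the launcher, e.g. torchrun
mesh = init_device_mesh("cuda", (world_size,))
B, T, C = 64, 16, 1024
# Shard the full tensor along dim 1 (the sequence dimension)
dt = distribute_tensor(torch.randn(B, T, C, device="cuda"), mesh, [Shard(1)])
print(dt.to_local().shape)  # expected: (B, T // world_size, C) on every rank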
Steps to Reproduce
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import destroy_process_group
from torch.distributed._tensor import DTensor, Shard, Replicate, distribute_tensor, distribute_module, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput
device = "cuda"
_world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh(device_type=device, mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.l1 = nn.Linear(n_embd, 12)
        self.l2 = nn.Linear(12, 11)

    def forward(self, x):
        # Inside LayerNorm.forward the DTensor shape is (B, T*_world_size, n_embd);
        # I'd expect (B, T//_world_size, n_embd) instead!
        x0 = self.ln1(x)
        x1 = self.l1(x0)
        x2 = F.relu(x1)
        x3 = self.l2(x2)
        return x3
B, T, n_embd = 64, 16, 1024
data = torch.randn(B, T, n_embd).to(device)
model = MyModel().to(device)
model = parallelize_module(
    model,
    device_mesh=device_mesh,
    parallelize_plan={
        "ln1": SequenceParallel(use_local_output=False),
        "l1": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        # NOTE: duplicate "l1" key -- this entry silently replaces the
        # PrepareModuleInput above, since a dict literal keeps only the last
        # value for a repeated key.
        "l1": ColwiseParallel(),
        "l2": RowwiseParallel(),
    },
)
out = model(data)
destroy_process_group()
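(The script assumes a distributed launcher that sets WORLD_SIZE, e.g. torchrun --nproc_per_node=2 repro.py on a 2-GPU node, where repro.py is whatever the script is saved as.)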
marib00 commented
Closing because I meant to open this in pytorch/pytorch 🤦♂️