awslabs/slapo

[Bug] Sharing embedding weights with last linear layer

zarzen opened this issue · 1 comment

Some models share the weights of the first (embedding) layer with the last linear layer.
This requires the embedding weights and the last linear layer's weights to be sharded in the same way.
Currently, the GPT schedule shards only the embedding weights, here: https://github.com/awslabs/slapo/blob/main/examples/gpt/schedule.py#L206
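To make the requirement concrete: in PyTorch, `nn.Embedding(V, H).weight` and `nn.Linear(H, V, bias=False).weight` both have shape `[V, H]`, which is why one tensor can serve both layers. Below is a minimal sketch, assuming the embedding is sharded row-wise along the vocabulary dimension and that the vocabulary size is divisible by the number of ranks. `vocab_range` and `shard_rows` are hypothetical helpers, not slapo APIs, and plain Python lists stand in for tensors.

```python
from __future__ import annotations


def vocab_range(vocab_size: int, rank: int, world_size: int) -> tuple[int, int]:
    """Return the [start, end) vocabulary rows owned by `rank`."""
    if vocab_size % world_size != 0:
        raise ValueError("vocab_size must be divisible by world_size")
    per_rank = vocab_size // world_size
    start = rank * per_rank
    return start, start + per_rank


def shard_rows(weight: list[list[float]], rank: int, world_size: int) -> list[list[float]]:
    """Copy the rows of a [vocab_size, hidden] weight that `rank` owns."""
    start, end = vocab_range(len(weight), rank, world_size)
    return [row[:] for row in weight[start:end]]


if __name__ == "__main__":
    vocab_size, hidden, world_size = 8, 3, 2
    # One tied [vocab_size, hidden] weight shared by the embedding and the last linear layer.
    tied = [[float(v * hidden + j) for j in range(hidden)] for v in range(vocab_size)]
    shards = [shard_rows(tied, rank, world_size) for rank in range(world_size)]
    for shard in shards:
        # Each rank holds a [vocab_size // world_size, hidden] slice; both tied
        # users must use this same slice for their local copies to stay in sync.
        assert (len(shard), len(shard[0])) == (vocab_size // world_size, hidden)
    # Concatenating the row shards in rank order recovers the full weight.
    assert [row for shard in shards for row in shard] == tied
```

Applying the same partition to both users of the tied weight is what keeps their local copies identical after sharding.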

If the last linear layer's weights are not sharded, then even with the tied-weight analysis we cannot synchronize the sharded embedding and the last linear layer, because their shapes do not match:

  • The tied-weight analysis can tell that "m.0.embedding.weights" and "m.{N}.linear.weights" are tied/shared.
  • "m.0.embedding.weights" is sharded.
  • "m.{N}.linear.weights" should therefore also be sharded, but currently it is not.
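A hypothetical consistency check that a tied-weight pass could run after sharding, comparing the local (per-rank) shapes within each tied group. `check_tied_shapes` and the sizes are illustrative assumptions; the parameter names are kept as plain strings mirroring the ones above, with `{N}` left as the issue's placeholder.

```python
from __future__ import annotations


def check_tied_shapes(
    local_shapes: dict[str, tuple[int, ...]], tied_groups: list[list[str]]
) -> list[str]:
    """Return one error message per tied group whose members' local shapes differ."""
    errors = []
    for group in tied_groups:
        group_shapes = {name: local_shapes[name] for name in group}
        if len(set(group_shapes.values())) > 1:
            errors.append(f"tied parameters cannot be synchronized: {group_shapes}")
    return errors


if __name__ == "__main__":
    # Illustrative sizes only.
    vocab_size, hidden, world_size = 50304, 1024, 4
    # Plain strings (not f-strings); "{N}" stands for the last layer's index.
    tied_groups = [["m.0.embedding.weights", "m.{N}.linear.weights"]]

    # Only the embedding is sharded along the vocabulary dimension.
    only_embedding = {
        "m.0.embedding.weights": (vocab_size // world_size, hidden),
        "m.{N}.linear.weights": (vocab_size, hidden),
    }
    # Both tied weights are sharded the same way.
    both_sharded = {
        "m.0.embedding.weights": (vocab_size // world_size, hidden),
        "m.{N}.linear.weights": (vocab_size // world_size, hidden),
    }
    print(check_tied_shapes(only_embedding, tied_groups))  # one mismatch reported
    print(check_tied_shapes(both_sharded, tied_groups))  # []
```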

BTW, if we use vocab-parallel logits for the loss function, the last linear layer's weights must be sharded even when they are not shared with the embedding.
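Why the loss forces sharding: with vocab-parallel logits, each rank computes logits only for its own vocabulary slice, i.e. the hidden state times that rank's row shard of the last linear weight, so the weight has to be split along the vocabulary dimension whether or not it is tied. The sketch below assumes a Megatron-style vocab-parallel cross-entropy with equal-sized slices and simulates the three all-reduces with Python `max`/`sum` over per-rank lists; every name in it is hypothetical.

```python
from __future__ import annotations

import math


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def full_cross_entropy(logits: list[float], target: int) -> float:
    """Reference: -log(softmax(logits)[target]) on the unsharded logits."""
    m = max(logits)
    return math.log(sum(math.exp(z - m) for z in logits)) + m - logits[target]


def vocab_parallel_cross_entropy(local_logits: list[list[float]], target: int) -> float:
    """Cross-entropy when rank r only holds logits for its contiguous vocab slice.

    Collectives are simulated with max()/sum() across the per-rank lists;
    assumes every rank owns the same number of vocab entries.
    """
    per_rank = len(local_logits[0])
    # all-reduce(MAX): global max for a numerically stable softmax.
    global_max = max(max(local) for local in local_logits)
    # all-reduce(SUM): global softmax denominator.
    sum_exp = sum(math.exp(z - global_max) for local in local_logits for z in local)
    # all-reduce(SUM): only the rank owning `target` contributes its logit.
    target_logit = 0.0
    for rank, local in enumerate(local_logits):
        start = rank * per_rank
        if start <= target < start + per_rank:
            target_logit += local[target - start]
    return math.log(sum_exp) + global_max - target_logit


if __name__ == "__main__":
    vocab_size, hidden, world_size = 8, 3, 2
    per_rank = vocab_size // world_size
    # Last linear weight in PyTorch layout: [vocab_size, hidden].
    weight = [[0.1 * ((v * hidden + j) % 7) - 0.3 for j in range(hidden)] for v in range(vocab_size)]
    hidden_state = [0.5, -1.0, 2.0]
    # Each rank multiplies by its own row shard of the weight -> local logits.
    local_logits = [
        [dot(row, hidden_state) for row in weight[r * per_rank:(r + 1) * per_rank]]
        for r in range(world_size)
    ]
    full_logits = [dot(row, hidden_state) for row in weight]
    for target in range(vocab_size):
        assert math.isclose(
            vocab_parallel_cross_entropy(local_logits, target),
            full_cross_entropy(full_logits, target),
        )
    print("vocab-parallel loss matches the unsharded loss")
```

The design never gathers the full vocabulary of logits on one rank: per token, only a max, a sum of exponentials, and the target logit are reduced.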

@zarzen #15 should fix this. Please check and close this issue if so.