keras-team/keras-nlp

Question about Gemma tensor parallel sharding policy


Thanks for the Gemma model implementation. I found that the layout_map returned by GemmaBackbone.get_layout_map() seems to use the exact opposite sharding policy compared to the typical tensor-parallel (TP) sharding policy used for the Transformer architecture.

Typical (a code sketch of this policy follows right after this list):

  • embedding: the embedding matrix is sharded along the vocab_size axis (for a matrix of shape [vocab_size, hidden_dim], along axis=0);
  • attention:
    (figure: tensor-parallel self-attention sharding, from the HF Transformers parallelism docs)
    • query|key|value dense kernels are sharded along the column (for a kernel of shape [hidden_dim, hidden_dim], along axis=1);
    • the output dense kernel is sharded along the row (for a kernel of shape [hidden_dim, hidden_dim], along axis=0);
  • feedforward:
    (figure: tensor-parallel MLP sharding, from the HF Transformers parallelism docs)
    • the gating dense (the first dense) kernel is sharded along the column (for a kernel of shape [hidden_dim, intermediate_dim], along axis=1);
    • the output dense (the second dense) kernel is sharded along the row (for a kernel of shape [intermediate_dim, hidden_dim], along axis=0);
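
For concreteness, here is roughly how this typical policy could be written down as a Keras 3 LayoutMap. This is only a sketch: the variable-path regexes (query/key/value_dense, output_dense, ffw_up, ffw_down) are hypothetical placeholders rather than Gemma's real variable names, the kernels are assumed to be plain 2-D Dense kernels with the shapes above, and an 8-accelerator mesh is assumed.

```python
import keras

# Hypothetical 1 x 8 mesh; the "model" axis is the tensor-parallel axis.
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)
layout_map = keras.distribution.LayoutMap(device_mesh)

# Embedding [vocab_size, hidden_dim]: shard the vocab axis (axis=0).
layout_map["token_embedding/embeddings"] = ("model", None)
# Attention q/k/v kernels [hidden_dim, hidden_dim]: shard the column (axis=1).
layout_map[".*(query|key|value)_dense.*kernel"] = (None, "model")
# Attention output kernel [hidden_dim, hidden_dim]: shard the row (axis=0).
layout_map[".*output_dense.*kernel"] = ("model", None)
# First FFN kernel [hidden_dim, intermediate_dim]: shard the column (axis=1).
layout_map[".*ffw_up.*kernel"] = (None, "model")
# Second FFN kernel [intermediate_dim, hidden_dim]: shard the row (axis=0).
layout_map[".*ffw_down.*kernel"] = ("model", None)
```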

Gemma (the actual map can be printed with the snippet after this list):

  • embedding: layout_map["token_embedding/embeddings"] = (None, model_dim), which seems to shard along the hidden_dim axis;
  • attention:
    • query|key|value dense: layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (None, model_dim, None), which seems to shard along the row (the hidden_dim input axis), with the num_heads axis unsharded.
    • output dense: layout_map["decoder_block.*attention_output.*kernel"] = (None, None, model_dim), which seems to shard along the column (the hidden_dim output axis), with the num_heads axis unsharded.
  • feedforward:
    • the first dense: layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None), which seems to shard along the row (the hidden_dim input axis).
    • the second dense: layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim), which seems to shard along the column (the hidden_dim output axis).
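
To verify the entries quoted above, the actual map can be printed straight from get_layout_map(); a minimal sketch, assuming 8 accelerators and Keras 3's keras.distribution API:

```python
import keras
import keras_nlp

# Assumes 8 accelerators (e.g. a TPU v3-8); adjust the mesh shape to your hardware.
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)

# Print every variable-path regex and the sharding layout it maps to.
for path, layout in layout_map.items():
    print(path, "->", layout)
```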

Is my understanding correct? If the policies really are opposite, could you please explain the reasoning?

@qlzh727 for thoughts.

Thanks for reporting the issue. Can you share more references for the "typical" sharding/layout here? I would like to take a closer look at that.

@qlzh727 Thanks for the reply. The typical tensor-parallel sharding policy I described comes from the Megatron-LM implementation and the HuggingFace Transformers documentation:

Thanks for the information.

I think in general these are just different ways to shard the tensors/weights, each suited to different conditions.

In your approach, the q/k/v matmuls run without an all-gather, and the collective happens afterwards (around the qk dot-product and softmax) because the q/k/v outputs are sharded. The current Keras implementation instead does the collective at the q/k/v matmul (since the contraction dimension is sharded) and avoids the collective afterward. Which one wins also depends on the cost of collectives (network interconnect) vs. local compute speed, and on whether the model is used only for prediction or also needs fine-tuning and weight updates.
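
To make the trade-off concrete, here is a small JAX sketch (not the Keras code path; the shapes, names, and two-matmul structure are simplified assumptions) showing where the collective lands in each scheme: sharding the first kernel along its output columns defers the all-reduce until after the second, row-sharded matmul, while sharding the first kernel along its contraction dimension forces the collective right at the first matmul.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Assumes the dims below are divisible by the device count (e.g. 8 TPU cores).
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

B, D, F = 4, 16, 32  # batch, hidden_dim, intermediate_dim
x = jnp.ones((B, D))
w_up = jnp.ones((D, F))    # first kernel
w_down = jnp.ones((F, D))  # second kernel


# Megatron-style: w_up sharded along its columns, w_down along its rows.
@partial(
    shard_map, mesh=mesh,
    in_specs=(P(), P(None, "model"), P("model", None)),
    out_specs=P(),
)
def column_then_row(x, w_up_shard, w_down_shard):
    h_local = x @ w_up_shard            # no collective: local output columns
    y_partial = h_local @ w_down_shard  # partial sums over the contraction dim
    return jax.lax.psum(y_partial, "model")  # single all-reduce at the end


# Contraction-dim sharding: w_up sharded along its rows (x sharded on features).
@partial(
    shard_map, mesh=mesh,
    in_specs=(P(None, "model"), P("model", None), P()),
    out_specs=P(),
)
def row_sharded_first(x_shard, w_up_shard, w_down_shard):
    h_partial = x_shard @ w_up_shard      # partial sums already here
    h = jax.lax.psum(h_partial, "model")  # collective right after the first matmul
    return h @ w_down_shard               # second matmul is fully replicated


print(column_then_row(x, w_up, w_down).shape)
print(row_sharded_first(x, w_up, w_down).shape)
```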

I did some benchmarking for this and the results are shown below; a rough reproduction sketch follows the numbers. I think your setting does have an advantage for the fine-tuning use case. I tested this on a TPU v3-8. Feel free to provide more benchmark results on GPU as well.

(Smaller values are better.)

===================
Baseline (current setting):
generate: 1342 ms per 100 tokens
finetune with LoRA: 125 ms/step

===================
Your setting:
generate: 1501 ms per 100 tokens
finetune with LoRA: 77 ms/step
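
For anyone who wants to reproduce or extend these numbers on GPU, here is a rough sketch of the harness. It is not the exact script used above: the JAX backend, the 8-device mesh, the "gemma_2b_en" preset, the prompt, and the LoRA rank are all assumptions.

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # assumption: the TPU numbers above used the JAX backend

import time

import keras
import keras_nlp

# Shard the model over an assumed 8-device "model" axis using the layout map under test.
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(device_mesh, layout_map, batch_dim_name="batch")
)

# "gemma_2b_en" is an assumed preset; swap in whichever checkpoint was benchmarked.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

# Generation timing: warm up once to exclude compilation, then time a 100-token generation.
gemma_lm.generate("What is Keras?", max_length=100)
start = time.time()
gemma_lm.generate("What is Keras?", max_length=100)
print("generate:", (time.time() - start) * 1000, "ms per 100 tokens")

# LoRA fine-tuning timing: the ms/step figure comes from the Keras progress bar.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(["Keras is a deep learning framework."] * 64, batch_size=8, epochs=1)
```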

Should be addressed by #1491