hpcaitech/ColossalAI

[BUG]: Colossalai-OpenMoE example failed to converge

marsggbo opened this issue · 2 comments

๐Ÿ› Describe the bug

Setup

I am currently running the Colossalai/examples/language/openmoe project with the following experimental setup:

  • dataset: load_dataset("yizhongw/self_instruct", "super_natural_instructions"); I also tried "wikitext-2"
  • model: openmoe-base
  • batch size: 2
  • 2 GPUs
  • parallel strategy:
    • pp_size=1
    • dp_size=1
    • ep_size=2
    • extra_dp_size=2
    • zero_stage=2

Issue 1: loss value is very large and cannot converge

I cannot get training to converge. The training loss stays extremely high (above 1e10) even after 10 epochs. The training log reports three loss values: aux_loss, z_loss, and the ZCrossEntropy (ce) loss. The first two look normal, but the ce loss is far larger.
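
For scale, assuming the logged ce value is a standard softmax cross-entropy, the loss for one token is logsumexp(z) - z_target. A model predicting uniformly over a vocabulary of size V therefore sits near ln(V), and since CE ≤ max(z) - z_target + ln(V), a ce loss above 1e10 requires some logit to exceed the target logit by nearly 1e10. The minimal pure-Python sketch below illustrates this with a hypothetical 8-token vocabulary; `cross_entropy` and `VOCAB` are illustrative names, not part of the openmoe code.

```python
import math

def cross_entropy(logits, target):
    """Softmax cross-entropy for one token, computed stably as logsumexp(z) - z[target]."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]

VOCAB = 8  # hypothetical tiny vocabulary, for illustration only

# Uniform logits: the loss equals ln(V), roughly where an untrained model starts.
print(cross_entropy([0.0] * VOCAB, target=3), math.log(VOCAB))

# One logit pushed far above the target's: the loss grows with the logit gap.
for gap in (1.0, 1e3, 1e10):
    logits = [gap] + [0.0] * (VOCAB - 1)
    print(gap, cross_entropy(logits, target=1))
```

In other words, a ce loss of this size points at diverged logits rather than at an unusually hard dataset.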


The default precision of the openmoe example is bf16, so I suspected a numerical overflow. I switched to fp16 (fp32 does not seem to be supported with ZeRO), but the same problem occurred.
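
The overflow hypothesis can be compared against the numeric ranges of the two formats. The minimal pure-Python sketch below (no torch required) assumes IEEE half precision for fp16 and emulates bf16 by keeping the upper 16 bits of a float32, which truncates instead of rounding to nearest-even as hardware casts usually do. The helpers `to_fp16` and `to_bf16` are hypothetical names for illustration.

```python
import math
import struct

FP16_MAX = (2 - 2**-10) * 2**15  # 65504.0, largest finite IEEE half-precision value
BF16_MAX = (2 - 2**-7) * 2**127  # ~3.39e38, largest finite bfloat16 value

def to_fp16(x: float) -> float:
    """Round x to IEEE fp16; return +/-inf on overflow, as a GPU cast would."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct refuses values beyond the fp16 range
        return math.copysign(math.inf, x)

def to_bf16(x: float) -> float:
    """Approximate a bf16 cast by zeroing the low 16 bits of the fp32 encoding.
    This truncates (rounds toward zero); real casts usually round to nearest-even."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

for loss in (3.5, 7.0e4, 1.0e10):
    print(f"{loss:>10.3g}  fp16={to_fp16(loss):<10.6g}  bf16={to_bf16(loss):.6g}")
```

bf16 keeps float32's 8-bit exponent, so a value around 1e10 is representable in bf16 (with only about 2 to 3 significant digits), while fp16 overflows above 65504. Switching from bf16 to fp16 therefore narrows the representable range rather than widening it.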

Issue 2: when checkpoint loading is disabled, the loss becomes NaN

Additionally, when I comment out the following two lines so that the pretrained checkpoint is not loaded, the loss becomes NaN. Which datasets were the provided pretrained weights (Hugging Face hpcaitech/openmoe-base) trained on?

if not test_mode:
    load_ckpt(repo_name, model, booster)
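
When training from scratch turns the loss into NaN, it helps to know which component fails first. The minimal sketch below assumes each component (aux_loss, z_loss, ce loss) is available as a plain Python float at every step, for example via `tensor.item()`. `find_bad_losses`, its `limit` threshold, and the sample values are hypothetical and not part of the openmoe example.

```python
import math

def find_bad_losses(losses, limit=1e4):
    """Return names of loss components that are NaN/inf or larger in magnitude than `limit`.

    `losses` maps a component name to a float. The default `limit` is an arbitrary
    illustrative threshold, not a ColossalAI setting."""
    return [name for name, value in losses.items()
            if not math.isfinite(value) or abs(value) > limit]

step_losses = {"aux_loss": 0.012, "z_loss": 0.35, "ce_loss": float("nan")}  # made-up values
bad = find_bad_losses(step_losses)
if bad:
    print("non-finite or exploding loss components:", bad)
```

In PyTorch, `torch.autograd.set_detect_anomaly(True)` can additionally report the backward operation that first produced a NaN, at a noticeable speed cost.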


Are there specific aspects I should pay attention to or common pitfalls that might cause model training failures when working with openmoe? Your insights or suggestions on resolving this issue or optimizing the project would be immensely helpful. Thank you for your assistance!

Environment

  • torch 2.0.1+cu118
  • python 3.8.12

Thank you for your feedback; we will address this issue as soon as possible.

I also encountered this problem.