NVIDIA/Megatron-LM

[QUESTION] Why does megatron-core seem slower and use more GPU memory than legacy for pretrain_gpt?


Your question
I ran pretrain_gpt.py with the same architecture, data, training hyperparameters, and hardware, once with and once without Megatron Core (--use-mcore-models) when building the model.
The mcore run shows clearly worse wall-clock time and memory usage:

| Setting | Wall-clock time per step (ms) | Memory per GPU (GB) |
|---|---|---|
| legacy | 630 | 45 |
| use_mcore | 690 | 63 |
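The relative overhead follows directly from the two rows above. A small sketch of the arithmetic (the helper name overhead_pct is just for illustration):

```bash
#!/usr/bin/env bash
# Percentage increase of NEW over BASE, printed with two decimals.
# Hypothetical helper for illustration only.
overhead_pct() {
    local base=$1 new=$2
    awk -v a="$base" -v b="$new" 'BEGIN { printf "%.2f\n", (b - a) / a * 100 }'
}

# Numbers copied from the table above (legacy vs. use_mcore).
echo "step time: +$(overhead_pct 630 690)%"   # ms per step
echo "memory:    +$(overhead_pct 45 63)%"     # GB per GPU
```

The mcore run is therefore about 9.5% slower per step and uses 40% more memory per GPU.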

Environment:

| Hardware | PyTorch version | CUDA version |
|---|---|---|
| A100-80G-PCIe x 4 | 2.1.2 | 12.2 |

For data, I use the c4_en dataset from Hugging Face, tokenized with the GPT-2 tokenizer. The experiments use the first 3.6e7 documents (the first 10%).
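A minimal sketch of how such a subset could be cut, assuming the raw c4_en data has been exported as a JSON Lines file with one document per line. The helper name take_first_percent and the file paths are hypothetical:

```bash
#!/usr/bin/env bash
set -euo pipefail

# Copy the first PCT percent of lines (one JSON document per line) from IN to OUT
# and print how many lines were kept. Hypothetical helper, not part of Megatron-LM.
take_first_percent() {
    local in=$1 out=$2 pct=$3
    local total n
    total=$(wc -l < "$in")
    n=$(( total * pct / 100 ))
    head -n "$n" "$in" > "$out"
    echo "$n"
}

# Example with hypothetical paths:
# take_first_percent /root/data/raw/c4_en.jsonl /root/data/raw/c4_en_gpt2_10p.jsonl 10
```

The resulting file can then be tokenized with Megatron-LM's tools/preprocess_data.py using the GPT-2 vocab.json and merges.txt. For the default text JSON key, that tool names its output files with a _text_document suffix, which matches the DATA_PATH used in the script.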

To Reproduce
megatron-lm commit hash: 9de386d
I adapted a script from pretrain_gpt_distributed.sh and renamed it pretrain_gpt_cli.sh:

#!/bin/bash
set -x

# Runs the "345M" parameter model with the default model hp below (HIDDEN=1024, ATTN_HEADS=16, NUM_LAYERS=24)

export CUDA_DEVICE_MAX_CONNECTIONS=1
# dist
GPUS_PER_NODE=4 # TODO: change in future
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

# tokenizer
TOKENIZER_DIR="/root/model_ckps/gpt2"
VOCAB_FILE="${TOKENIZER_DIR}/vocab.json"
MERGE_FILE="${TOKENIZER_DIR}/merges.txt" 


# training hp
MBS=${1:-32}
GBS=${2:-128}
LR=${3:-"1e-4"}
MIN_LR=${4:-"1e-5"}
WARMUP_RATIO=${5:-"0.01"}
WD=${6:-"1e-2"}
SEQ_LEN=${7:-1024}
TRAIN_STEPS=${8:-10000} # related to data?
LR_DECAY_STEPS=${9:-10000} # TODO: how to set this?
LR_SCHED=${10:-"cosine"}
USE_MCORE=${11:-"False"}
# POSITION_EMB? learnable ape by default

# model hp
HIDDEN=${12:-1024}
ATTN_HEADS=${13:-16}
NUM_LAYERS=${14:-24}
ATTN_DROP=${15:-0.0}
HIDDEN_DROP=${16:-0.0}
# ATTN_SOFTMAX_FP32?
# UNTIE?

# output setting
LOG_INTERVAL=${17:-1}
SAVE_INTERVAL=${18:-10000}
EVAL_STEPS=${19:-1000}
# EVAL_INTERVAL=${20:-1000}?

# data setting
DATA_NAME=${20:-"gpt2_10p"}

# misc setting
SEED=${21:-"1234"}

# moe setting
NUM_EXPERTS=${22:-"none"} # if none, no moe
LOAD_BALANCER=${23:-"aux_loss"} 
MOE_TOPK=${24:-2}
AUX_LW=${25:-"1e-2"} # should be tuned

MIX_PREC="bf16"

DATA_PATH=/root/data/tokenized/c4_en_${DATA_NAME}_text_document


EXP_NAME="c4_${DATA_NAME}_bs${GBS}_mbs${MBS}_lr${LR}_mlr${MIN_LR}_wm${WARMUP_RATIO}_wd${WD}"
EXP_NAME="${EXP_NAME}_${SEQ_LEN}_ts${TRAIN_STEPS}_${LR_SCHED}${LR_DECAY_STEPS}_${MIX_PREC}"
EXP_NAME="${EXP_NAME}_h${HIDDEN}_a${ATTN_HEADS}_l${NUM_LAYERS}_ad${ATTN_DROP}_hd${HIDDEN_DROP}"
EXP_NAME="${EXP_NAME}_${SEED}"

MOE_ARGS=""
if [ "$NUM_EXPERTS" != "none" ]; then
    MOE_ARGS="
        --num-experts $NUM_EXPERTS \
        --moe-router-load-balancing-type $LOAD_BALANCER \
        --moe-router-topk $MOE_TOPK \
        --moe-aux-loss-coeff $AUX_LW
    "
    EXP_NAME="${EXP_NAME}_moe${NUM_EXPERTS}-${MOE_TOPK}_${LOAD_BALANCER}${AUX_LW}"
fi

MCORE_ARGS=""
if [ "$USE_MCORE" == "True" ]; then
    MCORE_ARGS="--use-mcore-models"
    EXP_NAME="${EXP_NAME}_mcore"
fi

# Set CHECKPOINT_PATH only after the MoE/mcore suffixes are appended to EXP_NAME,
# so legacy and mcore runs do not save to (or --load from) the same directory.
CHECKPOINT_PATH="/root/model_ckps/${EXP_NAME}"

DISTRIBUTED_ARGS="
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"
DATA_ARGS="
    --data-path $DATA_PATH \
    --vocab-file $VOCAB_FILE \
    --merge-file $MERGE_FILE \
    --split 949,50,1
"

OUTPUT_ARGS="
    --log-interval ${LOG_INTERVAL} \
    --tensorboard-log-interval ${LOG_INTERVAL} \
    --save-interval ${SAVE_INTERVAL} \
    --eval-iters ${EVAL_STEPS} \
    --tensorboard-dir ${CHECKPOINT_PATH}/tb
"

WANDB_PROJECT="ScalingLaws"

WANDB_ENTITY="reign" \
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
    --micro-batch-size ${MBS} \
    --global-batch-size ${GBS} \
    --lr ${LR} \
    --min-lr ${MIN_LR} \
    --lr-warmup-fraction ${WARMUP_RATIO} \
    --weight-decay ${WD} \
    --seq-length ${SEQ_LEN} \
    --max-position-embeddings ${SEQ_LEN} \
    --train-iters ${TRAIN_STEPS} \
    --lr-decay-iters ${LR_DECAY_STEPS} \
    --lr-decay-style ${LR_SCHED} \
    --${MIX_PREC} \
    --hidden-size ${HIDDEN} \
    --num-attention-heads ${ATTN_HEADS} \
    --num-layers ${NUM_LAYERS} \
    --attention-dropout ${ATTN_DROP} \
    --hidden-dropout ${HIDDEN_DROP} \
    --clip-grad 1.0 \
    $DATA_ARGS \
    $OUTPUT_ARGS \
    --distributed-backend nccl \
    --save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH \
    --use-flash-attn \
    --wandb-project $WANDB_PROJECT \
    --wandb-exp-name "$EXP_NAME" \
    --seed $SEED \
    $MOE_ARGS \
    $MCORE_ARGS

To reproduce the experiment, run the following bash command:

STEP=20000
USE_MCORE="True" # or "False" to use legacy
bash examples/pretrain_gpt_cli.sh 64 512 1e-3 1e-5 \
    0.01 0.0 1024 $STEP $STEP cosine $USE_MCORE \
    512 8 8 0.0 0.0 1 $STEP 100 gpt2_10p 1234
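The reproduction command implies a few derived quantities that are useful when comparing the two runs. A sketch of the arithmetic, assuming the data-parallel size equals the 4 GPUs (the script sets no tensor or pipeline parallelism) and approximating the parameter count with the common 12·L·h² transformer estimate plus token and position embeddings (biases, LayerNorm weights, and Megatron's vocabulary padding are ignored). The helper names are hypothetical:

```bash
#!/usr/bin/env bash
# Values from the reproduction command above.
GPUS=4; MBS=64; GBS=512; SEQ_LEN=1024
HIDDEN=512; NUM_LAYERS=8
VOCAB=50257   # GPT-2 BPE vocabulary size (Megatron pads this; ignored here)

# Gradient-accumulation steps per iteration, assuming data-parallel size == GPUS.
grad_accum_steps() { echo $(( $2 / ($1 * $3) )); }      # args: MBS GBS DP

# Tokens processed per optimizer step.
tokens_per_step() { echo $(( $1 * $2 )); }              # args: GBS SEQ_LEN

# Rough parameter count: 12*L*h^2 (attention + MLP) + token and position embeddings.
approx_params() { echo $(( 12 * $1 * $2 * $2 + $3 * $2 + $4 * $2 )); }  # args: L h V S

echo "grad accumulation steps: $(grad_accum_steps $MBS $GBS $GPUS)"
echo "tokens per step:         $(tokens_per_step $GBS $SEQ_LEN)"
echo "approx. parameters:      $(approx_params $NUM_LAYERS $HIDDEN $VOCAB $SEQ_LEN)"
```

Under these assumptions each optimizer step runs two micro-batches of 64 sequences per GPU, and the model in the comparison has roughly 51M parameters rather than the 345M of the script's defaults.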

Is there any reason behind this?

A possible reason is that the mcore model built with the local layer spec does not support flash-attn, so --use-flash-attn would have no effect in the mcore run:

https://github.com/NVIDIA/Megatron-LM/blob/core_v0.6.0/megatron/core/models/gpt/gpt_layer_specs.py#L53
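To get a sense of how much memory a non-flash attention path can cost here, a back-of-the-envelope estimate follows. It assumes the unfused path materializes an attention-probability tensor of shape [micro-batch, heads, seq, seq] per layer in bf16 (2 bytes per element) and keeps it for the backward pass, whereas flash attention never stores this s×s tensor. The values come from the reproduction command; this is a rough lower bound, not a profile of the actual runs, and attn_probs_bytes is a hypothetical helper:

```bash
#!/usr/bin/env bash
# Bytes for one [b, heads, s, s] attention-probability tensor.
attn_probs_bytes() {
    local mbs=$1 heads=$2 seq=$3 bytes_per_elem=$4
    echo $(( mbs * heads * seq * seq * bytes_per_elem ))
}

# From the reproduction command: MBS=64, 8 heads, SEQ_LEN=1024, 8 layers, bf16.
MBS=64; HEADS=8; SEQ_LEN=1024; NUM_LAYERS=8; BF16_BYTES=2

per_layer=$(attn_probs_bytes $MBS $HEADS $SEQ_LEN $BF16_BYTES)
total=$(( per_layer * NUM_LAYERS ))
echo "per layer:  $(( per_layer / 1024 / 1024 )) MiB"
echo "all layers: $(( total / 1024 / 1024 )) MiB"
```

This gives about 8 GiB per GPU from these tensors alone, the same order of magnitude as the 18 GB gap reported in the table; other buffers in the unfused path (for example the pre-softmax scores) may add to it.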