mosaicml/llm-foundry

Fine-tune dbrx-instruct on a single VM with 8 H100s

hosseinsarshar opened this issue · 1 comment

I'm trying to fine-tune DBRX on a single machine with 8 H100 GPUs. I keep getting OOM errors with different configurations, and I wonder whether this is even doable.

I see a note that suggests ~64x80GB GPUs, but I wonder if there is a way to do it with 8x80GB GPUs.

Thanks

Some other info about my job:

My config:

# Note: This requires ~64x80GB GPUs
# Sequence lengths. max_seq_len is reused further down through the
# ${max_seq_len} interpolation (tokenizer and both dataloaders).
max_seq_len: 64
icl_seq_len: 64

# Run Name
# Explicit null (parses to the same value as a blank entry): when no name is
# given here it will be read from env var $RUN_NAME.
run_name: null

# Model
model:
  name: hf_causal_lm
  # Which pretrained weights to load
  pretrained_model_name_or_path: databricks/dbrx-instruct
  pretrained: true
  use_auth_token: true
  # Remaining model options
  init_device: mixed
  use_flash_attention_2: true
  config_overrides: {}

# Tokenizer (same identifier as model.pretrained_model_name_or_path)
tokenizer:
  name: databricks/dbrx-instruct
  kwargs:
    trust_remote_code: true
    # Interpolated from the top-level max_seq_len
    model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
  name: finetuning
  # Loader-level options (batch dropping, workers, prefetch, pinning)
  drop_last: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true
  pin_memory: true
  dataset:
    # Source and split
    hf_name: mosaicml/dolly_hhrlhf
    split: train
    shuffle: true
    # Sequence options; max_seq_len is interpolated from the top-level key
    max_seq_len: ${max_seq_len}
    decoder_only_format: true
    packing_ratio: auto
    allow_pad_trimming: false
    eos_token_id: 0

# Evaluation loader: test split, no shuffling, packing_ratio explicitly null
eval_loader:
  name: finetuning
  # Loader-level options (batch dropping, workers, prefetch, pinning)
  drop_last: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true
  pin_memory: true
  dataset:
    # Source and split
    hf_name: mosaicml/dolly_hhrlhf
    split: test
    shuffle: false
    # Sequence options; max_seq_len is interpolated from the top-level key
    max_seq_len: ${max_seq_len}
    decoder_only_format: true
    packing_ratio: null
    allow_pad_trimming: false

# Optimization
optimizer:
  name: decoupled_lionw
  # lr uses the same exponent form as weight_decay (equal to 0.000001)
  lr: 1.0e-06
  weight_decay: 1.0e-06
  betas: [0.9, 0.95]

# Learning-rate schedule
scheduler:
  name: cosine_with_warmup
  # Quoted so the value is always read as a string
  t_warmup: '0.02dur'
  alpha_f: 0

# Algorithms: a single gradient_clipping entry
algorithms:
  gradient_clipping:
    clipping_threshold: 1
    clipping_type: norm

# Run length and global batch size
max_duration: '2ep'
global_train_batch_size: 8

# Evaluation cadence
eval_first: false
eval_interval: '1ep'
# eval_subset_num_batches: -1

# System
seed: 17
precision: amp_bf16
autoresume: true

# Per-device batch sizes
device_train_microbatch_size: 1
device_eval_batch_size: 1

# Remaining system flags
dist_timeout: 3600
expandable_segments: true

# FSDP
fsdp_config:
  # Sharding and precision
  sharding_strategy: FULL_SHARD
  state_dict_type: sharded
  mixed_precision: PURE
  limit_all_gathers: true
  # Activation options: checkpointing on (non-reentrant), CPU offload off
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false

# Logging
log_to_console: true
console_log_interval: '1ba'
progress_bar: false

# Callbacks
# NOTE(review): callbacks may be constructed in the order listed here —
# confirm with the consumer before reordering entries.
callbacks:
  lr_monitor: {}
  speed_monitor:
    window_size: 1
  memory_monitor: {}
  hf_checkpointer:
    overwrite: true
    precision: bfloat16
    # {run_name} is a single-brace placeholder, unlike the ${...}
    # interpolations used elsewhere in this config — presumably filled in
    # by the training framework at runtime; TODO confirm.
    save_folder: ./{run_name}/checkpoints
    save_interval: 1dur
  runtime_estimator: {}
# Launch command for the job. The key=value arguments after the YAML path
# mirror keys in the config (e.g. max_duration, eval_interval) — presumably
# applied as overrides; confirm against train/train.py.
composer train/train.py \
  train/yamls/finetune/dbrx-full-ft.yaml \
  data_local=/home/.../dbrx-asset/c4/large \
  train_loader.dataset.split=train \
  eval_loader.dataset.split=test \
  max_duration=10ba \
  eval_interval=0 \
  save_folder=/home/.../dbrx-asset/dbrx-asset