Fine-tune dbrx-instruct on a single VM with 8 H100s
hosseinsarshar opened this issue · 1 comment
hosseinsarshar commented
I'm trying to fine-tune dbrx on a single machine with 8 H100 GPUs. I keep getting OOM errors with different configurations, so I wonder if this is even doable.
I see a note that suggests ~64x80GB GPUs, but I wonder if there is a way to do it with 8x80GB GPUs.
Thanks
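As a rough back-of-the-envelope check (my own sketch, not exact numbers from llm-foundry: it assumes ~132B total parameters for dbrx and bf16 for weights, gradients, and the single LionW momentum buffer, and ignores activations, buffers, and fragmentation):

```python
# Very rough estimate of sharded training state for full fine-tuning with
# FSDP FULL_SHARD. Assumptions: ~132e9 total parameters, 2 bytes (bf16) each
# for weights, gradients, and LionW momentum; activations are ignored.
TOTAL_PARAMS = 132e9
BYTES_PER_ELEM = 2  # bf16

state_bytes = TOTAL_PARAMS * BYTES_PER_ELEM * 3  # weights + grads + momentum

for n_gpus in (8, 64):
    per_gpu_gb = state_bytes / n_gpus / 1e9
    print(f"{n_gpus} GPUs: ~{per_gpu_gb:.0f} GB of sharded training state per GPU")
```

Under those assumptions the sharded state alone is roughly 99 GB per GPU on 8 cards, already past the 80 GB per H100, versus about 12 GB per GPU on 64 cards, which lines up with the note in the YAML.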
hosseinsarshar commented
Some other info about my job:
My config:
```yaml
# Note: This requires ~64x80GB GPUs
max_seq_len: 64
icl_seq_len: 64

# Run Name
run_name:  # If left blank, will be read from env var $RUN_NAME

# Model
model:
  name: hf_causal_lm
  pretrained: true
  init_device: mixed
  use_auth_token: true
  config_overrides: {}
  use_flash_attention_2: true
  pretrained_model_name_or_path: databricks/dbrx-instruct

# Tokenizer
tokenizer:
  name: databricks/dbrx-instruct
  kwargs:
    model_max_length: ${max_seq_len}
    trust_remote_code: true

# Dataloaders
train_loader:
  name: finetuning
  dataset:
    split: train
    hf_name: mosaicml/dolly_hhrlhf
    shuffle: true
    max_seq_len: ${max_seq_len}
    eos_token_id: 0
    packing_ratio: auto
    allow_pad_trimming: false
    decoder_only_format: true
  drop_last: true
  pin_memory: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true

eval_loader:
  name: finetuning
  dataset:
    split: test
    hf_name: mosaicml/dolly_hhrlhf
    shuffle: false
    max_seq_len: ${max_seq_len}
    packing_ratio: null
    allow_pad_trimming: false
    decoder_only_format: true
  drop_last: true
  pin_memory: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true

# Optimization
optimizer:
  lr: 0.000001
  name: decoupled_lionw
  betas:
  - 0.9
  - 0.95
  weight_decay: 1.0e-06

scheduler:
  name: cosine_with_warmup
  alpha_f: 0
  t_warmup: 0.02dur

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1

max_duration: 2ep
eval_interval: 1ep
global_train_batch_size: 8
eval_first: false
# eval_subset_num_batches: -1

# System
seed: 17
device_train_microbatch_size: 1
device_eval_batch_size: 1
precision: amp_bf16
autoresume: true
dist_timeout: 3600
expandable_segments: true

# FSDP
fsdp_config:
  mixed_precision: PURE
  state_dict_type: sharded
  limit_all_gathers: true
  sharding_strategy: FULL_SHARD
  activation_cpu_offload: false
  activation_checkpointing: true
  activation_checkpointing_reentrant: false

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

# Callbacks
callbacks:
  lr_monitor: {}
  speed_monitor:
    window_size: 1
  memory_monitor: {}
  hf_checkpointer:
    overwrite: true
    precision: bfloat16
    save_folder: ./{run_name}/checkpoints
    save_interval: 1dur
  runtime_estimator: {}
```
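For context, here is a minimal sketch of what the `fsdp_config` block above roughly corresponds to in plain PyTorch FSDP. Composer/llm-foundry build this wrapping internally, so this is only an illustration (the `wrap_with_fsdp` helper is made up for the example), not their actual code path:

```python
# Illustration only: approximate PyTorch FSDP equivalents of the fsdp_config
# keys above. Activation checkpointing and the sharded state_dict_type are
# applied separately and are not shown here.
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

def wrap_with_fsdp(model: torch.nn.Module) -> FSDP:
    """Assumes torch.distributed is already initialized (e.g. by `composer`)."""
    bf16 = MixedPrecision(                               # mixed_precision: PURE
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,   # sharding_strategy: FULL_SHARD
        mixed_precision=bf16,
        limit_all_gathers=True,                          # limit_all_gathers: true
    )
```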
Launch command:

```bash
composer train/train.py \
  train/yamls/finetune/dbrx-full-ft.yaml \
  data_local=/home/.../dbrx-asset/c4/large \
  train_loader.dataset.split=train \
  eval_loader.dataset.split=test \
  max_duration=10ba \
  eval_interval=0 \
  save_folder=/home/.../dbrx-asset/dbrx-asset
```
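The trailing `key=value` arguments are dotted-path overrides merged onto the YAML. A quick way to double-check what the run actually resolves to is a small OmegaConf snippet (a sketch, since the train script loads its configs with omegaconf; the path-based overrides from the command are omitted here):

```python
# Sketch: load the YAML and merge command-line style overrides to confirm
# resolved values such as ${max_seq_len}.
from omegaconf import OmegaConf

cfg = OmegaConf.load("train/yamls/finetune/dbrx-full-ft.yaml")
overrides = OmegaConf.from_dotlist([
    "max_duration=10ba",
    "eval_interval=0",
])
cfg = OmegaConf.merge(cfg, overrides)
print(OmegaConf.to_yaml(cfg, resolve=True))  # ${max_seq_len} resolves to 64, etc.
```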