DeepSpeed Integration
Closed this issue · 2 comments
Hi,
Since my GPU memory is low (12 GB), I am looking for a way to use DeepSpeed in the training code with the CPU offload setting enabled.
Here is my modification so far:
"""
Fine-Tune SantaCoder on code/text dataset
"""
import argparse
import os
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
TrainerCallback,
TrainerState,
TrainerControl,
logging,
set_seed,
)
import deepspeed
def get_args():
    """Parse the command-line options of the fine-tuning script.

    Returns:
        argparse.Namespace: one attribute per option declared below, read
        from ``sys.argv``.
    """
    option_specs = (
        # Model and data source.
        ("--model_path", {"type": str, "default": "bigcode/santacoder"}),
        ("--dataset_name", {"type": str, "default": "bigcode/the-stack-dedup"}),
        ("--subset", {"type": str, "default": "data"}),
        ("--split", {"type": str, "default": "train"}),
        ("--size_valid_set", {"type": int, "default": 4000}),
        ("--streaming", {"action": "store_true"}),
        ("--shuffle_buffer", {"type": int, "default": 5000}),
        ("--data_column", {"type": str, "default": "content"}),
        # Sequence / batch geometry.
        ("--seq_length", {"type": int, "default": 1024}),
        ("--max_steps", {"type": int, "default": 10000}),
        ("--batch_size", {"type": int, "default": 2}),
        ("--gradient_accumulation_steps", {"type": int, "default": 8}),
        ("--eos_token_id", {"type": int, "default": 49152}),
        # Optimisation.
        ("--learning_rate", {"type": float, "default": 5e-5}),
        ("--lr_scheduler_type", {"type": str, "default": "cosine"}),
        ("--num_warmup_steps", {"type": int, "default": 100}),
        ("--weight_decay", {"type": float, "default": 0.05}),
        ("--local_rank", {"type": int, "default": 0}),
        # store_false: these attributes are True unless the flag is passed,
        # i.e. `args.no_fp16 is True` means "fp16 stays enabled".
        ("--no_fp16", {"action": "store_false"}),
        ("--no_gradient_checkpointing", {"action": "store_false"}),
        # Reproducibility, I/O and reporting cadence.
        ("--seed", {"type": int, "default": 0}),
        ("--num_workers", {"type": int, "default": None}),
        ("--output_dir", {"type": str, "default": "./checkpoints"}),
        ("--log_freq", {"type": int, "default": 1}),
        ("--eval_freq", {"type": int, "default": 1000}),
        ("--save_freq", {"type": int, "default": 1000}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.

    At most ``nb_examples`` examples are read from the start of ``dataset``.
    For each one, the length of ``example[data_column]`` is compared with the
    number of tokens the tokenizer produces for it.

    Returns:
        float: total characters divided by total tokens over the sample.

    Raises:
        ZeroDivisionError: if no token was counted (e.g. the dataset is empty).
    """
    char_count = 0
    token_count = 0
    # zip with range() caps the sample without consuming extra examples.
    sample = zip(range(nb_examples), iter(dataset))
    for _, record in tqdm(sample, total=nb_examples):
        text = record[data_column]
        char_count += len(text)
        token_count += len(tokenizer(text).tokens())
    return char_count / token_count
# Base DeepSpeed configuration handed to the Hugging Face Trainer.
# run_training replaces the batch-size, learning-rate, warmup and weight-decay
# values below with the command-line settings before the config reaches the
# Trainer; the literals here are only defaults.
# NOTE(review): there is no 'fp16' section here although the script enables
# fp16 in TrainingArguments by default - confirm whether mixed precision is
# expected to be configured through DeepSpeed as well.
DEEPSPEED_CONFIG = {
    'optimizer': {
        'type': 'AdamW',
        'params': {'lr': 1e-05, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0.0},
    },
    'scheduler': {
        'type': 'WarmupLR',
        'params': {'warmup_min_lr': 0, 'warmup_max_lr': 1e-05, 'warmup_num_steps': 100},
    },
    # ZeRO stage 3 with both optimizer state and parameters offloaded to CPU.
    'zero_optimization': {
        'stage': 3,
        'offload_optimizer': {'device': 'cpu', 'pin_memory': False},
        'offload_param': {'device': 'cpu', 'pin_memory': False},
        'overlap_comm': True,
        'contiguous_gradients': True,
        'sub_group_size': 1e9,
        'reduce_bucket_size': 16777216,
        # Must be an integer: the inherited value was 0.9 * 16777216 =
        # 15099494.4, a non-integral float, so it is truncated here.
        'stage3_prefetch_bucket_size': 15099494,
        'stage3_param_persistence_threshold': 40960,
        'stage3_max_live_parameters': 1e9,
        'stage3_max_reuse_distance': 1e9,
    },
    # train_batch_size = micro batch (4) * accumulation steps (8) on one GPU.
    'train_batch_size': 32,
    'train_micro_batch_size_per_gpu': 4,
    'gradient_accumulation_steps': 8,
    'gradient_clipping': 1.0,
    'steps_per_print': 8,
    'wall_clock_breakdown': False,
    # Every compression feature is present but left with empty settings.
    'compression_training': {
        'weight_quantization': {'shared_parameters': {}, 'different_groups': {}},
        'activation_quantization': {'shared_parameters': {}, 'different_groups': {}},
        'sparse_pruning': {'shared_parameters': {}, 'different_groups': {}},
        'row_pruning': {'shared_parameters': {}, 'different_groups': {}},
        'head_pruning': {'shared_parameters': {}, 'different_groups': {}},
        'channel_pruning': {'shared_parameters': {}, 'different_groups': {}},
    },
}
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.

    Texts are collected into a character buffer, tokenized in one batch,
    joined with a separator token after each text, and cut into windows of
    exactly ``seq_length`` tokens. A trailing window shorter than
    ``seq_length`` is dropped.

    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True the iterator is reset after dataset reaches end else stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in buffer.
        chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.
        content_field (str): Key of each example that holds the text.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        content_field="content",
    ):
        self.tokenizer = tokenizer
        # Compare against None rather than truthiness: an EOS id of 0 is a
        # valid token id and must not trigger the fallback.
        if tokenizer.eos_token_id is not None:
            self.concat_token_id = tokenizer.eos_token_id
        else:
            # Fallback relies on the module-level `args` namespace created in
            # the `__main__` block of this script.
            self.concat_token_id = args.eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        # Number of sequences yielded so far.
        self.current_size = 0
        # Buffer budget, expressed in characters.
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.content_field = content_field

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        # Examples read since the iterator was last (re)started; used to
        # detect an empty dataset so infinite mode cannot spin forever.
        examples_in_pass = 0
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator)[self.content_field])
                except StopIteration:
                    if self.infinite and examples_in_pass > 0:
                        # Restart from the beginning of the dataset.
                        iterator = iter(self.dataset)
                        examples_in_pass = 0
                    else:
                        more_examples = False
                        break
                else:
                    buffer_len += len(buffer[-1])
                    examples_in_pass += 1
            if not buffer:
                # Nothing left to tokenize; avoid calling the tokenizer on an
                # empty batch.
                break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                # Only full-length windows are emitted.
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield {
                        "input_ids": torch.LongTensor(input_ids),
                        "labels": torch.LongTensor(input_ids),
                    }
def create_datasets(tokenizer, args):
    """Load the raw dataset and wrap it into train/validation token streams.

    In streaming mode the first ``args.size_valid_set`` examples become the
    validation data and the remainder, shuffled through a buffer, becomes the
    training data. Otherwise the dataset is split with a fixed 0.5% test
    fraction.

    Returns:
        tuple: ``(train_dataset, valid_dataset)``, both ConstantLengthDataset
        instances; only the training one is infinite.
    """
    streaming = args.streaming
    raw = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=None if streaming else args.num_workers,
        streaming=streaming,
    )

    if not streaming:
        splits = raw.train_test_split(test_size=0.005, seed=args.seed)
        train_data, valid_data = splits["train"], splits["test"]
        print(
            f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}"
        )
    else:
        print("Loading the dataset in streaming mode")
        valid_data = raw.take(args.size_valid_set)
        remainder = raw.skip(args.size_valid_set)
        train_data = remainder.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)

    # The ratio is measured on the training data and reused for both wrappers.
    chars_per_token = chars_token_ratio(train_data, tokenizer, args.data_column)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    shared_options = {
        "seq_length": args.seq_length,
        "chars_per_token": chars_per_token,
        "content_field": args.data_column,
    }
    train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, **shared_options)
    valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, **shared_options)
    return train_dataset, valid_dataset
class SantaCoderTrainerCallback(TrainerCallback):
    """Trainer callback that calls ``torch.cuda.empty_cache()`` after every
    training step and once more when training ends."""

    @staticmethod
    def _release_cuda_cache():
        # Single place for the cache release shared by both hooks.
        torch.cuda.empty_cache()

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        self._release_cuda_cache()

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        self._release_cuda_cache()
def run_training(args, train_data, val_data):
    """Fine-tune the model with the Hugging Face Trainer and DeepSpeed.

    Args:
        args: Parsed command-line namespace (see ``get_args``).
        train_data: Training dataset passed to the Trainer.
        val_data: Evaluation dataset passed to the Trainer.

    The final model is written to ``<args.output_dir>/final_checkpoint/``.
    """
    import copy  # local import: only needed to avoid mutating DEEPSPEED_CONFIG

    print("Loading the model")
    # disable caching mechanism when using gradient checkpointing
    # (`no_gradient_checkpointing` is a store_false flag: it is True unless
    # the flag was passed, so use_cache is False by default).
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        use_cache=not args.no_gradient_checkpointing,
    )
    train_data.start_iteration = 0

    print("Starting main loop")

    # Work on a private copy so the module-level template is not mutated.
    ds_config = copy.deepcopy(DEEPSPEED_CONFIG)
    ds_config['train_micro_batch_size_per_gpu'] = args.batch_size
    ds_config['gradient_accumulation_steps'] = args.gradient_accumulation_steps
    # "auto" lets the Trainer derive the global batch size (including the
    # number of processes) instead of assuming a single device.
    ds_config['train_batch_size'] = "auto"
    ds_config['scheduler']['params']['warmup_num_steps'] = args.num_warmup_steps
    ds_config['scheduler']['params']['warmup_max_lr'] = args.learning_rate
    ds_config['optimizer']['params']['lr'] = args.learning_rate
    ds_config['optimizer']['params']['weight_decay'] = args.weight_decay

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        evaluation_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        # Both flags below are store_false options: True unless the
        # corresponding --no_* switch was given on the command line.
        gradient_checkpointing=args.no_gradient_checkpointing,
        fp16=args.no_fp16,
        weight_decay=args.weight_decay,
        run_name=f"santacoder-{args.subset}",
        report_to="wandb",
        deepspeed=ds_config,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        callbacks=[SantaCoderTrainerCallback],
    )

    print("Training...")
    trainer.train()

    print("Saving last checkpoint of the model")
    output_dir = os.path.join(args.output_dir, "final_checkpoint/")
    os.makedirs(output_dir, exist_ok=True)
    # Save through the Trainer: with ZeRO stage 3 the parameters are
    # partitioned, so calling model.save_pretrained directly does not see the
    # full weights.
    # NOTE(review): consider enabling
    # zero_optimization.stage3_gather_16bit_weights_on_model_save if a single
    # consolidated weights file is wanted - confirm against the Transformers
    # DeepSpeed integration docs.
    trainer.save_model(output_dir)
def main(args):
    """Load the tokenizer, build both datasets and start training."""
    tok = AutoTokenizer.from_pretrained(args.model_path, use_auth_token=True)
    # create_datasets returns (train_dataset, valid_dataset), in that order.
    datasets = create_datasets(tok, args)
    run_training(args, *datasets)
if __name__ == "__main__":
    # Script entry point. The namespace must stay bound to the module-level
    # name `args`: ConstantLengthDataset.__init__ reads `args.eos_token_id`
    # from module scope as a fallback.
    args = get_args()

    # Seed before any data or model work.
    set_seed(args.seed)

    # Make sure the checkpoint directory exists up front.
    os.makedirs(args.output_dir, exist_ok=True)

    # Restrict the transformers logger to errors.
    logging.set_verbosity_error()

    main(args)
Could you help me check whether I am doing this the right way? Thanks ^^ The DeepSpeed config is taken from https://github.com/salesforce/jaxformer/blob/main/jaxformer/hf/train.py
It’s hard to see exactly what changed in this format, but it looks OK. You just don't need to supply the training arguments twice (once in the DeepSpeed config and once in the Trainer arguments), and you need to adjust the command line to launch with DeepSpeed (see the docs).
You can try it and report any issues you face, here or in transformers. You can also find an example of how to use DeepSpeed with the Trainer in this thread.
Ah I see your point, thank you ^^