[Feature Request] Example of training CNN with large batch size
elricwan opened this issue · 1 comment
Is your feature request related to a problem? Please describe.
Hi there, I have recently tried to train a ResNet-50 with batch size = 8192 using hivemind, but unfortunately the loss did not decrease.
Describe the solution you'd like
Is there any written example that I could look at for training neural networks other than transformers? (By that I mean one that does not use the Hugging Face Trainer.)
Additional context
I attached my code below for reference:
import os
import pickle
import sys
from dataclasses import asdict
from pathlib import Path
import torch
import transformers
from torch.utils.data import DataLoader
from torch_optimizer import Lamb
from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
from transformers import BertForMaskedLM, BertConfig, AutoTokenizer
from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer_utils import is_main_process
import argparse
import random
import shutil
import time
import warnings
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
import utils
from arguments import (
ModelTrainingArguments,
AveragerArguments,
CollaborationArguments,
DatasetArguments,
ProgressTrackerArguments,
)
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_IMAGENET_PCA = {
'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
'eigvec': torch.Tensor([
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
])
}
def setup_transformers_logging(process_rank: int):
if is_main_process(process_rank):
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.disable_default_handler()
transformers.utils.logging.enable_propagation()
def get_model(args):
logger.info(f"Training from scratch")
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()
return model
# wrap the optimizer with gradient clipping
class LambWithGradientClipping(Lamb):
"""A version of LAMB that clips gradients based on their norm."""
def __init__(self, *args, max_grad_norm=1.0, **kwargs):
self.max_grad_norm = max_grad_norm
super().__init__(*args, **kwargs)
def step(self, *args, **kwargs):
iter_params = (param for group in self.param_groups for param in group["params"])
torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
return super().step(*args, **kwargs)
class CollaborativeCallback(transformers.TrainerCallback):
"""
This callback monitors and reports collaborative training progress.
In case of a catastrophic failure, it can also revert training to a backup.
"""
def __init__(
self,
dht: DHT,
optimizer: Optimizer,
model: torch.nn.Module,
local_public_key: bytes,
statistics_expiration: float,
backup_every_steps: int,
):
super().__init__()
self.model = model
self.dht, self.optimizer = dht, optimizer
self.local_public_key = local_public_key
self.statistics_expiration = statistics_expiration
self.last_reported_collaboration_step = -1
self.samples = 0
self.steps = 0
self.loss = 0
self.total_samples_processed = 0
self.backup_every_steps = backup_every_steps
self.latest_backup = self.backup_state()
def on_train_begin(
self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
):
logger.info("Loading state from peers")
self.optimizer.load_state_from_peers()
def on_step_end(
self, loss, **kwargs
):
if not self.params_are_finite():
self.restore_from_backup(self.latest_backup)
local_progress = self.optimizer.local_progress
self.loss += loss
self.steps += 1
if self.optimizer.local_epoch != self.last_reported_collaboration_step:
self.last_reported_collaboration_step = self.optimizer.local_epoch
self.total_samples_processed += self.samples
samples_per_second = local_progress.samples_per_second
statistics = utils.LocalMetrics(
step=self.optimizer.local_epoch,
samples_per_second=samples_per_second,
samples_accumulated=self.samples,
loss=self.loss,
mini_steps=self.steps,
)
logger.info(f"Step #{self.optimizer.local_epoch}")
logger.info(f"Your current contribution: {self.total_samples_processed} samples")
logger.info(f"Performance: {samples_per_second:.3f} samples/sec")
if self.steps:
logger.info(f"Local loss: {self.loss / self.steps:.5f}")
if self.optimizer.local_epoch % self.backup_every_steps == 0:
self.latest_backup = self.backup_state()
self.loss = 0
self.steps = 0
if self.optimizer.is_synchronized_with_peers():
self.dht.store(
key=self.optimizer.run_id + "_metrics",
subkey=self.local_public_key,
value=statistics.dict(),
expiration_time=get_dht_time() + self.statistics_expiration,
return_future=True,
)
self.samples = local_progress.samples_accumulated
@torch.no_grad()
def params_are_finite(self):
for param in self.model.parameters():
if not torch.all(torch.isfinite(param)):
return False
return True
@torch.no_grad()
def backup_state(self) -> bytes:
return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})
@torch.no_grad()
def restore_from_backup(self, backup: bytes):
state = pickle.loads(backup)
self.model.load_state_dict(state["model"])
self.optimizer.load_state_dict(state["optimizer"])
class NoOpScheduler(LRSchedulerBase):
"""Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""
def get_lr(self):
return [group["lr"] for group in self.optimizer.param_groups]
def print_lr(self, *args, **kwargs):
if self.optimizer.scheduler:
return self.optimizer.scheduler.print_lr(*args, **kwargs)
def step(self):
self._last_lr = self.get_lr()
def state_dict(self):
return {}
def load_state_dict(self, *args, **kwargs):
logger.debug("Called NoOpScheduler.load_state_dict")
def train(train_loader, model, criterion, optimizer, epoch, collaborative_call):
batch_time = AverageMeter('Time', ':6.3f')
data_time = AverageMeter('Data', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
progress = ProgressMeter(
len(train_loader),
[batch_time, data_time, losses, top1, top5],
prefix="Epoch: [{}]".format(epoch))
# switch to train mode
model.train()
end = time.time()
for i, (images, target) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
images = images.cuda(device, non_blocking=True)
target = target.cuda(device, non_blocking=True)
# compute output
output = model(images)
loss = criterion(output, target)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
# at the end of the step: on_step_end
collaborative_call.on_step_end(loss=loss.item())
# display the accuracy
#progress.display(i)
def validate(val_loader, model, criterion):
batch_time = AverageMeter('Time', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
progress = ProgressMeter(
len(val_loader),
[batch_time, losses, top1, top5],
prefix='Test: ')
# switch to evaluate mode
model.eval()
with torch.no_grad():
end = time.time()
for i, (images, target) in enumerate(val_loader):
images = images.cuda(device, non_blocking=True)
target = target.cuda(device, non_blocking=True)
# compute output
output = model(images)
loss = criterion(output, target)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
# TODO: this should also be done with the ProgressMeter
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
def __init__(self, num_batches, meters, prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.prefix = prefix
def display(self, batch):
entries = [self.prefix + self.batch_fmtstr.format(batch)]
entries += [str(meter) for meter in self.meters]
print('\t'.join(entries))
def _get_batch_fmtstr(self, num_batches):
num_digits = len(str(num_batches // 1))
fmt = '{:' + str(num_digits) + 'd}'
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
def adjust_learning_rate(optimizer, epoch, training_args):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
lr = training_args.learning_rate * (0.1 ** (epoch // 30))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
class Lighting(object):
"""Lighting noise(AlexNet - style PCA - based noise)"""
def __init__(self, alphastd, eigval, eigvec):
self.alphastd = alphastd
self.eigval = eigval
self.eigvec = eigvec
def __call__(self, img):
if self.alphastd == 0:
return img
alpha = img.new().resize_(3).normal_(0, self.alphastd)
rgb = self.eigvec.type_as(img).clone()\
.mul(alpha.view(1, 3).expand(3, 3))\
.mul(self.eigval.view(1, 3).expand(3, 3))\
.sum(1).squeeze()
return img.add(rgb.view(3, 1, 1).expand_as(img))
def main():
parser = HfArgumentParser(
(
ModelTrainingArguments,
DatasetArguments,
CollaborationArguments,
AveragerArguments,
ProgressTrackerArguments,
)
)
training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()
logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
setup_transformers_logging(training_args.local_rank)
logger.info(f"Training/evaluation parameters:\n{training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
model = get_model(training_args)
model.cuda(device)
validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
dht = DHT(
start=True,
initial_peers=collaboration_args.initial_peers,
client_mode=collaboration_args.client_mode,
record_validators=validators,
use_ipfs=collaboration_args.use_ipfs,
host_maddrs=collaboration_args.host_maddrs,
announce_maddrs=collaboration_args.announce_maddrs,
identity_path=collaboration_args.identity_path,
wait_timeout=10
)
utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
if torch.cuda.device_count() != 0:
total_batch_size_per_step *= torch.cuda.device_count()
adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
# We need to make such a lambda function instead of just an optimizer instance
# to make hivemind.Optimizer(..., offload_optimizer=True) work
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda(device)
cudnn.benchmark = True
# Data loading code
traindir = os.path.join(dataset_args.dataset_path, 'train')
valdir = os.path.join(dataset_args.dataset_path, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
transforms.ToTensor(),
Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
normalize,
]))
train_sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=training_args.per_device_train_batch_size, shuffle=(train_sampler is None),
num_workers=4, pin_memory=True, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=training_args.per_device_train_batch_size, shuffle=False,
num_workers=4, pin_memory=True)
opt = lambda params: torch.optim.SGD(
params,
lr=training_args.learning_rate,
momentum=0.9,
weight_decay=1e-4,
)
no_decay = ["bias", "LayerNorm.weight"]
params = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": training_args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
scheduler = lambda opt: get_linear_schedule_with_warmup(
opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
)
optimizer = Optimizer(
dht=dht,
run_id=collaboration_args.experiment_prefix,
target_batch_size=adjusted_target_batch_size,
batch_size_per_step=total_batch_size_per_step,
optimizer=opt,
params=params,
#scheduler=scheduler,
matchmaking_time=collaboration_args.matchmaking_time,
averaging_timeout=collaboration_args.averaging_timeout,
offload_optimizer=True,
delay_optimizer_step=True,
delay_grad_averaging=True,
client_mode=collaboration_args.client_mode,
grad_compression=Float16Compression(),
state_averaging_compression=Float16Compression(),
averager_opts={"bandwidth": collaboration_args.bandwidth, "request_timeout": 5, **asdict(averager_args)},
tracker_opts=asdict(tracker_args),
verbose=True,
)
collaborative_call = CollaborativeCallback(
dht,
optimizer,
model,
local_public_key,
collaboration_args.statistics_expiration,
collaboration_args.backup_every_steps,
)
best_acc1 = 0
for epoch in range(90):
adjust_learning_rate(optimizer, epoch, training_args)
# train for one epoch
train(train_loader, model, criterion, optimizer, epoch, collaborative_call)
# evaluate on validation set
acc1 = validate(val_loader, model, criterion)
best_acc1 = max(acc1, best_acc1)
if __name__ == "__main__":
main()
Hi!
Short-term: in previous hivemind versions, there was a bug that prevented training when there was only one peer, which may have been the problem in your case.
To check, install hivemind from github: pip install https://github.com/learning-at-home/hivemind/archive/master.zip
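To double-check which version you ended up with after reinstalling, a quick check like this should work:

import hivemind
print(hivemind.__version__)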
Also, please check whether the issue is caused by offloading: try setting offload_optimizer=False, delay_optimizer_step=False, delay_grad_averaging=False.
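If it helps, that change against the Optimizer(...) call in your script would look roughly like this; everything except the three flags is copied from the code above:

optimizer = Optimizer(
    dht=dht,
    run_id=collaboration_args.experiment_prefix,
    target_batch_size=adjusted_target_batch_size,
    batch_size_per_step=total_batch_size_per_step,
    optimizer=opt,
    params=params,
    matchmaking_time=collaboration_args.matchmaking_time,
    averaging_timeout=collaboration_args.averaging_timeout,
    offload_optimizer=False,       # was True
    delay_optimizer_step=False,    # was True
    delay_grad_averaging=False,    # was True
    client_mode=collaboration_args.client_mode,
    grad_compression=Float16Compression(),
    state_averaging_compression=Float16Compression(),
    averager_opts={"bandwidth": collaboration_args.bandwidth, "request_timeout": 5, **asdict(averager_args)},
    tracker_opts=asdict(tracker_args),
    verbose=True,
)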
Long-term: we're (steadily) working towards an image example in #459.
Also, we've got a quickstart example with a simple CNN here
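For reference, the core of that quickstart adapted to an image model looks roughly like the sketch below: a plain PyTorch training loop where the only hivemind-specific parts are the DHT and hivemind.Optimizer. The model (ResNet-18 on CIFAR-10), batch sizes, and learning rate are illustrative placeholders rather than the quickstart's exact values:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import hivemind

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy data pipeline: CIFAR-10 with basic normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=2)

model = torchvision.models.resnet18(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()

# The first peer starts its own DHT; every other peer passes initial_peers=[...]
dht = hivemind.DHT(start=True)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="cnn_demo_run",          # peers with the same run_id train together
    batch_size_per_step=256,        # samples this peer processes per opt.step()
    target_batch_size=8192,         # global batch size at which peers average and advance one epoch
    optimizer=torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4),
    use_local_updates=True,         # apply local SGD updates between averaging rounds
    matchmaking_time=3.0,
    averaging_timeout=10.0,
    verbose=True,
)
opt.load_state_from_peers()         # catch up with collaborators before training

model.train()
for images, targets in trainloader:
    images, targets = images.to(device), targets.to(device)
    loss = criterion(model(images), targets)
    loss.backward()
    opt.step()
    opt.zero_grad()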