MMEngine is a foundational library for training deep learning models based on PyTorch. It provides a solid engineering foundation and frees developers from writing redundant code for their workflows. It serves as the training engine of all OpenMMLab codebases, which support hundreds of algorithms in various research areas. Importantly, MMEngine is also generic enough to be applied to non-OpenMMLab projects.
Major features:
- Universal and powerful runner.
  - Less code, e.g., train ImageNet with one-fifth of the lines of code of the PyTorch example.
  - Compatible with popular libraries such as OpenMMLab, TorchVision, timm, and Detectron2.
- Open architecture with unified interfaces.
  - Handle different algorithm tasks with a unified API, e.g., implement a method once and apply it to all compatible models.
  - Support different devices and hardware with a unified API, including CPU, GPU, IPU, Apple silicon, etc.
- Customizable training process.
  - Define the training process like building with Lego bricks, using a rich set of components and strategies.
  - Take complete control of training with different levels of APIs.
Before installing MMEngine, please ensure that PyTorch has been successfully installed following the official guide.
Install MMEngine
```bash
pip install -U openmim
mim install mmengine
```
Verify the installation
```bash
python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'
```
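As an alternative check from inside Python, the sketch below reports the installed versions of PyTorch and MMEngine using only the standard library's `importlib.metadata`, without importing either package. The helper name `installed_version` is hypothetical, and the sketch assumes the packages are installed under their usual PyPI distribution names, `torch` and `mmengine`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Assumed PyPI distribution names for PyTorch and MMEngine
    for name in ("torch", "mmengine"):
        found = installed_version(name)
        print(f"{name}: {found if found is not None else 'not installed'}")
```

A `None` result for `torch` means PyTorch should be installed first, as described above.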
Taking the training of a ResNet-50 model on the CIFAR-10 dataset as an example, we will build a complete, configurable training and validation pipeline with MMEngine in fewer than 80 lines of code.
Build Models
First, we need to define a model that 1) inherits from `BaseModel`, and 2) accepts an additional argument `mode` in its `forward` method, in addition to the arguments related to the dataset. During training, the value of `mode` is `"loss"` and the `forward` method should return a dict containing the key `"loss"`. During validation, the value of `mode` is `"predict"` and the `forward` method should return results containing both predictions and labels.
```python
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            # Training: return a dict of losses for the runner to optimize
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            # Validation: return predictions together with ground-truth labels
            return x, labels
```
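To make the `mode` contract concrete, the following torch-only sketch shows how a training loop might drive a model with this interface. It is a simplified illustration, not MMEngine's implementation: `TinyModel`, `simple_train_step`, and `simple_val_step` are hypothetical names, and summing every output whose key contains `"loss"` is an assumption modeled loosely on MMEngine's `BaseModel.parse_losses`. Because a TorchVision `DataLoader` yields `[imgs, labels]` lists, the sketch unpacks each batch positionally before passing `mode`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyModel(nn.Module):
    """Hypothetical stand-in following the same forward(..., mode) contract."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))

    def forward(self, imgs, labels, mode):
        x = self.net(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
        raise ValueError(f'Unsupported mode: {mode}')


def simple_train_step(model, data, optimizer):
    """Run one optimization step on an [imgs, labels] batch (simplified)."""
    losses = model(*data, mode='loss')
    # Assumption: the total loss is the sum of every entry whose key contains 'loss'
    total = sum(v for k, v in losses.items() if 'loss' in k)
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    return total.detach()


@torch.no_grad()
def simple_val_step(model, data):
    """Return (scores, labels) for an [imgs, labels] batch without tracking gradients."""
    return model(*data, mode='predict')


if __name__ == '__main__':
    torch.manual_seed(0)
    model = TinyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    batch = [torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,))]
    print('loss:', simple_train_step(model, batch, optimizer).item())
    scores, labels = simple_val_step(model, batch)
    print(scores.shape, labels.shape)
```

In MMEngine, the zero-grad, backward, and step calls are delegated to the `optim_wrapper` configured on the `Runner`, so a `BaseModel` subclass only needs to return the loss dict.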
Build Datasets
Next, we need to create a `Dataset` and `DataLoader` for training and validation. In this case, we simply use the CIFAR-10 dataset built into TorchVision.
```python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))
```
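The `norm_cfg` values are per-channel means and standard deviations consumed by `transforms.Normalize`; adapting this example to another dataset means supplying your own statistics. The sketch below shows one common way to estimate them from a `DataLoader` of `ToTensor()` images (pixel values in [0, 1]). The helper name `channel_stats` is hypothetical, and since there are several conventions for computing the standard deviation (over all pixels versus averaged per image), it may not reproduce the exact numbers above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_stats(loader):
    """Estimate per-channel mean and std over all pixels of (N, C, H, W) batches."""
    total = 0
    sum_ = None
    sum_sq = None
    for imgs, _ in loader:
        # Reshape to (C, N*H*W) so each row holds every pixel of one channel
        flat = imgs.transpose(0, 1).reshape(imgs.shape[1], -1).double()
        sum_ = flat.sum(dim=1) if sum_ is None else sum_ + flat.sum(dim=1)
        sq = (flat ** 2).sum(dim=1)
        sum_sq = sq if sum_sq is None else sum_sq + sq
        total += flat.shape[1]
    mean = sum_ / total
    # Population variance via E[x^2] - E[x]^2, clamped to avoid tiny negatives
    std = (sum_sq / total - mean ** 2).clamp_min(0).sqrt()
    return mean.tolist(), std.tolist()


if __name__ == '__main__':
    # Synthetic stand-in data; for real use, load the training set without Normalize
    fake = TensorDataset(torch.rand(100, 3, 32, 32), torch.zeros(100, dtype=torch.long))
    mean, std = channel_stats(DataLoader(fake, batch_size=32))
    print(dict(mean=[round(m, 3) for m in mean], std=[round(s, 3) for s in std]))
```

The statistics should be computed on the training split only, and without the `Normalize` transform in the pipeline.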
Build Metrics
To validate and test the model, we need to define a metric, such as accuracy, to evaluate it. This metric needs to inherit from `BaseMetric` and implement the `process` and `compute_metrics` methods.
```python
from mmengine.evaluator import BaseMetric


class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # Save the results of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # Returns a dictionary with the results of the evaluated metrics,
        # where the key is the name of the metric
        return dict(accuracy=100 * total_correct / total_size)
```
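Splitting the metric into `process` and `compute_metrics` lets per-batch results accumulate during validation and be reduced once at the end. The torch-only sketch below imitates that lifecycle with a hypothetical `SimpleAccuracy` class so it runs without MMEngine. As we understand it, MMEngine's own `BaseMetric.evaluate` additionally gathers `self.results` from all processes in distributed runs before calling `compute_metrics`, then clears them; the local `evaluate` here only approximates the single-process case.

```python
import torch


class SimpleAccuracy:
    """Hypothetical stand-in mirroring the process/compute_metrics split."""

    def __init__(self):
        self.results = []

    def process(self, data_batch, data_samples):
        # data_samples is the (scores, labels) pair returned in 'predict' mode
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)

    def evaluate(self):
        # Reduce the accumulated results, then reset for the next validation round
        metrics = self.compute_metrics(self.results)
        self.results.clear()
        return metrics


if __name__ == '__main__':
    metric = SimpleAccuracy()
    metric.process(None, (torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 0])))
    metric.process(None, (torch.tensor([[0.1, 0.9]]), torch.tensor([1])))
    print(metric.evaluate())
```

Storing counts rather than per-sample predictions keeps `self.results` small, which matters when results have to be gathered across processes.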
Build a Runner
Finally, we can construct a `Runner` with the previously defined model, dataloaders, and metric, plus some other configs, as shown below.
```python
from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    # a wrapper to execute back propagation and gradient update, etc.
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # set some training configs like epochs
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
```
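With `by_epoch=True`, `max_epochs=5`, and `val_interval=1`, training runs for five epochs with validation after each one. The pure-Python sketch below spells out that schedule. `epoch_schedule` is a hypothetical helper, and it assumes validation runs after every epoch whose 1-based index is a multiple of `val_interval`, plus after the final epoch (our understanding of recent MMEngine versions); hooks, checkpointing, logging, and options such as `val_begin` are omitted.

```python
from typing import List, Tuple


def epoch_schedule(max_epochs: int, val_interval: int = 1) -> List[Tuple[str, int]]:
    """List ('train', epoch) and ('val', epoch) events for an epoch-based loop (sketch)."""
    if max_epochs < 1 or val_interval < 1:
        raise ValueError('max_epochs and val_interval must be positive')
    events = []
    for epoch in range(1, max_epochs + 1):
        events.append(('train', epoch))
        # Assumption: validate every `val_interval` epochs and after the final epoch
        if epoch % val_interval == 0 or epoch == max_epochs:
            events.append(('val', epoch))
    return events


if __name__ == '__main__':
    # Mirrors train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1)
    for event in epoch_schedule(max_epochs=5, val_interval=1):
        print(event)
```

Raising `val_interval` trades less frequent feedback for less time spent on validation.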
Launch Training
```python
runner.train()
```
We appreciate all contributions to improve MMEngine. Please refer to CONTRIBUTING.md for the contributing guidelines.
This project is released under the Apache 2.0 license.
Projects in OpenMMLab
- MIM: MIM installs OpenMMLab packages.
- MMCV: OpenMMLab foundational library for computer vision.
- MMClassification: OpenMMLab image classification toolbox and benchmark.
- MMDetection: OpenMMLab detection toolbox and benchmark.
- MMDetection3D: OpenMMLab's next-generation platform for general 3D object detection.
- MMRotate: OpenMMLab rotated object detection toolbox and benchmark.
- MMSegmentation: OpenMMLab semantic segmentation toolbox and benchmark.
- MMOCR: OpenMMLab text detection, recognition, and understanding toolbox.
- MMPose: OpenMMLab pose estimation toolbox and benchmark.
- MMHuman3D: OpenMMLab 3D human parametric model toolbox and benchmark.
- MMSelfSup: OpenMMLab self-supervised learning toolbox and benchmark.
- MMRazor: OpenMMLab model compression toolbox and benchmark.
- MMFewShot: OpenMMLab few-shot learning toolbox and benchmark.
- MMAction2: OpenMMLab's next-generation action understanding toolbox and benchmark.
- MMTracking: OpenMMLab video perception toolbox and benchmark.
- MMFlow: OpenMMLab optical flow toolbox and benchmark.
- MMEditing: OpenMMLab image and video editing toolbox.
- MMGeneration: OpenMMLab image and video generative models toolbox.
- MMDeploy: OpenMMLab model deployment framework.