CenterNet is a strong single-stage, single-scale, anchor-free object detector. This implementation is built with PyTorch Lightning, supports TorchScript and ONNX export, and has a modular design that makes customizing components simple.
References
To read more about the architecture and code structure of this implementation, see implementation.md.

Clone this repo and navigate to the repo directory:

```bash
git clone https://github.com/gau-nernst/centernet-lightning.git
cd centernet-lightning
```

Install using `environment.yml`:

```bash
conda env create -f environment.yml
conda activate centernet
```

For more detailed instructions, see install.md.
Import `build_centernet` from `centernet_lightning.models` to build a CenterNet model from a YAML config file. Sample config files are provided in the `configs/` directory.

```python
from centernet_lightning.models import build_centernet

model = build_centernet("configs/coco_resnet34.yaml")
```

You can also load a CenterNet model directly from a checkpoint, thanks to PyTorch Lightning:

```python
from centernet_lightning.models import CenterNet

model = CenterNet.load_from_checkpoint("path/to/checkpoint.ckpt")
```

For inference, use `CenterNet.inference_detection()` or `CenterNet.inference_tracking()`:
```python
model = ...  # create a model as above
img_dir = "path/to/img/dir"
detections = model.inference_detection(img_dir, num_detections=100)
```

`detections` is a dictionary with the following keys:

| Key | Description | Shape |
|---|---|---|
| `bboxes` | bounding boxes in x1y1x2y2 format | (num_images x num_detections x 4) |
| `labels` | class labels | (num_images x num_detections) |
| `scores` | confidence scores | (num_images x num_detections) |
All values are NumPy arrays (`np.ndarray`), ready for post-processing.
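A typical first post-processing step is to drop low-confidence boxes. The following is a minimal NumPy sketch that relies only on the keys and shapes listed in the table above. The helper names `filter_detections` and `xyxy_to_xywh` and the default threshold of 0.3 are illustrative choices, not part of this library.

```python
import numpy as np


def filter_detections(detections, score_threshold=0.3):
    """Keep only detections whose confidence score is >= score_threshold.

    Returns a list with one (bboxes, labels, scores) tuple per image, because
    each image may keep a different number of boxes.
    """
    results = []
    # iterating over the first axis walks through the images one by one
    for bboxes, labels, scores in zip(
        detections["bboxes"], detections["labels"], detections["scores"]
    ):
        keep = scores >= score_threshold  # boolean mask over num_detections
        results.append((bboxes[keep], labels[keep], scores[keep]))
    return results


def xyxy_to_xywh(bboxes):
    """Convert (N, 4) boxes from x1y1x2y2 to x, y, width, height (COCO-style)."""
    xywh = bboxes.copy()
    xywh[:, 2:] = bboxes[:, 2:] - bboxes[:, :2]
    return xywh
```

A per-image list is returned instead of a stacked array because the number of surviving boxes generally differs between images.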
For finer control, for example when you embed CenterNet in your own application, you can run pre-processing and the forward pass yourself:

```python
import cv2
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# read image (OpenCV loads BGR, so convert to RGB)
img = cv2.imread("path/to/image")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# apply pre-processing: resize to 512x512 and normalize with ImageNet statistics
# (torchvision.transforms should also work)
transforms = A.Compose([
    A.Resize(height=512, width=512),
    A.Normalize(),
    ToTensorV2(),
])
img = transforms(image=img)["image"]

# create a model as above and put it in evaluation mode
model = ...
model.eval()

# turn off gradient calculation and run a forward pass on a batch of one image
with torch.no_grad():
    encoded_outputs = model(img.unsqueeze(0))
    detections = model.gather_detection2d(encoded_outputs)
```

`detections` has the same format as above, but the values are `torch.Tensor`s.
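`gather_detection2d()` turns the raw head outputs into boxes. Its internals are not described here, so the following is only a NumPy sketch of the generic CenterNet decoding described in the "Objects as Points" paper, for a single image: 3x3 max-pool peak extraction in place of NMS, top-k selection, then a size and offset lookup at each peak. The function name `decode_centernet`, the input layout (a per-class heatmap plus 2-channel size and offset maps), box sizes being expressed in output-grid cells, and the stride of 4 are all assumptions, not this repository's actual output format.

```python
import numpy as np


def decode_centernet(heatmap, box_size, offset, num_detections=100, stride=4):
    """Decode one image's CenterNet-style outputs into x1y1x2y2 boxes.

    heatmap:  (num_classes, H, W) center scores in [0, 1]   (assumed layout)
    box_size: (2, H, W) box width/height in output-grid cells (assumed)
    offset:   (2, H, W) sub-cell center offsets (x, y)      (assumed)
    """
    num_classes, h, w = heatmap.shape

    # 3x3 max-pool "NMS": keep a cell only if it is the max of its neighbourhood
    padded = np.pad(heatmap, ((0, 0), (1, 1), (1, 1)), constant_values=-np.inf)
    windows = np.stack([padded[:, dy:dy + h, dx:dx + w]
                        for dy in range(3) for dx in range(3)])
    peaks = np.where(heatmap == windows.max(axis=0), heatmap, 0.0)

    # top-k over all classes and positions at once
    flat = peaks.reshape(-1)
    k = min(num_detections, flat.size)
    top = np.argsort(-flat, kind="stable")[:k]
    scores = flat[top]
    labels = top // (h * w)
    ys, xs = np.divmod(top % (h * w), w)

    # refine the center with the offset, expand by half the size, scale to pixels
    cx = xs + offset[0, ys, xs]
    cy = ys + offset[1, ys, xs]
    half_w = box_size[0, ys, xs] / 2
    half_h = box_size[1, ys, xs] / 2
    bboxes = np.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], axis=1) * stride
    return bboxes, labels, scores
```

Because each object is represented by a single center peak, no anchor boxes are needed, and the max-pool step replaces box-level NMS.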
Note: Thanks to the data augmentations used during training, the model is robust enough that ImageNet normalization is not required at inference time. You can simply scale the input image to [0, 1], and CenterNet should still work fine.
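To make the two options concrete, here is a NumPy-only sketch of both pre-processing variants. The mean and std values are the ImageNet statistics that `A.Normalize()` uses by default; the function name `preprocess` is hypothetical, and resizing is assumed to have been done already.

```python
import numpy as np

# ImageNet statistics (the defaults of albumentations' A.Normalize())
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(img_rgb, imagenet_norm=True):
    """Convert an (H, W, 3) uint8 RGB image to a (3, H, W) float32 array.

    With imagenet_norm=False the output is simply scaled to [0, 1].
    """
    x = img_rgb.astype(np.float32) / 255.0
    if imagenet_norm:
        x = (x - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over the channel axis
    return np.ascontiguousarray(x.transpose(2, 0, 1))  # HWC -> CHW
```

`torch.from_numpy(preprocess(img)).unsqueeze(0)` then gives a batch of one image. Within albumentations, `A.Normalize(mean=(0, 0, 0), std=(1, 1, 1))` similarly reduces to dividing by 255, i.e. plain [0, 1] scaling.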
CenterNet is export-friendly. You can directly export a trained model to ONNX or TorchScript (tracing only) using the PyTorch Lightning API:

```python
import torch
from centernet_lightning.models import CenterNet

model = CenterNet.load_from_checkpoint("path/to/checkpoint.ckpt")
sample_input = torch.rand((1, 3, 512, 512))

# export to ONNX
model.to_onnx("model.onnx", sample_input)

# export to TorchScript via tracing (scripting might not work);
# tracing needs an example input
model.to_torchscript("model.pt", method="trace", example_inputs=sample_input)
```

WIP
You can train CenterNet with the provided training script `train.py` and a config file:

```bash
python train.py --config "configs/coco_resnet34.yaml"
```

See the sample config files in `configs/`. To customize training, see training.md.
The following dataset formats are supported:
Detection:
Tracking:
To see how to use each dataset type, see datasets.md.