The TensorFlow Object Detection API provides a collection of detection models pre-trained on the COCO, KITTI, and Open Images datasets, among others. These models can be used for out-of-the-box inference or to initialize your own models when training on custom datasets.
Typically, the following steps are required to train on a custom dataset:
- Installation
  - Anaconda with python=3.7 (optional)
  - tensorflow 1.15
  - Required Python packages (i.e., requirements.txt)
  - Clone the TensorFlow model garden
  - Download and install protobuf
- Prepare the custom dataset
  - Prepare images (.jpg format)
  - Make an annotation file (.xml format) for each image
  - Combine the annotation .xml files into a .csv file for the train set and one for the test set
  - Create the label map file (.pbtxt)
- Generate TFRecords from these datasets
- Download pre-trained models
- Configure the training pipeline (e.g., edit the pipeline.config file)
- Prepare options and run `python object_detection/model_main.py`
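The "combine annotation .xml files into a .csv file" step can be sketched roughly as below. This is a minimal illustration rather than this repo's actual script: it assumes the annotations are in the Pascal VOC layout that labelImg writes, and the function names (`rows_from_root`, `annotations_to_csv`) and the column order are hypothetical choices borrowed from common tutorials.

```python
import csv
import xml.etree.ElementTree as ET
from pathlib import Path

# Assumed column order; the real scripts in this repo may differ.
COLUMNS = ["filename", "width", "height", "class",
           "xmin", "ymin", "xmax", "ymax"]


def rows_from_root(root):
    """Return one row per <object> found in a Pascal VOC <annotation> element."""
    filename = root.findtext("filename")
    width = int(root.findtext("size/width"))
    height = int(root.findtext("size/height"))
    rows = []
    for obj in root.iter("object"):
        box = obj.find("bndbox")
        rows.append([
            filename, width, height, obj.findtext("name"),
            int(box.findtext("xmin")), int(box.findtext("ymin")),
            int(box.findtext("xmax")), int(box.findtext("ymax")),
        ])
    return rows


def annotations_to_csv(anno_dir, csv_path):
    """Combine every .xml file in anno_dir into a single .csv file."""
    rows = []
    for xml_file in sorted(Path(anno_dir).glob("*.xml")):
        rows.extend(rows_from_root(ET.parse(xml_file).getroot()))
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(COLUMNS)   # header line
        writer.writerows(rows)     # one line per bounding box
    return len(rows)
```

Calling `annotations_to_csv(anno_dir, "train_labels.csv")` on a directory that holds only the training annotations would produce the train .csv; the output file name here is just an example.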
This repo aims to run all of the steps listed above with a single bash command.
Requirements:
- Anaconda3
- Python 3.7
bash main.sh --config_file {PATH_TO_CONFIG_FILE} --install
main.sh [--install] [--config_file str] [--env str]
Implement tensorflow (tf-gpu 1.15.0) Object Detection API on custom Dataset.
Requirements: Anaconda3, python 3.7
Args:
-h|--help Show this help message
--install Install conda env, pip requirements, and tf models API.
Default: False
--config_file str Path to config file.
Default: /tf_detection_api/config.json
--env str Conda environment name.
Default: tf1Detection
The contents of the config .json file are as follows:
{
"image_dir": "PATH_TO_IMAGES_DIRECTORY",
"anno_dir": "PATH_TO_ANNOTATIONS_DIRECTORY",
"id":[1,2],
"name":["cat","dog"],
"pretrained_model": "ssd_mobilenet_v1_coco",
"train_ratio":0.75
}
- `image_dir`: absolute path to the local directory that contains all of the images (both train and test sets). If you plan to download images from Google search, Selenium could be a good choice to automate this process. Detailed instructions can be found in [4].
- `anno_dir`: absolute path to the annotations directory that contains the .xml annotation files. labelImg is a nice tool for generating the annotation .xml files from input images. Details can be found in [5].
- `id`: list of category IDs from your custom dataset. 📝 Note that IDs start from 1 because id=0 is used for the background by default.
- `name`: list of category names from your custom dataset.
- `pretrained_model`: pre-trained model name from the TensorFlow model garden (see its full list of pretrained models).
- `train_ratio`: ratio used to split the train and test sets. Default: 0.75
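To make the roles of `id`, `name`, and `train_ratio` more concrete, here is a small sketch of how such a config might be consumed. It is not taken from this repo: the helper names (`load_config`, `label_map_text`, `split_train_test`) are hypothetical, the validation rules (equal list lengths, ratio strictly between 0 and 1) are assumptions, and the seeded shuffle is just one plausible way to split. The `item { id: ... name: ... }` layout is the label map format the Object Detection API reads.

```python
import json
import random


def load_config(text):
    """Parse the config JSON and apply a few assumed sanity checks."""
    cfg = json.loads(text)
    if len(cfg["id"]) != len(cfg["name"]):
        raise ValueError("'id' and 'name' must have the same length")
    if min(cfg["id"]) < 1:
        raise ValueError("category ids must start from 1 (0 is background)")
    cfg.setdefault("train_ratio", 0.75)  # documented default
    if not 0.0 < cfg["train_ratio"] < 1.0:
        raise ValueError("'train_ratio' must be between 0 and 1")
    return cfg


def label_map_text(ids, names):
    """Render the contents of a label map .pbtxt file, one item per category."""
    items = []
    for cat_id, cat_name in zip(ids, names):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (cat_id, cat_name))
    return "\n".join(items)


def split_train_test(filenames, train_ratio, seed=0):
    """Shuffle reproducibly, then put the first train_ratio share in train."""
    files = sorted(filenames)
    random.Random(seed).shuffle(files)
    n_train = int(len(files) * train_ratio)
    return files[:n_train], files[n_train:]
```

With the example config above, `label_map_text([1, 2], ["cat", "dog"])` yields two `item` blocks, and splitting eight image names with `train_ratio=0.75` gives six training files and two test files.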
[1] TensorFlow model garden installation: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
[2] Setup for a custom dataset: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md
[3] Running the training job: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_locally.md
[4] Searching and downloading images from Google with Python and Selenium: https://towardsdatascience.com/image-scraping-with-python-a96feda8af2d
[5] Labeling images with the labelImg tool: https://github.com/tzutalin/labelImg