This is a small semantic segmentation project built on a synthetic dataset generated with Unreal Engine. The task is to create the dataset with the Unrealcv plugin and then recognize the objects in each image. The dataset consists of RGBD (RGB + depth) images and their ground truth labels. The environment is a simple room with some couches, two TVs, two plants and some tables. Once the dataset has been collected, objects can be detected and localized in the images with a semantic segmentation network, which performs pixel-wise classification: it assigns a class to every pixel.
Inside semantic-seg the main files are:
- generate_dataset.py : Generates the dataset by interacting with the unrealcv server.
- write_dataset_to_tf.py : Reads the previously generated dataset and stores it in tfrecord files.
- train.py : Trains the network using a modified version of the Semantic-Segmentation-Suite.
Clone the repo and install requirements:
git clone https://github.com/EmanueleGhelfi/semantic-seg/
cd semantic-seg
conda create -n semantic-seg
pip install -r requirements.txt
Download the Unreal environment from here. Open the project with Unreal Engine.
Install the Unrealcv plugin following the instructions. Start a new project and import the desired objects.
Start the environment defined in Step 0. Configure the generate_dataset.py script to connect to the correct port. To check the unrealcv port, type the following command in the Unreal console while the environment is running:
vget /unrealcv/status
Check that the server is listening, note the port it reports, and make sure that no client is connected.
By default, unrealcv assigns a different label (color) to each object in the scene, based on its id. Object detection, however, is the task of recognizing the object class, not the object instance. The file conf.json defines the class of each object of interest in the scene, based on its id. When the script starts, it configures unrealcv by setting the defined color for each object of interest and setting the color of every other object to black.
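The recoloring step described above can be sketched as follows. This is only an illustration, not the code in generate_dataset.py: the layout of conf.json is not shown here, so the two dictionaries (`class_of_object`, `class_colors`) and the object ids are hypothetical. The `vset /object/<id>/color r g b` command is, to my knowledge, the unrealcv command for changing an object's label color.

```python
def build_color_commands(object_ids, class_of_object, class_colors):
    """Return one unrealcv `vset` command per object in the scene.

    object_ids:      every object id in the scene (e.g. from `vget /objects`)
    class_of_object: hypothetical mapping id -> class name for objects of interest
    class_colors:    hypothetical mapping class name -> (r, g, b)
    Objects without a class are painted black.
    """
    commands = []
    for obj_id in object_ids:
        class_name = class_of_object.get(obj_id)
        r, g, b = class_colors.get(class_name, (0, 0, 0))
        commands.append(f"vset /object/{obj_id}/color {r} {g} {b}")
    return commands


# Hypothetical ids; the real ones depend on the Unreal scene.
scene = ["Couch_1", "TV_2", "Floor_3"]
classes = {"Couch_1": "couch", "TV_2": "tv"}
colors = {"couch": (255, 0, 0), "tv": (255, 255, 0)}
for command in build_color_commands(scene, classes, colors):
    print(command)
```

Each string would then be sent to the server through the unrealcv client.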
Launch:
python generate_dataset.py
This should save the RGB images, the depth images and the associated labels inside the folder dataset. The script moves the camera randomly within a predefined range and takes photos. A check is added to avoid saving empty photos.
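A minimal sketch of the random camera placement and the empty-photo check, under stated assumptions: the ranges, the 1% threshold and all function names are hypothetical, and "empty" is taken to mean that almost every label pixel is black. The `vset /camera/<id>/location` and `vset /camera/<id>/rotation` strings follow the unrealcv command format as I understand it.

```python
import random


def sample_pose(loc_range, rot_range, rng=random):
    """Draw a random location (x, y, z) and rotation (pitch, yaw, roll).

    Each range is a list of three (low, high) pairs.
    """
    location = tuple(rng.uniform(lo, hi) for lo, hi in loc_range)
    rotation = tuple(rng.uniform(lo, hi) for lo, hi in rot_range)
    return location, rotation


def pose_commands(location, rotation, camera_id=0):
    """Format the two unrealcv commands that move and rotate the camera."""
    x, y, z = location
    pitch, yaw, roll = rotation
    return [
        f"vset /camera/{camera_id}/location {x:.2f} {y:.2f} {z:.2f}",
        f"vset /camera/{camera_id}/rotation {pitch:.2f} {yaw:.2f} {roll:.2f}",
    ]


def is_empty(label_pixels, min_fraction=0.01):
    """Assumed emptiness test: fewer than min_fraction non-black label pixels."""
    pixels = list(label_pixels)
    if not pixels:
        return True
    labelled = sum(1 for p in pixels if tuple(p) != (0, 0, 0))
    return labelled < min_fraction * len(pixels)
```

A photo would be kept only when `is_empty` returns False for its label image; otherwise a new pose is sampled.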
Example of obtained image with the corresponding label:
Image | Label |
---|---|
The default configuration takes 5000 photos.
Convert the generated dataset to tfrecord files.
python write_dataset_to_tf.py
This should save the training, validation and test sets inside dataset/tfrecord/. Images are saved in RGBD format. Labels are saved in color index format: each pixel holds a single class index that depends on the object it belongs to. This way, once the labels are imported into TensorFlow, it is enough to transform them with tf.one_hot. The default configuration splits the dataset into 60% for training, 20% for validation and 20% for test. Each .tfrecord file contains 10 RGBD images and the corresponding labels.
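The 60/20/20 split and the grouping into files of 10 samples can be sketched in plain Python. The function names, the seed and the shuffling step are assumptions made for illustration; write_dataset_to_tf.py may order and split the samples differently.

```python
import random


def split_dataset(samples, train=0.6, val=0.2, seed=0):
    """Shuffle the samples and split them into train, validation and test lists."""
    samples = list(samples)
    random.Random(seed).shuffle(samples)  # shuffling is an assumption
    n_train = round(len(samples) * train)
    n_val = round(len(samples) * val)
    return (
        samples[:n_train],
        samples[n_train:n_train + n_val],
        samples[n_train + n_val:],
    )


def chunks(samples, size=10):
    """Yield consecutive groups of `size` samples, one group per tfrecord file."""
    for start in range(0, len(samples), size):
        yield samples[start:start + size]


train_set, val_set, test_set = split_dataset(range(5000))
print(len(train_set), len(val_set), len(test_set))
print(len(list(chunks(train_set))), "training files")
```

With the default 5000 photos this gives 3000 training, 1000 validation and 1000 test samples, so 300 training files of 10 samples each.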
Specify the object classes to recognize inside the file dataset/class_dict.csv:
name,r,g,b
void,0,0,0
tv,255,255,0
couch,255,0,0
plant,0,255,0
table,0,0,255
The void class is needed because the network outputs, for each pixel, a probability for each class, so pixels that belong to none of the objects of interest still need a class of their own.
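To make the role of the class file concrete, here is a minimal NumPy sketch that reads the class list above, converts a color label image into class indices and one-hot encodes it. The function names are hypothetical and `one_hot` is only a NumPy stand-in for tf.one_hot. It assumes void is the first row, so that any color not listed in the file falls back to index 0.

```python
import csv
import io

import numpy as np

CLASS_DICT = """name,r,g,b
void,0,0,0
tv,255,255,0
couch,255,0,0
plant,0,255,0
table,0,0,255
"""


def read_class_dict(text):
    """Parse the CSV text into a list of class names and a list of RGB colors."""
    names, colors = [], []
    for row in csv.DictReader(io.StringIO(text)):
        names.append(row["name"])
        colors.append((int(row["r"]), int(row["g"]), int(row["b"])))
    return names, colors


def color_to_index(label_rgb, colors):
    """Map an (H, W, 3) color label to an (H, W) index map; unknown colors stay 0."""
    index = np.zeros(label_rgb.shape[:2], dtype=np.int64)
    for i, color in enumerate(colors):
        index[np.all(label_rgb == np.array(color), axis=-1)] = i
    return index


def one_hot(index, num_classes):
    """NumPy equivalent of one-hot encoding: (H, W) -> (H, W, num_classes)."""
    return np.eye(num_classes, dtype=np.float32)[index]


names, colors = read_class_dict(CLASS_DICT)
label = np.array([[[255, 0, 0], [0, 0, 0]],
                  [[0, 0, 255], [0, 255, 0]]], dtype=np.uint8)
print(color_to_index(label, colors))
```

Because every pixel gets exactly one index, the one-hot target always sums to 1 along the class axis, which matches a per-pixel probability output.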
Train the selected semantic segmentation network:
python train.py --dataset dataset --model MobileUNet
The training is implemented using the tf.data API. Images are loaded from the tfrecord files saved before and resized. The script evaluates the network on some batches of images, saves the network weights and saves plots. TODO: Implement saving in tensorflow format.
Test the network on the test set:
python test.py --dataset dataset --model MobileUNet --checkpoint_path checkpoint_path
The following results are obtained using MobileUNet.
Epoch 1:
Network Output | Ground Truth |
---|---|
Epoch 70:
Network Output | Ground Truth |
---|---|
Coming soon: test results and other models...
Check the folder: https://drive.google.com/open?id=15tp9lZVJVC44LCzfjv22Y1kfXqk4iSKu to see some results. Look in {ModelName}/checkpoints/, inside the folders numbered with a multiple of 5 (sorry). Images whose names contain _gt are ground truths.
You can also check the notebook: https://colab.research.google.com/drive/1CdDyMWF3vNAPz-BbRXSg0eK_SvAkzSgF