This repository contains the code for the project Real-time domain adaptation for semantic segmentation, developed for the course Advanced Machine Learning. The repo also contains the assignment (with tables filled in with the values obtained) and a report that elaborates on methods, results and conclusions.
- The first goal of the project is to implement and test BiSeNet, a deep network for semantic segmentation, on Cityscapes. The network is defined in the folder model, while the file to train it on the labeled dataset is train.py.
- Secondly, the project aims at training the network on a domain-adaptation task. In particular, the network is trained using the labeled GTA5 dataset as source domain and the unlabeled Cityscapes as target domain. A discriminator network, which distinguishes between the two domains and helps in learning meaningful representations, is defined in model/discriminator.py, whereas the file to perform the training is domain_adaptation_train.py.
- Finally, the performance of domain adaptation is improved by implementing a pseudo-labeling technique. In particular, pseudo labels are generated for the target domain (Cityscapes) and are used for training in the next iteration. The file to perform the training is pseudo_labels_train.py, whereas the file to generate pseudo labels is SSL.py.
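The adversarial part of the domain-adaptation step can be sketched in plain Python as follows. This is only an illustration of a common scheme, not necessarily what domain_adaptation_train.py does: it assumes the discriminator outputs the probability that a prediction comes from the target domain, and that a binary cross-entropy loss is used. The names `SOURCE_LABEL`, `TARGET_LABEL`, `bce`, `discriminator_loss` and `adversarial_loss` are hypothetical.

```python
import math

# Assumed labelling convention for the discriminator (hypothetical).
SOURCE_LABEL = 0.0
TARGET_LABEL = 1.0


def bce(prob, label, eps=1e-7):
    """Binary cross-entropy for a single probability, clamped for stability."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))


def discriminator_loss(d_source, d_target):
    """Loss that trains the discriminator to tell the two domains apart.

    d_source / d_target: discriminator outputs (probability of "target")
    for segmentation predictions on source and target images.
    """
    losses = [bce(p, SOURCE_LABEL) for p in d_source]
    losses += [bce(p, TARGET_LABEL) for p in d_target]
    return sum(losses) / len(losses)


def adversarial_loss(d_target):
    """Loss for the segmentation network: target predictions should be
    classified as source, i.e. the discriminator should be fooled."""
    return sum(bce(p, SOURCE_LABEL) for p in d_target) / len(d_target)
```

In a typical training loop of this kind, the segmentation network would minimise its supervised loss on GTA5 plus a small multiple of `adversarial_loss`, while the discriminator is updated separately with `discriminator_loss`.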
Some predictions for the different models are reported below:
The images correspond, in order, to:
- ground truth
- baseline (BiSeNet trained on labelled Cityscapes)
- domain adaptation with standard discriminator
- domain adaptation with lightweight discriminator
- domain adaptation with pseudo labels and fixed threshold
- domain adaptation with pseudo labels and variable threshold
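The difference between a fixed and a variable threshold for pseudo labels can be illustrated with the minimal sketch below. It is not taken from SSL.py: it assumes that each pixel comes with a list of class probabilities, that rejected pixels receive the ignore index 255, and that the variable threshold is chosen per class as the median confidence of that class, capped at a maximum value. All function names and the `IGNORE` constant are hypothetical.

```python
from statistics import median

IGNORE = 255  # assumed ignore index for pixels without a pseudo label


def argmax_with_conf(probs):
    """Return (predicted class, confidence) for one pixel."""
    conf = max(probs)
    return probs.index(conf), conf


def pseudo_labels_fixed(pixel_probs, threshold=0.9):
    """Keep a prediction only if its confidence reaches a single global threshold."""
    labels = []
    for probs in pixel_probs:
        label, conf = argmax_with_conf(probs)
        labels.append(label if conf >= threshold else IGNORE)
    return labels


def pseudo_labels_variable(pixel_probs, cap=0.9):
    """Use one threshold per class: the median confidence of the pixels
    predicted as that class, never higher than `cap` (an assumed rule)."""
    preds = [argmax_with_conf(probs) for probs in pixel_probs]
    by_class = {}
    for label, conf in preds:
        by_class.setdefault(label, []).append(conf)
    thresholds = {c: min(median(confs), cap) for c, confs in by_class.items()}
    return [label if conf >= thresholds[label] else IGNORE for label, conf in preds]
```

Under these assumptions, a per-class threshold lets less confident classes contribute pseudo labels that a single global threshold would discard.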

Other files in the repository:
- demo.py provides functions to save a PNG image in which a model's predicted labels are overlaid on the original image.
- eval.py is used to perform evaluation on the test dataset.
- loss.py contains functions used to compute the loss.
- make_plots.py is used to create plots of the losses during training.
- utils.py contains useful functions to compute accuracy and mIoU of predictions.
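
As a reference for the two metrics mentioned for utils.py, pixel accuracy and mIoU can be computed from a confusion matrix as in the sketch below. This is a plain-Python illustration on flattened label lists, not the code in utils.py; the function names and the ignore index 255 are assumptions.

```python
def confusion_matrix(pred, target, num_classes, ignore_index=255):
    """Rows are ground-truth classes, columns are predicted classes."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(pred, target):
        if t == ignore_index:
            continue  # unlabeled pixels do not count
        cm[t][p] += 1
    return cm


def pixel_accuracy(cm):
    """Fraction of labeled pixels whose class was predicted correctly."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(len(cm)))
    return correct / total if total else 0.0


def mean_iou(cm):
    """Mean over classes of TP / (TP + FP + FN), skipping absent classes."""
    ious = []
    for c in range(len(cm)):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp
        fp = sum(row[c] for row in cm) - tp
        denom = tp + fp + fn
        if denom:
            ious.append(tp / denom)
    return sum(ious) / len(ious) if ious else 0.0
```

For example, with predictions `[0, 0, 1, 1]` and ground truth `[0, 1, 1, 255]`, the last pixel is ignored, two of the three labeled pixels are correct, and both classes have an IoU of 0.5.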