Domain adaptation in real-time semantic segmentation. Project for the course Advanced Machine Learning.

Advanced-Machine-Learning

This repository contains the code for the project Real-time domain adaptation for semantic segmentation, developed for the course Advanced Machine Learning. The repo also contains the assignment (with its tables filled in with the values obtained) and a report that elaborates on methods, results and conclusions.

Goals

  • The first goal of the project is to implement and test BiSeNet, a deep network for semantic segmentation, on Cityscapes. The network is defined in the folder model, and train.py trains it on the labeled dataset.
  • The second goal is to train the network on a domain-adaptation task, using the labeled GTA5 dataset as the source domain and the unlabeled Cityscapes dataset as the target domain. A discriminator network, which distinguishes between the two domains and helps in learning meaningful representations, is defined in model/discriminator.py. The training script is domain_adaptation_train.py.
  • Finally, the performance of domain adaptation is improved with a pseudo-labeling technique: pseudo labels are generated for the target domain (Cityscapes) and used for training in the next iteration. The training script is pseudo_labels_train.py, and SSL.py generates the pseudo labels.
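
The pseudo-labeling step in the last goal can be sketched as follows. This is a minimal pure-Python illustration and not the code in SSL.py: the function name generate_pseudo_labels, the 0.9 confidence threshold and the ignore label 255 are assumptions made for the example. Each pixel takes its most likely class as the pseudo label only when the model is confident enough, and is otherwise marked to be ignored by the loss.

```python
IGNORE_INDEX = 255  # assumed "ignore" label, chosen for this example


def generate_pseudo_labels(probs, threshold=0.9, ignore_index=IGNORE_INDEX):
    """Turn per-pixel class probabilities into pseudo labels.

    probs is a nested list of shape H x W x C (rows, pixels, class
    probabilities). A pixel whose highest probability is below
    `threshold` receives `ignore_index` instead of a class id.
    """
    labels = []
    for row in probs:
        label_row = []
        for pixel in row:
            # Index of the most probable class for this pixel.
            best_class = max(range(len(pixel)), key=lambda c: pixel[c])
            if pixel[best_class] >= threshold:
                label_row.append(best_class)
            else:
                label_row.append(ignore_index)
        labels.append(label_row)
    return labels


# Hypothetical 2 x 2 prediction over 3 classes.
probs = [
    [[0.95, 0.03, 0.02], [0.40, 0.35, 0.25]],
    [[0.05, 0.92, 0.03], [0.10, 0.10, 0.80]],
]
print(generate_pseudo_labels(probs, threshold=0.9))  # [[0, 255], [1, 255]]
```

In a real pipeline the same logic would run on the network's softmax output tensors, and the ignored pixels would be excluded from the segmentation loss.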

Results

Some predictions for the different models are reported below:

[six images]
The images correspond, in order, to:
  • ground truth
  • baseline (BiSeNet trained on labeled Cityscapes)
  • domain adaptation with standard discriminator
  • domain adaptation with lightweight discriminator
  • domain adaptation with pseudo labels and fixed threshold
  • domain adaptation with pseudo labels and variable threshold

Additional files

  • demo.py provides functions to save a PNG image in which a model's label prediction is overlaid on the original image.
  • eval.py is used to perform evaluation on the test dataset.
  • loss.py contains functions used to compute the loss.
  • make_plots.py is used to create plots of the losses during training.
  • utils.py contains utility functions to compute the accuracy and mIoU of predictions.
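
For reference, pixel accuracy and mIoU are commonly derived from a confusion matrix, as in the sketch below. It is a pure-Python illustration and not the code in utils.py: the function names, the flat label lists and the ignore label 255 are assumptions made for the example.

```python
def confusion_matrix(y_true, y_pred, num_classes, ignore_index=255):
    """Count (true class, predicted class) pairs over flat label lists."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        if t == ignore_index:  # skip pixels without a valid label
            continue
        matrix[t][p] += 1
    return matrix


def pixel_accuracy(matrix):
    """Fraction of labeled pixels whose prediction is correct."""
    correct = sum(matrix[c][c] for c in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total if total else 0.0


def mean_iou(matrix):
    """Average of TP / (TP + FP + FN) over classes that occur at all."""
    n = len(matrix)
    ious = []
    for c in range(n):
        tp = matrix[c][c]
        fn = sum(matrix[c]) - tp
        fp = sum(matrix[r][c] for r in range(n)) - tp
        union = tp + fp + fn
        if union > 0:  # ignore classes absent from both labels and predictions
            ious.append(tp / union)
    return sum(ious) / len(ious) if ious else 0.0


# Hypothetical flattened labels for six pixels and three classes.
y_true = [0, 0, 1, 1, 2, 255]
y_pred = [0, 1, 1, 1, 2, 2]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(pixel_accuracy(cm))  # 0.8
print(round(mean_iou(cm), 4))  # 0.7222
```

Accumulating one confusion matrix over the whole test set and computing the metrics once at the end avoids averaging per-image scores, which would weight small and large images equally.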