
PyTorch implementation of U-Net on the Cityscapes dataset


U-Net PyTorch Implementation

This repository contains a PyTorch implementation of the U-Net architecture, trained on the Cityscapes dataset. Training currently uses four classes: road, sky, car, and unlabeled. The architecture code is split across two files: blocks.py defines the blocks for the contracting and expanding paths, and model.py combines them into the full network.
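The split between path blocks and the assembled network can be sketched as follows. This is a minimal illustration, not the contents of blocks.py or model.py: the class names (DoubleConv, Down, Up, UNet), the channel widths, and the reduced depth (two down/up steps instead of the four in the original U-Net paper) are assumptions. Padded 3x3 convolutions are used so the output keeps the input's height and width, a common variant of the paper's unpadded, crop-based design.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H and W."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class Down(nn.Module):
    """Contracting step: 2x2 max-pool halves H and W, then DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(self.pool(x))


class Up(nn.Module):
    """Expanding step: 2x upsample, concatenate the skip connection, DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(out_ch * 2, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        # Odd input sizes can leave the upsampled map 1px smaller than the skip; pad to match.
        dh, dw = skip.size(2) - x.size(2), skip.size(3) - x.size(3)
        x = F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return self.conv(torch.cat([skip, x], dim=1))


class UNet(nn.Module):
    """Hypothetical two-level U-Net assembled from the blocks above."""

    def __init__(self, in_ch=3, num_classes=4, base=16):
        super().__init__()
        self.inc = DoubleConv(in_ch, base)
        self.down1 = Down(base, base * 2)
        self.down2 = Down(base * 2, base * 4)
        self.up1 = Up(base * 4, base * 2)
        self.up2 = Up(base * 2, base)
        self.head = nn.Conv2d(base, num_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)     # skip connection at full resolution
        x2 = self.down1(x1)  # skip connection at 1/2 resolution
        x3 = self.down2(x2)  # bottleneck at 1/4 resolution
        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        return self.head(x)  # per-pixel logits: (N, num_classes, H, W)


model = UNet(num_classes=4)
print(model(torch.randn(1, 3, 64, 128)).shape)  # torch.Size([1, 4, 64, 128])
```

Keeping the blocks as separate modules mirrors the described layout: the paths are built from small reusable pieces, and only the top-level class decides depth and widths.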

Examples


Guide

Clone this repository:

git clone https://github.com/finallyupper/pytorch-U-Net
cd pytorch-U-Net

Create a virtual environment and install dependencies:

conda create -n unet python=3.8
conda activate unet
pip install -r requirements.txt

NOTE: Edit the settings in configs.yaml to match your own setup before training.
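One way a script might read configs.yaml is sketched below, assuming PyYAML is installed. The keys and default values in DEFAULTS are hypothetical placeholders, not the repository's actual configuration schema; the point is only the pattern of overlaying user settings on defaults.

```python
import yaml

# Hypothetical defaults: the real keys in configs.yaml may differ.
DEFAULTS = {
    "data_root": "./data/cityscapes",
    "num_classes": 4,
    "batch_size": 4,
    "lr": 1e-4,
    "epochs": 50,
}


def parse_config(text):
    """Parse YAML text and overlay its top-level keys on DEFAULTS."""
    user_cfg = yaml.safe_load(text) or {}  # an empty file yields None
    return {**DEFAULTS, **user_cfg}


def load_config(path="configs.yaml"):
    """Read a YAML config file from disk and merge it with DEFAULTS."""
    with open(path, "r", encoding="utf-8") as f:
        return parse_config(f.read())
```

A caveat when editing numeric values: PyYAML follows YAML 1.1, so `lr: 1e-4` is loaded as the string "1e-4", while `lr: 1.0e-4` is loaded as a float.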

Run the following command to start training the model:

python train.py
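As a rough picture of what one training epoch for a segmentation model like this involves, here is a hedged sketch. It is not the code in train.py: the function name `train_one_epoch` and the batch format (float images of shape (N, 3, H, W), integer masks of shape (N, H, W) holding class indices 0 to 3) are assumptions. Pixel-wise cross-entropy is a standard choice for multi-class segmentation.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, optimizer, device="cpu"):
    """Run one pass over `loader` and return the mean per-batch loss.

    Assumes each batch is (images, masks): images float (N, 3, H, W),
    masks long (N, H, W) with class indices in [0, num_classes).
    """
    criterion = nn.CrossEntropyLoss()  # averages the loss over all N*H*W pixels
    model.train()
    total = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        logits = model(images)  # (N, num_classes, H, W)
        loss = criterion(logits, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / max(len(loader), 1)
```

Returning the epoch loss as a plain float makes it easy to pass to a logger, for example `wandb.log({"train_loss": loss})`.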

When training starts, you will be prompted to log in to Weights & Biases (wandb). To authenticate ahead of time, run:

wandb login

Run the following command to start testing the model:

python inference.py
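Turning the network's logits into a segmentation map usually comes down to a per-pixel argmax. The sketch below assumes a model that returns (N, C, H, W) logits; `predict_mask`, `colorize`, and the class index order in PALETTE are hypothetical, not taken from inference.py. The RGB values are the standard Cityscapes colors for these classes.

```python
import torch

# Hypothetical class index order; colors follow the Cityscapes palette.
PALETTE = torch.tensor(
    [
        [0, 0, 0],       # 0: unlabeled
        [128, 64, 128],  # 1: road
        [70, 130, 180],  # 2: sky
        [0, 0, 142],     # 3: car
    ],
    dtype=torch.uint8,
)


@torch.no_grad()
def predict_mask(model, image):
    """Return an (H, W) tensor of class indices for one (3, H, W) image."""
    model.eval()  # use BatchNorm running statistics, disable dropout
    logits = model(image.unsqueeze(0))      # (1, C, H, W)
    return logits.argmax(dim=1).squeeze(0)  # (H, W), dtype int64


def colorize(mask):
    """Map an (H, W) class-index mask to an (H, W, 3) uint8 RGB image."""
    return PALETTE[mask]
```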

Results

Dataset

TO-DO

  • Define Contracting/Expansive Path
  • Define custom Cityscapes dataset and dataloader
  • Add additional functions
  • Train / Test U-Net
  • Results and hyperparameter tuning
  • Refactoring
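For the custom Cityscapes dataset, the full label set has to be reduced to the four training classes. A minimal sketch of that step uses a lookup table over Cityscapes labelIds, where road = 7, sky = 23, and car = 26 in the official label definitions. The output index order (0 = unlabeled, 1 = road, 2 = sky, 3 = car) and the choice to send every other label to "unlabeled" are assumptions, not necessarily what this repository does.

```python
import numpy as np

# Cityscapes labelIds from the official label definitions.
ROAD_ID, SKY_ID, CAR_ID = 7, 23, 26


def build_lut(num_ids=256):
    """Lookup table from labelId to a hypothetical 4-class training index."""
    lut = np.zeros(num_ids, dtype=np.uint8)  # every other id -> 0 (unlabeled)
    lut[ROAD_ID] = 1
    lut[SKY_ID] = 2
    lut[CAR_ID] = 3
    return lut


LUT = build_lut()


def remap_labels(label_ids):
    """Map an (H, W) uint8 array of Cityscapes labelIds to class indices 0-3."""
    return LUT[label_ids]
```

In a dataset class, `label_ids` would typically come from a `*_gtFine_labelIds.png` file loaded as a NumPy array; the remapped mask can then be converted to a long tensor for the loss.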