PyTorch implementation of UNet for semantic segmentation of aerial imagery
- Supports training UNet with various encoders, such as ResNet18 and ResNet34.
- Uses a compound loss (cross-entropy + Jaccard) to train the network.
- Makes it easy to train the model on a custom dataset.
- Upcoming: a self-supervised method for training the network encoder on unlabeled data.
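A minimal sketch of how a cross-entropy + Jaccard compound loss can be computed for multi-class segmentation. It uses NumPy so the math is easy to inspect; in PyTorch the same operations apply to tensors. The equal weighting (`ce_weight`, `jaccard_weight`), the soft (probability-based) Jaccard term, and the `eps` smoothing are assumptions, not necessarily what this repository uses.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def compound_loss(logits: np.ndarray, target: np.ndarray,
                  ce_weight: float = 1.0, jaccard_weight: float = 1.0,
                  eps: float = 1e-7) -> float:
    """Cross-entropy + soft Jaccard loss for multi-class segmentation.

    logits: float array of shape (N, C, H, W), raw network outputs.
    target: int array of shape (N, H, W) with class indices in [0, C).
    Assumption: equal weights and a soft Jaccard term; the repository may differ.
    """
    num_classes = logits.shape[1]
    probs = softmax(logits, axis=1)

    # One-hot encode the target to (N, C, H, W) so it lines up with probs.
    one_hot = np.eye(num_classes, dtype=probs.dtype)[target]  # (N, H, W, C)
    one_hot = np.transpose(one_hot, (0, 3, 1, 2))             # (N, C, H, W)

    # Cross-entropy: mean negative log-probability of the true class per pixel.
    ce = -np.mean(np.sum(one_hot * np.log(probs + eps), axis=1))

    # Soft Jaccard (IoU) per class, summed over batch and spatial dims,
    # then averaged over classes and turned into a loss (1 - IoU).
    intersection = np.sum(probs * one_hot, axis=(0, 2, 3))
    union = np.sum(probs + one_hot, axis=(0, 2, 3)) - intersection
    jaccard = 1.0 - np.mean((intersection + eps) / (union + eps))

    return float(ce_weight * ce + jaccard_weight * jaccard)
```

Cross-entropy gives a per-pixel gradient signal, while the Jaccard term directly rewards overlap per class, which tends to help classes that cover few pixels.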
Example visualization of the network outputs
- Clone the repository:
git clone https://github.com/Niloofaresf1996/field_segmentation.git
We refer to the cloned directory as $RESA_ROOT.
- Create and activate an environment (we used conda, but this is optional):
conda create -n aiss python=3.9 -y
conda activate aiss
- Install dependencies:
# Install PyTorch first; the cudatoolkit version should match the CUDA version on your system (you can also install pytorch and torchvision with pip)
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
# Install the remaining libraries
pip install opencv-python
pip install numpy
pip install matplotlib
pip install segmentation_models_pytorch
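After installation, a quick check like the following (a hypothetical helper script, not part of the repository) can confirm that every package is importable and that the installed PyTorch build can see a GPU. Note that `opencv-python` is imported as `cv2`.

```python
import importlib.util

# Import names of the packages installed above (opencv-python is imported as cv2).
REQUIRED_MODULES = [
    "torch",
    "torchvision",
    "cv2",
    "numpy",
    "matplotlib",
    "segmentation_models_pytorch",
]


def check_modules(names):
    """Return {module_name: bool} telling whether each module can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    status = check_modules(REQUIRED_MODULES)
    for name, found in status.items():
        print(f"{name:<28} {'OK' if found else 'MISSING'}")
    if status["torch"]:
        import torch
        # Compare the CUDA version PyTorch was built with against your system.
        print("CUDA build:", torch.version.cuda,
              "| GPU available:", torch.cuda.is_available())
```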
- Download and extract this Kaggle dataset.
Note: this repository is still under development, and this dataset is used to test the model. Our model achieved 81.0% segmentation accuracy on this dataset.
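The note does not say which metric the 81.0% figure refers to. Assuming it is pixel accuracy (the fraction of correctly labeled pixels), the sketch below shows one way to compute it from a confusion matrix, along with mean IoU, which is often reported for segmentation. The function names are illustrative and not taken from the repository.

```python
import numpy as np


def confusion_matrix(pred: np.ndarray, target: np.ndarray, num_classes: int) -> np.ndarray:
    """Count (true, predicted) class pairs; rows are true classes, columns predicted."""
    idx = target.astype(np.int64).ravel() * num_classes + pred.astype(np.int64).ravel()
    counts = np.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def pixel_accuracy(cm: np.ndarray) -> float:
    """Fraction of pixels whose predicted class equals the true class."""
    return float(np.trace(cm) / cm.sum())


def mean_iou(cm: np.ndarray) -> float:
    """Mean intersection-over-union over classes present in pred or target."""
    intersection = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - intersection
    valid = union > 0  # skip classes that never appear, to avoid 0/0
    return float(np.mean(intersection[valid] / union[valid]))
```

To evaluate a whole test set, sum the per-image confusion matrices before computing either metric, so every pixel is weighted equally.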
After extraction, the dataset has the following structure:
semantic-segmentation-of-aerial-imagery/
├── Tile 1
│   ├── images
│   │   ├── image_part_001.jpg
│   │   ├── image_part_002.jpg
│   │   └── ...
│   └── masks
│       ├── image_part_001.png
│       ├── image_part_002.png
│       └── ...
├── Tile 2
├── ...
├── Tile 9
└── classes
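A sketch of how image/mask pairs might be collected from this layout and how a color-coded mask can be turned into class indices. It assumes each mask is an RGB PNG with the same file stem as its image, and that the class colors come from the dataset's `classes` file; the palette passed to `rgb_mask_to_index` is a placeholder to fill in from that file. Both helper names are hypothetical.

```python
from pathlib import Path

import numpy as np


def list_pairs(root):
    """Return sorted (image_path, mask_path) pairs from every 'Tile *' folder.

    Assumes <root>/Tile N/images/*.jpg, each with a mask of the same stem
    in <root>/Tile N/masks/*.png. Images without a mask are skipped.
    """
    pairs = []
    for tile in sorted(Path(root).glob("Tile *")):
        for image_path in sorted((tile / "images").glob("*.jpg")):
            mask_path = tile / "masks" / (image_path.stem + ".png")
            if mask_path.exists():
                pairs.append((image_path, mask_path))
    return pairs


def rgb_mask_to_index(mask_rgb, palette):
    """Map an (H, W, 3) color-coded mask to an (H, W) array of class indices.

    palette: list of (R, G, B) tuples; position i is the color of class i
    (placeholder: read the real colors from the dataset's classes file).
    Pixels whose color is not in the palette are set to -1.
    """
    index = np.full(mask_rgb.shape[:2], -1, dtype=np.int64)
    for class_id, color in enumerate(palette):
        matches = np.all(mask_rgb == np.array(color, dtype=mask_rgb.dtype), axis=-1)
        index[matches] = class_id
    return index
```

OpenCV's `cv2.imread` returns channels in BGR order, so convert a mask with `cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)` before matching it against RGB colors.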
- Train the network. The following example command trains the default configuration; customize its parameters as needed:
python train.py --dataset_path Semantic_segmentation_dataset --encoder resnet34 --encoder_weights imagenet --gpu_id 0 --gpus 1
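The flags in this command come from the repository, but the internals of train.py are not shown here. The sketch below is a hypothetical reconstruction: an argparse parser that accepts the same flags (the defaults and help texts are assumptions) and a helper that builds a UNet with `smp.Unet` from segmentation_models_pytorch, which is one plausible way `--encoder` and `--encoder_weights` could be consumed. The library is imported only inside `build_model`, so argument parsing works without it installed.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser mirroring the flags of the example command above."""
    parser = argparse.ArgumentParser(description="Train UNet on aerial imagery (sketch)")
    parser.add_argument("--dataset_path", required=True,
                        help="root folder of the extracted dataset")
    parser.add_argument("--encoder", default="resnet34",
                        help="encoder backbone, e.g. resnet18 or resnet34")
    parser.add_argument("--encoder_weights", default="imagenet",
                        help='pre-trained encoder weights, or "none" (assumed convention)')
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU index")
    parser.add_argument("--gpus", type=int, default=1, help="number of GPUs")
    return parser.parse_args(argv)


def build_model(args, num_classes):
    """Build a UNet with the requested encoder via segmentation_models_pytorch."""
    import segmentation_models_pytorch as smp  # heavy dependency, imported lazily

    # Assumption: the string "none" means random initialization of the encoder.
    weights = None if args.encoder_weights.lower() == "none" else args.encoder_weights
    return smp.Unet(
        encoder_name=args.encoder,
        encoder_weights=weights,
        in_channels=3,         # RGB aerial images
        classes=num_classes,   # one output channel per segmentation class
    )
```

Using ImageNet weights for the encoder usually speeds up convergence on small datasets like this one, while the decoder is still trained from scratch.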
- Run the inference demo with a pre-trained model. Download our pre-trained weights from here, then customize the following command to run the demo:
python demo.py --checkpoint checkpoints/best.pth --image_path demo/sample_1.jpg --fname demo/result_1.jpg --gpu_id 0
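The steps inside demo.py are not described here. A common inference pipeline, sketched under assumptions: pad the input so its height and width are multiples of 32 (UNet encoders with five downsampling stages, such as ResNet34, reduce resolution by 2**5 = 32), run the model, take the argmax over class scores, crop off the padding, and colorize the result for saving. The model call is omitted, and both helper names and the palette are hypothetical.

```python
import numpy as np


def pad_to_multiple(image, multiple=32):
    """Zero-pad an (H, W, C) image at the bottom/right so H and W divide by `multiple`.

    Returns the padded image and the original (H, W) for cropping predictions.
    """
    h, w = image.shape[:2]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant")
    return padded, (h, w)


def colorize(scores, original_size, palette):
    """Turn (C, H, W) class scores into an (H, W, 3) uint8 color image.

    scores: logits or probabilities for one image (the argmax is the same).
    original_size: (H, W) returned by pad_to_multiple, used to remove padding.
    palette: (C, 3) array-like; row i is the RGB color drawn for class i.
    """
    h, w = original_size
    labels = np.argmax(scores, axis=0)[:h, :w]  # predicted class per pixel
    return np.asarray(palette, dtype=np.uint8)[labels]
```

`cv2.imwrite` expects BGR channel order, so convert the colorized output with `cv2.cvtColor(color, cv2.COLOR_RGB2BGR)` before saving it.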