Requirements:

- Python 3.8
- PyTorch 1.7.1
- numpy
- pillow
- opencv
- matplotlib
- tqdm
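Before running anything, you can check which of the requirements listed above are installed. The snippet below is a minimal sketch using only the standard library (`importlib.metadata`, available from Python 3.8). The PyPI distribution names are assumptions: in particular, OpenCV may be installed as `opencv-python`, `opencv-python-headless` or `opencv-contrib-python`, so all three are tried.

```python
# Sketch: report installed versions of the listed requirements.
# Assumption: the PyPI distribution names below match how the packages were installed.
import sys
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

REQUIREMENTS = {
    "PyTorch": ["torch"],
    "numpy": ["numpy"],
    "pillow": ["Pillow"],
    "opencv": ["opencv-python", "opencv-python-headless", "opencv-contrib-python"],
    "matplotlib": ["matplotlib"],
    "tqdm": ["tqdm"],
}


def installed_versions(requirements=REQUIREMENTS):
    """Map each requirement to the version of the first installed distribution, or None."""
    found = {}
    for name, dists in requirements.items():
        found[name] = None
        for dist in dists:
            try:
                found[name] = version(dist)
                break  # stop at the first distribution that is installed
            except PackageNotFoundError:
                continue
    return found


if __name__ == "__main__":
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}")
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
```

The listed versions (Python 3.8, PyTorch 1.7.1) are the ones the repository states; newer versions may work but are untested here.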
Run the Jupyter notebook `data/gen_dataset-colorjitter-smooth.ipynb` to generate the dataset.
To train the model:

```shell
python train_ta-res18-unet-cml.py --epochs 100 --batch-size 4 --learning-rate 1e-5 --classes 1 --channels 2 --scale 0.5 --bilinear
```
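The `--channels 2` and `--scale 0.5` flags recur in every command. As an illustration only, the sketch below shows how they are assumed to shape one input sample, modelled on the `BasicDataset.preprocess` step of the upstream U-Net repository: downscale height and width by `scale`, convert HWC to CHW, and divide by 255 when values exceed 1. Whether this repository keeps that exact preprocessing is an assumption, and the function name `preprocess` here is hypothetical. Nearest-neighbour resampling is used only to keep the example NumPy-only; the real scripts likely resize with Pillow or OpenCV.

```python
# Sketch (assumption): how --scale and --channels may shape a single input sample.
# Modelled on the upstream U-Net repo's preprocessing; not this repository's code.
import numpy as np


def preprocess(img, scale=0.5, channels=2):
    """Resize an HxW or HxWxC array by `scale`; return a float32 CxH'xW' array."""
    img = np.asarray(img)
    if img.ndim == 2:  # single-band image -> add a channel axis
        img = img[:, :, np.newaxis]
    if img.shape[2] != channels:
        raise ValueError(f"expected {channels} channels, got {img.shape[2]}")

    h, w = img.shape[:2]
    new_h, new_w = int(scale * h), int(scale * w)
    if new_h <= 0 or new_w <= 0:
        raise ValueError("scale is too small for this image")

    # Nearest-neighbour sampling: pick evenly spaced source rows and columns.
    rows = (np.arange(new_h) * h / new_h).astype(int)
    cols = (np.arange(new_w) * w / new_w).astype(int)
    resized = img[rows][:, cols]

    chw = resized.transpose(2, 0, 1).astype(np.float32)  # HWC -> CHW
    if chw.max() > 1:  # treat as 8-bit data and map to [0, 1], as upstream does
        chw /= 255.0
    return chw
```

For example, a 256 x 256 two-band image becomes a tensor-ready array of shape `(2, 128, 128)`, which is why `--scale 0.5` roughly quarters the memory per sample.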
The trained models can be found in the `data` directory.
To evaluate a trained model on the `trainval` images and masks (results are written to the CSV file given by `--output`):

```shell
python evaluate2.py --model ./data/TAres18unet-pre-cml.pth --name TAResnet18_Unet --input_sar data/dataset/trainval_imgs/ --input_mask data/dataset/trainval_masks/ --output ./result_eval/out_ResUNet-TAM-CML.csv --classes 1 --channels 2 --scale 0.5 --bilinear --batch_size 4
```
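Which metrics `evaluate2.py` writes to the CSV is not documented here. The upstream U-Net repository scores binary masks with the Dice coefficient, so, as an assumption, the sketch below shows that metric for single-class (`--classes 1`) masks in NumPy. The function name `dice_coefficient` and the `eps` smoothing value are illustrative, not taken from this repository.

```python
# Sketch (assumption): Dice coefficient for binary segmentation masks.
import numpy as np


def dice_coefficient(pred, target, eps=1e-4):
    """Dice = 2|P ∩ T| / (|P| + |T|); eps keeps the result defined for empty masks."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    inter = np.logical_and(pred, target).sum()
    return (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)
```

With this smoothing, two empty masks score 1.0 and a perfect prediction scores 1.0, while disjoint non-empty masks score close to 0.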
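To predict masks for the demo images in `data/demo/` (outputs are written to `data/demo/output/`):

```shell
python predict.py --model ./data/TAres18unet-pre-cml.pth --name TAResnet18_Unet --input_sar data/demo/ --input_mask data/demo/ --output data/demo/output/ --classes 1 --channels 2 --scale 0.5 --bilinear
```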
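With `--classes 1` the network produces one logit per pixel. In the upstream U-Net repository, prediction applies a sigmoid and thresholds the probability at 0.5; assuming `predict.py` does the same, the post-processing step looks roughly like the sketch below. The function name `logits_to_mask` is hypothetical, and the 0.5 threshold is the upstream default rather than a value confirmed for this repository.

```python
# Sketch (assumption): turn single-class logits into a binary 0/255 mask image.
import numpy as np


def logits_to_mask(logits, threshold=0.5):
    """Apply a sigmoid to H x W logits and threshold into a uint8 mask (0 or 255)."""
    x = np.asarray(logits, dtype=np.float64)
    # sigmoid(x) == 0.5 * (1 + tanh(x / 2)); this form never overflows for large |x|
    probs = 0.5 * (1.0 + np.tanh(0.5 * x))
    return ((probs > threshold) * 255).astype(np.uint8)
```

The resulting array can be saved with Pillow's `Image.fromarray(mask).save(...)`. Note that with `--scale 0.5` the mask is at half the input resolution unless it is resized back afterwards.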
This code is based on U-Net: Semantic segmentation with PyTorch. We thank the authors for sharing their code.