PyTorch re-implementation of the WaSR network [1]. Contains training code, prediction code and models pretrained on the MaSTr1325 dataset [2].
Requirements: Python >= 3.6, PyTorch, PyTorch Lightning (for training)
Install the dependencies provided in `requirements.txt`.
pip install -r requirements.txt
Currently available pretrained model weights. All models are trained on the MaSTr1325 dataset and evaluated on the MODS benchmark [3].
model | backbone | IMU | url |
---|---|---|---|
wasr_resnet101 | ResNet-101 | | weights |
wasr_resnet101_imu | ResNet-101 | ✓ | weights |
- Download and prepare the MaSTr1325 dataset (images and GT masks). If you plan to use the IMU-enabled model, also download the IMU masks.
- Edit the dataset configuration files (`configs/mastr1325_train.yaml`, `configs/mastr1325_val.yaml`) so that they correctly point to the dataset directories.
- Use the `train.py` script to train the network.
export CUDA_VISIBLE_DEVICES=0,1,2,3 # GPUs to use
python train.py \
--train_config configs/mastr1325_train.yaml \
--val_config configs/mastr1325_val.yaml \
--model_name my_wasr \
--validation \
--batch_size 4 \
--epochs 50
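Before launching a long training run, it can help to confirm that the paths in the edited configuration files actually exist. The following is a minimal sketch, not part of this repository: the function name `find_missing_paths`, the key names in the example, and the convention that path-valued keys end in `_dir` or `_file` are all assumptions made for illustration. It works on a plain dictionary, such as the one `yaml.safe_load` returns for a YAML file.

```python
from pathlib import Path


def find_missing_paths(config, base_dir="."):
    """Return (key, path) pairs whose path does not exist on disk.

    Only string values whose key ends in "_dir" or "_file" are checked.
    That naming convention is an assumption of this sketch, not something
    the repository guarantees.
    """
    missing = []
    for key, value in config.items():
        if isinstance(value, dict):
            # Recurse into nested sections and prefix the keys for readability.
            nested = find_missing_paths(value, base_dir)
            missing.extend((f"{key}.{k}", p) for k, p in nested)
        elif isinstance(value, str) and key.endswith(("_dir", "_file")):
            path = Path(base_dir) / value
            if not path.exists():
                missing.append((key, str(path)))
    return missing


if __name__ == "__main__":
    # Hypothetical config contents; the real keys live in the YAML files.
    example = {
        "image_dir": "data/mastr1325/images",
        "mask_dir": "data/mastr1325/masks",
    }
    for key, path in find_missing_paths(example):
        print(f"{key}: {path} does not exist")
```

If the real configuration files use different key names, adjust the suffix check accordingly.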
By default, the ResNet-101, IMU-enabled version of WaSR is used in training. To select a different model architecture, use the `--model` argument. Currently implemented model architectures:
model | backbone | IMU |
---|---|---|
wasr_resnet101_imu | ResNet-101 | ✓ |
wasr_resnet101 | ResNet-101 | |
wasr_resnet50_imu | ResNet-50 | ✓ |
wasr_resnet50 | ResNet-50 | |
deeplab | ResNet-101 | |
A log directory with the specified model name will be created inside the `output` directory. Model checkpoints and training logs will be stored there. At the end of training, the model weights are also exported to a `weights.pth` file inside this directory.
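To connect training with inference, the exported weights file can be located programmatically and passed to `predict.py`. The sketch below is illustrative only: the helper names are hypothetical, and it assumes the log directory is `output/logs/<model_name>`, as the TensorBoard command in this README suggests. The `predict.py` flags are the ones used in the inference example of this README.

```python
from pathlib import Path


def exported_weights_path(model_name, output_dir="output"):
    """Path where the exported weights are expected after training.

    Assumes the log directory is <output_dir>/logs/<model_name>; adjust
    this if your setup stores logs elsewhere.
    """
    return Path(output_dir) / "logs" / model_name / "weights.pth"


def predict_command(model_name, architecture, dataset_config,
                    output_dir="output/predictions"):
    """Build the predict.py argument list for a freshly trained model."""
    return [
        "python", "predict.py",
        "--dataset_config", dataset_config,
        "--model", architecture,
        "--weights", str(exported_weights_path(model_name)),
        "--output_dir", output_dir,
    ]


if __name__ == "__main__":
    cmd = predict_command("my_wasr", "wasr_resnet101_imu", "configs/examples.yaml")
    print(" ".join(cmd))
```

The architecture passed to `--model` should match the one the weights were trained with.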
Logged metrics (loss, validation accuracy, validation IoU) can be inspected using TensorBoard.
tensorboard --logdir output/logs/model_name
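For readers unfamiliar with the IoU metric, the standard definition for one class is the number of pixels labeled with that class in both the prediction and the ground truth, divided by the number of pixels labeled with it in either. The sketch below implements that textbook definition on flattened label lists. It is not taken from this repository's training code, which may aggregate the metric differently (for example, over a whole validation set rather than per image).

```python
def class_iou(pred, target, cls):
    """IoU of one class between two equally sized label sequences."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    intersection = sum(1 for p, t in zip(pred, target) if p == cls and t == cls)
    union = sum(1 for p, t in zip(pred, target) if p == cls or t == cls)
    # A class absent from both masks has no defined IoU; return None.
    return intersection / union if union else None


def mean_iou(pred, target, num_classes):
    """Average IoU over the classes that occur in either mask."""
    scores = [class_iou(pred, target, c) for c in range(num_classes)]
    scores = [s for s in scores if s is not None]
    return sum(scores) / len(scores) if scores else None


if __name__ == "__main__":
    # Flattened toy masks with labels 0, 1, 2 (label meaning is arbitrary here).
    pred = [0, 0, 1, 1, 2, 2]
    target = [0, 1, 1, 1, 2, 0]
    print(class_iou(pred, target, 1))
    print(mean_iou(pred, target, 3))
```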
To run model inference using pretrained weights, use the `predict.py` script. A sample dataset config file (`configs/examples.yaml`) is provided to run examples from the `examples` directory.
# export CUDA_VISIBLE_DEVICES=-1 # CPU only
export CUDA_VISIBLE_DEVICES=0 # GPU to use
python predict.py \
--dataset_config configs/examples.yaml \
--model wasr_resnet101_imu \
--weights path/to/model/weights.pth \
--output_dir output/predictions
Predictions will be stored as color-coded masks in the specified output directory.
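For downstream processing, a color-coded mask usually has to be converted back into class indices. The sketch below shows one way to do that on rows of RGB tuples. The palette values are placeholders, not the colors `predict.py` actually writes, and the function names are hypothetical. Inspect a predicted mask to find the real colors before using it.

```python
# Hypothetical palette: RGB color -> class index. Replace these placeholder
# values with the colors that actually appear in the predicted masks.
PALETTE = {
    (255, 0, 0): 0,
    (0, 0, 255): 1,
    (0, 255, 0): 2,
}


def decode_mask(rgb_rows, palette=PALETTE, unknown=-1):
    """Convert rows of (R, G, B) pixels into rows of class indices.

    Pixels whose color is not in the palette are mapped to `unknown`.
    """
    return [[palette.get(tuple(px), unknown) for px in row] for row in rgb_rows]


def class_fractions(label_rows):
    """Fraction of pixels assigned to each class index."""
    counts = {}
    total = 0
    for row in label_rows:
        for label in row:
            counts[label] = counts.get(label, 0) + 1
            total += 1
    return {label: n / total for label, n in counts.items()} if total else {}


if __name__ == "__main__":
    # A 2x2 toy mask; the last pixel uses a color outside the palette.
    mask = [[(255, 0, 0), (0, 0, 255)], [(0, 0, 255), (1, 2, 3)]]
    labels = decode_mask(mask)
    print(labels)
    print(class_fractions(labels))
```

A non-zero fraction for the `unknown` label is a quick sign that the palette does not match the masks.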
Lojze Žust & Matej Kristan. "Learning Maritime Obstacle Detection from Weak Annotations by Scaffolding." Accepted to Winter Conference on Applications of Computer Vision (WACV), 2022. [arXiv]
If you use this code, please cite our papers:
@InProceedings{Zust2022Learning,
title = {Learning Maritime Obstacle Detection from Weak Annotations by Scaffolding},
author = {Lojze \v{Z}ust and Matej Kristan},
booktitle = {WACV},
year = {2022}
}
@article{Bovcon2021WaSR,
title={WaSR--A Water Segmentation and Refinement Maritime Obstacle Detection Network},
author={Bovcon, Borja and Kristan, Matej},
journal={IEEE transactions on cybernetics}
}
[1] Bovcon, B., & Kristan, M. (2021). WaSR--A Water Segmentation and Refinement Maritime Obstacle Detection Network. IEEE Transactions on Cybernetics.
[2] Bovcon, B., Muhovič, J., Perš, J., & Kristan, M. (2019). The MaSTr1325 dataset for training deep USV obstacle detection models. 2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS).
[3] Bovcon, B., Muhovič, J., Vranac, D., Mozetič, D., Perš, J., & Kristan, M. (2021). MODS -- A USV-oriented object detection and obstacle segmentation benchmark. http://arxiv.org/abs/2105.02359