A primary challenge in artificial intelligence is ensuring that models perform robustly under unexpected conditions. Traditional object detection models, such as Faster R-CNN, struggle with Out-Of-Distribution (OOD) data, which often leads to incorrect detections or misclassifications. Inspired by CLIPN, we developed CLOUD (Contrastive Learning Based Out-of-Distribution Unified Detector), which brings OOD detection capability to multi-object detection scenarios.
Our contributions include:
- the creation of a new dataset
- the development of a joint training pipeline
- the implementation of region-text matching techniques, alongside new loss strategies, to enhance overall model performance
git clone https://github.com/Andy-wyx/cloud.git
cd cloud
conda create -n cloud python=3.9
conda activate cloud
pip install -r ./requirements.txt
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
The torch and cudatoolkit versions you need may depend on your GPU driver version. Refer to this page for previous versions.
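After installation, it can help to confirm which PyTorch build ended up in the environment. The following is a minimal sketch, not part of the repository: the helper name `describe_torch_install` is hypothetical, and it relies only on the standard `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()` attributes.

```python
import importlib.util


def describe_torch_install():
    """Return a small dict describing the installed PyTorch build, if any.

    Hypothetical helper for illustration; it is not part of the CLOUD codebase.
    """
    # Avoid an ImportError if torch has not been installed yet.
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}

    import torch

    return {
        "installed": True,
        "torch_version": torch.__version__,
        # CUDA version the binary was built against; None for CPU-only builds.
        "built_with_cuda": torch.version.cuda,
        # True only if a usable GPU and a compatible driver are present.
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in describe_torch_install().items():
        print(f"{key}: {value}")
```

If `cuda_available` is reported as false on a GPU machine, the installed cudatoolkit version likely does not match the driver.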
- Step 1: Download the COCO dataset from here. The expected structure is:
$coco/
  annotations/
  train2017/
  val2017/
- Step 2: Run the Python script to generate our dataset:
cd src/preprocess
python data_gen.py
- Step 3: Update the dataset details in ./src/tuning_util.py
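Before running data_gen.py, the directory layout from Step 1 can be checked with a short script. This is a minimal sketch under the assumption that the three folders listed above sit directly under the COCO root; the function name `missing_coco_dirs` and the default path `coco` are hypothetical and not part of the repository.

```python
import sys
from pathlib import Path

# Sub-directories listed in the expected structure of Step 1.
EXPECTED_SUBDIRS = ("annotations", "train2017", "val2017")


def missing_coco_dirs(coco_root):
    """Return the expected sub-directories that are absent under coco_root."""
    root = Path(coco_root)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    # The default path is an assumption; pass your own COCO root as an argument.
    root = sys.argv[1] if len(sys.argv) > 1 else "coco"
    missing = missing_coco_dirs(root)
    if missing:
        print(f"Missing under {root}: {', '.join(missing)}")
    else:
        print(f"{root} matches the expected layout.")
```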
Side notes:
- train2017 is better for region-text matching
- val2017 is suitable for finetuning CLOUD with the feature distance loss. It is also faster for quick, lightweight experiments.
- If you also want to try traditional OOD datasets (e.g. iNaturalist, Textures, Places, SUN), see the download instructions here.
Pretrain by:
cd src
sh run.sh
You may need to adjust the parameters to match the training you intend to run and the number of GPU devices you are using.
Inference (OOD detection network only) by:
cd src
python zero_shot_infer.py
Inference (CLOUD) by:
cd src
python zero_shot_infer_cloud.py
After these steps, the ./logs folder will contain model checkpoints and inference results as CSV and JSON files for further visualization.
The visualizer is mainly built on detectron2: we first parse the JSON file and then visualize it.
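The parsing step can be sketched as follows. This is an illustration only: it assumes the JSON file follows the standard COCO detection-results format (a list of records with `image_id`, `category_id`, `bbox` as `[x, y, width, height]`, and `score`), which may differ from the file CLOUD actually writes. The function name `load_detections` and the threshold value are hypothetical.

```python
import json
from collections import defaultdict


def load_detections(json_path, score_threshold=0.5):
    """Group detections by image and convert boxes to corner format.

    Assumes COCO-style result records; adjust the keys if the actual
    file in ./logs uses a different schema.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        records = json.load(f)

    per_image = defaultdict(list)
    for rec in records:
        # Drop low-confidence detections before drawing.
        if rec["score"] < score_threshold:
            continue
        x, y, w, h = rec["bbox"]
        per_image[rec["image_id"]].append(
            {
                # (x1, y1, x2, y2) corners, convenient for drawing utilities.
                "box_xyxy": [x, y, x + w, y + h],
                "category_id": rec["category_id"],
                "score": rec["score"],
            }
        )
    return dict(per_image)
```

The grouped output can then be handed, one image at a time, to whichever drawing utility is used for visualization.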
We also visualize our own RPN results here:
Refer to our report