
PyTorch Implementation for SoftTriple Loss

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

SoftTriple Loss

PyTorch implementation for our ICCV'19 paper: "SoftTriple Loss: Deep Metric Learning Without Triplet Sampling"
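The idea behind the loss can be illustrated with a minimal NumPy sketch. It is written from the paper's description and is not this repository's PyTorch code. Each class has K centers. A sample's similarity to a class is a softmax-weighted (temperature gamma) average of its similarities to that class's centers. These relaxed class similarities are scaled by la, a small margin is subtracted for the true class, and the result goes through softmax cross-entropy. A regularizer weighted by tau pulls the centers of the same class toward each other. The (D, C, K) layout for the centers, the function and parameter names, and the default values are illustrative assumptions, not the repository's API.

```python
import numpy as np


def _softmax(z, axis=-1):
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def _l2_normalize(a, axis):
    return a / np.linalg.norm(a, axis=axis, keepdims=True)


def soft_triple_loss(x, labels, centers, la=20.0, gamma=0.1, tau=0.2, margin=0.01):
    """Hypothetical SoftTriple loss sketch (forward value only).

    x:       (N, D) embeddings, L2-normalized here.
    labels:  (N,) integer class ids in [0, C).
    centers: (D, C, K) K centers per class, L2-normalized here (assumed layout).
    """
    x = _l2_normalize(x, axis=1)
    centers = _l2_normalize(centers, axis=0)
    _, C, K = centers.shape
    n = np.arange(x.shape[0])

    # Cosine similarity of every sample to every center: (N, C, K).
    sim = np.einsum("nd,dck->nck", x, centers)
    # Soft assignment over the K centers of each class (temperature gamma).
    prob = _softmax(sim / gamma, axis=2)
    # Relaxed similarity between each sample and each class: (N, C).
    sim_class = (prob * sim).sum(axis=2)

    # Scale by la and subtract the margin from the true-class logit.
    logits = la * sim_class
    logits[n, labels] -= la * margin
    # Softmax cross-entropy, computed with a stable log-sum-exp.
    m = logits.max(axis=1, keepdims=True)
    log_p = logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))
    loss = -log_p[n, labels].mean()

    if tau > 0 and K > 1:
        # Distances between distinct centers of the same class (pairs t < s).
        gram = np.einsum("dck,dcl->ckl", centers, centers)  # (C, K, K)
        iu = np.triu_indices(K, k=1)
        pair_cos = gram[:, iu[0], iu[1]]
        dist = np.sqrt(np.clip(2.0 - 2.0 * pair_cos, 0.0, None) + 1e-5)
        loss += tau * dist.sum() / (C * K * (K - 1))
    return float(loss)
```

In actual training, `centers` would be a learnable parameter updated together with the network. This NumPy version only evaluates the loss value. With K = 1 and tau = 0 it reduces to a normalized softmax loss with a margin.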

Usage: Train on Cars196

Here is an example of using this package.

  1. Obtain the dataset
wget http://imagenet.stanford.edu/internal/car196/car_ims.tgz
tar -xf car_ims.tgz
  2. Generate the train/test sets
python genCars.py
  3. Learn 64-dimensional embeddings
python train.py --gpu 0 --dim 64 -C 98 --freeze_BN [folder with train and test folders]
  The same script can also train on another dataset folder. This example uses 512-dimensional embeddings, -C 2468, and custom names for the train and test subfolders:
python train.py --gpu 0 --dim 512 -C 2468 --freeze_BN --train_name train_small --test_name val1_small ../../datasets/hotels50k_v5_restructured/
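The `-C 98` in step 3 is consistent with the standard Cars196 metric-learning protocol, in which the first 98 classes form the training set and the remaining 98 the test set, so test classes are never seen during training. The sketch below is hypothetical and is not `genCars.py`. It assumes you already have (image path, 1-based class id) pairs and only computes the destination paths in a `train/<class>/` and `test/<class>/` layout without copying any files. The helper name `split_by_class` and its parameters are illustrative.

```python
from pathlib import PurePosixPath


def split_by_class(samples, num_train_classes=98):
    """Hypothetical helper: map (image_path, class_id) pairs to split folders.

    Classes 1..num_train_classes go to train/, all others to test/,
    so the training and test classes are disjoint.
    """
    layout = {"train": [], "test": []}
    for path, cls in samples:
        split = "train" if cls <= num_train_classes else "test"
        # Destination: <split>/<class id>/<original file name>
        dest = PurePosixPath(split) / str(cls) / PurePosixPath(path).name
        layout[split].append(str(dest))
    return layout


samples = [("car_ims/000001.jpg", 1), ("car_ims/008200.jpg", 99)]
print(split_by_class(samples))
# {'train': ['train/1/000001.jpg'], 'test': ['test/99/008200.jpg']}
```

Splitting by class rather than by image is what makes this a retrieval benchmark on unseen classes, which is why the classifier-style `-C` value counts only the training classes.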

Requirements

  • Python 3.7
  • PyTorch 1.1
  • scikit-learn 0.20.1

Citation

If you use the package in your research, please cite our paper:

@inproceedings{qian2019striple,
  author    = {Qi Qian and
               Lei Shang and
               Baigui Sun and
               Juhua Hu and
               Hao Li and
               Rong Jin},
  title     = {SoftTriple Loss: Deep Metric Learning Without Triplet Sampling},
  booktitle = {{IEEE} International Conference on Computer Vision, {ICCV} 2019},
  year      = {2019}
}