This repo contains a PyTorch implementation of T. Bui et al.'s paper "Compact Descriptors for Sketch-based Image Retrieval using a Triplet loss Convolutional Neural Network" [Repo|Page].
One detail of the proposed architecture confused me: in Figure 1 of the paper, the conv4 layer in each branch is not followed by a ReLU, whereas in the original code it is. I followed the original code when building the network.
The network seems able to reproduce the paper's results, though there is still much room for improvement in my code:
- The modified triplet loss function proposed in the paper is not implemented yet; the default `torch.nn.TripletMarginLoss` is used instead.
- The training policy described in the paper is not implemented.
- No evaluation split is used: all photographs from Flickr15k go into training, which is admittedly a bad idea.
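Since the paper's modified loss is not implemented, the default PyTorch triplet loss is used. A minimal sketch of how it is applied to the three branch outputs (the batch size and embedding dimension here are illustrative):

```python
import torch
import torch.nn as nn

# Default triplet loss as currently used; the paper's modified loss
# would replace this criterion.
criterion = nn.TripletMarginLoss(margin=1.0, p=2.0)

# Dummy embeddings standing in for the three branch outputs
# (batch of 4, 100-D features).
anchor = torch.randn(4, 100, requires_grad=True)
positive = torch.randn(4, 100, requires_grad=True)
negative = torch.randn(4, 100, requires_grad=True)

loss = criterion(anchor, positive, negative)
loss.backward()  # gradients flow back through all three branches
```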
- pytorch 0.4.0 with torchvision 0.2.1
- python 3.6.4
- anaconda 4.4.10 (recommended)
First, the pretrained model based on Flickr15k can be downloaded here, and the Flickr15k dataset can be downloaded here. A resized Flickr15k used for preview is provided here. The 330sketches query set can be downloaded here, and the groundtruth is provided here. Canny edge detection must be applied to produce the images' edge maps; Flickr_15K_edge2, containing precomputed edge maps, is also provided here.
Second, modify the root path to Flickr15k in `./train.py`. The model checkpoints will be stored at `./out/flickr15k_yymmddHHMM/*.pth`. The default root path is `../deep_hashing`, which matches my setup.
Third, run `./train.py` to train the network. Use `./extract_feat_sketch.py` and `./extract_feat_photo.py` to extract features from sketches and photographs. The features will be stored at `./out_feat/flickr15k_yymmddHHMM/feat_sketch.npz` and `./out_feat/flickr15k_yymmddHHMM/feat_photo.npz`.
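The `.npz` files can be read back with NumPy. A small sketch of the round trip (the `feat` key name and the feature dimension are assumptions; check the extraction scripts for the actual names):

```python
import io
import numpy as np

# Hypothetical feature matrix: 330 sketches, 100-D embeddings.
feat_sketch = np.random.randn(330, 100).astype(np.float32)

# Save and reload in .npz form; an in-memory buffer stands in for
# ./out_feat/flickr15k_yymmddHHMM/feat_sketch.npz.
buf = io.BytesIO()
np.savez(buf, feat=feat_sketch)
buf.seek(0)
loaded = np.load(buf)['feat']
```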
Last, run `./retrieval.py` to obtain the results. The retrieval lists will be stored at `./out_feat/flickr15k_yymmddHHMM/result`. To stay consistent with the 330sketches query file structure, the results for each query are saved in groups and sorted by similarity.
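Conceptually, retrieval ranks the gallery photographs by feature distance to each query sketch. A toy sketch with random features (the actual script may use a different distance measure):

```python
import numpy as np

rng = np.random.default_rng(0)
feat_sketch = rng.normal(size=(3, 8)).astype(np.float32)  # 3 query sketches
feat_photo = rng.normal(size=(5, 8)).astype(np.float32)   # 5 gallery photos

# Pairwise Euclidean distances: dists[i, j] = ||sketch_i - photo_j||_2.
dists = np.linalg.norm(feat_sketch[:, None, :] - feat_photo[None, :, :], axis=2)

# For each query, gallery indices ordered from most to least similar.
ranking = np.argsort(dists, axis=1)
```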
```
.
├── accessory
│   └── pr_curve.png
├── dataset
│   ├── 330sketches
│   ├── groundtruth
│   └── Flickr_15K_edge2
├── extract_feat_photo.py
├── extract_feat_sketch.py
├── flickr15k_dataset.py
├── model
│   ├── SketchTriplet_half_sharing.py
│   └── SketchTriplet.py
├── out
│   └── flickr15k_1904041458
│       ├── 500.pth
│       └── loss_and_accurary.txt
├── out_feat
│   └── flickr15k_1904041458
│       ├── feat_photo.npz
│       ├── feat_sketch.npz
│       └── result
├── README.md
├── retrieval.py
├── train.py
└── utils.py
```
We will train the SketchTriplet network on the Flickr15k dataset. The network takes an anchor (a sketch), a positive example (a photograph edge map of the same class as the anchor), and a negative example (a photograph edge map of a different class).
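Schematically, a triplet forward pass looks like the sketch below, where the positive and negative inputs share one branch. The `Branch` module here is a tiny placeholder for the real architecture in `model/SketchTriplet.py`:

```python
import torch
import torch.nn as nn

class Branch(nn.Module):
    """Toy stand-in for one SketchTriplet branch (conv1..conv4, fc layers)."""
    def __init__(self, dim=100):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, dim))

    def forward(self, x):
        return self.net(x)

class TripletNet(nn.Module):
    """Sketch branch for anchors; positive/negative share the photo branch."""
    def __init__(self):
        super().__init__()
        self.sketch_branch = Branch()
        self.photo_branch = Branch()  # shared by positive and negative inputs

    def forward(self, anchor, positive, negative):
        return (self.sketch_branch(anchor),
                self.photo_branch(positive),
                self.photo_branch(negative))

net = TripletNet()
a = torch.randn(2, 3, 32, 32)  # anchor sketches
p = torch.randn(2, 3, 32, 32)  # positive edge maps
n = torch.randn(2, 3, 32, 32)  # negative edge maps
fa, fp, fn = net(a, p, n)
```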
The main parameters are as follows:
- Edge extraction algorithm: Canny edge detection
- Batch size: 128
- Number of epochs: 500
- Optimizer: `torch.optim.SGD`
  - learning rate: 1e-3
  - weight decay: 0.0005
  - momentum: 0.9
- Loss function: `torch.nn.TripletMarginLoss`
  - margin: 1.0
  - p: 2.0
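Wired together, these hyperparameters correspond to the following setup; the `model` here is only a placeholder for the SketchTriplet network:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 100)  # stands in for the SketchTriplet network

# Optimizer and loss configured with the hyperparameters listed above.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3,
                            weight_decay=0.0005, momentum=0.9)
criterion = nn.TripletMarginLoss(margin=1.0, p=2.0)
```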
After 500 epochs of training, here is the PR curve we get for the testing set.
Pr curve for testing
Also, the loss curve during training is shown below.
Loss curve during training
Although it scores 67.7% mAP, which is only mediocre, the PR curve shows that the model is overfitting.
- Implement the modified triplet loss function from the paper
- Add an evaluation split: at least 75% for training, 25% for evaluation
- Remove dead code and refactor
- Extract features in parallel
- Fix typos
[1] adambielski's siamese-triplet repo
[2] weixu000's DSH-pytorch repo
[3] TuBui's Triplet_Loss_SBIR repo