SENTRY stands for Selective ENTRopY Optimization via Committee Consistency for Unsupervised Domain Adaptation. It attempts to achieve domain adaptation under label distribution shift (LDS). SENTRY was introduced in this paper:
@InProceedings{Prabhu_2021_ICCV,
author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy},
title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year = {2021},
pages = {8558-8567}
}
This SENTRY implementation uses a 2D toy dataset under label distribution shift (LDS). Built-in plots help visualize how the SENTRY algorithm learns the new features.
The code starts by retrieving the source dataset from the data folder. It then applies a rotation (the domain shift) to a copy of that dataset. The rotated copy, with an unbalanced label distribution, is the target dataset. Here is a visualization of source and target datasets:
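The construction of the target set described above (rotate a copy of the source, then unbalance its labels) can be sketched as follows. This is a minimal illustration and not the repository's code: the helper names `rotate` and `unbalance`, the 30-degree angle, and the per-class keep fractions are all hypothetical.

```python
import math

def rotate(points, degrees):
    """Rotate 2D points counter-clockwise about the origin."""
    theta = math.radians(degrees)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    return [(cos_t * x - sin_t * y, sin_t * x + cos_t * y) for x, y in points]

def unbalance(points, labels, keep_fraction):
    """Keep only the first keep_fraction[c] share of each class c."""
    totals = {}
    for y in labels:
        totals[y] = totals.get(y, 0) + 1
    # Number of samples still allowed per class (classes not listed are kept whole).
    quota = {c: round(keep_fraction.get(c, 1.0) * n) for c, n in totals.items()}
    kept_points, kept_labels = [], []
    for p, y in zip(points, labels):
        if quota[y] > 0:
            kept_points.append(p)
            kept_labels.append(y)
            quota[y] -= 1
    return kept_points, kept_labels

# Hypothetical toy source set with two balanced classes.
source_points = [(1.0, 0.0), (2.0, 0.0), (0.0, 1.0), (0.0, 2.0)]
source_labels = [0, 0, 1, 1]

# Target = rotated copy of the source with class 1 cut to half its size.
target_points, target_labels = unbalance(
    rotate(source_points, 30.0), source_labels, {0: 1.0, 1: 0.5}
)
```

The source labels are reused for the rotated points only so the class imbalance can be created; during adaptation the target labels are treated as unknown.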
The domain adaptation takes place in the core.train_tgt function. SENTRY attempts to minimize three loss functions:
- loss_CE: cross-entropy loss with respect to the source ground-truth labels.
- loss_IE: information-entropy loss computed over the classes predicted by the model for the last-Q target instances.
- loss_SENTRY: selective entropy optimization:
  - minimizing predictive entropy with respect to the current target sample and one of its consistent versions.
  - maximizing predictive entropy with respect to the current target sample and one of its inconsistent versions.
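The three losses above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code in core.train_tgt: the function names are hypothetical, loss_IE is assumed to be the negative entropy of the class histogram over the last-Q predictions (so minimizing it spreads predictions across classes), and the selective loss is assumed to use a majority vote over the alternative versions and to pick the first matching version.

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def argmax(probs):
    """Index of the largest probability."""
    return max(range(len(probs)), key=probs.__getitem__)

def cross_entropy(probs, label):
    """loss_CE for one source sample: negative log-probability of the true label."""
    return -math.log(probs[label])

def information_entropy_loss(recent_predictions, num_classes):
    """loss_IE sketch: negative entropy of the histogram of the last-Q predicted classes.
    Assumption: lower values mean predictions are spread more evenly over classes."""
    q = len(recent_predictions)
    histogram = [recent_predictions.count(c) / q for c in range(num_classes)]
    return -entropy(histogram)

def selective_entropy_loss(sample_probs, version_probs):
    """loss_SENTRY sketch for one target sample.
    version_probs holds the predicted distributions for its alternative versions."""
    if not version_probs:
        raise ValueError("at least one alternative version is required")
    label = argmax(sample_probs)
    consistent = [p for p in version_probs if argmax(p) == label]
    inconsistent = [p for p in version_probs if argmax(p) != label]
    if len(consistent) > len(inconsistent):
        # Majority agrees: minimize entropy on a consistent version.
        return entropy(consistent[0])
    # Majority disagrees: maximize entropy, i.e. minimize its negative.
    return -entropy(inconsistent[0])
```

For example, `selective_entropy_loss([0.9, 0.1], [[0.8, 0.2], [0.7, 0.3], [0.4, 0.6]])` takes the consistent branch because two of the three versions agree with the sample's predicted class.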
The original paper used RandAugment to compute augmented versions of the current target sample. Since we're using 2D points, the image transformations in RandAugment do not apply. Instead, we use a k-nearest-neighbors search to find the k points closest to the current target sample, and treat those neighbors as its versions. We set k=7.
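A brute-force version of this neighbor lookup might look like the sketch below. The function name `k_nearest_neighbors` is hypothetical and Euclidean distance is an assumption; the repository may use a library routine or a different metric.

```python
import math

def k_nearest_neighbors(points, index, k=7):
    """Return the indices of the k points closest to points[index], excluding itself."""
    query = points[index]
    # Sort by (distance, index) so ties are broken by the lower index.
    distances = sorted(
        (math.dist(query, p), i) for i, p in enumerate(points) if i != index
    )
    return [i for _, i in distances[:k]]

# Hypothetical target samples laid out on a line.
points = [(float(i), 0.0) for i in range(10)]
neighbors = k_nearest_neighbors(points, 0)
```

An odd k such as 7 avoids ties when the neighbors' predicted classes are compared against the sample's own prediction.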
The goal is to train the feature_extractor to learn features for both source and target samples. Here is the feature_extractor performance on source samples:
Avg Loss = 0.39374, Avg Accuracy = 90.500000%, ARI = 0.65438
Next, we use the same feature_extractor to classify target samples. Note that domain adaptation has not been performed yet:
Avg Loss = 0.50390, Avg Accuracy = 83.333333%, ARI = 0.40878
After performing domain adaptation in the core.train_tgt function, we use the feature_extractor to classify target samples again:
Avg Loss = 0.20647, Avg Accuracy = 90.833333%, ARI = 0.60555
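For reference, the accuracy and ARI figures reported above could be computed along the lines of the sketch below. It assumes ARI denotes the adjusted Rand index between true and predicted labels; the function names are hypothetical and the repository may rely on a library implementation instead.

```python
from collections import Counter
from math import comb

def accuracy(true_labels, predicted_labels):
    """Fraction of samples whose predicted label equals the true label."""
    correct = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return correct / len(true_labels)

def adjusted_rand_index(true_labels, predicted_labels):
    """Adjusted Rand index computed from pair counts of the contingency table."""
    n = len(true_labels)
    if n < 2:
        return 1.0  # no pairs to compare
    pairs_both = sum(
        comb(c, 2) for c in Counter(zip(true_labels, predicted_labels)).values()
    )
    pairs_true = sum(comb(c, 2) for c in Counter(true_labels).values())
    pairs_pred = sum(comb(c, 2) for c in Counter(predicted_labels).values())
    expected = pairs_true * pairs_pred / comb(n, 2)
    maximum = (pairs_true + pairs_pred) / 2
    if maximum == expected:
        return 1.0  # degenerate case: both labelings are trivial
    return (pairs_both - expected) / (maximum - expected)
```

Unlike accuracy, the adjusted Rand index is invariant to how the classes are numbered and is corrected for chance agreement.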
I reused some code from the original repository provided by the authors.