A Contrastive Objective for Disentangled Representations
Jonathan Kahana and Yedid Hoshen
Official PyTorch Implementation
Abstract: Learning representations of images that are invariant to sensitive or unwanted attributes is important for many tasks including bias removal and cross domain retrieval. Here, our objective is to learn representations that are invariant to the domain (sensitive attribute) for which labels are provided, while being informative over all other image attributes, which are unlabeled. We present a new approach, proposing a new domain-wise contrastive objective for ensuring invariant representations. This objective crucially restricts negative image pairs to be drawn from the same domain, which enforces domain invariance whereas the standard contrastive objective does not. This domain-wise objective is insufficient on its own as it suffers from shortcut solutions resulting in feature suppression. We overcome this issue by a combination of a reconstruction constraint, image augmentations and initialization with pre-trained weights. Our analysis shows that the choice of augmentations is important, and that a misguided choice of augmentations can harm the invariance and informativeness objectives. In an extensive evaluation, our method convincingly outperforms the state-of-the-art in terms of representation invariance, representation informativeness, and training speed. Furthermore, we find that in some cases our method can achieve excellent results even without the reconstruction constraint, leading to a much faster and resource efficient training.
This repository is the official PyTorch implementation of A Contrastive Objective for Disentangled Representations
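To illustrate the domain-wise contrastive objective described in the abstract, here is a minimal, framework-free sketch. It is not the code used in this repository. It assumes an InfoNCE-style loss with cosine similarity and a temperature, where the positive for each sample is the embedding of an augmented view of that sample; the function name, the `temperature` default, and these choices are illustrative assumptions. The only property taken directly from the abstract is that negatives are restricted to samples from the same domain as the anchor.

```python
import math


def domain_wise_contrastive_loss(anchors, positives, domains, temperature=0.1):
    """Average InfoNCE-style loss where negatives share the anchor's domain.

    anchors, positives: lists of non-zero feature vectors (lists of floats);
        positives[i] is assumed to embed an augmented view of sample i.
    domains: one domain label per sample.
    temperature: softmax temperature (hypothetical default).
    """
    def cosine(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
        return dot / norm

    losses = []
    for i, anchor in enumerate(anchors):
        pos_logit = cosine(anchor, positives[i]) / temperature
        # Negatives are drawn only from the anchor's own domain.
        neg_logits = [
            cosine(anchor, anchors[j]) / temperature
            for j in range(len(anchors))
            if j != i and domains[j] == domains[i]
        ]
        logits = [pos_logit] + neg_logits
        # Numerically stable log-sum-exp over the positive and the negatives.
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
        losses.append(log_denominator - pos_logit)
    return sum(losses) / len(losses)
```

Because samples from other domains never appear as negatives, the loss gives the encoder no incentive to separate images by domain, which is the intuition behind the invariance claim in the abstract.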
By default, the <base-dir> directory is the main directory of the repository, although it can be changed in the code itself.
Please create a directory <base-dir>/pretrained_weights and put the ImageNet pre-trained weights in it.
MocoV2 weights can be downloaded from here or from the official GitHub page.
NOTE: you need to download the datasets first.
We provide pre-processed versions of the datasets, which can be found here.
Please put the pre-processed versions under cache/preprocess.
NOTE: As mentioned above, you can download the pre-processed versions from here.
We also supply scripts for creating the pre-processed versions.
- Edges2Shoes can be downloaded by running the given script scripts/downlod_e2s_zappos.sh
- Cars3D can be downloaded by following the documentation of disentanglement_lib
- Shapes3D can be downloaded from here. Please put it under $DISENTANGLEMENT_LIB_DATA/3dshapes/
- CelebA can be downloaded from here. Please download its files under raw_data/celeba
If you download the datasets to other locations, make sure to update the path at the beginning of the corresponding preprocessing script before running it.
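As a quick sanity check before preprocessing, the directories mentioned above can be verified with a short script. This is only a sketch: it assumes that cache/preprocess and raw_data/celeba are relative to the base directory, and the function names and dictionary keys are hypothetical and not part of the repository.

```python
import os
from pathlib import Path


def expected_paths(base_dir="."):
    """Return the directories this README refers to, keyed by a short label."""
    base = Path(base_dir)
    paths = {
        "pretrained_weights": base / "pretrained_weights",
        "preprocessed_cache": base / "cache" / "preprocess",
        "celeba_raw": base / "raw_data" / "celeba",
    }
    # Shapes3D lives under the disentanglement_lib data directory, if it is set.
    lib_data = os.environ.get("DISENTANGLEMENT_LIB_DATA")
    if lib_data:
        paths["shapes3d_raw"] = Path(lib_data) / "3dshapes"
    return paths


def missing_paths(base_dir="."):
    """List the labels of expected directories that do not exist yet."""
    return [name for name, p in expected_paths(base_dir).items() if not p.is_dir()]


if __name__ == "__main__":
    for name in missing_paths("."):
        print(f"missing: {name} -> {expected_paths('.')[name]}")
```

Not every directory is needed for every experiment; for instance, raw_data/celeba only matters if you preprocess CelebA yourself.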
Preprocessing can be run with the script for the corresponding dataset, where #DATASET_NAME# is a placeholder for the dataset name:
scripts/prepare_#DATASET_NAME#.py
Given a preprocessed train set and test set, as created by the scripts above, training can be done by running one of the attached bash scripts in the bash_scripts folder, according to the desired experiment.
To train DCoDR on smallnorb for example, simply run:
bash bash_scripts/DCoDR/smallnorb/DCoDR__smallnorb__pipeline.sh
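If you prefer to launch experiments from Python, the command above could be wrapped roughly as follows. This sketch assumes that other methods and datasets follow the same bash_scripts/<method>/<dataset>/<method>__<dataset>__pipeline.sh naming as the smallnorb example, which is the only path confirmed here; the helper names are hypothetical.

```python
import subprocess
from pathlib import Path


def pipeline_script(method, dataset, base_dir="."):
    """Build the pipeline script path, assuming the naming of the example above."""
    name = f"{method}__{dataset}__pipeline.sh"
    return Path(base_dir) / "bash_scripts" / method / dataset / name


def run_pipeline(method, dataset, base_dir="."):
    """Run a pipeline script with bash from the repository's base directory."""
    script = pipeline_script(method, dataset, base_dir).resolve()
    if not script.is_file():
        # Fail early with a clear message instead of a bash error.
        raise FileNotFoundError(f"No pipeline script at {script}")
    subprocess.run(["bash", str(script)], cwd=base_dir, check=True)


if __name__ == "__main__":
    run_pipeline("DCoDR", "smallnorb")
```

Check the bash_scripts folder for the exact script names before relying on the generated path.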
We provide trained models for all of the evaluated datasets from the main experiment in the paper.
Please download the model .pth files as well as the config.pkl file, which is needed for evaluation.
Dataset | DCoDR-norec | DCoDR |
---|---|---|
Cars3D | DCoDR-norec Cars3D | DCoDR Cars3D |
SmallNorb | DCoDR-norec SmallNorb | DCoDR SmallNorb |
CelebA | DCoDR-norec CelebA | DCoDR CelebA |
Edges2Shoes | DCoDR-norec Edges2Shoes | DCoDR Edges2Shoes |
Shapes3D | DCoDR-norec Shapes3D | DCoDR Shapes3D |
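After downloading, you may want to confirm that a model directory contains both the .pth files and config.pkl before running evaluation. The sketch below assumes all files for one model sit in a single directory and that config.pkl is a standard pickle; the function names are hypothetical, and unpickling the real config may require running from within the repository so that its modules can be imported.

```python
import pickle
from pathlib import Path


def find_checkpoint_files(model_dir):
    """Return (list of .pth files, path to config.pkl) or raise if any is missing."""
    model_dir = Path(model_dir)
    weights = sorted(model_dir.glob("*.pth"))
    config_path = model_dir / "config.pkl"
    if not weights:
        raise FileNotFoundError(f"No .pth files found in {model_dir}")
    if not config_path.is_file():
        raise FileNotFoundError(f"Missing config.pkl in {model_dir}")
    return weights, config_path


def load_config(config_path):
    """Load the pickled config. Only unpickle files from a source you trust."""
    with open(config_path, "rb") as f:
        return pickle.load(f)
```

The .pth files themselves would typically be read with torch.load; how the loaded objects are wired into the evaluation code is defined by the repository's scripts.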
If you find this useful, please cite our paper:
@inproceedings{kahana2022dcodr,
author = {Kahana, Jonathan and Hoshen, Yedid},
booktitle = {European Conference on Computer Vision (ECCV)},
title = {A Contrastive Objective for Disentangled Representations},
year = {2022}
}