
[NeurIPS 2023, Spotlight] Rank-N-Contrast: Learning Continuous Representations for Regression


Rank-N-Contrast: Learning Continuous Representations for Regression

Paper | BibTeX

Kaiwen Zha*, Peng Cao*, Jeany Son, Yuzhe Yang, Dina Katabi (*equal contribution)
NeurIPS 2023 (Spotlight)

Loss Function

The loss function RnCLoss in loss.py takes features (two views per sample) and labels as input and returns the loss.

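import torch
from loss import RnCLoss

# define loss function with temperature, label difference measure,
# and feature similarity measure
criterion = RnCLoss(temperature=2, label_diff='l1', feature_sim='l2')

# example sizes, for illustration only
bs, feat_dim, label_dim = 32, 128, 1

# features: [bs, 2, feat_dim], two views per sample
# (random placeholders here; use your encoder outputs in practice)
features = torch.randn(bs, 2, feat_dim)
# labels: [bs, label_dim]
labels = torch.randn(bs, label_dim)

# compute RnC loss
loss = criterion(features, labels)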
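For intuition about what the loss computes, here is a minimal plain-Python sketch of the Rank-N-Contrast objective. It assumes label_diff='l1' (L1 distance between label vectors), feature_sim='l2' (negative Euclidean distance between features) and the same [bs, 2, feat_dim] / [bs, label_dim] inputs. The two views are flattened into 2·bs samples that share their sample's label; for each anchor i and each other sample j, the softmax denominator runs over every sample k ≠ i whose label distance to i is at least that of j, so samples are contrasted according to their rank in label space. The name rnc_loss is hypothetical: this is a readable reference under those assumptions, not the repository's loss.py, and its O(n³) loops are only meant for tiny inputs.

```python
import math
from typing import Sequence


def rnc_loss(features: Sequence[Sequence[Sequence[float]]],
             labels: Sequence[Sequence[float]],
             temperature: float = 2.0) -> float:
    """Hypothetical plain-Python sketch of the RnC loss.

    features: [bs][2][feat_dim], two views per sample
    labels:   [bs][label_dim]
    Assumes L1 label differences and negative-L2 feature similarity.
    """
    # Stack the two views into 2*bs samples; both views keep the sample's label.
    feats = [f[0] for f in features] + [f[1] for f in features]
    labs = list(labels) + list(labels)
    n = len(feats)

    def label_diff(a, b):
        # L1 distance between label vectors
        return sum(abs(x - y) for x, y in zip(a, b))

    def logit(a, b):
        # negative Euclidean distance, divided by the temperature
        return -math.dist(a, b) / temperature

    total = 0.0
    for i in range(n):  # anchor
        others = [k for k in range(n) if k != i]
        logits = {k: logit(feats[i], feats[k]) for k in others}
        diffs = {k: label_diff(labs[i], labs[k]) for k in others}
        for j in others:  # positive
            # candidates: every k != i at least as far from i in label space as j (j included)
            cand = [logits[k] for k in others if diffs[k] >= diffs[j]]
            # log-sum-exp over the candidates, shifted by the max for stability
            m = max(cand)
            log_denom = m + math.log(sum(math.exp(v - m) for v in cand))
            # negative log-softmax of the positive pair within its candidate set
            total += log_denom - logits[j]
    # average over all n * (n - 1) anchor/positive pairs
    return total / (n * (n - 1))


# One sample, two views: each view's only candidate is its partner view.
print(rnc_loss([[[0.0], [1.0]]], [[5.0]]))  # 0.0
```

Because the positive j is always in its own candidate set, every term is non-negative; the loss shrinks as feature distances come to follow the ordering of label distances.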

Running

Download the AgeDB dataset from here and extract the zip file to the folder ./data (you may need to contact the authors of the AgeDB dataset for the zip password).
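If you prefer to extract the archive from Python, a minimal sketch with the standard-library zipfile module is below. extract_dataset is a hypothetical helper, the archive name is whatever file you downloaded, and the folder layout the training scripts expect inside ./data is not specified here, so check it against the data-loading code. Note that zipfile can only decrypt the legacy ZipCrypto scheme (and does so slowly); if the archive uses AES encryption, extract it with an external tool such as 7-Zip instead.

```python
import zipfile
from pathlib import Path
from typing import List, Optional


def extract_dataset(zip_path: str, dest: str = "./data",
                    password: Optional[str] = None) -> List[str]:
    """Hypothetical helper: extract zip_path into dest and return the member names.

    zipfile only supports legacy ZipCrypto decryption; AES-encrypted
    archives raise an error and need an external tool.
    """
    Path(dest).mkdir(parents=True, exist_ok=True)
    # zipfile expects the password as bytes
    pwd = password.encode("utf-8") if password is not None else None
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(path=dest, pwd=pwd)
        return zf.namelist()


# Usage (file name and password are placeholders, not values from the AgeDB release):
# extract_dataset("<DOWNLOADED_ZIP>", dest="./data", password="<PASSWORD_FROM_AGEDB_AUTHORS>")
```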

  • To train the model with the L1 loss, run

    python main_l1.py
    
  • To train the model with the RnC framework, first run

    python main_rnc.py
    

    to train the encoder. The encoder checkpoint will be saved to ./save. Then, run

    python main_linear.py --ckpt <PATH_TO_THE_TRAINED_ENCODER_CHECKPOINT>
    

    to train the regressor.
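Conceptually, the second stage keeps the trained encoder fixed and fits only a regressor on its features. As a framework-free illustration, assuming a linear regressor trained with an L1 (mean absolute error) objective on frozen features, the sketch below runs full-batch subgradient descent. fit_linear_l1, mae and the lr and epochs values are hypothetical; main_linear.py itself is a PyTorch script with its own optimizer and hyperparameters.

```python
from typing import List, Sequence, Tuple


def fit_linear_l1(feats: Sequence[Sequence[float]], targets: Sequence[float],
                  lr: float = 0.01, epochs: int = 500) -> Tuple[List[float], float]:
    """Hypothetical sketch: fit y ~ w.x + b by minimizing mean |error|.

    `feats` are treated as constants, standing in for outputs of the frozen
    encoder, so only the regressor parameters w and b are updated.
    """
    n, dim = len(feats), len(feats[0])
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(feats, targets):
            err = sum(wi * xi for wi, xi in zip(w, x)) + b - y
            sign = (err > 0) - (err < 0)  # subgradient of |err| w.r.t. the prediction
            for d in range(dim):
                grad_w[d] += sign * x[d] / n
            grad_b += sign / n
        # one full-batch subgradient step
        w = [wi - lr * g for wi, g in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def mae(feats, targets, w, b) -> float:
    """Mean absolute error of the linear model on (feats, targets)."""
    preds = [sum(wi * xi for wi, xi in zip(w, x)) + b for x in feats]
    return sum(abs(p - y) for p, y in zip(preds, targets)) / len(targets)


# Toy "frozen features" with targets y = 2x + 1
X = [[0.0], [1.0], [2.0], [3.0], [4.0]]
Y = [1.0, 3.0, 5.0, 7.0, 9.0]
w, b = fit_linear_l1(X, Y)
print(mae(X, Y, [0.0], 0.0), mae(X, Y, w, b))
```

Keeping the encoder frozen means the regressor can only exploit structure already present in the learned representation, which is what the two-stage setup is meant to isolate.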

Citation

If you use this code for your research, please cite our paper:

@inproceedings{zha2023rankncontrast,
    title={Rank-N-Contrast: Learning Continuous Representations for Regression},
    author={Zha, Kaiwen and Cao, Peng and Son, Jeany and Yang, Yuzhe and Katabi, Dina},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023}
}