softsort

Code for "SoftSort: A Continuous Relaxation for the argsort Operator", ICML 2020.

This repository is a fork of ermongroup/neuralsort implementing the SoftSort operator and reproducing all the results reported in the paper "SoftSort: A Continuous Relaxation for the argsort Operator".
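For readers who want a feel for the operator before running the experiments, here is a minimal, dependency-free sketch of SoftSort as defined in the paper: each row of the output is a softmax over the negated absolute differences between one entry of the sorted scores (decreasing order) and every original score, scaled by the temperature tau. This is an illustration only, not the repository's TensorFlow or PyTorch code; the function name `softsort` and the example scores are choices made for this sketch.

```python
import math


def softsort(s, tau=1.0):
    """Return the n x n SoftSort matrix for a list of scores s.

    Row i is softmax_j(-|sorted(s)[i] - s[j]| / tau), where sorted(s) is in
    decreasing order, so row i is a probability distribution that peaks at
    the position of the i-th largest score.
    """
    s_sorted = sorted(s, reverse=True)
    rows = []
    for target in s_sorted:
        logits = [-abs(target - x) / tau for x in s]
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


scores = [2.0, 5.0, 4.0]  # illustrative input
P_hat = softsort(scores, tau=0.1)

# The row-wise argmax recovers the hard argsort (decreasing order).
hard = [max(range(len(row)), key=row.__getitem__) for row in P_hat]
print(hard)  # → [1, 2, 0]
```

Lower values of tau make each row closer to one-hot; larger values spread the mass across positions.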

Requirements

The codebase is implemented in Python 3.7. To install the necessary requirements, run the following command:

pip3 install -r requirements.txt

Sorting Handwritten Numbers Experiment

To reproduce the results in Table 1, just run:

cd tf
bash run_sort.sh
python3 run_sort_table_of_results.py

The first script (bash) will train all models. This takes a long time. You can inspect this script to see what parameters were used to train each model (which are the ones reported in the paper). The second script (python) will process the results from the models and print Table 1.

To train a single model directly, you can use the tf/run_sort.py script, with the following arguments:

  --M INT                 Minibatch size
  --n INT                 Number of elements to compare at a time
  --l INT                 Number of digits in each multi-mnist dataset element
  --tau FLOAT             Temperature (either of sinkhorn or neuralsort relaxation)
  --method STRING         One of 'deterministic_neuralsort', 'stochastic_neuralsort', 'deterministic_softsort', 'stochastic_softsort'
  --n_s INT               Number of samples for stochastic methods
  --num_epochs INT        Number of epochs to train
  --lr FLOAT              Initial learning rate
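As a sketch of how the flags above fit together, the snippet below assembles a single run_sort.py invocation from a dictionary of hyperparameters and prints it without executing anything. All values are hypothetical and chosen only for illustration; the settings used in the paper are the ones in run_sort.sh. It assumes the command is run from inside the tf/ directory.

```python
import shlex

# Hypothetical hyperparameter values, for illustration only.
# See run_sort.sh for the values actually used in the paper.
args = {
    "M": 20,                             # minibatch size
    "n": 5,                              # elements compared at a time
    "l": 4,                              # digits per multi-mnist element
    "tau": 1.0,                          # temperature
    "method": "deterministic_softsort",
    "n_s": 5,                            # samples (for stochastic methods)
    "num_epochs": 100,
    "lr": 0.001,                         # initial learning rate
}

cmd = ["python3", "run_sort.py"]
for flag, value in args.items():
    cmd += ["--" + flag, str(value)]

# Quote each token so the line can be pasted into a shell as-is.
command_line = " ".join(shlex.quote(token) for token in cmd)
print(command_line)
```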

Quantile Regression Experiment

To reproduce the results in Table 2, just run:

cd tf
bash run_median.sh
python3 run_median_table_of_results.py

The first script (bash) will train all models. This takes a long time. You can inspect this script to see what parameters were used to train each model (which are the ones reported in the paper). The second script (python) will process the results from the models and print Table 2.

To train a single model directly, you can use the tf/run_median.py script, with the following arguments:

  --M INT                 Minibatch size
  --n INT                 Number of elements to compare at a time
  --l INT                 Number of digits in each multi-mnist dataset element
  --tau FLOAT             Temperature (either of sinkhorn or neuralsort relaxation)
  --method STRING         One of 'deterministic_neuralsort', 'stochastic_neuralsort', 'deterministic_softsort', 'stochastic_softsort'
  --n_s INT               Number of samples for stochastic methods
  --num_epochs INT        Number of epochs to train
  --lr FLOAT              Initial learning rate
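To illustrate why a relaxed sorting operator is useful for median regression, the sketch below uses the middle row of a SoftSort matrix as soft selection weights over the inputs, giving a differentiable stand-in for the median. This is an assumption about the general technique, not a description of what tf/run_median.py does internally; the helper names `softsort` and `soft_median` and the sample values are invented for this example.

```python
import math


def softsort(s, tau):
    """SoftSort matrix: row i is softmax_j(-|sorted(s)[i] - s[j]| / tau)."""
    rows = []
    for target in sorted(s, reverse=True):
        logits = [-abs(target - x) / tau for x in s]
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


def soft_median(s, tau):
    """Weighted average of s using the middle row of the SoftSort matrix.

    For odd len(s), the middle row corresponds to the median rank.
    """
    mid_row = softsort(s, tau)[len(s) // 2]
    return sum(p * x for p, x in zip(mid_row, s))


values = [1.0, 2.0, 10.0, 3.0, 4.0]  # true median is 3, mean is 4

sharp = soft_median(values, tau=0.1)      # close to the true median
smooth = soft_median(values, tau=1000.0)  # close to the plain mean
print(sharp, smooth)
```

The two calls show the role of tau: a small temperature concentrates the weights on the median element, while a very large one flattens them toward a uniform average.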

Differentiable kNN Experiment

To reproduce the results in Table 3, run:

cd pytorch
bash run_dknn.sh
python3 run_dknn_table_of_results.py

The first script (bash) will train all the models. This takes about two days, since it tests the different hyperparameter configurations sequentially. The second script (python) iterates through the logs and prints the best results.

To train a single model directly, you can use the pytorch/run_dknn.py script, with the following arguments:

  --simple                Whether to use our softsort, or the baseline neuralsort
  --k INT                 Number of nearest neighbors
  --tau FLOAT             Temperature of sorting operator
  --nloglr FLOAT          Negative log10 of learning rate
  --method STRING         One of 'deterministic', 'stochastic'
  --dataset STRING        One of 'mnist', 'fashion-mnist', 'cifar10'
  --num_train_queries INT Number of queries to evaluate during training.
  --num_train_neighbors INT Number of neighbors to consider during training.
  --num_samples INT       Number of samples for stochastic methods
  --num_epochs INT        Number of epochs to train
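The --nloglr flag takes the negative base-10 logarithm of the learning rate rather than the learning rate itself. Going by that description, the conversion should be lr = 10^(-nloglr); the small sketch below works through a few values. The helper name `lr_from_nloglr` is hypothetical and not part of the repository.

```python
import math


def lr_from_nloglr(nloglr):
    """Convert a negative log10 learning rate into the learning rate."""
    return 10.0 ** (-nloglr)


def nloglr_from_lr(lr):
    """Inverse conversion: learning rate to its negative log10."""
    return -math.log10(lr)


# For example, nloglr=3 corresponds to a learning rate of 1e-3.
for nloglr in (2, 3, 4):
    print(nloglr, lr_from_nloglr(nloglr))
```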

Speed Comparison Experiment

To reproduce the results in Figure 6, just run:

bash synthetic_experiment_speed_comparison.sh
python3 synthetic_experiment_speed_comparison_plot.py

The first script (bash) will train all models. This takes some time. The second script (python) will process the results and save the graphs from Figure 6 under the images/ directory.

Learning Curves

For the synthetic experiment learning curves, run:

bash synthetic_experiment_learning_curves.sh

Then, to generate the plot in Figure 8, run:

python3 synthetic_experiment_learning_curves_plot.py

To generate the run_sort and run_median learning curve plots (Figure 7), run:

cd tf
python3 run_sort_learning_curves.py
python3 run_median_learning_curves.py