Implementation of the Common Gradient Descent (CGD) algorithm, published at ICLR 2022 (https://openreview.net/forum?id=irARV_2VFs4).
All required packages are listed in the environment.yaml file. To reproduce the results, create a new conda environment with the following command:
conda env create --file environment.yaml
The code is built on the WILDS codebase v1.2.2 and was run on a TPU v3-8. For efficiency, we release only the algorithm files and detail below the minimal edits to be made to the WILDS codebase.
- Move the Python files under algorithms to examples/algorithms of the WILDS codebase.
- Edit examples/algorithms/initializer.py to add an import and an initialization branch as follows:
from algorithms.CG import CG
import numpy as np
.....
    elif config.algorithm.startswith('CG'):
        train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)
        is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0
        groups, u_counts = np.unique(train_g, return_counts=True)
        g_counts = np.zeros(train_grouper.n_groups)
        g_counts[groups] = u_counts
        alg = CG
        algorithm = alg(
            config=config,
            d_out=d_out,
            grouper=train_grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
            is_group_in_train=is_group_in_train,
            group_info=[groups, g_counts]
        )
- Add default configurations for the algorithms to examples/configs/algorithm.py, such as the following:
    'CG': {
        'train_loader': 'standard',
        'uniform_over_groups': True,
        'distinct_groups': True,
        'eval_loader': 'standard',
        'cg_step_size': 0.1
    },
- Add the algorithm name to the algorithms variable in examples/configs/supported.py.
- Finally, add the lines below to examples/run_expt.py.
parser.add_argument('--cg_C', type=float, default=0)
parser.add_argument('--cg_step_size', type=float, default=0.05)
parser.add_argument('--pgi_penalty_weight', type=float)
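The group bookkeeping in the initializer.py edit can be tried in isolation. The snippet below is a minimal sketch on made-up data: the toy group assignment and the group count are hypothetical, plain NumPy arrays stand in for the WILDS metadata, and np.bincount is used as a stand-in for the WILDS get_counts helper, which is not imported here.

```python
import numpy as np

# Hypothetical toy data: the group index of each training example, for a
# grouper with 5 groups where groups 1 and 4 never occur in the training split.
n_groups = 5
train_g = np.array([0, 2, 2, 3, 0, 2])

# Same pattern as the initializer.py edit: count examples per observed group...
groups, u_counts = np.unique(train_g, return_counts=True)
# ...and scatter the counts into a dense vector, leaving zeros for absent groups.
g_counts = np.zeros(n_groups)
g_counts[groups] = u_counts

# Stand-in (assumption) for get_counts(train_g, n_groups) > 0 from WILDS.
is_group_in_train = np.bincount(train_g, minlength=n_groups) > 0

print(groups)             # ids of the groups that appear in training
print(g_counts)           # per-group example counts, indexed by group id
print(is_group_in_train)  # boolean mask over all groups
```

The dense g_counts vector is what lets group_info be indexed directly by group id, even for groups that are missing from the training split.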
After the edits above, we can run using:
python run_expt.py --dataset $DATASET --algorithm $ALG --root_dir data --progress_bar --log_dir logs/"$DATASET"/$ALG/run:1:seed:"$seed" --seed $seed --weight_decay 1e-4 --n_epochs 100;
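To repeat the command above over several seeds, a small helper can assemble the argument list. This is only a sketch: build_command is a hypothetical helper that is not part of this repository, and the dataset name and seeds are placeholders. It prints the commands as a dry run rather than launching them.

```python
import shlex

def build_command(dataset, alg, seed):
    """Return the run_expt.py invocation shown above as an argument list."""
    log_dir = f"logs/{dataset}/{alg}/run:1:seed:{seed}"
    return [
        "python", "run_expt.py",
        "--dataset", dataset,
        "--algorithm", alg,
        "--root_dir", "data",
        "--progress_bar",
        "--log_dir", log_dir,
        "--seed", str(seed),
        "--weight_decay", "1e-4",
        "--n_epochs", "100",
    ]

# Placeholder choices (assumptions): one WILDS dataset name and three seeds.
commands = [build_command("camelyon17", "CG", seed) for seed in (0, 1, 2)]
for cmd in commands:
    # Dry run: print each command; hand cmd to subprocess.run to launch it.
    print(shlex.join(cmd))
```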
In the algorithms folder, we include implementations of our method (CGD) and of PGI (Ahmed et al., ICLR 2021).
- CGD: algorithms/cg.py. Hyperparameters: --cg_step_size sets the step size parameter and --cg_C sets the group adjustment parameter C discussed in our paper.
- PGI (ICLR 2021): algorithms/pgi.py. Hyperparameter: --pgi_penalty_weight controls the weight of the distributional divergence term (lambda) discussed in their paper.
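The flags above can be checked on their own with a standalone parser. The sketch below reuses the three add_argument lines listed for run_expt.py; the override values passed on the second call are arbitrary illustrations, not recommended settings.

```python
import argparse

# Standalone parser holding only the three flags added to run_expt.py.
parser = argparse.ArgumentParser()
parser.add_argument('--cg_C', type=float, default=0)
parser.add_argument('--cg_step_size', type=float, default=0.05)
parser.add_argument('--pgi_penalty_weight', type=float)

# No flags given: the declared defaults apply; --pgi_penalty_weight stays None.
defaults = parser.parse_args([])

# Illustrative overrides (arbitrary values chosen for this example).
custom = parser.parse_args(['--cg_step_size', '0.1', '--cg_C', '2'])

print(vars(defaults))
print(vars(custom))
```

Since --pgi_penalty_weight has no default, it is None unless it is passed explicitly.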
Most of the datasets can be readily downloaded by passing --download when running run_expt.py. Additional datasets that are not part of WILDS are included in the datasets folder; these include Colored-MNIST (cmnist_debug_dataset.py) and the datasets used for the qualitative evaluation in Section 4 of our paper.