/penalized-bilevel-gradient-descent

An implementation of the penalty-based bilevel gradient descent (PBGD) algorithm and the iterative differentiation (ITD/RHG) methods.


Introduction

This repo includes an implementation of the penalty-based bilevel gradient descent (PBGD) algorithm presented in the paper "On Penalty-based Bilevel Gradient Descent Method", along with several baseline algorithms.

The algorithms solve the bilevel optimization problem $$\min_{x,y}\; f(x,y)\quad{\rm s.t.}\quad y\in\arg\min_{y'} g(x,y').$$ Bilevel optimization has a wide range of applications, e.g., meta-learning, image processing, hyper-parameter optimization, and reinforcement learning.
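Penalty-based methods replace the lower-level constraint with a penalty term weighted by a constant $\gamma>0$. As a hedged summary, assuming the V- and G- prefixes used below refer to a value-function penalty and a lower-level gradient penalty respectively (and up to the choice of scaling), the two single-level surrogates read

$$\min_{x,y}\; f(x,y)+\gamma\big(g(x,y)-v(x)\big),\qquad v(x):=\min_{y'} g(x,y'),$$

$$\min_{x,y}\; f(x,y)+\frac{\gamma}{2}\,\big\|\nabla_y g(x,y)\big\|^2.$$

Both penalties are nonnegative. The first vanishes exactly when $y\in\arg\min_{y'}g(x,y')$; the second vanishes exactly at lower-level stationary points, which coincide with the minimizers when $g(x,\cdot)$ is convex. A larger $\gamma$ enforces the lower-level condition more tightly.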

Implemented algorithms

Dependencies

The combination below works for us.

Running the code

Toy problem

The problem is described in the 'numerical verification' section of the paper.

To reproduce the result, navigate to ./V-PBGD/toy/ and run:

python toy.py

Left: plot of the hyper-objective (dashed line). Right: red points are the last iterates generated by PBGD from 1000 randomly initialized points. PBGD finds the local solutions of the hyper-objective.
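The loop below is a minimal sketch of a value-function-penalty (V-PBGD-style) iteration, assuming the penalty $\gamma\,(g(x,y)-\min_{y'}g(x,y'))$. A few inner gradient steps on $g(x,\cdot)$ produce $z\approx y^*(x)$, and $(x,y)$ then take one gradient step on $f(x,y)+\gamma\,(g(x,y)-g(x,z))$ with $z$ held fixed. The quadratic toy problem, step sizes, and function names are hypothetical and are not the problem solved by toy.py; as in the figure, the sketch starts from 1000 random initial points.

```python
import numpy as np

# Hypothetical toy problem (not the one in toy.py):
#   f(x, y) = 0.5 * (y - 1)^2 + 0.5 * x^2,   g(x, y) = 0.5 * (y - x)^2,  so y*(x) = x.
def grad_f(x, y):
    return x, y - 1.0            # (df/dx, df/dy)

def grad_g(x, y):
    return -(y - x), y - x       # (dg/dx, dg/dy)

def v_pbgd_sketch(x, y, gamma=10.0, lr=0.02, inner_lr=0.5, inner_steps=20, outer_steps=3000):
    """Value-function-penalty gradient descent on arrays of starting points."""
    z = y.copy()
    for _ in range(outer_steps):
        # Approximate y*(x) with a few gradient steps on g(x, .), warm-started from z.
        for _ in range(inner_steps):
            z = z - inner_lr * grad_g(x, z)[1]
        fx, fy = grad_f(x, y)
        gx, gy = grad_g(x, y)
        gx_z, _ = grad_g(x, z)
        # Gradient of f + gamma * (g(x, y) - g(x, z)), treating z as a constant.
        x = x - lr * (fx + gamma * (gx - gx_z))
        y = y - lr * (fy + gamma * gy)
    return x, y

rng = np.random.default_rng(0)
x0 = rng.uniform(-3.0, 3.0, size=1000)
y0 = rng.uniform(-3.0, 3.0, size=1000)
x_last, y_last = v_pbgd_sketch(x0, y0)
print(x_last.min(), x_last.max())
```

For this instance the penalized problem has the unique solution $x=\gamma/(1+2\gamma)$, $y=1-x$, so every run ends near $(10/21,\,11/21)$, close to the bilevel solution $x=y=0.5$; increasing gamma moves it closer.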

Data hyper-cleaning

The problem is described in the 'Data hyper-cleaning' section of the paper.
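For orientation, data hyper-cleaning is commonly posed as follows: the upper-level variable $x$ holds one weight logit per (possibly mislabeled) training sample, the lower level trains model parameters $W$ (playing the role of $y$) on the training loss reweighted by $\sigma(x_i)$, and the upper level measures the loss of $W$ on a clean validation set. The sketch below writes both objectives for a linear softmax classifier in NumPy; the regularization constant, function names, and synthetic data are assumptions for illustration and may differ from data_hyper_clean.py.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def per_sample_xent(W, A, labels):
    """Cross-entropy of a linear softmax classifier W (d x k) on each row of A (n x d)."""
    logits = A @ W
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

def lower_objective(x, W, A_tr, y_tr, reg=1e-3):
    # g(x, W): training loss with sample i weighted by sigmoid(x_i), plus a small L2 term.
    return np.mean(sigmoid(x) * per_sample_xent(W, A_tr, y_tr)) + reg * np.sum(W ** 2)

def upper_objective(W, A_val, y_val):
    # f(x, W): unweighted loss on clean validation data; x enters only through W.
    return np.mean(per_sample_xent(W, A_val, y_val))

# Tiny synthetic example (random data, 3 classes).
rng = np.random.default_rng(0)
n, d, k = 100, 5, 3
A_tr = rng.normal(size=(n, d))
y_tr = rng.integers(0, k, size=n)
x = np.zeros(n)          # one weight logit per training sample
W = np.zeros((d, k))     # model parameters (lower-level variable)
print(lower_objective(x, W, A_tr, y_tr))  # equals 0.5 * log(3) at W = 0, x = 0
```

Driving a weight logit $x_i$ toward $-\infty$ effectively removes sample $i$ from the lower-level training loss, which is how the bilevel problem can learn to discard corrupted labels.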

To run V-PBGD, navigate to ./V-PBGD/data-hyper-cleaning/ and run either of the following commands:

python data_hyper_clean.py 

python data_hyper_clean.py --net MLP --lrx 0.1 --lry 0.01 --lr_inner 0.01 --gamma_max 0.1 --gamma_argmax_step 10000 --outer_itr 80000

To run G-PBGD, navigate to ./G-PBGD/ and run either of the following commands:

python data_hyper_clean_gpbgd.py

python data_hyper_clean_gpbgd.py --net MLP --outer_itr 50000 --lrx 0.5 --lry 0.5 --gamma_max 37 --gamma_argmax_step 30000  
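The snippet below sketches gradient descent on a lower-level gradient penalty, assuming G-PBGD penalizes $\tfrac{\gamma}{2}\|\nabla_y g(x,y)\|^2$. It uses a hypothetical quadratic lower level $g(x,y)=\tfrac12 y^\top H y - x^\top y$, for which the penalty gradient only needs the matrix $H$; for a general $g$ the same terms would come from Hessian-vector products. The matrices, step size, and function names are illustrative, not taken from data_hyper_clean_gpbgd.py.

```python
import numpy as np

# Hypothetical quadratic instance:
#   g(x, y) = 0.5 * y^T H y - x^T y      (grad_y g = H y - x, so y*(x) = H^{-1} x)
#   f(x, y) = 0.5 * ||y - b||^2 + 0.5 * lam * ||x||^2
H = np.diag([1.0, 2.0])
b = np.array([1.0, -1.0])
lam = 0.1

def penalized_grads(x, y, gamma):
    # Gradients of f(x, y) + (gamma / 2) * ||grad_y g(x, y)||^2.
    r = H @ y - x                     # lower-level gradient (residual)
    gx = lam * x - gamma * r          # d r / d x = -I
    gy = (y - b) + gamma * H.T @ r    # d r / d y = H
    return gx, gy

def g_pbgd_sketch(gamma, lr=0.015, steps=20000):
    x, y = np.zeros(2), np.zeros(2)
    for _ in range(steps):
        gx, gy = penalized_grads(x, y, gamma)
        x, y = x - lr * gx, y - lr * gy
    return x, y

x_pen, y_pen = g_pbgd_sketch(gamma=10.0)

# Bilevel solution of this instance: minimize 0.5 * ||H^{-1} x - b||^2 + 0.5 * lam * ||x||^2.
Hinv = np.linalg.inv(H)
x_star = np.linalg.solve(Hinv @ Hinv + lam * np.eye(2), Hinv @ b)
print(x_pen, x_star)
```

With gamma = 10 the penalized solution differs from the bilevel solution by about 0.01 per coordinate; the gap shrinks as gamma grows, at the cost of a worse-conditioned problem and a smaller admissible step size.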

To run RHG, navigate to ./RHG/ and run either of the following commands:

python data_hyper_clean_rhg.py 

python data_hyper_clean_rhg.py --net MLP --lr_inner 0.4

To run T-RHG, navigate to ./RHG/ and run either of the following commands:

python data_hyper_clean_rhg.py --K 100

python data_hyper_clean_rhg.py --net MLP --K 100 --lr_inner 0.4
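RHG computes the hypergradient by reverse-mode differentiation through the unrolled inner gradient steps; T-RHG backpropagates through only the last few of them. The sketch below does this by hand for a hypothetical quadratic lower level, where each inner step has a constant Jacobian, and checks the full reverse pass against central finite differences. Treating --K as the truncation length is an assumption about the flag; the matrices, step size alpha, and unroll length T are illustrative.

```python
import numpy as np

# Hypothetical quadratic instance (not the repo's hyper-cleaning model):
#   inner: y_{t+1} = y_t - alpha * grad_y g(x, y_t),  g(x, y) = 0.5 y^T H y - x^T y
#   outer: F(x) = f(x, y_T),  f(x, y) = 0.5 ||y - b||^2 + 0.5 * lam * ||x||^2
H = np.array([[2.0, 0.5], [0.5, 1.0]])
b = np.array([1.0, -1.0])
lam, alpha, T = 0.1, 0.3, 50

def unroll(x):
    # Forward pass: T gradient steps on g(x, .) starting from y_0 = 0.
    y = np.zeros_like(x)
    for _ in range(T):
        y = y - alpha * (H @ y - x)
    return y

def outer_value(x):
    y = unroll(x)
    return 0.5 * np.sum((y - b) ** 2) + 0.5 * lam * np.sum(x ** 2)

def hypergrad(x, K=None):
    """Reverse-mode hypergradient through the last K inner steps (K=None: all T steps)."""
    K = T if K is None else K
    p = unroll(x) - b                  # adjoint dF/dy_T
    grad = lam * x                     # direct dependence of f on x
    A = np.eye(len(x)) - alpha * H     # d y_{t+1} / d y_t (constant for quadratic g)
    for _ in range(K):                 # walk backwards over the last K steps
        grad = grad + alpha * p        # d y_{t+1} / d x = alpha * I
        p = A.T @ p
    return grad

x = np.array([0.3, -0.2])
eps = 1e-6
fd = np.array([(outer_value(x + eps * e) - outer_value(x - eps * e)) / (2 * eps)
               for e in np.eye(2)])
print(np.abs(hypergrad(x) - fd).max())       # full reverse pass: matches finite differences
print(np.abs(hypergrad(x, K=5) - fd).max())  # truncated pass: biased estimate
```

For a general (non-quadratic) g the Jacobians depend on the iterates, so the forward pass must store $y_t$ and each backward step evaluates Hessian-vector and mixed second-derivative products at the stored iterate; automatic differentiation frameworks provide these products.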

Citation

If you find this repo helpful, please cite the paper.
