Source code for the experiments in A Unified Framework for Alternating Offline Model Training and Policy Learning. [Paper], [Poster], [Slides].
Bibtex:

```bibtex
@inproceedings{yang2022unified,
  title={A Unified Framework for Alternating Offline Model Training and Policy Learning},
  author={Shentao Yang and Shujian Zhang and Yihao Feng and Mingyuan Zhou},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022},
  url={https://arxiv.org/abs/2210.05922}
}
```
1. Install basic packages, e.g.,

   ```shell
   conda create -n ampl python=3.8.5
   conda activate ampl
   pip install numpy matplotlib seaborn gym==0.17.0 torch==1.10.1 cudatoolkit==11.1.74
   ```

   and add other dependencies as needed.
2. Install MuJoCo and mujoco-py.
3. Install D4RL.
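A quick sanity check after installation is to verify that the required packages are importable. A minimal sketch (the package list follows the steps above; missing packages are reported rather than raising):

```python
import importlib.util

# Packages expected after the installation steps above
for pkg in ("numpy", "matplotlib", "seaborn", "gym", "torch", "mujoco_py", "d4rl"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'ok' if found else 'MISSING'}")
```

`find_spec` only locates a package without importing it, so this check is safe to run even when a package (e.g. mujoco-py) is installed but misconfigured.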
The run files for the experiments are generated by the `submit_jobs_server_gan.py` script. An example invocation is

```shell
python submit_jobs_server_gan.py
```

Flags can be appended to this command; please see the script for the available flags. The location of the generated run files will be printed out.
The run files will generate a folder for each (dataset, seed) pair. Within each such folder, the file `eval_norm.npy` stores the normalized scores and `eval.npy` records the unnormalized scores. The normalized scores are calculated by the D4RL package.
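These `.npy` files can be loaded and aggregated with NumPy. The sketch below fabricates a dummy results layout (the folder-naming scheme here is an illustrative assumption, not the repository's actual layout) and averages the final normalized score across seeds:

```python
import os
import tempfile

import numpy as np

# Fabricate a dummy layout: one folder per (dataset, seed) pair,
# each containing eval_norm.npy (the normalized-score curve).
root = tempfile.mkdtemp()
for seed in (0, 1):
    run_dir = os.path.join(root, f"halfcheetah-medium-v0_seed{seed}")  # hypothetical naming
    os.makedirs(run_dir)
    np.save(os.path.join(run_dir, "eval_norm.npy"), np.array([40.0, 42.0 + seed]))

# Average the final normalized score over seeds
finals = [np.load(os.path.join(root, d, "eval_norm.npy"))[-1]
          for d in sorted(os.listdir(root))]
print(np.mean(finals))  # mean of 42.0 and 43.0
```

For real runs, point `root` at the directory printed by the run files and keep the `[-1]` index (or an average over the last few evaluations) depending on the reporting convention you want.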
The commands for the variants used in our ablation study are listed below.
- No weighted model-retraining (train the model only once at the beginning using MLE)

  ```shell
  python submit_jobs_server_gan.py --model_retrain_period=1000
  ```

- Use VPM to train the MIW model

  ```shell
  python submit_jobs_server_gan.py --dr_method="VPM" --weight_output_clipping="True"
  ```

- Use GenDICE to train the MIW model

  ```shell
  python submit_jobs_server_gan.py --dr_method="GenDICE" --weight_output_clipping="True"
  ```

- Use DualDICE to train the MIW model

  ```shell
  python submit_jobs_server_gan.py --dr_method="DualDICE" --weight_output_clipping="True"
  ```

- Use the weighted policy-regularizer

  ```shell
  python submit_jobs_server_gan.py --weighted_policy_training='True'
  ```

- KL-Dual + weighted policy-regularizer

  ```shell
  python submit_jobs_server_gan.py --weighted_policy_training='True' --use_kl_dual='True' --use_weight_wpr='True'
  ```

- KL-Dual + unweighted policy-regularizer

  ```shell
  python submit_jobs_server_gan.py --weighted_policy_training='True' --use_kl_dual='True' --use_weight_wpr='False'
  ```

- Gaussian policy + JSD for the policy training

  ```shell
  python submit_jobs_server_gan.py --use_gaussian_policy='True'
  ```

- No regularization in the policy training

  ```shell
  python submit_jobs_server_gan.py --remove_reg='True'
  ```

- No model-rollout data (`real_data_pct=1`)

  ```shell
  python submit_jobs_server_gan.py --real_data_pct=1.
  ```

- Use the reward function as the test function in training the MIW

  ```shell
  python submit_jobs_server_gan.py --use_reward_test_func='True'
  ```

- Use the value function as the discriminator to train the model

  ```shell
  python submit_jobs_server_gan.py --q_dis_model='True'
  ```
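To queue several of these ablations in one go, the MIW-training variants can be looped over in the shell. A sketch (`echo` is used here only to preview the commands; drop it to actually run them):

```shell
# Preview the three MIW-training ablation commands from the list above
for m in VPM GenDICE DualDICE; do
  echo python submit_jobs_server_gan.py --dr_method="$m" --weight_output_clipping="True"
done
```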
MIT License.