TBA
TBA
- Download the UCSD Ped2, CUHK Avenue, and CalTech Pedestrian datasets:
UCSD Ped2 | CUHK Avenue | CalTech Pedestrian Dataset |
---|---|---|
Google Drive | Google Drive | Google Drive |
# Train with the default settings on the specified dataset.
python train_gen.py --dataset=avenue
# Train with a different batch size; you may need to tune the learning rate yourself.
python train_gen.py --dataset=avenue --batch_size=16
# Set the maximum number of training iterations.
python train_gen.py --dataset=avenue --iters=80000
# Set the checkpoint-save interval and the validation interval.
python train_gen.py --dataset=avenue --save_interval=2000 --val_interval=2000
# Resume training from the latest saved checkpoint...
python train_gen.py --dataset=avenue --resume latest
# ...or from a specific checkpoint file.
python train_gen.py --dataset=avenue --resume avenue_10000.pth
# Fine-tune on the chosen dataset with default settings. --resume_g names the
# generator checkpoint to start from (presumably one written by train_gen.py).
python ft_dqn.py \
  --dataset=CalTech \
  --resume_g=CalTech_10000.pth
# Fine-tune with a different batch size; the learning rate may need retuning.
python ft_dqn.py \
  --dataset=CalTech \
  --resume_g=CalTech_10000.pth \
  --batch_size=16
# Cap the number of training iterations.
python ft_dqn.py \
  --dataset=CalTech \
  --resume_g=CalTech_10000.pth \
  --iters=80000
# Choose how often checkpoints are saved and validation is run.
python ft_dqn.py \
  --dataset=CalTech \
  --resume_g=CalTech_10000.pth \
  --save_interval=2000 \
  --val_interval=2000
# Continue fine-tuning from a previously saved RL model (--resume_r).
python ft_dqn.py \
  --dataset=CalTech \
  --resume_g=CalTech_10000.pth \
  --resume_r=ft_CalTech_10000.pth
# View logged metrics in TensorBoard; point --logdir at your own run's folder.
# NOTE(review): "ped2_bs4" looks like <dataset>_bs<batch_size> — confirm
# against the naming used by the training scripts.
tensorboard --logdir tensorboard_log/ped2_bs4
# Evaluate a trained checkpoint on the chosen dataset.
python evaluate_ft.py \
  --dataset=CalTech \
  --trained_model=CalTech_10000.pth