PyTorch implementation for "Gated Transfer Network for Transfer Learning"


This is the official PyTorch implementation of the paper:

Gated Transfer Network for Transfer Learning

Yi Zhu and Jia Xue and Shawn Newsam

ACCV 2018


We recommend using a Conda environment. We use PyTorch 1.1, CUDA 9.0, and Python 3.7.

conda create -n gtn python=3.7
conda activate gtn
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
pip install easydict

Data Preparation

Please see the datasets README for details.


We use the CUB200 example in the experiments folder. The other experiments follow the same steps, apart from some hyper-parameter changes.

  • Set config.py correctly (dataset path, hyper-parameters, etc.)

  • python train.py

  • Evaluation is done on-the-fly.
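
As a rough illustration of the first step, the sketch below shows how an easydict-style config.py is commonly laid out. It is not the repository's actual file: every field name and value (DATASET.ROOT, TRAIN.BATCH_SIZE, and so on) is an assumption made for illustration, so check the real config.py in the experiments folder for the names it expects.

```python
# Hypothetical config.py sketch. Every field name and value below is an
# assumption for illustration, not the repository's actual configuration.
try:
    from easydict import EasyDict as edict  # installed via `pip install easydict`
except ImportError:
    # Minimal stand-in so the sketch still runs without easydict installed.
    class edict(dict):
        __getattr__ = dict.__getitem__
        __setattr__ = dict.__setitem__

cfg = edict()

# Dataset settings (hypothetical names).
cfg.DATASET = edict()
cfg.DATASET.NAME = "CUB200"
cfg.DATASET.ROOT = "/path/to/CUB200"  # placeholder: point this at your local copy
cfg.DATASET.NUM_CLASSES = 200         # CUB200 has 200 bird categories

# Training hyper-parameters (hypothetical names and values).
cfg.TRAIN = edict()
cfg.TRAIN.BATCH_SIZE = 32
cfg.TRAIN.LEARNING_RATE = 0.001
cfg.TRAIN.EPOCHS = 50
```

With this layout, values can be read with attribute access (for example `cfg.TRAIN.BATCH_SIZE`), which is the main convenience easydict offers over a plain dictionary.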

Note that the evaluation performance reported on UCF101 is not the final result, because UCF101 is a video dataset. To obtain the final clip-level results, you need to aggregate the predictions (an example script can be found here).
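
For readers who want a sense of what such aggregation involves, here is a minimal sketch. It assumes that aggregation means averaging the per-class scores of all samples that share a group identifier and then taking the highest-scoring class; the authors' example script may use a different scheme. The function names `aggregate_scores` and `accuracy`, and the input layout, are hypothetical.

```python
from collections import defaultdict


def aggregate_scores(sample_scores, group_ids):
    """Average per-class scores over all samples that share a group id.

    sample_scores: list of per-class score lists, one per evaluated sample.
    group_ids: list of identifiers (e.g. video names), aligned with sample_scores.
    Returns a dict mapping each group id to its predicted class index.
    """
    sums = {}
    counts = defaultdict(int)
    for scores, gid in zip(sample_scores, group_ids):
        if gid not in sums:
            sums[gid] = [0.0] * len(scores)
        for i, s in enumerate(scores):
            sums[gid][i] += s
        counts[gid] += 1

    predictions = {}
    for gid, total in sums.items():
        mean = [t / counts[gid] for t in total]
        # Index of the class with the highest averaged score.
        predictions[gid] = max(range(len(mean)), key=mean.__getitem__)
    return predictions


def accuracy(predictions, labels):
    """Fraction of groups whose aggregated prediction matches the label."""
    correct = sum(1 for gid, pred in predictions.items() if labels[gid] == pred)
    return correct / len(predictions)


# Toy usage: three samples from "v1", one from "v2", two classes.
scores = [[0.6, 0.4], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3]]
ids = ["v1", "v1", "v1", "v2"]
preds = aggregate_scores(scores, ids)
```

Averaging scores is only one common choice. Majority voting over per-sample predictions is another, and the two can disagree when a few samples are very confident.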


If you use this code for your research, please consider citing our paper:

  author    = {Yi Zhu and Jia Xue and Shawn Newsam},
  title     = {Gated Transfer Network for Transfer Learning},
  booktitle = {Asian Conference on Computer Vision (ACCV)},
  year      = {2018}