GTN

This repository provides code for all the results reported in the GTN paper.

Primary language: Python. License: MIT.

GTN based on A2C

This repo provides a PyTorch implementation of GTN built on A2C (Advantage Actor-Critic). For details, see the paper "Generalization Tower Network: A Novel Deep Neural Network Architecture for Multi-Task Learning".
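For readers new to A2C: the critic's value estimates turn each rollout into n-step discounted returns, and the advantage (return minus value) weights the policy-gradient term. The sketch below shows that standard computation in plain Python. It is not taken from this repo's main.py, and the function names n_step_returns and a2c_losses are hypothetical.

```python
from typing import List, Sequence, Tuple


def n_step_returns(rewards: Sequence[float], dones: Sequence[bool],
                   bootstrap_value: float, gamma: float = 0.99) -> List[float]:
    """Discounted returns for one rollout, bootstrapped from the critic's V(s_T)."""
    returns = [0.0] * len(rewards)
    running = bootstrap_value
    for t in reversed(range(len(rewards))):
        # dones[t] means the episode ended after step t, so no value flows back past it.
        mask = 0.0 if dones[t] else 1.0
        running = rewards[t] + gamma * running * mask
        returns[t] = running
    return returns


def a2c_losses(log_probs: Sequence[float], values: Sequence[float],
               returns: Sequence[float]) -> Tuple[float, float]:
    """Policy and value losses averaged over the rollout."""
    advantages = [r - v for r, v in zip(returns, values)]
    n = len(advantages)
    # In an autograd framework the advantage is treated as a constant (detached) here.
    policy_loss = -sum(lp * a for lp, a in zip(log_probs, advantages)) / n
    value_loss = sum(a * a for a in advantages) / n
    return policy_loss, value_loss


rets = n_step_returns([1.0, 0.0, 1.0], [False, False, False], bootstrap_value=0.5, gamma=0.5)
print(rets)  # [1.3125, 0.625, 1.25]
```

A2C implementations usually also subtract an entropy bonus from the total loss. It is left out here to keep the sketch short.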

Supported (and tested) environments (via OpenAI Gym)

All environments are operated through exactly the same Gym interface. See the Gym documentation for a comprehensive list.
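Because every environment exposes the same reset()/step() methods, one rollout loop works for all of them. Below is a minimal sketch of that loop, assuming the classic Gym API of this code's era (before gym 0.26), where reset() returns an observation and step() returns (observation, reward, done, info). ToyEnv is a hypothetical stand-in so the example runs without Atari ROMs; a real environment would come from gym.make("BreakoutNoFrameskip-v4").

```python
from typing import Any, Callable, Dict, Tuple


class ToyEnv:
    """Hypothetical environment mimicking the classic Gym reset()/step() interface."""

    def __init__(self, episode_length: int = 5):
        self.episode_length = episode_length
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t  # the observation is just the step counter

    def step(self, action: int) -> Tuple[int, float, bool, Dict[str, Any]]:
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= self.episode_length
        return self.t, reward, done, {}


def run_episode(env, policy: Callable[[Any], int]) -> float:
    """Roll out one episode through the shared reset/step interface; return total reward."""
    obs = env.reset()
    total, done = 0.0, False
    while not done:
        obs, reward, done, info = env.step(policy(obs))
        total += reward
    return total


print(run_episode(ToyEnv(episode_length=5), policy=lambda obs: 1))  # 5.0
```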

Requirements

Prerequisites

Refer to my personal basic setup for some convenient command lines.

Other requirements

To install the other requirements, run:

```bash
# remove any previous environment
source ~/.bashrc
source deactivate
conda remove --name gtn_env --all

# create (the PyTorch wheel below is built for Python 3.6, so pin python=3.6)
conda create -n gtn_env python=3.6

# activate
source ~/.bashrc
source activate gtn_env

# clear dir
rm -r gtn_env

# create dir
mkdir -p gtn_env/project/ && cd gtn_env/project/

# PyTorch
pip install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
pip install torchvision
pip install visdom
pip install numpy -I
pip install 'gym[atari]'

# Other requirements
git clone https://github.com/YuhangSong/gtn_a2c.git
cd gtn_a2c
pip install -r requirements.txt
cd ..
```
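To confirm that the environment is usable before training, one option is to import each package and report its version or the error it raises. This is a minimal sketch; the module list mirrors the pip commands above, and check_imports is a hypothetical helper, not part of the repo.

```python
import importlib
from typing import Dict, Iterable

REQUIRED = ["torch", "torchvision", "visdom", "numpy", "gym"]


def check_imports(names: Iterable[str]) -> Dict[str, str]:
    """Try to import each module; map its name to a version string or the error text."""
    report = {}
    for name in names:
        try:
            module = importlib.import_module(name)
            report[name] = getattr(module, "__version__", "installed (no __version__)")
        except Exception as exc:  # ImportError, or any error raised while importing
            report[name] = "FAILED: {}: {}".format(type(exc).__name__, exc)
    return report


if __name__ == "__main__":
    for name, status in check_imports(REQUIRED).items():
        print("{:12s} {}".format(name, status))
```

The broad except is deliberate: a broken install can raise an ImportError from inside a package (as with torch below), not only a missing-module error.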

Running into issues? See Problems below.

Run the code

Start a Visdom server with:

```bash
source ~/.bashrc
source activate gtn_env
python -m visdom.server
```

It will serve http://localhost:8097/ by default.
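Training output only reaches the browser if that server is up, so a quick reachability check can save a wasted run. The sketch below uses only the Python standard library and assumes the default address shown above. visdom_is_up is a hypothetical helper, not a Visdom API.

```python
import urllib.error
import urllib.request


def visdom_is_up(url: str = "http://localhost:8097/", timeout: float = 2.0) -> bool:
    """Return True if an HTTP server answers with status 200 at `url`."""
    # Bypass any configured HTTP proxy: the server is expected on this machine.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    try:
        with opener.open(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        # Connection refused, timeout, or a non-2xx reply (HTTPError is a URLError).
        return False


if __name__ == "__main__":
    print("Visdom reachable:", visdom_is_up())
```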

Run GTN based on A2C on the Atari domain with:

```bash
source ~/.bashrc
source activate gtn_env
CUDA_VISIBLE_DEVICES=0 python main.py
```
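CUDA_VISIBLE_DEVICES=0 exposes only the first GPU to the process, which then addresses it as device 0 (cuda:0 in PyTorch). If you prefer to launch runs from Python, a hypothetical wrapper like the one below sets the variable for the child process only. main.py is the repo's entry point; launch and build_env are illustrative names, and no command-line flags of main.py are assumed.

```python
import os
import subprocess
import sys
from typing import Dict, Iterable, List, Sequence, Tuple, Union


def build_env(gpu_ids: Iterable[int]) -> Dict[str, str]:
    """Copy the current environment and expose only the given GPUs to the child."""
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env


def launch(gpu_ids: Sequence[int] = (0,), script: str = "main.py",
           extra_args: Sequence[str] = (), dry_run: bool = False
           ) -> Union[int, Tuple[List[str], str]]:
    """Run `script` with the current Python interpreter on the chosen GPUs."""
    cmd = [sys.executable, script, *extra_args]
    env = build_env(gpu_ids)
    if dry_run:
        # Report what would run instead of starting training.
        return cmd, env["CUDA_VISIBLE_DEVICES"]
    return subprocess.call(cmd, env=env)


if __name__ == "__main__":
    print(launch(gpu_ids=(0,), dry_run=True))
```

Copying os.environ keeps the activated conda environment's variables while leaving the parent process's own GPU visibility untouched.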

Results

XX

[Figure: BreakoutNoFrameskip-v4]

Contributions

Contributions are very welcome. If you know how to make this code better, don't hesitate to send a pull request. Also see the TODO list below.

TODO

  • Improve this README file. Rearrange images.

Problems

conda `GLIBCXX_3.4.20' not found

is solved by

```bash
source ~/.bashrc
source deactivate
conda install libgcc
source ~/.bashrc
source activate gtn_env
```
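To check whether a given libstdc++ actually provides the missing version (for example before and after `conda install libgcc`), you can scan the library for its GLIBCXX_ version strings, much like `strings libstdc++.so.6 | grep GLIBCXX`. This is a minimal sketch. The default path ($CONDA_PREFIX/lib/libstdc++.so.6) is an assumption about your setup, and glibcxx_versions and provides are hypothetical helpers.

```python
import os
import re
import sys
from typing import List


def glibcxx_versions(path: str) -> List[str]:
    """List the GLIBCXX_x.y.z version strings embedded in a libstdc++ binary."""
    with open(path, "rb") as f:
        data = f.read()
    found = set(re.findall(rb"GLIBCXX_(\d+(?:\.\d+)*)", data))
    # Sort numerically so that 3.4.9 comes before 3.4.20.
    parsed = sorted(tuple(int(p) for p in v.decode("ascii").split(".")) for v in found)
    return [".".join(str(p) for p in v) for v in parsed]


def provides(path: str, wanted: str = "3.4.20") -> bool:
    return wanted in glibcxx_versions(path)


if __name__ == "__main__":
    # Assumed default: the libstdc++ of the active conda env; pass another path to override.
    default = os.path.join(os.environ.get("CONDA_PREFIX", sys.prefix), "lib", "libstdc++.so.6")
    lib = sys.argv[1] if len(sys.argv) > 1 else default
    print(lib, "provides GLIBCXX_3.4.20:", provides(lib))
```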
ImportError when importing torch:

```text
Traceback (most recent call last):
  File "main.py", line 7, in <module>
    import torch
  File "/home/yuhangsong/anaconda3/lib/python3.6/site-packages/torch/__init__.py", line 53, in <module>
    from torch._C import *
ImportError: numpy.core.multiarray failed to import
```

is solved by

```bash
source ~/.bashrc
source activate gtn_env
pip install numpy -I
```