This project builds on the FastCUT variant of Contrastive Unpaired Translation (CUT) and uses the OpenAI Gym simulation environment.
To pre-train the edge detection network, run:
python styletransfer/cut/train.py --dataroot ./datasets/road_tree_new_train --name new_trees/edge_detection_MSE --batch_size 32 --dataset_mode "conditional" --model "edge" --n_epochs 150 --display_freq 100 --output_nc 1 --ngf 16 --edge_loss "MSE"
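For reference, `--edge_loss "MSE"` presumably selects a per-pixel mean squared error between the predicted single-channel edge map (`--output_nc 1`) and the target edge map. The following is a minimal pure-Python illustration under that assumption; the name `edge_mse` is hypothetical, and the repository's own loss (likely equivalent to PyTorch's `torch.nn.MSELoss` with its default mean reduction) may be implemented differently.

```python
from typing import List

EdgeMap = List[List[float]]  # rows of pixel values for one single-channel edge map


def edge_mse(pred: EdgeMap, target: EdgeMap) -> float:
    """Mean of the squared per-pixel differences between two same-shaped edge maps (hypothetical helper)."""
    if len(pred) != len(target) or any(len(p) != len(t) for p, t in zip(pred, target)):
        raise ValueError("edge maps must have the same shape")
    # Flatten the per-pixel squared errors of both maps into one list.
    squared_errors = [
        (p - t) ** 2 for row_p, row_t in zip(pred, target) for p, t in zip(row_p, row_t)
    ]
    if not squared_errors:
        raise ValueError("edge maps must not be empty")
    return sum(squared_errors) / len(squared_errors)


if __name__ == "__main__":
    predicted = [[0.0, 0.5], [1.0, 0.0]]
    target = [[0.0, 1.0], [1.0, 0.0]]
    print(edge_mse(predicted, target))  # 0.0625
```

Only one of the four pixels differs (by 0.5), so the mean squared error is 0.25 / 4 = 0.0625.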
The trained edge detection model must be present in the cut/checkpoint/experiment_name folder before you train the style transfer model. To train the dual style transfer model, run:
python styletransfer/cut/train.py --dataroot ./datasets/road_tree_new_train --name new_trees/styletransfer_MSE_histo_10_edges_10 --CUT_mode FastCUT --batch_size 4 --dataset_mode "conditional" --model "conditional_cut" --netG "conditional_resnet_9" --netD "conditional" --display_freq 100 --lambda_hist 10 --lambda_edge 10 --edge_loss "MSE" --n_epochs 50
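Because the style transfer run depends on the pre-trained edge model, a small pre-flight check can save a failed launch. This is a sketch only: the helper `has_checkpoint` is hypothetical, and it assumes the edge model is saved as one or more `.pth` files (the convention of the upstream CUT code base) under `cut/checkpoint/<experiment_name>`.

```python
from pathlib import Path


def has_checkpoint(checkpoints_dir: str, experiment_name: str) -> bool:
    """Return True if <checkpoints_dir>/<experiment_name> holds at least one .pth file (hypothetical helper)."""
    experiment_dir = Path(checkpoints_dir) / experiment_name
    return experiment_dir.is_dir() and any(experiment_dir.glob("*.pth"))


if __name__ == "__main__":
    # Assumed location, following the cut/checkpoint/experiment_name layout in this README.
    found = has_checkpoint("cut/checkpoint", "new_trees/edge_detection_MSE")
    print("edge checkpoint found" if found else "edge checkpoint missing: pre-train the edge model first")
```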
To run the simulation, use:
python simulation/play_evaluation.py
Example output videos: out.mp4 and GradCAM.mp4
On Ubuntu 16.04 and 18.04, install the system dependencies:
apt-get install -y libglu1-mesa-dev libgl1-mesa-dev libosmesa6-dev xvfb ffmpeg curl patchelf libglfw3 libglfw3-dev cmake zlib1g zlib1g-dev swig
Create the conda environment with:
conda env create [--name envname] -f environment.yml
conda activate [envname or DL_CAR]
([...] denotes an optional argument.)
Then install gym:
pip install gym
pyglet
pip install --upgrade pyglet
Install PyTorch and torchvision, choosing the build that matches your CUDA version (see pytorch.org).
dominate
pip install dominate
visdom
pip install visdom
OpenCV
pip install opencv-python
scikit-image
conda install scikit-image
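After the required packages are installed, a quick way to confirm that the environment is complete is to check that each module can be located. This sketch is not part of the repository, and `check_modules` is a hypothetical helper. Note that several import names differ from the package names: `opencv-python` is imported as `cv2`, `scikit-image` as `skimage`, and PyTorch as `torch`.

```python
import importlib.util
from typing import Dict, Iterable

# Import names of the required packages listed above.
REQUIRED_MODULES = ("gym", "pyglet", "torch", "torchvision", "dominate", "visdom", "cv2", "skimage")


def check_modules(names: Iterable[str] = REQUIRED_MODULES) -> Dict[str, bool]:
    """Map each top-level module name to True if it can be found, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_modules().items():
        print(f"{name:12s} {'ok' if found else 'MISSING'}")
```

`find_spec` only locates a module, so the check is fast and has no import side effects; it will not, however, catch a package that is present but fails at import time.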
Optional, for visualization:
Plotly with JupyterLab support:
conda install -c plotly plotly
conda install jupyterlab "ipywidgets>=7.5"
jupyter labextension install jupyterlab-plotly@4.14.3
jupyter labextension install @jupyter-widgets/jupyterlab-manager plotlywidget@4.14.3
pip install pykeops
pip install geomloss
pip install kornia
conda install -c conda-forge ipympl
conda install -c conda-forge nodejs
jupyter labextension install @jupyter-widgets/jupyterlab-manager jupyter-matplotlib
FID comparison
pip install pytorch-fid
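pytorch-fid is typically run as a module on two image folders (`python -m pytorch_fid <dir1> <dir2>`). The wrapper below is a hedged sketch for comparing real target-domain images with style-transferred outputs; the helper names `fid_command` and `run_fid` and any folder paths you pass are hypothetical, and only the `--device` option of pytorch-fid is used.

```python
import subprocess
import sys
from typing import List, Optional


def fid_command(real_dir: str, fake_dir: str, device: Optional[str] = None) -> List[str]:
    """Build the pytorch-fid command line for two image folders (hypothetical helper)."""
    cmd = [sys.executable, "-m", "pytorch_fid", real_dir, fake_dir]
    if device is not None:
        cmd += ["--device", device]  # e.g. "cuda:0" or "cpu"
    return cmd


def run_fid(real_dir: str, fake_dir: str, device: Optional[str] = None) -> str:
    """Run pytorch-fid in a subprocess and return its standard output, which reports the FID value."""
    result = subprocess.run(
        fid_command(real_dir, fake_dir, device),
        capture_output=True,
        text=True,
        check=True,  # raise if pytorch-fid exits with an error
    )
    return result.stdout
```

For example, `print(run_fid("path/to/real", "path/to/translated", device="cuda:0"))` would print the tool's output, including the FID value; lower values indicate that the two image sets are statistically closer.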
GradCAM: class activation map methods implemented in PyTorch
pip install grad-cam
The SSIM loss from https://github.com/Po-Hsun-Su/pytorch-ssim is included in this repository, so it needs no separate installation.
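To illustrate what the SSIM loss measures, the following computes a simplified SSIM that treats the whole image as a single window, using the standard constants C1 = (0.01·L)² and C2 = (0.03·L)² for data range L. This is a simplification under stated assumptions: pytorch-ssim averages SSIM over local 11×11 Gaussian windows, whether this repository uses `1 - SSIM` as its loss is an assumption, and the function names are hypothetical.

```python
from typing import Sequence


def global_ssim(x: Sequence[float], y: Sequence[float], data_range: float = 1.0) -> float:
    """Single-window SSIM between two equally sized grayscale images given as flat pixel lists."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("images must have the same size and at least two pixels")
    n = len(x)
    c1 = (0.01 * data_range) ** 2  # stabilises the luminance term
    c2 = (0.03 * data_range) ** 2  # stabilises the contrast/structure term
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


def ssim_loss(x: Sequence[float], y: Sequence[float], data_range: float = 1.0) -> float:
    """A common SSIM-based loss, 1 - SSIM: 0 for identical images, larger for dissimilar ones."""
    return 1.0 - global_ssim(x, y, data_range)
```

Identical images give SSIM = 1 (loss 0), because the numerator and denominator then coincide term by term.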