PyTorch Colab notebook: ARShadowGAN-like
- Python 3
- CPU or NVIDIA GPU + CUDA CuDNN
- Clone this repo:
git clone https://github.com/Everypixel/arshadowgan-like.git arshadowgan
cd arshadowgan
- Install dependencies (e.g., segmentation_models_pytorch, ...)
pip install -r requirements.txt
We will use the shadow-ar dataset for training and testing our model. It is already split into train and test parts. Please download and extract it.
Your own dataset must follow the same structure as the ShadowAR dataset. Each folder contains images.
dataset
├── train
│   ├── noshadow ── example1.png, ...
│   ├── shadow ──── example1.png, ...
│   ├── mask ────── example1.png, ...
│   ├── robject ─── example1.png, ...
│   └── rshadow ─── example1.png, ...
└── test
    ├── noshadow ── example2.png, ...
    ├── shadow ──── example2.png, ...
    ├── mask ────── example2.png, ...
    ├── robject ─── example2.png, ...
    └── rshadow ─── example2.png, ...
- noshadow - images without shadows
- shadow - images with shadows
- mask - masks of the inserted objects
- robject - masks of the occluders
- rshadow - shadows of the occluders
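Before training on your own data, it can help to verify that the layout matches the tree above. The following is a minimal sketch, not part of this repo: `check_dataset` is a hypothetical helper, and it assumes that each sample uses the same file name in every folder of a split (as `example1.png` does above).

```python
from pathlib import Path

SPLITS = ("train", "test")
FOLDERS = ("noshadow", "shadow", "mask", "robject", "rshadow")


def check_dataset(root):
    """Return a list of layout problems (an empty list means none were found)."""
    root = Path(root)
    problems = []
    for split in SPLITS:
        # Map each existing folder of this split to the file names it contains.
        names = {}
        for folder in FOLDERS:
            directory = root / split / folder
            if not directory.is_dir():
                problems.append(f"missing folder: {split}/{folder}")
                continue
            names[folder] = {p.name for p in directory.iterdir() if p.is_file()}
        if not names:
            continue
        # Assumption: a sample is the same file name repeated in every folder.
        all_names = set().union(*names.values())
        for folder, present in names.items():
            for name in sorted(all_names - present):
                problems.append(f"missing file: {split}/{folder}/{name}")
    return problems
```

Calling `check_dataset('/content/arshadowgan/dataset/')` and printing the returned list shows which folders or files are missing, if any.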
Set the parameters for training the attention model:
- dataset_path - path to the dataset
- model_path - path where the attention model is saved
- batch_size - number of images in a batch (reduce it if you get a "CUDA: out of memory" error)
- seed - seed for random functions
- img_size - image width and height (must be divisible by 32)
- lr - learning rate
- n_epoch - number of epochs
For example:
python3 scripts/train_attention.py \
--dataset_path '/content/arshadowgan/dataset/' \
--model_path '/content/drive/MyDrive/attention128.pth' \
--batch_size 200 \
--seed 42 \
--img_size 256 \
--lr 1e-4 \
--n_epoch 100
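The img_size value must be divisible by 32. A likely reason (an assumption about the networks, not something stated here) is that the encoder halves the resolution five times, and 2^5 = 32. A minimal sketch for picking a valid size, using a hypothetical helper `nearest_valid_sizes` that is not part of this repo:

```python
def nearest_valid_sizes(img_size, multiple=32):
    """Return the closest sizes below and above img_size divisible by `multiple`."""
    if img_size < multiple:
        raise ValueError(f"img_size must be at least {multiple}")
    # Round down to the nearest multiple.
    lower = (img_size // multiple) * multiple
    # If img_size is already valid, both bounds equal img_size.
    upper = lower if lower == img_size else lower + multiple
    return lower, upper
```

For instance, `nearest_valid_sizes(250)` returns `(224, 256)`, so either 224 or 256 could be passed as `--img_size`.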
Set the parameters for training the shadow generation model (generator and discriminator):
- dataset_path - path to the dataset
- Gmodel_path - path where the generator model is saved
- Dmodel_path - path where the discriminator model is saved
- batch_size - number of images in a batch (reduce it if you get a "CUDA: out of memory" error)
- seed - seed for random functions
- img_size - image width and height (must be divisible by 32)
- lr_G - generator learning rate
- lr_D - discriminator learning rate
- n_epoch - number of epochs
- betta1, betta2, betta3 - loss function coefficients (see the ARShadowGAN paper)
For example:
python3 scripts/train_SG.py \
--dataset_path '/content/arshadowgan/dataset/' \
--Gmodel_path '/content/drive/MyDrive/SG_generator.pth' \
--Dmodel_path '/content/drive/MyDrive/SG_discriminator.pth' \
--batch_size 64 \
--seed 42 \
--img_size 256 \
--lr_G 1e-4 \
--lr_D 1e-6 \
--n_epoch 600 \
--betta1 10 \
--betta2 1 \
--betta3 1e-2 \
--patience 10 \
--encoder 'resnet18'
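To illustrate what the three coefficients do, here is a minimal sketch of a weighted generator loss. It assumes the formulation from the ARShadowGAN paper, where the total loss is a weighted sum of an L2 term, a perceptual (content) term, and an adversarial term. Whether betta1, betta2, and betta3 map to the terms in exactly this order in `train_SG.py` is an assumption, and `total_generator_loss` is a hypothetical name.

```python
def total_generator_loss(l2, perceptual, adversarial,
                         betta1=10.0, betta2=1.0, betta3=1e-2):
    """Weighted sum of three loss terms (assumed order: L2, perceptual, adversarial)."""
    # The defaults mirror the values passed in the example command.
    return betta1 * l2 + betta2 * perceptual + betta3 * adversarial
```

With these defaults the L2 term dominates and the adversarial term contributes least, so raising betta3 would give the discriminator's feedback more influence on the generator.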
Run inference and save the results.
For example:
python3 scripts/test.py \
--batch_size 1 \
--img_size 256 \
--dataset_path '/content/arshadowgan/dataset/test' \
--result_path '/content/arshadowgan/results' \
--path_att '/content/drive/MyDrive/ARShadowGAN-like/attention.pth' \
--path_SG '/content/drive/MyDrive/ARShadowGAN-like/SG_generator.pth'
We thank the ARShadowGAN authors for their amazing work.
We also thank segmentation_models.pytorch for the network architecture, albumentations for the augmentations, PyTorch-GAN for the discriminator architecture, and piq for the content loss.