- Python 3.10.4
- TensorFlow 2.10.0
cited from https://arxiv.org/pdf/1609.04802.pdf
Create the following directory tree for your dataset in the project root, and place the original images in the train/high_resolution and validate/high_resolution directories.
datasets
└── (your dataset name)
    ├── test
    │   ├── high_resolution
    │   └── low_resolution
    ├── train
    │   ├── high_resolution
    │   └── low_resolution
    └── validate
        ├── high_resolution
        └── low_resolution
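The layout above can also be created programmatically. The following is a minimal sketch using only the standard library. It assumes the root directory is named `datasets`, as in the paths used elsewhere in this document; the function name `make_dataset_tree` is hypothetical and not part of this repository.

```python
from pathlib import Path

SPLITS = ("train", "validate", "test")
RESOLUTIONS = ("high_resolution", "low_resolution")


def make_dataset_tree(project_root, dataset_name):
    """Create datasets/<dataset_name>/<split>/<resolution> directories.

    Returns the list of directories, in creation order.
    """
    base = Path(project_root) / "datasets" / dataset_name
    created = []
    for split in SPLITS:
        for resolution in RESOLUTIONS:
            directory = base / split / resolution
            # exist_ok=True makes the call safe to repeat.
            directory.mkdir(parents=True, exist_ok=True)
            created.append(directory)
    return created
```

Calling `make_dataset_tree(".", "my_dataset")` from the project root would create the six leaf directories, where `my_dataset` is a placeholder name.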
Next, create low-resolution images at one quarter the size of the originals and place them in the corresponding low_resolution directories.
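One possible way to produce the low-resolution copies is sketched below. It rests on several assumptions: that "quarter size" means one quarter of the width and of the height (the 4x factor used in the SRGAN paper), that Pillow is installed (it is not listed in the requirements above), and that bicubic resampling is acceptable. The function names and the `extension` parameter are hypothetical and not part of this repository.

```python
from pathlib import Path

SCALE = 4  # assumption: each side is reduced to 1/4 of the original


def low_res_size(width, height, scale=SCALE):
    """Return the (width, height) of the downscaled image."""
    if width < scale or height < scale:
        raise ValueError("image is smaller than the scale factor")
    return width // scale, height // scale


def make_low_resolution(split_dir, extension="png", scale=SCALE):
    """Downscale every image in <split_dir>/high_resolution and save it
    under the same file name in <split_dir>/low_resolution.

    Returns the number of images written.
    """
    # Imported here so that low_res_size can be used without Pillow.
    from PIL import Image

    src = Path(split_dir) / "high_resolution"
    dst = Path(split_dir) / "low_resolution"
    dst.mkdir(parents=True, exist_ok=True)
    written = 0
    for path in sorted(src.glob(f"*.{extension}")):
        with Image.open(path) as img:
            size = low_res_size(img.width, img.height, scale)
            img.resize(size, Image.BICUBIC).save(dst / path.name)
        written += 1
    return written
```

Under these assumptions, a 128x128 original yields a 32x32 low-resolution image.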
This program expects the dataset in TFRecords format, and a helper script is provided to build it. Set dataset_name and extension in src/datasets.py, then run the script from the project root.
poetry run python src/dataset.py
Make sure that train.tfrecords and valid.tfrecords exist in the datasets/(your dataset name) directory.
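This check can be automated before starting a long training run. The following is a small sketch using only the standard library; the helper name `missing_tfrecords` is hypothetical, and treating an empty file as missing is an added assumption.

```python
from pathlib import Path

EXPECTED = ("train.tfrecords", "valid.tfrecords")


def missing_tfrecords(dataset_dir, expected=EXPECTED):
    """Return the expected TFRecords files that are absent or empty."""
    root = Path(dataset_dir)
    missing = []
    for name in expected:
        path = root / name
        # A zero-byte file usually means the conversion did not finish.
        if not path.is_file() or path.stat().st_size == 0:
            missing.append(name)
    return missing
```

For example, `missing_tfrecords("datasets/my_dataset")` returns an empty list when both files are present and non-empty, where `my_dataset` is a placeholder name.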
Hyper-parameters and other settings are defined in config.yaml.
config.yaml
TYPE: SRResNet
EPOCHS: 10000
BATCH_SIZE: 16
IMG_HEIGHT: 32
IMG_WIDTH: 32
LEARNING_RATE: 0.0001
TRAIN_DATA_PATH: ./datasets/train.tfrecords
VALIDATE_DATA_PATH: ./datasets/valid.tfrecords
CHECKPOINT_PATH: ./checkpoint/generator_train
START_EPOCH: 0
GEN_WEIGHT:
DISC_WEIGHT:
G_LOSS: 100000000
Start SRResNet training with the following command:
$ poetry run python src/train.py
config.yaml
TYPE: SRGAN
EPOCHS: 10000
BATCH_SIZE: 16
IMG_HEIGHT: 32
IMG_WIDTH: 32
LEARNING_RATE: 0.0001
TRAIN_DATA_PATH: ./datasets/train.tfrecords
VALIDATE_DATA_PATH: ./datasets/valid.tfrecords
CHECKPOINT_PATH: ./checkpoint/gan_train
START_EPOCH: 0
GEN_WEIGHT: ./generator_train/generator_best
DISC_WEIGHT:
G_LOSS: 100000000
Start SRGAN training with the following command:
$ poetry run python src/train.py
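Comparing the two config.yaml listings, only TYPE, CHECKPOINT_PATH and GEN_WEIGHT differ, which suggests that the SRGAN stage starts from the generator weights saved during SRResNet training. The sketch below makes that relationship explicit with plain dictionaries. The helper `srgan_config_from` is hypothetical and not part of this repository, and empty YAML values are represented as `None`.

```python
def srgan_config_from(srresnet_config, generator_weights,
                      checkpoint_path="./checkpoint/gan_train"):
    """Return a copy of the SRResNet settings adjusted for the SRGAN stage."""
    config = dict(srresnet_config)  # shallow copy; all values are scalars
    config["TYPE"] = "SRGAN"
    config["CHECKPOINT_PATH"] = checkpoint_path
    config["GEN_WEIGHT"] = generator_weights
    return config


# Values copied from the SRResNet listing in this document.
srresnet = {
    "TYPE": "SRResNet",
    "EPOCHS": 10000,
    "BATCH_SIZE": 16,
    "IMG_HEIGHT": 32,
    "IMG_WIDTH": 32,
    "LEARNING_RATE": 0.0001,
    "TRAIN_DATA_PATH": "./datasets/train.tfrecords",
    "VALIDATE_DATA_PATH": "./datasets/valid.tfrecords",
    "CHECKPOINT_PATH": "./checkpoint/generator_train",
    "START_EPOCH": 0,
    "GEN_WEIGHT": None,
    "DISC_WEIGHT": None,
    "G_LOSS": 100000000,
}

srgan = srgan_config_from(srresnet, "./generator_train/generator_best")
changed = sorted(key for key in srgan if srgan[key] != srresnet[key])
print(changed)  # ['CHECKPOINT_PATH', 'GEN_WEIGHT', 'TYPE']
```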
You can download pre-trained weights from here. These weights were trained on 32x32 images.