NOTE: This repo is based on the code of this Single Image Super Resolution repo.
This repo trains models to do image restoration on the DIV2K dataset. It supports all distortion modes in the DIV2K dataset: `bicubic`, `unknown`, `mild`, and `difficult`. It also supports the downscaling factors `2`, `3`, `4`, and `8`. The training parameters can be specified so that the model trains on one dataset and is evaluated on a different one.
Create a virtual environment:

    python3 -m venv .venv

Activate the virtual environment:

    source .venv/bin/activate

Install the dependencies:

    pip install -r requirements.txt

For training and validation on DIV2K images, applications should use the provided `DIV2K` data loader. It automatically downloads DIV2K images to the `.div2k` directory and converts them to a different format for faster loading.

A `DIV2K` data provider automatically downloads the DIV2K training and validation images for a given scale (2, 3, 4 or 8) and downgrade operator ("bicubic", "unknown", "mild" or "difficult").

    from data import DIV2K

    train_loader = DIV2K(scale=4,               # 2, 3, 4 or 8
                         downgrade='bicubic',   # 'bicubic', 'unknown', 'mild' or 'difficult'
                         subset='train')        # Training images are 001 - 800

    # Create a tf.data.Dataset
    train_ds = train_loader.dataset(batch_size=16,          # batch size as described in the EDSR and WDSR papers
                                    random_transform=True,  # random crop, flip, rotate as described in the EDSR paper
                                    repeat_count=None)      # repeat iterating over training images indefinitely

    from data import DIV2K

    valid_loader = DIV2K(scale=4,               # 2, 3, 4 or 8
                         downgrade='bicubic',   # 'bicubic', 'unknown', 'mild' or 'difficult'
                         subset='valid')        # Validation images are 801 - 900

    # Create a tf.data.Dataset
    valid_ds = valid_loader.dataset(batch_size=1,            # use a batch size of 1 as DIV2K images have different sizes
                                    random_transform=False,  # use DIV2K images in original size
                                    repeat_count=1)          # 1 epoch

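The `random_transform=True` option above crops matching low- and high-resolution patches. The following is a minimal sketch of how the two crop windows can be kept aligned. It is not the loader's actual code: the function name `paired_crop_boxes` and the default 48-pixel patch size are assumptions made for illustration.

```python
import random


def paired_crop_boxes(lr_height, lr_width, scale, lr_crop=48, rng=random):
    """Pick a random LR crop window and the matching HR window.

    Returns two (top, left, height, width) tuples. The HR window is the
    LR window multiplied by `scale`, so both cover the same image content.
    """
    if lr_crop > lr_height or lr_crop > lr_width:
        raise ValueError("crop size exceeds the low-resolution image size")

    # Random top-left corner such that the window stays inside the LR image.
    top = rng.randint(0, lr_height - lr_crop)
    left = rng.randint(0, lr_width - lr_crop)

    lr_box = (top, left, lr_crop, lr_crop)
    hr_box = (top * scale, left * scale, lr_crop * scale, lr_crop * scale)
    return lr_box, hr_box
```

Passing a seeded `random.Random` instance as `rng` makes the crops reproducible, which can help when comparing runs.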
The repo supports two model architectures.
- Enhanced Deep Residual Networks for Single Image Super-Resolution (EDSR)
- Wide Activation for Efficient and Accurate Image Super-Resolution (WDSR)
To run training, use:

    python train_image_resolution_model.py --scale 2 \
        --downgrade bicubic \
        --model edsr \
        --batch-size 16 \
        --depth 16 \
        --scale_val 4 \
        --downgrade_val mild

- `model`: Model architecture, can be `edsr` or `wdsr`. For quantized models, `qedsr` can be used.
- `downgrade`: Distortion type for training, can be `bicubic`, `unknown`, `mild`, or `difficult`.
- `scale`: Downsampling factor for training, can be `2`, `3`, `4`, or `8`.
- `depth`: Depth of the model, default `16`.
- `batch-size`: Training batch size, default `16`.
- `downgrade_val`: Distortion type for validation, can be `bicubic`, `unknown`, `mild`, or `difficult`.
- `scale_val`: Downsampling factor for validation, can be `2`, `3`, `4`, or `8`.
- `pretrained`: Path to the pretrained model.
- `eval_all_distortions`: Evaluate the model on all distortion types. Note that only scale `4` is supported.
- `train_all_distortions`: Train the model on all distortion types. Note that only scale `4` is supported.
- `precision`: The precision of the model's weights. Can only be used when `model=edsr`.

The resulting model is stored under the `weight/` directory.
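To illustrate how the options listed above constrain each other, here is a hedged `argparse` sketch covering only the flags that appear in the example command. It is not the parser used by `train_image_resolution_model.py`: the names `build_parser` and `parse_args`, the choice to make `--model`, `--scale`, and `--downgrade` required, and the fallback of the validation settings to the training settings are all assumptions.

```python
import argparse

SCALES = (2, 3, 4, 8)
DOWNGRADES = ("bicubic", "unknown", "mild", "difficult")


def build_parser():
    """Hypothetical parser for the flags shown in the example command."""
    parser = argparse.ArgumentParser(description="Train an image restoration model")
    parser.add_argument("--model", choices=("edsr", "wdsr", "qedsr"), required=True)
    parser.add_argument("--scale", type=int, choices=SCALES, required=True)
    parser.add_argument("--downgrade", choices=DOWNGRADES, required=True)
    parser.add_argument("--depth", type=int, default=16)
    parser.add_argument("--batch-size", type=int, default=16)
    # Assumption: validation settings may be omitted on the command line.
    parser.add_argument("--scale_val", type=int, choices=SCALES, default=None)
    parser.add_argument("--downgrade_val", choices=DOWNGRADES, default=None)
    return parser


def parse_args(argv=None):
    """Parse arguments and fill in omitted validation settings."""
    args = build_parser().parse_args(argv)
    # Assumption: fall back to the training settings when none are given.
    if args.scale_val is None:
        args.scale_val = args.scale
    if args.downgrade_val is None:
        args.downgrade_val = args.downgrade
    return args
```

Restricting values with `choices` means an unsupported scale or distortion type is rejected before any data is downloaded.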
Model | Distortion | Downscaling Factor | PSNR (dB) | Precision |
---|---|---|---|---|
EDSR | Bicubic | 2 | 31.53 | FP32 |
EDSR | Bicubic | 4 | 26.98 | FP32 |
EDSR | Mild | 4 | 18.77 | FP32 |
EDSR | Difficult | 4 | 19.18 | FP32 |
EDSR | Unknown | 4 | 22.85 | FP32 |
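
For readers unfamiliar with the metric in the table, PSNR is defined as 10 * log10(MAX^2 / MSE), where MSE is the mean squared error between the reference and restored images. Below is a minimal pure-Python sketch of that definition. It assumes 8-bit pixel values passed as flat sequences, and the repo's own evaluation may differ in details such as border cropping or color channels.

```python
import math


def psnr(reference, restored, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized pixel sequences."""
    if len(reference) != len(restored):
        raise ValueError("inputs must have the same number of pixels")
    if len(reference) == 0:
        raise ValueError("inputs must not be empty")

    # Mean squared error over all pixel values.
    mse = sum((r - s) ** 2 for r, s in zip(reference, restored)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)
```

Higher values indicate a restored image closer to the reference.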