Deep Learning for Smartphone ISP
Overview
[Challenge Website] [Workshop Website]
This repository provides the implementation of the baseline model, PUNET, for the Learned Smartphone ISP Challenge in the Mobile AI (MAI) Workshop @ CVPR 2021. The model is trained to convert RAW Bayer data obtained directly from a mobile camera sensor into photos matching those captured with a professional Fujifilm DSLR camera, thus replacing the entire hand-crafted ISP camera pipeline. The provided pre-trained PUNET model can be used to generate full-resolution 12MP photos from RAW image files captured using the Sony IMX586 camera sensor. PUNET is a UNet-like architecture modified from PyNET and serves as an extension of the PyNET project.
Contents:
- Overview
- Prerequisites
- Dataset and model preparation
- Learned ISP Pipeline
- Training
- Test/Inference
- Convert checkpoint to pb
- Convert pb to tflite
- [Optional] Some useful tools
- Results
- Folder structure (default)
- Model optimization
- Common FAQ
- Acknowledge
- Citation
- License
- Contact
Prerequisites
- Python: numpy, scipy, imageio and pillow packages
- TensorFlow 1.15.0 + CUDA and cuDNN
- GPU for training (e.g., Nvidia GeForce GTX 1080)
Dataset and model preparation
- Download Mediatek's pre-trained PUNET model and put it into the models/original/ folder.
- Download the training data and extract it into the raw_images/train/ folder.
- Download the validation data and extract it into the raw_images/val/ folder.
- Download the testing data and extract it into the raw_images/test/ folder.
  The dataset folder (default name: raw_images/) should contain three subfolders: train/, val/ and test/. Please find the download links to the above files on the MAI'21 Learned Smartphone ISP Challenge website (registration needed).
- [Optional] Download the pre-trained VGG-19 model (Mirror) and put it into the vgg_pretrained/ folder.
  The VGG model is used for one of the loss functions, loss_content, in the baseline, which takes the output of PUNET as its input. You are free to remove that loss (lines 65-72 in train_model.py). This may affect the resulting PSNR, but won't affect the whole pipeline.
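For reference, loss_content can be understood as a mean-squared distance between VGG-19 feature maps of the PUNET output and the DSLR target. The snippet below is a minimal numpy illustration with toy tensors standing in for real VGG activations; the actual implementation lives in train_model.py and may differ:

```python
import numpy as np

def content_loss(feat_output, feat_target):
    # Hypothetical content loss: mean squared distance between the
    # VGG-19 feature maps of the PUNET output and the DSLR target.
    return float(np.mean((feat_output - feat_target) ** 2))

# Toy tensors standing in for real VGG-19 activations.
feat_out = np.ones((1, 16, 16, 64), dtype=np.float32)
feat_tgt = np.zeros((1, 16, 16, 64), dtype=np.float32)
print(content_loss(feat_out, feat_tgt))  # 1.0
```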
Learned ISP Pipeline
The whole pipeline of the Learned Smartphone ISP has two main steps (assume the input resolution is H x W):

- deBayer pre-processing (in load_dataset.py):
  - Input: RAW data [H x W x 1]
  - Output: deBayer RAW data [(H/2) x (W/2) x 4]
  - You are free to modify the pre-processing method as long as the input & output shapes are kept.
- PUNET model (in model.py): PUNET is a UNet-like architecture modified from PyNET.
  - Input: deBayer RAW data [(H/2) x (W/2) x 4]
  - Output: RGB image [H x W x 3]
  - [Important] The TFLite model submitted to the Learned Smartphone ISP Challenge @ MAI 2021 is required to have the same input & output shapes as PUNET. Please check the challenge website for more details.
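The deBayer pre-processing step can be sketched with numpy. Note that the channel order below (one value per position of the 2x2 Bayer block) is an assumption for illustration; check load_dataset.py for the exact packing used in the baseline:

```python
import numpy as np

def pack_bayer(raw):
    # Pack a [H, W, 1] RAW Bayer image into a [H/2, W/2, 4] tensor by
    # grouping each 2x2 block of the mosaic into four channels.
    # The channel order here is an assumption for illustration.
    raw = raw[..., 0]          # [H, W]
    h, w = raw.shape
    assert h % 2 == 0 and w % 2 == 0
    return np.stack([raw[0::2, 0::2],   # top-left of each 2x2 block
                     raw[0::2, 1::2],   # top-right
                     raw[1::2, 0::2],   # bottom-left
                     raw[1::2, 1::2]],  # bottom-right
                    axis=-1)            # [H/2, W/2, 4]

raw = np.arange(16, dtype=np.float32).reshape(4, 4, 1)
packed = pack_bayer(raw)
print(packed.shape)  # (2, 2, 4)
```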
Training
Start training
To train the model, use the following command:
python train_model.py
Optional parameters (and default values):

- dataset_dir: raw_images/ - path to the folder with the dataset
- model_dir: models/ - path to the folder with the model to be restored or saved
- vgg_dir: vgg_pretrained/imagenet-vgg-verydeep-19.mat - path to the pre-trained VGG-19 network
- dslr_dir: fujifilm/ - path to the folder with the RGB data
- phone_dir: mediatek_raw/ - path to the folder with the RAW data
- arch: punet - architecture name
- num_maps_base: 16 - base channel number (e.g. 8, 16, 32)
- restore_iter: None - iteration to restore
- patch_w: 256 - width of the training patches
- patch_h: 256 - height of the training patches
- batch_size: 32 - batch size [small values can lead to unstable training]
- train_size: 5000 - the number of training patches randomly loaded every 1000 iterations
- learning_rate: 5e-5 - learning rate
- eval_step: 1000 - every eval_step iterations the accuracy is computed and the model is saved
- num_train_iters: 100000 - the number of training iterations
Below we provide an example command used for training the PUNET model on the Nvidia GeForce GTX 1080 GPU with 8GB of RAM.
CUDA_VISIBLE_DEVICES=0 python train_model.py \
model_dir=models/punet_MAI/ arch=punet num_maps_base=16 \
patch_w=256 patch_h=256 batch_size=32 \
eval_step=1000 num_train_iters=100000
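To illustrate how patch_w, patch_h, and train_size interact, here is a numpy sketch of random patch sampling (not the actual load_dataset.py code); it assumes the packed RAW input is at half the resolution of the RGB target, so crop offsets are kept even to preserve alignment:

```python
import numpy as np

def sample_patches(raw, dslr, patch_h, patch_w, n, rng):
    # Randomly crop n aligned (RAW, RGB) training pairs. The packed RAW
    # tensor is at half resolution with 4 channels, so a full-resolution
    # (patch_h, patch_w) RGB target corresponds to a
    # (patch_h // 2, patch_w // 2, 4) RAW input. Offsets are kept even
    # so the half-resolution crop stays aligned with the RGB crop.
    H, W, _ = dslr.shape
    pairs = []
    for _ in range(n):
        y = 2 * rng.integers(0, (H - patch_h) // 2 + 1)
        x = 2 * rng.integers(0, (W - patch_w) // 2 + 1)
        raw_patch = raw[y // 2:y // 2 + patch_h // 2,
                        x // 2:x // 2 + patch_w // 2, :]
        dslr_patch = dslr[y:y + patch_h, x:x + patch_w, :]
        pairs.append((raw_patch, dslr_patch))
    return pairs

rng = np.random.default_rng(0)
raw = np.zeros((544, 960, 4), dtype=np.float32)     # packed RAW, half res
dslr = np.zeros((1088, 1920, 3), dtype=np.float32)  # full-res RGB target
pairs = sample_patches(raw, dslr, 256, 256, 4, rng)
print(pairs[0][0].shape, pairs[0][1].shape)  # (128, 128, 4) (256, 256, 3)
```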
After training, the following files will be produced under model_dir:

- checkpoint - contains all the checkpoint names
- logs_[restore_iter]-[num_train_iters].txt - training log (including loss, PSNR, etc.)
- [arch]_iteration_[iter].ckpt.data - part of the checkpoint data for the model [arch]_iteration_[iter]
- [arch]_iteration_[iter].ckpt.index - part of the checkpoint data for the model [arch]_iteration_[iter]
Resume training
To resume training from restore_iter, use a command like the following:
CUDA_VISIBLE_DEVICES=0 python train_model.py \
model_dir=models/punet_MAI/ arch=punet num_maps_base=16 \
patch_w=256 patch_h=256 batch_size=32 \
eval_step=1000 num_train_iters=110000 restore_iter=100000
Test/Inference
test_model.py runs a model on testing images with height img_h and width img_w. Here we use img_h=1088 and img_w=1920 as an example. If save=True, the protobuf (frozen graph) that corresponds to the testing image resolution will also be produced.
Use the provided pre-trained model
To produce output images and protobuf using the pre-trained model, use the following command:
python test_model.py orig=True
Use the self-obtained model
To produce output images and protobuf using the self-trained model, use the following command:
python test_model.py
Optional parameters (and default values):

- dataset_dir: raw_images/ - path to the folder with the dataset
- test_dir: fujifilm_full_resolution/ - path to the folder with the test data
- model_dir: models/ - path to the folder with the models to be restored/loaded
- result_dir: results/ - path to the folder with the outputs produced by the loaded model
- arch: punet - architecture name
- num_maps_base: 16 - base channel number (e.g. 8, 16, 32)
- orig: True, False - whether to use the pre-trained model
- restore_iter: None - iteration to restore (when not specified with a self-trained model, the last saved model will be loaded)
- img_h: 1088 - height of the testing images
- img_w: 1920 - width of the testing images
- use_gpu: True, False - run the model on GPU or CPU
- save: True - save the loaded checkpoint and protobuf (frozen graph) again
- test_image: True - run the loaded model on the test images. Can be set to False if you only want to save models.
Below we provide an example command used for testing the model:
CUDA_VISIBLE_DEVICES=0 python test_model.py \
test_dir=fujifilm_full_resolution/ model_dir=models/punet_MAI/ result_dir=results/full-resolution/ \
arch=punet num_maps_base=16 orig=False restore_iter=98000 \
img_h=1088 img_w=1920 use_gpu=True save=True test_image=True
After inference, the output images will be produced under result_dir.
[Optional] If save=True, the following files will be produced under model_dir:

- [model_name].ckpt.meta - graph data for the model model_name
- [model_name].pb - protobuf (frozen graph) for the model model_name
- [model_name]/ - a folder containing TensorBoard data for the model model_name
Notes:
- To export the protobuf (frozen graph), the output node name needs to be specified. In this sample code, we use output_l0 for PUNET. If you use a different name, please modify the argument of the function utils.export_pb (line #111 in test_model.py). You can also use TensorBoard to check the output node name.
- In the Learned Smartphone ISP Challenge in the Mobile AI (MAI) Workshop @ CVPR 2021, you may need to use different models for different evaluation goals (e.g. quality and latency). In this case, please specify different img_h and img_w for different evaluation goals.
- Example 1: Only produce the RGB images from the validation data (resolution: 256x256) without saving the model:

CUDA_VISIBLE_DEVICES=0 python test_model.py \
test_dir=mediatek_raw/ model_dir=models/punet_MAI/ result_dir=results/ \
arch=punet num_maps_base=16 orig=False restore_iter=98000 \
img_h=256 img_w=256 use_gpu=True save=False test_image=True

- Example 2: Only produce the protobuf (e.g. resolution: 1088x1920) to evaluate the latency without generating any output images:

CUDA_VISIBLE_DEVICES=0 python test_model.py \
model_dir=models/punet_MAI/ \
arch=punet num_maps_base=16 orig=False restore_iter=98000 \
img_h=1088 img_w=1920 use_gpu=True save=True test_image=False
Convert checkpoint to pb
test_model.py can produce the protobuf automatically if save=True.
If you want to directly convert a checkpoint (including the .meta, .data, and .index files) to protobuf, we also provide ckpt2pb.py to do so.
The main arguments (and default values) are as follows:

- --in_path: models/punet_MAI/punet_iteration_100000.ckpt - input checkpoint file (including .meta, .data, and .index)
- --out_path: models/punet_MAI/punet_iteration_100000.pb - output protobuf file
- --out_nodes: output_l0 - output node name
Notes:
- As mentioned earlier, the output node name needs to be specified. There are two ways to check the output node name:
  - check the graph in TensorBoard.
  - directly specify the node name in the source code (e.g. using tf.identity).
- .meta is necessary to convert a checkpoint to protobuf since it contains important model information (e.g. architecture, input size, etc.).
Below we provide an example command:
python ckpt2pb.py \
--in_path models/punet_MAI/punet_iteration_100000.ckpt \
--out_path models/punet_MAI/punet_iteration_100000.pb \
--out_nodes output_l0
Convert pb to tflite
The last step is converting the frozen graph to TFLite so that the evaluation server can evaluate the performance on MediaTek devices. Please use the official TensorFlow tool tflite_convert. The main arguments (and default values) are as follows:

- graph_def_file: models/original/punet_pretrained.pb - input protobuf file
- output_file: models/original/punet_pretrained.tflite - output tflite file
- input_shape: 1,544,960,4 - the network input shape, which is after debayering/demosaicing. If the raw image shape is (img_h, img_w, 1), input_shape should be (1, img_h/2, img_w/2, 4).
- input_arrays: Placeholder - input node name (can be found in TensorBoard, or specified in the source code)
- output_arrays: output_l0 - output node name (can be found in TensorBoard, or specified in the source code)
Below we provide an example command:
tflite_convert \
--graph_def_file=models/punet_MAI/punet_iteration_100000.pb \
--output_file=models/punet_MAI/punet_iteration_100000.tflite \
--input_shape=1,544,960,4 \
--input_arrays=Placeholder \
--output_arrays=output_l0
Feel free to use our provided bash script as well:
bash pb2tflite.sh
Note: pb2tflite.sh converts our provided pre-trained model, so it is not exactly the same as the example command above.
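The input_shape arithmetic is easy to get wrong, so a small helper (a hypothetical convenience, not part of this repository) makes the relation explicit:

```python
def tflite_input_shape(img_h, img_w):
    # Network input shape for a raw image of shape (img_h, img_w, 1):
    # batch 1, spatial dims halved by deBayer packing, 4 channels.
    assert img_h % 2 == 0 and img_w % 2 == 0
    return (1, img_h // 2, img_w // 2, 4)

print(tflite_input_shape(1088, 1920))  # (1, 544, 960, 4)
```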
[Optional] Some useful tools
We also provide some useful tools. Feel free to try them.
Results
We evaluate the pre-trained PUNET on the validation data (resolution: 256x256), and obtain the following results:
- PSNR: 23.03
- Some visualized comparison with the ground truth:
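For reference, the PSNR metric reported above is the standard peak signal-to-noise ratio; a minimal numpy version, assuming images scaled to [0, 1]:

```python
import numpy as np

def psnr(pred, target, max_val=1.0):
    # Peak signal-to-noise ratio between two images in [0, max_val].
    mse = np.mean((pred - target) ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)

a = np.full((256, 256, 3), 0.5)
b = np.full((256, 256, 3), 0.4)
print(round(psnr(a, b), 2))  # 20.0
```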
Folder structure (default)
- models/ - logs and models that are saved during the training process
- models/original/ - the folder with the provided pre-trained PUNET model
- raw_images/ - the folder with the dataset
- results/ - visual results for the produced images
- vgg_pretrained/ - the folder with the pre-trained VGG-19 network
- tools/ - [optional] some useful tools
- load_dataset.py - python script that loads the training data
- model.py - PUNET implementation (TensorFlow)
- train_model.py - implementation of the training procedure
- test_model.py - applies the trained model to testing images
- utils.py - auxiliary functions
- vgg.py - loads the pre-trained VGG-19 network
- ckpt2pb.py - converts a checkpoint to protobuf (frozen graph)
- pb2tflite.sh - bash script that converts protobuf to tflite
Model Optimization
To make your model run faster on device, please follow the preferred network operations as much as possible to leverage the power of the AI accelerator. You may also find some optimization hints in our paper: Deploying Image Deblurring across Mobile Devices: A Perspective of Quality and Latency.
Please find the download link to the optimization guide on the MAI'21 Learned Smartphone ISP Challenge website (registration needed).
Common FAQ
Acknowledge
This project is an extension of the PyNET project.
Citation
If you find this repository useful, please cite our MAI'21 work:
@misc{mtk2021mai,
title={Mobile AI Workshop: Learned Smartphone ISP Challenge},
year={2021},
url={https://github.com/MediaTek-NeuroPilot/mai21-learned-smartphone-isp}
}
License
PyNet License: PyNet License
Mediatek License: Mediatek Apache License 2.0
Contact
Please contact Min-Hung Chen for more information.
Email: mh.chen AT mediatek DOT com