Christoph Reich, Biplob Debnath, Deep Patel & Srimat Chakradhar
This repository includes the official and maintained implementation of the differentiable JPEG approach proposed in the paper Differentiable JPEG: The Devil is in the Details.
JPEG remains one of the most widespread lossy image coding methods. However, the non-differentiable nature of JPEG restricts its application in deep learning pipelines. Several differentiable approximations of JPEG have recently been proposed to address this issue. This paper conducts a comprehensive review of existing differentiable JPEG approaches and identifies critical details that previous methods have missed. Based on this review, we propose a novel differentiable JPEG approach that overcomes these limitations. Our approach is differentiable w.r.t. the input image, the JPEG quality, the quantization tables, and the color conversion parameters. We evaluate the forward and backward performance of our differentiable JPEG approach against existing methods. Additionally, we perform extensive ablations to evaluate crucial design choices. Our proposed differentiable JPEG resembles the (non-differentiable) reference implementation most closely, significantly surpassing the previous best differentiable approach by 3.47 dB (PSNR) on average. At strong compression rates, we even improve PSNR by 9.51 dB. Our differentiable JPEG also yields strong adversarial attack results, demonstrating its effective gradient approximation.
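The PSNR figures above measure how closely a differentiable approximation reproduces the output of the reference JPEG implementation. For readers reproducing such comparisons, here is a minimal pure-Python sketch of PSNR for 8-bit images. The function name `psnr` and the flattened-list input format are illustrative assumptions, not part of this package's API.

```python
import math
from typing import Sequence


def psnr(reference: Sequence[float], test: Sequence[float], max_value: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between two equally sized, flattened images."""
    if len(reference) != len(test) or len(reference) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all pixel values.
    mse = sum((r - t) ** 2 for r, t in zip(reference, test)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


if __name__ == "__main__":
    print(psnr([0, 128, 255], [1, 127, 255]))
```

Because PSNR is logarithmic in the MSE, a gain of Δ dB on a single image corresponds to an MSE that is 10^(Δ/10) times lower: roughly 2.2× for 3.47 dB and roughly 8.9× for 9.51 dB.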
If you use our differentiable JPEG or find this research useful in your work, please cite our paper:
```bibtex
@inproceedings{Reich2024,
    author={Reich, Christoph and Debnath, Biplob and Patel, Deep and Chakradhar, Srimat},
    title={{Differentiable JPEG: The Devil is in the Details}},
    booktitle={{WACV}},
    year={2024}
}
```
Our differentiable JPEG implementation can be installed as a Python package by running:
```bash
pip install git+https://github.com/necla-ml/Diff-JPEG
```

All dependencies are listed in `requirements.txt`.
We offer both a functional and a class-based (`nn.Module`) implementation of our differentiable JPEG approach. Beyond the examples provided here, see also the `example.py` file.
The following example showcases the use of the functional implementation.
```python
import torch
import torchvision
from torch import Tensor
from diff_jpeg import diff_jpeg_coding

# Load test image and reshape to [B, 3, H, W]
image: Tensor = torchvision.io.read_image("test_images/test_image.png").float()[None]
# Init JPEG quality
jpeg_quality: Tensor = torch.tensor([2.0])
# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding(image_rgb=image, jpeg_quality=jpeg_quality)
```
In the following code example, the class (nn.Module) implementation is used.
```python
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from diff_jpeg import DiffJPEGCoding

# Init module
diff_jpeg_coding_module: nn.Module = DiffJPEGCoding()
# Load test image and reshape to [B, 3, H, W]
image: Tensor = torchvision.io.read_image("test_images/test_image.png").float()[None]
# Init JPEG quality
jpeg_quality: Tensor = torch.tensor([19.04])
# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding_module(image_rgb=image, jpeg_quality=jpeg_quality)
```
To use the proposed straight-through estimator (STE) variant, set the `ste: bool = True` parameter.

```python
# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding(image_rgb=image, jpeg_quality=jpeg_quality, ste=True)
```

```python
# Init module
diff_jpeg_coding_module: nn.Module = DiffJPEGCoding(ste=True)
```
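Conceptually, a straight-through estimator uses exact (non-differentiable) rounding in the forward pass and a surrogate gradient in the backward pass. The sketch below illustrates this pattern in PyTorch, assuming the cubic rounding surrogate `round(x) + (x - round(x))**3` commonly used in differentiable JPEG work. The functions `soft_round` and `ste_round` are hypothetical names, and this package's actual STE variant may differ in detail.

```python
import torch
from torch import Tensor


def soft_round(x: Tensor) -> Tensor:
    """Differentiable rounding surrogate: round(x) + (x - round(x)) ** 3."""
    rounded = torch.round(x)
    # torch.round has a zero gradient, so d/dx = 3 * (x - round(x)) ** 2.
    return rounded + (x - rounded) ** 3


def ste_round(x: Tensor) -> Tensor:
    """Straight-through rounding: exact rounding forward, surrogate gradient backward."""
    surrogate = soft_round(x)
    # The detached correction term is a constant for autograd: the forward value
    # equals torch.round(x), while gradients flow only through soft_round(x).
    return surrogate + (torch.round(x) - surrogate).detach()


if __name__ == "__main__":
    x = torch.tensor([0.2, 1.4, 2.6], requires_grad=True)
    y = ste_round(x)
    y.sum().backward()
    print(y.detach())  # exact rounded values
    print(x.grad)  # surrogate gradient 3 * (x - round(x)) ** 2
```

The `.detach()` trick keeps the forward output identical to hard rounding, so the coded image matches a non-differentiable pipeline, while the backward pass still receives a non-zero gradient.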
Both the `diff_jpeg_coding` function and the `forward` function of `DiffJPEGCoding` support custom quantization tables via the `quantization_table_y: Optional[Tensor]` and `quantization_table_c: Optional[Tensor]` parameters, for the luminance (Y) and chrominance channels respectively. Both parameters must be a `torch.Tensor` of shape `[8, 8]`. If a quantization table is not given (or is set to `None`), the respective standard JPEG quantization table is used.

Here we provide two examples of using custom quantization tables.
```python
import torch
import torchvision
from torch import Tensor
from diff_jpeg import diff_jpeg_coding

# Load test image and reshape to [B, 3, H, W]
image: Tensor = torchvision.io.read_image("test_images/test_image.png").float()[None]
# Init JPEG quality
jpeg_quality: Tensor = torch.tensor([2.0])
# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding(
    image_rgb=image,
    jpeg_quality=jpeg_quality,
    quantization_table_y=torch.randint(low=1, high=256, size=(8, 8)),
    quantization_table_c=torch.randint(low=1, high=256, size=(8, 8)),
)
```
```python
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from diff_jpeg import DiffJPEGCoding

# Init module
diff_jpeg_coding_module: nn.Module = DiffJPEGCoding()
# Load test image and reshape to [B, 3, H, W]
image: Tensor = torchvision.io.read_image("test_images/test_image.png").float()[None]
# Init JPEG quality
jpeg_quality: Tensor = torch.tensor([19.04])
# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding_module(
    image_rgb=image,
    jpeg_quality=jpeg_quality,
    quantization_table_y=torch.randint(low=1, high=256, size=(8, 8)),
    quantization_table_c=torch.randint(low=1, high=256, size=(8, 8)),
)
```
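For reference, standard JPEG encoders derive the quantization table for a given quality by scaling a base table. The sketch below shows the integer scaling rule used by the IJG libjpeg library, applied to the standard luminance table. This is the non-differentiable reference rule, not this package's implementation; the examples above pass non-integer qualities such as 19.04, which suggests the package uses a continuous variant. The names `STD_LUMA_TABLE` and `scale_quantization_table` are illustrative.

```python
from typing import List

# Standard JPEG luminance quantization table (ITU-T T.81, Annex K, Table K.1).
STD_LUMA_TABLE: List[List[int]] = [
    [16, 11, 10, 16, 24, 40, 51, 61],
    [12, 12, 14, 19, 26, 58, 60, 55],
    [14, 13, 16, 24, 40, 57, 69, 56],
    [14, 17, 22, 29, 51, 87, 80, 62],
    [18, 22, 37, 56, 68, 109, 103, 77],
    [24, 35, 55, 64, 81, 104, 113, 92],
    [49, 64, 78, 87, 103, 121, 120, 101],
    [72, 92, 95, 98, 112, 100, 103, 99],
]


def scale_quantization_table(table: List[List[int]], quality: int) -> List[List[int]]:
    """Scale a base quantization table by a JPEG quality, following the libjpeg rule."""
    quality = min(max(int(quality), 1), 100)  # libjpeg clamps quality to [1, 100]
    # Qualities below 50 enlarge the entries (coarser quantization); above 50 shrink them.
    scale = 5000 // quality if quality < 50 else 200 - 2 * quality
    # Integer rounding, then baseline clamp of each entry to [1, 255].
    return [
        [min(max((value * scale + 50) // 100, 1), 255) for value in row]
        for row in table
    ]


if __name__ == "__main__":
    for row in scale_quantization_table(STD_LUMA_TABLE, 75):
        print(row)
```

At quality 50 the base table is returned unchanged, and at quality 100 every entry becomes 1 (the finest quantization).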
If you encounter any issues with this implementation, please open a GitHub issue!