This package implements the 1D,2D,3D Discrete Wavelet Transform and inverse DWT (IDWT) in Pytorch. The package was heavily inspired by pytorch_wavelets and extends its functionality into the third dimension.
The wavelets are provided by the PyWavelets package.
All operations in this package are fully differentiable.
git clone https://github.com/KeKsBoTer/torch-dwt
cd torch-dwt
pip install -e .
from torch_dwt.functional import dwt3,idwt3
import torch
# 8 images with 3 color channels and size of 100x100
x = torch.rand(8,3,100,100,100)
coefs = dwt3(x,"sym2") # coefs of shape (1,2,3,50)
# reconstruct signal from coefficients
y = idwt3(coefs,"sym2")
from torch_dwt.functional import dwt2,idwt2
import torch
# 8 images with 3 color channels and size of 100x100
x = torch.rand(8,3,100,100)
coefs = dwt2(x,"db2") # coefs of shape (1,2,3,50)
# reconstruct signal from coefficients
y = idwt2(coefs,"db2")
from torch_dwt.functional import dwt,idwt
import torch
# batch of size 8 with 3 channels
x = torch.rand(8,3,100)
coefs = dwt(x,"haar") # coefs of shape (1,2,3,50)
# reconstruct signal from coefficients
y = idwt(coefs,"haar")
For testing we compare our implementation againts PyWavelets. This command runs the tests:
# navigate into torch-dwt directory
pytest .