A package that provides a PyTorch C extension for performing batches of 2D CuFFT transformations, by Eric Wong
This package is on PyPI. Install with `pip install pytorch-fft`.
- From the `pytorch_fft.fft` module, you can use the following to do forward and backward FFT transformations (complex to complex):
  - `fft` and `ifft` for 1D transformations
  - `fft2` and `ifft2` for 2D transformations
  - `fft3` and `ifft3` for 3D transformations
- From the same module, you can also use the following for real to complex / complex to real FFT transformations:
  - `rfft` and `irfft` for 1D transformations
  - `rfft2` and `irfft2` for 2D transformations
  - `rfft3` and `irfft3` for 3D transformations
- For a `d`-D transformation, the input tensors are required to have >= (d+1) dimensions (n1 x ... x nk x m1 x ... x md), where `n1 x ... x nk` is the batch of FFT transformations and `m1 x ... x md` are the dimensions of the `d`-D transformation. `d` must be a number from 1 to 3.
- Finally, the module contains the following helper functions you may find useful:
  - `reverse(X, group_size=1)` reverses the elements of a tensor and returns the result in a new tensor. Note that PyTorch does not currently support negative slicing, see this issue. If a group size is supplied, the elements will be reversed in groups of that size.
  - `expand(X, imag=False, odd=True)` takes a tensor output of a real 2D or 3D FFT and expands it with its redundant entries to match the output of a complex FFT.
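Since these functions follow NumPy semantics, the batch convention above can be illustrated on the CPU with NumPy alone. The sketch below is not part of `pytorch_fft`: `batched_fftn` is a hypothetical helper that takes separate real and imaginary arrays (mirroring how the complex-to-complex functions are called in the usage example), transforms only the last `d` axes, and treats every leading axis as the batch.

```python
import numpy as np

def batched_fftn(x_real, x_imag, d):
    """Hypothetical NumPy reference for a batch of d-D complex FFTs.

    Input shape is (n1, ..., nk, m1, ..., md): the leading axes form the
    batch and only the last d axes are transformed.
    """
    if not 1 <= d <= 3:
        raise ValueError("d must be a number from 1 to 3")
    if x_real.ndim < d + 1:
        raise ValueError("input needs at least d+1 dimensions")
    out = np.fft.fftn(x_real + 1j * x_imag, axes=tuple(range(-d, 0)))
    return out.real, out.imag

# A batch of three 2D transformations of size 4 by 5.
A = np.random.randn(3, 4, 5)
zeros = np.zeros_like(A)
B_real, B_imag = batched_fftn(A, zeros, d=2)

# Identical to transforming each batch element on its own.
for i in range(A.shape[0]):
    expected = np.fft.fft2(A[i])
    assert np.allclose(B_real[i], expected.real)
    assert np.allclose(B_imag[i], expected.imag)
```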
```python
# Example that does a batch of three 2D transformations of size 4 by 5.
import torch
import pytorch_fft.fft as fft

A, zeros = torch.randn(3,4,5).cuda(), torch.zeros(3,4,5).cuda()
B_real, B_imag = fft.fft2(A, zeros)
fft.ifft2(B_real, B_imag) # equals (A, zeros)

B_real, B_imag = fft.rfft2(A) # is a truncated version which omits
                              # redundant entries

fft.reverse(torch.arange(0,6)) # outputs [5,4,3,2,1,0]
fft.reverse(torch.arange(0,6), 2) # outputs [4,5,2,3,0,1]

fft.expand(B_real) # is equivalent to fft.fft2(A, zeros)[0]
fft.expand(B_imag, imag=True) # is equivalent to fft.fft2(A, zeros)[1]
```
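To make the behavior of `reverse` and `expand` concrete, here is a rough NumPy reference for both. The helpers `reverse_reference` and `expand_reference` are hypothetical and do not reproduce the library's internals. Assumptions: `reverse_reference` works along the last axis; `expand_reference` handles the 2D case on a complex array, whereas the library's `expand` works on the real and imaginary parts separately (selected by `imag`), and its `odd` flag presumably encodes whether the original last dimension was odd. Here the original length `m2` is passed explicitly instead.

```python
import numpy as np

def reverse_reference(x, group_size=1):
    """Reverse the last axis of x in blocks of group_size (hypothetical helper)."""
    x = np.asarray(x)
    n = x.shape[-1]
    if n % group_size != 0:
        raise ValueError("last dimension must be divisible by group_size")
    # View the last axis as (n // group_size) groups, flip the group order,
    # keep the element order inside each group, then flatten back.
    groups = x.reshape(x.shape[:-1] + (n // group_size, group_size))
    return groups[..., ::-1, :].reshape(x.shape)

def expand_reference(R, m2):
    """Rebuild a full 2D FFT from an rfft2 output R (hypothetical helper).

    R has shape (..., m1, m2 // 2 + 1) and m2 is the original last-dimension
    size. Uses the Hermitian symmetry of real-input FFTs:
    X[k1, k2] = conj(X[-k1 mod m1, -k2 mod m2]).
    """
    m1, kept = R.shape[-2], R.shape[-1]
    full = np.empty(R.shape[:-1] + (m2,), dtype=np.complex128)
    full[..., :kept] = R
    rows = (-np.arange(m1)) % m1       # -k1 modulo m1 for every row
    cols = m2 - np.arange(kept, m2)    # -k2 modulo m2 for each omitted column
    full[..., kept:] = np.conj(R[..., rows[:, None], cols[None, :]])
    return full

print(reverse_reference(np.arange(6)))     # [5 4 3 2 1 0]
print(reverse_reference(np.arange(6), 2))  # [4 5 2 3 0 1]

A = np.random.randn(3, 4, 5)
assert np.allclose(expand_reference(np.fft.rfft2(A), 5), np.fft.fft2(A))
```

Only the columns past `m2 // 2` are missing from the real FFT output, and each one is the complex conjugate of an entry that was kept, which is why no information is lost by the truncation.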
- This follows NumPy semantics and behavior, so `ifft2(fft2(x)) = x`. Note that the CuFFT inverse FFT only flips the sign of the exponent and does not divide by the transform size, so on its own it is not a true inverse.
- Similarly, the real to complex / complex to real variants also follow NumPy semantics and behavior. In the 1D case, this means that for an input of size `N`, they return an output of size `N//2+1` (the redundant entries are omitted, see the NumPy docs).
- The functions in the `pytorch_fft.fft` module do not implement the PyTorch autograd `Function`, and are semantically and functionally like their NumPy equivalents.
- Autograd functionality is experimental, lives in the `autograd` module, and is currently untested.
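The normalization and size conventions above can be checked directly in NumPy. In this sketch, `unnormalized_inverse` is a hypothetical helper that imitates an inverse transform without the `1/N` scaling (the behavior described for CuFFT); it is built from `np.fft.ifftn` by multiplying the scaling back out.

```python
import numpy as np

def unnormalized_inverse(X, axes):
    """Inverse FFT without 1/N scaling (hypothetical helper, CuFFT-style)."""
    n = int(np.prod([X.shape[a] for a in axes]))
    # np.fft.ifftn divides by n, so multiply by n to undo that normalization.
    return np.fft.ifftn(X, axes=axes) * n

x = np.random.randn(3, 4, 5)
X = np.fft.fft2(x)

# NumPy semantics: ifft2 exactly undoes fft2.
assert np.allclose(np.fft.ifft2(X), x)

# Without normalization, the round trip is scaled by the transform size 4 * 5 = 20.
assert np.allclose(unnormalized_inverse(X, axes=(-2, -1)), 20 * x)

# Real to complex: an input of size N yields N//2 + 1 outputs.
for N in (7, 8):
    y = np.random.randn(N)
    Y = np.fft.rfft(y)
    assert Y.shape == (N // 2 + 1,)
    # irfft needs the original length to tell odd N from even N.
    assert np.allclose(np.fft.irfft(Y, n=N), y)
```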
- pytorch_fft/src: C source code
- pytorch_fft/fft: Python convenience wrapper
- build.py: compilation file
- test.py: tests against NumPy FFTs
If you have any issues or feature requests, file an issue or send in a PR.