How to save the frequency domain image after FFT operation?
AbandonedWarlord opened this issue · 1 comment
` # 2D FFT
x_freq = torch.fft.fft2(image)
# shift low frequency to the center
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
# mask a portion of frequencies
# NOTE(review): no mask is applied here; x_freq_masked is the same tensor as x_freq
x_freq_masked = x_freq
# restore the original frequency order
x_freq_masked = torch.fft.ifftshift(x_freq_masked, dim=(-2, -1))
# 2D iFFT (only keep the real part)
# NOTE(review): after the inverse FFT this is a spatial-domain image again, not the frequency spectrum
x_corrupted = torch.fft.ifft2(x_freq_masked).real
# clamp to the 0-1 range before scaling to 8-bit
x_corrupted = torch.clamp(x_corrupted, min=0., max=1.)
x_np = x_corrupted.numpy()
# NOTE(review): presumably `image` is channel-first; Image.fromarray expects (H, W) or (H, W, C) — confirm the layout
# scale by 255 and cast to uint8 for saving
im = Image.fromarray((x_np * 255).astype(np.uint8))
im.save(os.path.join(output_folder, filename))`
The image I save with the above code looks very different from the cat example image. Could you please provide a demo of this?
You can use the script below:
def fft(x):
    """Return the centered 2D Fourier spectrum of ``x``.

    Args:
        x: Tensor whose last two dimensions are the spatial ones, e.g.
            (B, 3, H, W) with values in 0-1. Any leading dimensions are
            left untouched.

    Returns:
        Complex tensor with the same shape as ``x``. The zero-frequency
        (DC) component sits at index ``(H // 2, W // 2)`` of the last two
        dimensions.
    """
    # 2D FFT over the last two dimensions
    x_freq = torch.fft.fft2(x)
    # shift low frequency to the center; restrict the shift to the spatial
    # dims so batch/channel entries are not rolled into each other
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
    return x_freq
def show_image(image, save_path):
    """Render ``image`` with the viridis colormap and a colorbar, then save it.

    Args:
        image: Array-like image data; ``image.shape[0]`` and
            ``image.shape[1]`` are used as height and width in pixels.
        save_path: Destination passed to ``plt.savefig``.
    """
    # Draw on a fresh figure so repeated calls do not stack images and
    # colorbars on whatever figure happens to be current.
    # Figure size is in inches; at dpi=100 this is one inch per 100 pixels.
    fig = plt.figure(figsize=(image.shape[1] / 100, image.shape[0] / 100))
    try:
        plt.imshow(image, cmap='viridis')
        plt.axis('off')
        plt.colorbar()
        # Remove the margins around the axes before saving.
        plt.subplots_adjust(left=0, bottom=0, right=1, top=1)
        plt.savefig(save_path, dpi=100, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
# Load the input and reduce it to a single (L) channel before converting
# it with ToTensor.
img = Image.open("path_to_image").convert('L')
img = T.ToTensor()(img)

# Centered spectrum -> magnitude -> log scale for better visualization.
fft_img = torch.log1p(torch.abs(fft(img)))

# Min-max rescale the log-magnitude to 0-1.
min_val = torch.min(fft_img)
max_val = torch.max(fft_img)
fft_img = torch.div(fft_img - min_val, max_val - min_val)

# Reorder to channel-last, convert to NumPy, and save the rendered figure.
fft_img = torch.einsum('chw->hwc', fft_img).numpy()
show_image(fft_img, "save_path")
` # 2D FFT
x_freq = torch.fft.fft2(image)
# shift low frequency to the center
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
# mask a portion of frequencies
x_freq_masked = x_freq
# restore the original frequency order
x_freq_masked = torch.fft.ifftshift(x_freq_masked, dim=(-2, -1))
# 2D iFFT (only keep the real part)
x_corrupted = torch.fft.ifft2(x_freq_masked).real
x_corrupted = torch.clamp(x_corrupted, min=0., max=1.)
x_np = x_corrupted.numpy()
im = Image.fromarray((x_np * 255).astype(np.uint8))
im.save(os.path.join(output_folder, filename))`
The image I save with the above code looks very different from the cat example image. Could you please provide a demo of this?