lustrouselixir/FGDNet

Question about the network for the flash guided non-flash denoising task

Opened this issue · 1 comments

For the RGB-NIR task, the inputs are 'noisy_img' (h, w, 3) and 'guidance_img' (h, w, 1), and you split the RGB noisy input as follows:

    with torch.no_grad():
        denoised_r, _, _ = model(noisy_img[:, 0, :, :].unsqueeze(0), guidance_img)
        denoised_g, _, _ = model(noisy_img[:, 1, :, :].unsqueeze(0), guidance_img)
        denoised_b, _, _ = model(noisy_img[:, 2, :, :].unsqueeze(0), guidance_img)
    denoised_img = torch.cat([denoised_r, denoised_g, denoised_b], dim=1)

But in the flash guided non-flash denoising task, both input images have 3 channels:
'noisy_img' (h, w, 3) and 'guidance_img' (h, w, 3).

Could you tell me what modifications are needed in 'FGDNet.py' to train the model for the flash guided non-flash denoising task?
Or can I keep 'FGDNet.py' unchanged and split both inputs as follows:

    with torch.no_grad():
        denoised_r, _, _ = model(noisy_img[:, 0, :, :].unsqueeze(0), guidance_img[:, 0, :, :].unsqueeze(0))
        denoised_g, _, _ = model(noisy_img[:, 1, :, :].unsqueeze(0), guidance_img[:, 1, :, :].unsqueeze(0))
        denoised_b, _, _ = model(noisy_img[:, 2, :, :].unsqueeze(0), guidance_img[:, 2, :, :].unsqueeze(0))
    denoised_img = torch.cat([denoised_r, denoised_g, denoised_b], dim=1)

Thanks a lot!

@JingyiXu404 Hi! Thanks for your interest. FGDNet processes one channel at a time for both the target and guidance images, as in your second snippet.
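The per-channel pattern above can also be written as a loop, which avoids the copy-paste for R, G, and B. This is a minimal sketch, not code from the repository; it assumes (as in the snippets above) a batch size of 1 and that `model(noisy, guidance)` takes two `(1, 1, H, W)` tensors and returns a tuple whose first element is the denoised channel.

```python
import torch

def denoise_per_channel(model, noisy_img, guidance_img):
    """Apply a single-channel guided-denoising model channel by channel.

    Sketch only: assumes batch size 1 and that `model` returns a tuple
    (denoised, _, _), as in the snippets above.
    noisy_img, guidance_img: (1, C, H, W) tensors with matching C.
    """
    outputs = []
    with torch.no_grad():
        for c in range(noisy_img.shape[1]):
            # Slicing with c:c+1 keeps the channel dimension, giving (1, 1, H, W)
            noisy_c = noisy_img[:, c:c + 1, :, :]
            guide_c = guidance_img[:, c:c + 1, :, :]
            denoised_c, _, _ = model(noisy_c, guide_c)
            outputs.append(denoised_c)
    # Reassemble the channels into a (1, C, H, W) image
    return torch.cat(outputs, dim=1)
```

For the RGB-NIR case the same loop works by passing the single guidance channel for every color channel instead of slicing `guidance_img`.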