Types of input/output in each discriminator
Hi,
Thank you for sharing this great work!
I would like to use the pretrained discriminator alongside my from-scratch discriminator to improve my model. I added the vision-aided-gan discriminator with `cv_type` set to `swin`, `vgg`, or `clip` (my code is structured similarly to edge-connect).
```python
self.discr = vision_aided_loss.Discriminator(cv_type='swin', loss_type='sigmoid', device=config.DEVICE).to(config.DEVICE)
self.discr.cv_ensemble.requires_grad_(False)
```
When I feed the generated images (B×C×H×W) and ground-truth images into the discriminator, I get the following `lossD` from `vgg` and `swin`:
```
tensor([[1.3401],
        [1.3370],
        [1.2983],
        [1.2942],
        [1.1943],
        [1.3307],
        [1.2072],
        [1.2092]], device='cuda:0', grad_fn=<AddBackward0>)
```
I can backpropagate it by taking the mean:

```python
dis_loss = dis_real_loss + dis_fake_loss + torch.mean(lossD)
```
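The reduction above can be mimicked in plain Python. This is an illustrative sketch with made-up names, not the library's code: a single-level head (e.g. `vgg`, `swin`) yields one per-sample output, while a multi-level head yields a list of such outputs, so a small helper can reduce either case to one scalar.

```python
def reduce_disc_loss(pred):
    """Reduce a discriminator output to one scalar.

    `pred` is either a per-sample output (modeled here as a list of
    floats) or, for a multi-level head, a list of per-level outputs.
    Hypothetical helper for illustration only.
    """
    if isinstance(pred[0], list):  # multi-level: average the per-level means
        return sum(reduce_disc_loss(p) for p in pred) / len(pred)
    return sum(pred) / len(pred)   # single-level: mean over the batch

single_level = [1.3401, 1.3370, 1.2983, 1.2942]
multi_level = [single_level, [0.9, 1.1, 1.0, 1.0]]
print(reduce_disc_loss(single_level))  # ≈ 1.3174
print(reduce_disc_loss(multi_level))   # ≈ 1.1587
```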
But I got the following error from `clip`:
```
Traceback (most recent call last):
  File "/home/naoki/MyProject/src/models.py", line 434, in process
    lossD = self.discr(dis_input_real, for_real=True) + self.discr(dis_input_fake, for_real=False)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/vision_aided_loss/cv_discriminator.py", line 187, in forward
    return self.loss_type(pred_mask, **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/vision_aided_loss/cv_losses.py", line 104, in forward
    loss_ = self.losses[i](input[i], **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/vision_aided_loss/cv_losses.py", line 21, in forward
    target_ = target.expand_as(input).to(input.device)
TypeError: expand_as(): argument 'other' (position 1) must be Tensor, not list
```
I am not familiar with these pretrained models.
What are the input and output types for each discriminator?
Thank you in advance.
Hi, thanks for your interest in our code.
For the CLIP network, the default discriminator architecture is multi-level, so `loss_type` should be `multilevel_sigmoid_s`. If you want to use the `sigmoid_s` loss_type with a single-level discriminator architecture, passing `output_type` as `conv` in the arguments should enable that.
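For intuition, a multi-level loss applies the usual sigmoid (softplus) discriminator loss at each level of the head and combines the per-level terms. A rough pure-Python sketch, assuming a sum over levels; the function names are mine, not the `vision_aided_loss` API, and the library's `_s` variants add further modifications not modeled here:

```python
import math

def sigmoid_d_loss(pred, for_real):
    """Softplus form of the sigmoid GAN discriminator loss at one level:
    log(1 + exp(-x)) for real inputs, log(1 + exp(x)) for fake ones,
    averaged over the batch."""
    sign = -1.0 if for_real else 1.0
    return sum(math.log1p(math.exp(sign * p)) for p in pred) / len(pred)

def multilevel_sigmoid_d_loss(level_preds, for_real):
    """Sum the single-level loss over every level of the head."""
    return sum(sigmoid_d_loss(p, for_real) for p in level_preds)

levels = [[2.0, 1.5], [0.5, 1.0]]  # per-level, per-sample logits
print(multilevel_sigmoid_d_loss(levels, for_real=True))  # ≈ 0.5578
```

Because the multi-level head returns a list of per-level predictions rather than a single tensor, a `loss_type` that expects one tensor (plain `sigmoid`) raises the `expand_as(): ... must be Tensor, not list` error above.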
Let me know if this resolves the issue.