WCT with Mask
Closed this issue · 6 comments
Hi Xueting Li, many thanks for sharing your great implementation. I am wondering whether you have provided PyTorch code for testing WCT with a mask, similar to this code in the original repository.
Sorry, I didn't include that part. It should be straightforward to implement. I don't have time to organize and push that code to this repo, but I can provide the core code for your reference. The main idea is to replace the transform function with the feature_wct function below, so that it only takes features within a mask.
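import torch
import torch.nn.functional as F

def large_dff(a, b):
    # True when the two mask areas differ by a factor of 100 or more.
    if (a / b >= 100):
        return True
    if (b / a >= 100):
        return True
    return False

def scale_dialate(seg, W, H):
    # Resize a 2-D mask to H x W with nearest-neighbour interpolation.
    # TODO: dilate the mask
    seg = seg.view(1, 1, seg.shape[0], seg.shape[1])
    seg = F.interpolate(seg, size = (H, W), mode = 'nearest')
    return seg.squeeze()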
def feature_wct(self, cF, sF, cmasks, smasks, alpha):
    color_code_number = 1  # number of mask regions to process
    # Features are C x H x W. Height is unpacked before width so that
    # scale_dialate(mask, W, H) resizes to (H, W) and matches cF.view(C, -1)
    # for non-square inputs as well.
    C, H, W = cF.size(0), cF.size(1), cF.size(2)
    _, H1, W1 = sF.size(0), sF.size(1), sF.size(2)
    targetFeature = cF.view(C, -1).clone()
    for i in range(color_code_number):
        cmask = cmasks[i].clone().squeeze(0)
        smask = smasks[i].clone().squeeze(0)
        # Skip regions that are too small or whose areas differ by 100x or more.
        if (torch.sum(cmask) >= 10 and torch.sum(smask) >= 10 and not large_dff(torch.sum(cmask), torch.sum(smask))):
            cmaskResized = scale_dialate(cmask, W, H)
            if (torch.max(cmaskResized) <= 0):
                print('cmaskResized has no 1, ignore')
                continue
            # Flat indices of the content positions inside the mask.
            cmaskView = cmaskResized.view(-1)
            fgcmask = (cmaskView == 1).nonzero().squeeze(1)
            fgcmask = fgcmask.to(cF.device)
            cFView = cF.view(C, -1)
            cFFG = torch.index_select(cFView, 1, fgcmask.long())

            smaskResized = scale_dialate(smask, W1, H1)
            if (torch.max(smaskResized) <= 0):
                print('smaskResized has no 1, ignore')
                continue
            # Flat indices of the style positions inside the mask.
            smaskView = smaskResized.view(-1)
            fgsmask = (smaskView == 1).nonzero().squeeze(1)
            fgsmask = fgsmask.to(sF.device)
            sFView = sF.view(C, -1)
            sFFG = torch.index_select(sFView, 1, fgsmask.long())

            # Transform only the masked content features and write them back.
            targetFeatureFG = wct2(cFFG, sFFG)
            targetFeature.index_copy_(1, fgcmask, targetFeatureFG)
            del fgcmask
            del fgsmask
    targetFeature = targetFeature.view_as(cF)
    # Blend the transformed features with the original content features.
    ccsF = alpha * targetFeature + (1.0 - alpha) * cF
    ccsF = ccsF.float().unsqueeze(0)
    return ccsF
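To make the masking idea in feature_wct concrete, here is a minimal sketch of the gather, transform, scatter pattern on a toy tensor. The tensor values and the negation used in place of wct2 are placeholders chosen for illustration; only the index_select / index_copy_ pattern mirrors the code above.

```python
import torch

# Toy feature map: C channels over an H x W grid (values are arbitrary).
C, H, W = 2, 2, 3
feat = torch.arange(1, C * H * W + 1, dtype=torch.float32).view(C, H, W)

# Binary mask at feature resolution; 1 marks the region to transform.
mask = torch.tensor([[1., 0., 1.],
                     [0., 0., 1.]])

# Flat indices of the masked positions (row-major, matching feat.view(C, -1)).
idx = (mask.view(-1) == 1).nonzero().squeeze(1)

# Gather the masked columns: shape C x (number of masked positions).
flat = feat.view(C, -1)
selected = torch.index_select(flat, 1, idx)

# Placeholder for the real transform (wct2 in this thread): just negate.
transformed = -selected

# Scatter the transformed columns back; unmasked positions stay untouched.
out = flat.clone()
out.index_copy_(1, idx, transformed)
out = out.view_as(feat)
```

Because the mask is flattened in the same row-major order as the features, the mask has to be resized to exactly the feature map's height and width before indexing.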
Many thanks for sharing these functions; they were very helpful. I just have one question. What is wct2 in this line: targetFeatureFG = wct2(cFFG,sFFG)? It is not defined in util.py. Thanks!
Here it is. It's the feature transform function:
def wct2(cF, sF):
    # cF, sF: C x N matrices of masked content / style features.
    cFSize = cF.size()  # c x (h x w)
    c_mean = torch.mean(cF, 1)
    c_mean = c_mean.unsqueeze(1).expand_as(cF)
    cF = cF - c_mean

    # Identity added to the covariance as regularization; built on the
    # same dtype and device as the features.
    iden = torch.eye(cFSize[0], dtype=cF.dtype, device=cF.device)
    contentConv = torch.mm(cF, cF.t()).div(max(1, cFSize[1] - 1)) + iden
    c_u, c_e, c_v = torch.svd(contentConv, some=False)

    # Keep only eigenvalues above the threshold.
    k_c = cFSize[0]
    for i in range(cFSize[0]):
        if c_e[i] < 0.00001:
            k_c = i
            break

    sFSize = sF.size()
    s_mean = torch.mean(sF, 1)
    sF = sF - s_mean.unsqueeze(1).expand_as(sF)
    styleConv = torch.mm(sF, sF.t()).div(max(1, sFSize[1] - 1)) + iden
    s_u, s_e, s_v = torch.svd(styleConv, some=False)

    k_s = sFSize[0]
    for i in range(sFSize[0]):
        if s_e[i] < 0.00001:
            k_s = i
            break

    # Whitening: truncate c_e to k_c so its shape matches c_v[:, 0:k_c].
    c_d = (c_e[0:k_c]).pow(-0.5)
    step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
    step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))
    whiten_cF = torch.mm(step2, cF)

    # Coloring with the style statistics, then restore the style mean.
    s_d = (s_e[0:k_s]).pow(0.5)
    diag_matrix = torch.diag(s_d)
    targetFeature = torch.mm(torch.mm(torch.mm(s_v[:, 0:k_s], diag_matrix), (s_v[:, 0:k_s].t())), whiten_cF)
    targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
    return targetFeature
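As a sanity check on what a whitening and coloring transform is meant to achieve, the sketch below verifies that the output features take on the style's mean and covariance. It is not the function from this thread: the name whiten_color is made up for illustration, it uses torch.linalg.eigh instead of torch.svd, and it omits the identity term that wct2 adds to the covariance, so wct2's output will match the style covariance only approximately.

```python
import torch

def whiten_color(cF, sF, eps=1e-5):
    """Whitening-coloring on C x N feature matrices (illustrative sketch)."""
    c_mean = cF.mean(dim=1, keepdim=True)
    s_mean = sF.mean(dim=1, keepdim=True)
    cF = cF - c_mean
    sF = sF - s_mean

    c_cov = cF @ cF.t() / (cF.size(1) - 1)
    s_cov = sF @ sF.t() / (sF.size(1) - 1)

    # Symmetric eigendecompositions; the clamp guards against tiny eigenvalues.
    c_e, c_v = torch.linalg.eigh(c_cov)
    s_e, s_v = torch.linalg.eigh(s_cov)
    c_e = c_e.clamp_min(eps)
    s_e = s_e.clamp_min(eps)

    whiten = c_v @ torch.diag(c_e.pow(-0.5)) @ c_v.t()  # maps c_cov to identity
    color = s_v @ torch.diag(s_e.pow(0.5)) @ s_v.t()    # maps identity to s_cov
    return color @ (whiten @ cF) + s_mean

def covariance(x):
    # Unbiased covariance of a C x N matrix.
    x = x - x.mean(dim=1, keepdim=True)
    return x @ x.t() / (x.size(1) - 1)

torch.manual_seed(0)
cF = torch.randn(4, 500, dtype=torch.float64)
sF = torch.randn(4, 300, dtype=torch.float64) * 3.0 + 1.0

out = whiten_color(cF, sF)
cov_matches = torch.allclose(covariance(out), covariance(sF), atol=1e-6)
mean_matches = torch.allclose(out.mean(dim=1), sF.mean(dim=1), atol=1e-6)
```

Note that the content and style matrices may have different numbers of columns, which is what allows differently sized masked regions to be matched.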
Thanks a lot for sharing the useful code. I have updated the code in this forked repository by adding WCT_mask.py, loader_mask.py and util_mask.py.
I made two binary masks for content and style and tried to run the code with this command:
python WCT_mask.py --cuda --contentPath /content/img --stylePath /style/img --content_mask_Path /content/mask --style_mask_Path /style/mask
but I got the following error at line 171 of util_mask.py:
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
The cause is that (cmaskView == 1).nonzero().shape is torch.Size([0]), so there is no dimension 1 to squeeze. I'm wondering whether the problem is in the styleTransfer implementation or in transform_mask.
The full traceback is below:
Traceback (most recent call last):
  File "WCT_mask.py", line 154, in <module>
    styleTransfer(cImg,sImg,cMImg,sMImg,imname,csF)
  File "WCT_mask.py", line 64, in styleTransfer
    csF5 = wct.transform_mask(cF5,sF5,cmF5,smF5,args.alpha) #(cF5,sF5,csF,args.alpha)
  File "/projects/WCT_Pytorch/util_mask.py", line 171, in transform_mask
    fgcmask = (cmaskView == 1).nonzero().squeeze(1)
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
I appreciate a lot all your provided help. Thanks!
I solved the previous issue by changing this line:
fgcmask = (cmaskView == 1).nonzero().squeeze(1)
to:
fgcmask = (cmaskView > 0 ).nonzero().squeeze(1)
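One plausible reason the == 1 comparison finds nothing is that the mask's foreground pixels are not stored as exactly 1.0 after loading. This is an assumption, since the thread does not say what values the mask held. The sketch below uses a made-up mask with foreground 254/255 to show the difference between the two comparisons, and uses reshape(-1) so that an empty result flattens cleanly instead of depending on a second dimension being present.

```python
import torch

# Hypothetical mask after loading: foreground stored as 254/255, not 1.0.
mask = torch.tensor([[0.0, 254.0 / 255.0],
                     [254.0 / 255.0, 0.0]])

# Exact comparison misses every foreground pixel.
idx_exact = (mask.view(-1) == 1).nonzero().reshape(-1)

# Comparing against zero picks up the foreground.
idx_positive = (mask.view(-1) > 0).nonzero().reshape(-1)

# Alternative: binarize once with an explicit threshold, then compare to 1.
binary = (mask > 0.5).float()
idx_binary = (binary.view(-1) == 1).nonzero().reshape(-1)
```

Binarizing the mask once, right after loading, keeps the later == 1 checks in the transform code valid without editing each of them.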
With some other small changes, I got this image as the result:
I used this image as the content:
This is the mask for the content:
This is for the style:
And this is for the mask of the style:
It works well for capturing a part of the style image and transferring it to the whole content image, but I am still wondering how I can use it to transfer the style to only a part of the content image and keep the rest of the content image as it is.
I updated the code in the forked repository.
Thanks a lot for all your great help.
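One possible way to keep the rest of the content image unchanged is to blend the stylized output with the original content image in image space, using the content mask. The thread does not say which approach was eventually used, so the sketch below is only an illustration; blend_with_mask is a hypothetical helper, and it assumes both images are 3 x H x W tensors of the same size and the mask is H x W with values in [0, 1].

```python
import torch

def blend_with_mask(stylized, content, mask):
    """Take `stylized` where mask is 1 and `content` where mask is 0."""
    # 1 x H x W mask broadcasts over the channel dimension.
    mask = mask.to(stylized.dtype).unsqueeze(0)
    return mask * stylized + (1.0 - mask) * content

# Toy inputs: an all-zero "content" image and an all-one "stylized" image.
content = torch.zeros(3, 2, 2)
stylized = torch.ones(3, 2, 2)
mask = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0]])

result = blend_with_mask(stylized, content, mask)
```

A soft mask (for example a slightly blurred one) would give a gradual transition at the region boundary instead of a hard edge.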
I solved the last problem and everything works fine now. I got the image below with the content mask and style mask.
I updated the code in the forked repository.
Thanks again for your useful help.