sunshineatnoon/PytorchWCT

WCT with Mask

Closed this issue · 6 comments

Hi Xueting Li, many thanks for sharing your great implementation. I am wondering whether you have provided PyTorch code for testing WCT with a mask, similar to this code in the original repository.

Sorry, I didn't include that part. It should be straightforward to implement, but I don't have time to organize and push that code to this repo. I can provide the core code for your reference, though. The main idea is to modify the transform function into the feature_wct function below so that it only applies WCT to the features within a mask.

import torch
import torch.nn.functional as F

def large_dff(a, b):
    # Skip a region if the content/style mask areas differ by 100x or more.
    if a / b >= 100:
        return True
    if b / a >= 100:
        return True
    return False

def scale_dialate(seg, W, H):
    # TODO: dilate
    # Resize the binary mask with nearest-neighbor interpolation so it
    # matches the spatial size of the feature map.
    seg = seg.view(1, 1, seg.shape[0], seg.shape[1])
    seg = F.interpolate(seg, size=(H, W), mode='nearest')
    return seg.squeeze()
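For reference, here is a minimal check of what scale_dialate does (a CPU-only sketch; the small 2x2 mask and target size are arbitrary): a binary mask resized with nearest-neighbor interpolation keeps its values exactly 0 and 1, which is what the == 1 indexing in feature_wct below relies on.

```python
import torch
import torch.nn.functional as F

# A small 2x2 binary mask; float dtype is required by F.interpolate.
seg = torch.tensor([[1.0, 0.0],
                    [0.0, 1.0]])

# Same steps as scale_dialate: add batch/channel dims, resize, squeeze back.
resized = F.interpolate(seg.view(1, 1, 2, 2), size=(4, 4), mode='nearest').squeeze()

print(resized.shape)              # torch.Size([4, 4])
print(resized.unique().tolist())  # nearest interpolation keeps exactly 0.0 and 1.0
```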

def feature_wct(self, cF, sF, cmasks, smasks, alpha):
    color_code_number = 1  # number of mask labels (a single binary mask here)

    C, W, H = cF.size(0), cF.size(1), cF.size(2)
    _, W1, H1 = sF.size(0), sF.size(1), sF.size(2)
    targetFeature = cF.view(C, -1).clone()
    for i in range(color_code_number):
        cmask = cmasks[i].clone().squeeze(0)
        smask = smasks[i].clone().squeeze(0)
        # Skip tiny masks and pairs whose areas differ too much.
        if torch.sum(cmask) >= 10 and torch.sum(smask) >= 10 and not large_dff(torch.sum(cmask), torch.sum(smask)):
            cmaskResized = scale_dialate(cmask, W, H)
            if torch.max(cmaskResized) <= 0:
                print('cmaskResized has no 1, ignore')
                continue
            cmaskView = cmaskResized.view(-1)
            fgcmask = (cmaskView == 1).nonzero().squeeze(1)
            fgcmask = fgcmask.cuda()
            # Gather the content features inside the mask.
            cFView = cF.view(C, -1)
            cFFG = torch.index_select(cFView, 1, fgcmask.long())

            smaskResized = scale_dialate(smask, W1, H1)
            if torch.max(smaskResized) <= 0:
                print('smaskResized has no 1, ignore')
                continue
            smaskView = smaskResized.view(-1)
            fgsmask = (smaskView == 1).nonzero().squeeze(1)
            fgsmask = fgsmask.cuda()
            # Gather the style features inside the mask.
            sFView = sF.view(C, -1)
            sFFG = torch.index_select(sFView, 1, fgsmask.long())

            # Run WCT on the masked features only, then scatter the result back.
            targetFeatureFG = wct2(cFFG, sFFG)
            targetFeature.index_copy_(1, fgcmask, targetFeatureFG)
            del fgcmask
            del fgsmask
    targetFeature = targetFeature.view_as(cF)
    # Blend the stylized features with the original content features.
    ccsF = alpha * targetFeature + (1.0 - alpha) * cF
    ccsF = ccsF.float().unsqueeze(0)
    return ccsF
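The core gather/transform/scatter pattern in feature_wct (select the masked columns, transform them, write them back in place) can be exercised on its own with dummy tensors. This is a CPU-only sketch with arbitrary shapes, and multiplying by 10 stands in for the WCT transform:

```python
import torch

C = 3
features = torch.arange(12, dtype=torch.float).view(C, 4)  # C x (H*W)
mask = torch.tensor([0., 1., 1., 0.])                      # flattened binary mask

fg = (mask == 1).nonzero().squeeze(1)          # indices of masked positions
masked = torch.index_select(features, 1, fg)   # C x n_masked columns

out = features.clone()
out.index_copy_(1, fg, masked * 10.0)          # write transformed columns back

print(out[0].tolist())     # [0.0, 10.0, 20.0, 3.0] -- only masked columns changed
print(out[:, 0].tolist())  # [0.0, 4.0, 8.0] -- unmasked column untouched
```

This is why the final result keeps the original content features outside the mask: index_copy_ only overwrites the selected columns of targetFeature.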

Many thanks for sharing these functions, it was very helpful. I just have one question about it.
What is wct2 in this line: targetFeatureFG = wct2(cFFG,sFFG)? It is not defined in util.py. Thanks!

Here it is; it's the feature transform function:

def wct2(cF, sF):
    # cF, sF: C x N matrices of masked content / style features.
    cFSize = cF.size()
    c_mean = torch.mean(cF, 1)  # per-channel mean over the N positions
    c_mean = c_mean.unsqueeze(1).expand_as(cF)
    cF = cF - c_mean

    # Content covariance; the identity is added for numerical stability.
    iden = torch.eye(cFSize[0]).cuda().double()
    contentConv = torch.mm(cF, cF.t()).div(cFSize[1] - 1) + iden
    c_u, c_e, c_v = torch.svd(contentConv, some=False)

    # Keep only singular values above a small threshold.
    k_c = cFSize[0]
    for i in range(cFSize[0]):
        if c_e[i] < 0.00001:
            k_c = i
            break

    sFSize = sF.size()
    s_mean = torch.mean(sF, 1)
    sF = sF - s_mean.unsqueeze(1).expand_as(sF)
    styleConv = torch.mm(sF, sF.t()).div(max(1, sFSize[1] - 1)) + iden
    s_u, s_e, s_v = torch.svd(styleConv, some=False)

    k_s = sFSize[0]
    for i in range(sFSize[0]):
        if s_e[i] < 0.00001:
            k_s = i
            break

    # Whitening: remove the content covariance. Note that c_d must be
    # truncated to k_c entries so torch.diag(c_d) matches c_v[:, 0:k_c].
    c_d = (c_e[0:k_c]).pow(-0.5)
    step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
    step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))
    whiten_cF = torch.mm(step2, cF)

    # Coloring: impose the style covariance, then re-add the style mean.
    s_d = (s_e[0:k_s]).pow(0.5)
    diag_matrix = torch.diag(s_d)
    targetFeature = torch.mm(torch.mm(torch.mm(s_v[:, 0:k_s], diag_matrix), (s_v[:, 0:k_s].t())), whiten_cF)
    targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
    return targetFeature
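As a quick sanity check of the whitening step in wct2, here is a CPU-only sketch on random features (it omits the + iden stabilizer so the identity comes out exactly): after multiplying by c_v diag(c_e^-1/2) c_v^T, the feature covariance becomes the identity, which is what lets the style covariance be imposed afterwards.

```python
import torch

torch.manual_seed(0)
C, N = 8, 500
cF = torch.randn(C, N).double()

# Center, then compute the covariance as in wct2 (without the identity term).
cF = cF - cF.mean(1, keepdim=True)
conv = cF @ cF.t() / (N - 1)
u, e, v = torch.svd(conv)

# Whitening: the covariance of the whitened features is the identity.
whiten = v @ torch.diag(e.pow(-0.5)) @ v.t() @ cF
cov_w = whiten @ whiten.t() / (N - 1)
print(torch.allclose(cov_w, torch.eye(C).double(), atol=1e-6))  # True
```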

Thanks a lot for sharing the useful code. I have updated the code in this forked repository by adding WCT_mask.py, loader_mask.py, and util_mask.py.
I made two binary masks for the content and style images and ran the code with this command: python WCT_mask.py --cuda --contentPath /content/img --stylePath /style/img --content_mask_Path /content/mask --style_mask_Path /style/mask, but I got the following error at line 171 of util_mask.py:
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
The problem is that (cmaskView == 1).nonzero() returns a tensor of shape torch.Size([0]), so the squeeze(1) fails. I'm wondering whether the problem is in the styleTransfer implementation or in transform_mask.

The full error traceback is below:

Traceback (most recent call last):
  File "WCT_mask.py", line 154, in <module>
    styleTransfer(cImg,sImg,cMImg,sMImg,imname,csF)
  File "WCT_mask.py", line 64, in styleTransfer
    csF5 = wct.transform_mask(cF5,sF5,cmF5,smF5,args.alpha)   #(cF5,sF5,csF,args.alpha)
  File "/projects/WCT_Pytorch/util_mask.py", line 171, in transform_mask
    fgcmask = (cmaskView == 1).nonzero().squeeze(1)
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

I appreciate a lot all your provided help. Thanks!

I could solve the previous issue by changing this line:
fgcmask = (cmaskView == 1).nonzero().squeeze(1)
to:
fgcmask = (cmaskView > 0).nonzero().squeeze(1)
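One likely reason the == 1 comparison matched nothing (this is an assumption about how the masks were loaded, not something the traceback proves): mask images read from 8-bit files often contain the values 0 and 255 rather than 0 and 1, so the exact comparison selects no pixels while > 0 still finds the foreground:

```python
import torch

# Hypothetical mask as loaded from an 8-bit image: background 0, foreground 255.
cmaskView = torch.tensor([0., 255., 255., 0.])

print((cmaskView == 1).nonzero().squeeze(1))  # empty -- no value equals 1 exactly
print((cmaskView > 0).nonzero().squeeze(1))   # tensor([1, 2])
```

An alternative fix under the same assumption would be to normalize the mask to 0/1 right after loading, which also keeps the == 1 checks in feature_wct valid.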
With a few other small changes I got this image as the result:
[result image]
I used this image as the content:
[content image]
This is the mask for the content:
[content mask]
This is for the style:
[style image]
And this is for the mask of the style:
[style mask]
It works well for taking a part of the style image and transferring its style to the whole content image, but I am still wondering how to transfer the style to only a part of the content image while keeping the rest of the content unchanged.
I updated the code in the forked repository.
Thanks a lot for all your great help.

I could solve the last problem and everything works fine now. I got the image below using the content mask and style mask.
[result image]
I updated the code in the forked repository.
Thanks again for your useful help.