ray075hl/DeepPhotoStyle_pytorch

您好,我想请教一下关于代码的问题

cccusername opened this issue · 4 comments

您好,在main.py中的gen_mask()中,首先从segmentation()中得到seg_result。我想请教一下在gen_mask中对seg_result是进行了一个什么处理,目的是为什么呢。看不太懂这一块,所以来请教一下,谢谢

seg_result 的结果是150类的分割结果. DeepPhotoStyle_pytorch/seg/objectInfo150.txt
程序将150类剔除了一些类别, 还合并了一些类别(形成了9大子类). DeepPhotoStyle_pytorch/merge_index.py

def gen_mask(image_path):
    seg_result = segmentation(image_path).squeeze(0)
    channel, height_, width_ = seg_result.size()

    for classes in merge_classes:
        for index, each_class in enumerate(classes):
            if index == 0:
                zeros_index = each_class
                base_map = seg_result[each_class, :, :].clone()
            else:
                base_map = base_map | seg_result[each_class, :, :]
        seg_result[zeros_index, :, :] = base_map

return seg_result, height_, width_

上述程序中的classes 指的是9大类中的某个子类,

base_map = base_map | seg_result[each_class, :, :] 

这一句说的是, 将同属于某个子类的语义分割图进行合并(或运算).
@cccusername 希望我说明白了 😃

那后面的merged_style_mask表示的是什么呢,为什么他的第一维默认设为117
merged_style_mask = np.zeros((117, height_, width_), dtype='int')

del_classed = [26, 60, 128, 9, 17, 32, 1, 25, 48, 79, 84, 11, 13, 29, 46, 51, 52, 68, 91, 94, 101,
86, 34, 59, 121, 23, 30, 31, 69, 75, 80, 83, 102] 删除的类别是33类 150-3=117 .
代码写得太随意了 haha||

merged_style_mask 是对 gen_mask产生的结果做后处理, 例如去掉 面积小于50的 mask之类的.

def gen_mask(image_path):
    seg_result = segmentation(image_path).squeeze(0)
    channel, height_, width_ = seg_result.size()

    for classes in merge_classes:
        for index, each_class in enumerate(classes):
            if index == 0:
                zeros_index = each_class
                base_map = seg_result[each_class, :, :].clone()
            else:
                base_map = base_map | seg_result[each_class, :, :]
            
            @@@@@@@@@@@@@@@@@@@@
            seg_result[each_class, :, :] = 0.0 
            @@@@@@@@@@@@@@@@@@@@

        seg_result[zeros_index, :, :] = base_map

return seg_result, height_, width_

似乎应该加一句
@@@@@@@@@@@@@@@@@@@@
seg_result[each_class, :, :] = 0.0
@@@@@@@@@@@@@@@@@@@@
不好意思啊 不太记得了~~~

好的,谢谢您!!!