您好,我想请教一下关于代码的问题
cccusername opened this issue · 4 comments
您好,在main.py中的gen_mask()中,首先从segmentation()中得到seg_result。我想请教一下在gen_mask中对seg_result是进行了一个什么处理,目的是为什么呢。看不太懂这一块,所以来请教一下,谢谢
seg_result 的结果是150类的分割结果. DeepPhotoStyle_pytorch/seg/objectInfo150.txt
程序将150类剔除了一些类别, 还合并了一些类别(形成了9大子类). DeepPhotoStyle_pytorch/merge_index.py
def gen_mask(image_path):
seg_result = segmentation(image_path).squeeze(0)
channel, height_, width_ = seg_result.size()
for classes in merge_classes:
for index, each_class in enumerate(classes):
if index == 0:
zeros_index = each_class
base_map = seg_result[each_class, :, :].clone()
else:
base_map = base_map | seg_result[each_class, :, :]
seg_result[zeros_index, :, :] = base_map
return seg_result, height_, width_
上述程序中的classes 指的是9大类中的某个子类,
base_map = base_map | seg_result[each_class, :, :]
这一句说的是, 将同属于某个子类的语义分割图进行合并(或运算).
@cccusername 希望我说明白了 😃
那后面的merged_style_mask表示的是什么呢,为什么他的第一维默认设为117
merged_style_mask = np.zeros((117, height_, width_), dtype='int')
del_classed = [26, 60, 128, 9, 17, 32, 1, 25, 48, 79, 84, 11, 13, 29, 46, 51, 52, 68, 91, 94, 101,
86, 34, 59, 121, 23, 30, 31, 69, 75, 80, 83, 102] 删除的类别是33类 150-3=117 .
代码写得太随意了 haha||
merged_style_mask 是对 gen_mask产生的结果做后处理, 例如去掉 面积小于50的 mask之类的.
def gen_mask(image_path):
seg_result = segmentation(image_path).squeeze(0)
channel, height_, width_ = seg_result.size()
for classes in merge_classes:
for index, each_class in enumerate(classes):
if index == 0:
zeros_index = each_class
base_map = seg_result[each_class, :, :].clone()
else:
base_map = base_map | seg_result[each_class, :, :]
@@@@@@@@@@@@@@@@@@@@
seg_result[each_class, :, :] = 0.0
@@@@@@@@@@@@@@@@@@@@
seg_result[zeros_index, :, :] = base_map
return seg_result, height_, width_
似乎应该加一句
@@@@@@@@@@@@@@@@@@@@
seg_result[each_class, :, :] = 0.0
@@@@@@@@@@@@@@@@@@@@
不好意思啊 不太记得了~~~
好的,谢谢您!!!