How to select the uncertainty map topk small batch on the 3d heart volume data?
Closed this issue · 2 comments
zhuji423 commented
你好,非常高兴看见你们的工作被ijcai接受,我复现了你们的代码,再ISIC数据集上面可以达到你们的效果,甚至略微超过了0.1个点,证明你们的方法是可行有效的。
但我想在3d的ct数据上面实现UCMT,关键的难点在于unfold 函数pytorch官方只能实现4d的张量,对于ct数据而言,训练时的数据维度是五维的[batch,class,depth,width,height]。我无法实现umix的这一步,请问我应该使用哪一个函数来对3d的ct数据进行不确定度最高的块的选择呢。我在这一步卡了很长时间,非常感谢你们的回复。
unfolds = torch.nn.Unfold(kernel_size=(h, w), stride=s).to(device)
folds = torch.nn.Fold(output_size=(args.image_size, args.image_size), kernel_size=(h, w), stride=s).to(device)
x11 = unfolds(uncertainty_map11) # B x C*kernel_size[0]*kernel_size[1] x L 8 256 256
x11 = x11.view(B, 1, h, w, -1) # B x C x h x w x L
x11_mean = torch.mean(x11, dim=(1, 2, 3)) # B x L
_, x11_max_index = torch.sort(x11_mean, dim=1, descending=True) # B x L B x L
# for student 2
x22 = unfolds(uncertainty_map22) # B x C*kernel_size[0]*kernel_size[1] x L
x22 = x22.view(B, 1, h, w, -1) # B x C x h x w x L
x22_mean = torch.mean(x22, dim=(1, 2, 3)) # B x L
_, x22_max_index = torch.sort(x22_mean, dim=1, descending=True) # B x L B x L
img_unfold = unfolds(imageA1).view(B, C, h, w, -1) # B x C x h x w x L
lab_unfold = unfolds(label.float()).view(B, 1, h, w, -1) # B x C x h x w x L
for i in range(B):## 对8张图片进行操作
img_unfold[i, :, :, :, x11_max_index[i, :topk]] = img_unfold[i, :, :, :, x22_max_index[i, -topk:]]
img_unfold[i, :, :, :, x22_max_index[i, :topk]] = img_unfold[i, :, :, :, x11_max_index[i, -topk:]]
lab_unfold[i, :, :, :, x11_max_index[i, :topk]] = lab_unfold[i, :, :, :, x22_max_index[i, -topk:]]
lab_unfold[i, :, :, :, x22_max_index[i, :topk]] = lab_unfold[i, :, :, :, x11_max_index[i, -topk:]]
image2 = folds(img_unfold.view(B, C*h*w, -1))
label2 = folds(lab_unfold.view(B, 1*h*w, -1))
Senyh commented
感谢你对我们工作的关注。
你可以用unfoldNd库https://github.com/f-dangel/unfoldNd/tree/main
unfolds = unfoldNd.UnfoldNd(kernel_size=(d, h, w), stride=(d, h, w)).to(device)
folds = unfoldNd.FoldNd(output_size=(args.image_size[0], args.image_size[1], args.image_size[2]), kernel_size=(d, h, w), stride=(d, h, w)).to(device)
zhuji423 commented
非常感谢您的回复,我去尝试一下