mask in test_in_any_resolution.py
IDKiro opened this issue · 0 comments
IDKiro commented
Maybe you can pad the image first; otherwise invalid (zero-filled) information will be fused into the large-scale tokens:
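import math

import torch
import torch.nn.functional as F


def expand2square(timg, factor=128):
    # Pad first: reflect-pad H and W up to the next multiple of `factor`
    _, _, h, w = timg.size()
    mod_pad_h = (factor - h % factor) % factor
    mod_pad_w = (factor - w % factor) % factor
    timg = F.pad(timg, (0, mod_pad_w, 0, mod_pad_h), 'reflect')

    # Then place the padded image in the top-left corner of a zero-filled X-by-X square
    _, _, h, w = timg.size()
    X = int(math.ceil(max(h, w) / float(factor)) * factor)
    img = torch.zeros(1, 3, X, X).type_as(timg)  # 1,3,X,X
    mask = torch.ones(1, 1, X, X).type_as(timg)  # 1 marks the zero-filled region (for -inf)
    img[:, :, :h, :w] = timg
    mask[:, :, :h, :w].fill_(0.0)
    return img, mask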
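To make the padding arithmetic concrete, here is a small pure-Python sketch of the sizes the function above would produce. The helper names `padded_sizes` and `reflect_pad_ok` are hypothetical and only for illustration; they are not part of the repository. The second helper reflects my understanding that `F.pad` in `'reflect'` mode needs each pad amount to be smaller than the dimension being padded, so very small inputs may need a different mode.

```python
import math


def padded_sizes(h, w, factor=128):
    """Return (padded_h, padded_w, square_side) for an h-by-w image."""
    # Same modular arithmetic as in expand2square: distance to the next multiple
    mod_pad_h = (factor - h % factor) % factor
    mod_pad_w = (factor - w % factor) % factor
    ph, pw = h + mod_pad_h, w + mod_pad_w
    # Side of the square canvas that holds the padded image
    side = int(math.ceil(max(ph, pw) / float(factor)) * factor)
    return ph, pw, side


def reflect_pad_ok(h, w, factor=128):
    """Check that both pad amounts are smaller than the dimension they pad."""
    mod_pad_h = (factor - h % factor) % factor
    mod_pad_w = (factor - w % factor) % factor
    return mod_pad_h < h and mod_pad_w < w


print(padded_sizes(300, 500))    # (384, 512, 512)
print(reflect_pad_ok(300, 500))  # True
print(reflect_pad_ok(50, 500))   # False: 78 rows of padding on a 50-row image
```

For a 300x500 input, the reflect padding brings the image to 384x512, so only the remaining 128 rows of the 512x512 square are zero-filled and masked, rather than 212 rows without the padding step.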