How to understand downsampling
Sinking-Stone opened this issue · 10 comments
Hello, I'm very sorry to bother you again. When reading your source code, I don't quite understand how your downsampling is calculated. Could you explain it to me? The paper only gives a brief description.
Essentially there are 3 phases. I'll take `Ctxt FHEController::downsample1024to256(const Ctxt &c1, const Ctxt &c2)` as an example:
(notice the line numbers in the image)
This is needed because some slots will be empty: the previous convolution leaves blank values, since its stride is equal to 2.
Can you explain line 346 in main.cpp?
vector<Ctxt> res1sx = controller.convbn1632sx(boot_in, 4, 1, scaleSx, timing); //This is slow
vector<Ctxt> res1dx = controller.convbn1632dx(boot_in, 4, 1, scaleDx, timing); //This is slow
What does this do? I can't understand it from this code alone. Can you explain what `sx` and `dx` are? How can I understand them better?
Oh, I now understand `sx` and `dx`, but I still can't connect the downsampling picture you gave me with the code. Could you give me an example? Thank you very much.
Ah ok.
The "re-arranging" figure is the procedure performed inside the Conv2D blocks, since they have a stride of {2, 2}. This means that the kernel window is shifted by 2 positions (and not by 1, as in previous convolutions), leaving empty slots in our ciphertexts (because the HE convolution assumes that the stride is equal to {1, 1}). For this reason, we have to re-arrange the values in order to fill the empty slots.
Simulation of FHEController (lines 1191-1194):
def rot(nums, index):
    """Return a copy of ``nums`` rotated left by ``index`` positions.

    A positive ``index`` moves element ``index`` to the front; a negative
    ``index`` rotates right. Used here to simulate rotating the slots of a
    ciphertext. The result is a new outer list, but its elements are the
    same objects as in ``nums`` (shallow copy).

    An empty ``nums`` returns an empty copy instead of raising
    ``ZeroDivisionError`` from ``index % 0``.
    """
    if not nums:
        return nums[:]
    # Python's % always yields a value in [0, len(nums)), so a negative
    # index is already mapped to the equivalent left rotation and no
    # separate right-rotation branch is needed.
    shift = index % len(nums)
    return nums[shift:] + nums[:shift]
def create_vec(n):
    """Build ``n`` one-element slot lists: ``[[1], [2], ..., [n]]``.

    Each slot is tagged with its own 1-based index, so later simulated
    additions record which original slots ended up in which position.
    """
    return [[slot] for slot in range(1, n + 1)]
def vec_add(vec1, vec2):
    """Simulate slot-wise ciphertext addition on index lists, in place.

    ``vec1`` and ``vec2`` are lists whose entries are lists of slot indices.
    ``[0]`` marks an empty (masked-out) slot. For every position ``i`` of
    ``vec1``:

    * if ``vec2[i]`` is empty, ``vec1[i]`` is left unchanged;
    * if ``vec1[i]`` is empty, it becomes a copy of ``vec2[i]``;
    * otherwise ``vec1[i]`` becomes the concatenation ``vec1[i] + vec2[i]``.

    The outer list ``vec1`` is updated in place and also returned.

    Inner lists are never mutated. ``vec2`` is normally ``rot(vec1, k)``,
    which shares its inner lists with ``vec1``. Extending ``vec1[i]`` in
    place would therefore also change the ``vec2`` entry that refers to
    the same list. For the wrap-around slots, that entry has not been read
    yet, so the result would be corrupted (e.g. ``[n, 1, 2]`` instead of
    ``[n, 1]``).

    Raises:
        IndexError: if ``vec2`` is shorter than ``vec1``.
    """
    for i, current in enumerate(vec1):
        addend = vec2[i]
        if addend == [0]:
            continue
        # Build fresh lists so no entry is ever shared between the operands.
        vec1[i] = list(addend) if current == [0] else current + addend
    return vec1
def vec_mult(vec1, vec2):
    """Apply a plaintext 0/1 mask to a simulated ciphertext, in place.

    Every slot of ``vec1`` whose mask entry ``vec2[pos]`` equals 0 is
    replaced by the empty marker ``[0]``. All other slots are kept.
    Returns ``vec1`` itself.
    """
    for pos, _ in enumerate(vec1):
        if vec2[pos] == 0:
            vec1[pos] = [0]
    return vec1
def gen_mask(n, length=32 * 32 * 32):
    """Return a 0/1 mask of ``length`` entries: ``n`` ones, ``n`` zeros, repeated.

    Example: ``gen_mask(2, 8) == [1, 1, 0, 0, 1, 1, 0, 0]``.

    Args:
        n: run length of each block of ones and zeros. ``n <= 0`` gives an
            all-zero mask.
        length: number of mask entries. It defaults to 32*32*32, the size
            that used to be hard-coded.
    """
    if n <= 0:
        return [0] * length
    period = 2 * n
    return [1 if i % period < n else 0 for i in range(length)]
if __name__ == '__main__':
    # Replay the simulated packing: three rounds of "add a rotated copy, then
    # mask", followed by a final unmasked add of the copy rotated by 8.
    # Each intermediate layout is written to test.txt, separated by a rule.
    with open('test.txt', 'w') as out:
        slots = create_vec(32 * 32 * 32)
        for offset, run in ((1, 2), (2, 4), (4, 8)):
            slots = vec_mult(vec_add(slots, rot(slots, offset)), gen_mask(run))
            out.write(str(slots))
            out.write('\n------------------------------------------------------------------------------------------\n')
        slots = vec_add(slots, rot(slots, 8))
        out.write(str(slots))
Sorry to bother you again. I have been simulating your program; the numbers in my code represent slot indices (subscripts). Maybe something is wrong in my code, because my result differs from the diagram you gave. Could you give me an example for the downsampling code? Thank you very much.
Ah ok. The "re-arranging" figure is the procedure performed inside the Conv2D blocks, since they have a stride of {2, 2}. This means that the kernel window is shifted by 2 positions (and not by 1, as in previous convolutions), leaving empty slots in our ciphertexts (because the HE convolution assumes that the stride is equal to {1, 1}). For this reason, we have to re-arrange the values in order to fill the empty slots.
How do you choose the rotation and mask indices in the downsample function, e.g. 1, 2, and 8? Are these numbers related to the stride, the downsampling size, or something else?
# Slot indices 64*r + 2*c for r, c in 0..15 (256 values; identical to the
# explicit list previously written out here). Reading the 1024 slots as a
# row-major 32x32 image, these are the even-row, even-column positions.
nums = [64 * row + 2 * col for row in range(16) for col in range(16)]
# NOTE(review): fill_with_zeros, left_rotate and np are not defined in this
# snippet; presumably fill_with_zeros spreads the values of nums over a
# zero-filled vector -- confirm against its definition.
ori = fill_with_zeros(nums)
# Pad with zeros up to 1024 slots (never truncates a longer vector).
m = list(ori)
m.extend([0] * (1024 - len(m)))
ori = np.array(m)
# Step 1 (first half): add the copy rotated by one slot
# (presumably a left rotation -- see left_rotate).
shift = left_rotate(list(ori), 1)
merge = np.array(ori) + np.array(shift)
def genmask(n, length):
    """Return a NumPy 0/1 mask of ``length`` entries: ``n`` ones, ``n`` zeros, repeated.

    Example: ``genmask(2, 8)`` -> ``array([1, 1, 0, 0, 1, 1, 0, 0])``.
    A non-positive ``n`` yields an all-zero mask.
    """
    period = 2 * n
    # The ``n > 0`` guard short-circuits before the modulo, so n == 0 never
    # divides by zero.
    pattern = [1 if n > 0 and k % period < n else 0 for k in range(length)]
    return np.array(pattern)
# Step 1 (second half): keep pairs of slots (mask 1,1,0,0,...) of
# ori + rot(ori, 1).
mask = genmask(2, len(merge))
fullpack_step1 = merge * mask
# Step 2: add the copy rotated by 2, keep runs of 4 (mask 1,1,1,1,0,0,0,0,...).
shift = left_rotate(list(fullpack_step1), 2)
merge = shift + fullpack_step1
mask = genmask(4, len(shift))
fullpack_step2 = mask * merge
# Step 3: add the copy rotated by 4, keep runs of 8.
shift = left_rotate(list(fullpack_step2), 4)
mask = genmask(8, len(fullpack_step2))
merge = shift + fullpack_step2
fullpack_step3 = mask * merge
# Step 4: add the copy rotated by 8, with no mask.
finallOut = fullpack_step3 + left_rotate(list(fullpack_step3), 8)
print(finallOut[:128])
def mask_first_n_mod(n, padding, pos):
    """Return a NumPy 0/1 mask selecting the ``pos``-th run of ``n`` slots.

    The mask is ``pos * n`` zeros, then ``n`` ones, then zeros up to a total
    of ``padding`` entries. If ``pos * n + n`` already exceeds ``padding``,
    no trailing zeros are added and the mask is longer than ``padding``.
    """
    leading = [0] * (pos * n)
    window = [1] * n
    trailing = [0] * (padding - n - pos * n)
    return np.array(leading + window + trailing)
# Gather 16 groups of 16 slots into slots 0..255 of `downsample`.
# Assumes the previous steps left group i at slots 64*i .. 64*i + 15 and
# that left_rotate rotates left -- TODO confirm.
orix = finallOut
print(orix[:32])
downsample = np.zeros(orix.shape[0], dtype=np.uint32)
for i in range(16):
    # Copy slots [16*i, 16*i + 16) of the current rotation into downsample.
    mask = mask_first_n_mod(16, 1024, i)
    maksed = mask * orix
    downsample = downsample + maksed
    if i == 15:
        continue  # the last group needs no further rotation
    # After i rotations by 64 - 16 = 48, group i (originally at 64*i) sits
    # at 64*i - 48*i = 16*i, right after the previously copied group.
    orix = left_rotate(list(orix), 64 - 16)
print(downsample)
Simulation of FHEController (lines 1191-1194):
def rot(nums, index):
    """Return ``nums`` rotated left by ``index`` slots (negative rotates right).

    The result is a new outer list sharing its elements with ``nums``.
    An empty ``nums`` yields an empty copy instead of ZeroDivisionError.
    """
    if not nums:
        return nums[:]
    # % already maps any index into [0, len(nums)), so one slice pair suffices.
    shift = index % len(nums)
    return nums[shift:] + nums[:shift]


def create_vec(n):
    """Return ``[[1], [2], ..., [n]]``: each slot tagged with its 1-based index."""
    return [[i] for i in range(1, n + 1)]


def vec_add(vec1, vec2):
    """Slot-wise add of two simulated ciphertexts; updates and returns ``vec1``.

    Entries are lists of slot indices, with ``[0]`` meaning an empty slot.
    Inner lists are never mutated. ``vec2`` is normally ``rot(vec1, k)``
    and shares them with ``vec1``, so an in-place ``extend`` would also
    change ``vec2`` entries that, for the wrap-around slots, have not been
    read yet.
    """
    for i, current in enumerate(vec1):
        addend = vec2[i]
        if addend == [0]:
            continue
        vec1[i] = list(addend) if current == [0] else current + addend
    return vec1


def vec_mult(vec1, vec2):
    """Apply 0/1 mask ``vec2``: slots with mask 0 become ``[0]``. Returns ``vec1``."""
    for i in range(len(vec1)):
        if vec2[i] == 0:
            vec1[i] = [0]
    return vec1


def gen_mask(n, length=32 * 32 * 32):
    """Return ``length`` entries of the pattern ``n`` ones, ``n`` zeros, repeated.

    ``length`` defaults to the previously hard-coded 32*32*32; ``n <= 0``
    gives an all-zero mask.
    """
    if n <= 0:
        return [0] * length
    return [1 if i % (2 * n) < n else 0 for i in range(length)]


if __name__ == '__main__':
    # Three rounds of "add rotated copy, then mask", then a final unmasked
    # add of the copy rotated by 8; each layout is dumped to test.txt.
    with open('test.txt', 'w') as f:
        fullpacke = create_vec(32 * 32 * 32)
        for offset, run in ((1, 2), (2, 4), (4, 8)):
            fullpacke = vec_mult(vec_add(fullpacke, rot(fullpacke, offset)), gen_mask(run))
            f.write(str(fullpacke))
            f.write('\n------------------------------------------------------------------------------------------\n')
        fullpacke = vec_add(fullpacke, rot(fullpacke, 8))
        f.write(str(fullpacke))
# I am very sorry that I have been simulating your program, the numbers in my code represent subscripts, maybe there is something wrong with my writing, I am different from the diagram you gave, could you give me an example about the subsampling code? Thank you very much.