Detailed packing and calculation method for multiple input channels and multiple output channels
minnow54426 opened this issue · 10 comments
Hello, thank you very much for your elegant implementation of ResNet20 under CKKS using FHE.
While analyzing the code corresponding to the paper, namely the convbn function in FHEController.cpp, I ran into trouble.
In the function convbn_initial there are 3 input channels and 16 output channels. The 3 input channels are summed using two rotations and additions, followed by a mask that deletes the useless data. The different output channels are then packed together using right rotations and additions.
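To make the description above concrete, here is a minimal plaintext sketch that simulates the slot vector with NumPy instead of real ciphertexts. It assumes each channel occupies one contiguous block of `n` slots (1024 in the real network, 4 here), reduces the 3×3 kernel to a single tap with one hypothetical weight per channel pair, and uses a `rotate` helper that mimics OpenFHE's sign convention (positive index = left, negative = right). It is not the repository's code.

```python
import numpy as np

def rotate(v, k):
    """Cyclic rotation, OpenFHE-style sign: k > 0 rotates left, k < 0 right."""
    return np.roll(v, -k)

# Toy sizes (assumption): the real layer has 1024 slots per channel
# and 16 output channels.
n = 4          # slots per channel block
c_in = 3       # input channels
c_out = 4      # output channels
slots = n * c_out

rng = np.random.default_rng(0)
image = rng.standard_normal((c_in, n))        # one row per input channel
weights = rng.standard_normal((c_out, c_in))  # hypothetical single-tap weights

# Input channels laid out contiguously at the start of the slot vector.
x = np.zeros(slots)
x[: c_in * n] = image.ravel()

mask = np.zeros(slots)
mask[:n] = 1.0   # keeps only the first channel block

packed = np.zeros(slots)
for o in range(c_out):
    # Plaintext: weights[o][i] repeated over the slots of input block i.
    w = np.zeros(slots)
    w[: c_in * n] = np.repeat(weights[o], n)
    prod = x * w
    # Two left rotations plus additions bring blocks 1 and 2 onto block 0,
    # summing the 3 input channels there; the mask deletes the leftovers.
    summed = (prod + rotate(prod, n) + rotate(prod, 2 * n)) * mask
    # A right rotation (negative index) moves output channel o to block o.
    packed = packed + rotate(summed, -o * n)
```

After the loop, block `o` of `packed` holds `weights[o] @ image`, i.e. output channel `o`.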
When it comes to the convbn family of functions, the packing and calculation method confuses me. From the figure at the bottom of page 14 of the paper, I guess that every row in the figure includes 9 (3 × 3) plaintexts, so there are two main points I find hard to understand:
- How are the different input channels summed?
- Why is the rotation at the end of this function toward the left, rather than toward the right?
Thanks very much.
My guess, based on the figure at the bottom of page 14 of the paper, is that some packing trick such as diagonal packing is used, just like the packing trick in the FC layer, but I cannot figure out the details.
Hello! If you check Figure 5, you will notice that the packing is actually just row-major order: each channel is written as a vector, row by row.
Then all the channels are placed one after another in a single ciphertext.
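The row-major layout described above can be sketched in a few lines. This assumes 16 channels of 32×32 pixels, which fill exactly the 16384 slots mentioned later in the thread; `slot_index` is a hypothetical helper, not a function from the repository.

```python
import numpy as np

# Assumed sizes: 16 channels of 32x32 pixels fill all 16384 CKKS slots.
channels, height, width = 16, 32, 32

rng = np.random.default_rng(1)
feature_map = rng.standard_normal((channels, height, width))

# Row-major packing: each channel is flattened row by row, and the
# channels are concatenated one after another in a single vector.
slot_vector = feature_map.reshape(-1)   # C-order reshape is row-major

def slot_index(c, row, col):
    """Slot that holds pixel (row, col) of channel c in this layout."""
    return c * height * width + row * width + col
```

Channel `c` therefore occupies the contiguous slot range `[1024 * c, 1024 * (c + 1))`.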
In convbn, the rotation in the figure at the bottom of page 14 is in reality always toward the left; see line 790 of FHEController.cpp. After 16 rotations, the ciphertext returns to its original state (16 * 1024 = 16384, which is the number of slots).
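A quick plaintext check of the period argument above, assuming 16384 slots split into 16 channel blocks of 1024 and simulating a left rotation with `np.roll` (the `rotate` helper is illustrative, not the repository's code):

```python
import numpy as np

def rotate(v, k):
    """Left rotation by k slots (positive index = left, as in OpenFHE)."""
    return np.roll(v, -k)

slots, block = 16384, 1024          # 16 channel blocks of 1024 slots
v = np.arange(slots, dtype=float)

after_one = rotate(v, block)        # channel 1 now sits where channel 0 was

w = v.copy()
for _ in range(16):
    w = rotate(w, block)            # shift one channel block to the left
```

After one rotation each channel block has moved one position; after 16 the vector `w` equals the original `v` again.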
If you are familiar with Python, I suggest checking the section "Layer1[0]: Conv1+Bn1" of notebooks/Algorithm 2 - Exporting Weights.ipynb.
Thanks very much!
Maybe I did not phrase my question well: I am confused about how the filter elements are packed. In the multi-input-channel, multi-output-channel case there are c_in * c_out filters (in ResNet we conveniently have c_in = c_out), and I am wondering how the c_in * c_out filters are arranged in 9 * c_in plaintexts (I think that number is right).
I will check notebooks/Algorithm2 for layer1[0] soon, thanks again.
I figured it out by reading how k1 to k9 are generated in the notebooks.
When the filters of the different input and output channels are placed in the plaintexts, a trick is used so that the subsequent rotations and additions automatically sum over the different input channels.
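The trick described above behaves like diagonal packing. Below is a minimal plaintext sketch of it for a single kernel tap (effectively a 1×1 convolution), assuming toy sizes of 4 channels with 3 slots each and a hypothetical weight matrix `k[o][i]` (output `o`, input `i`). Plaintext `r` holds `k[o][(o + r) % C]` in block `o`, which is the pattern in the first matrix shown later in this thread. This is an illustration of the technique, not the notebook's code.

```python
import numpy as np

def rotate(v, k):
    """OpenFHE-style sign convention: positive k rotates left."""
    return np.roll(v, -k)

C, n = 4, 3                          # toy sizes (real layers: 16 x 1024)
rng = np.random.default_rng(3)
image = rng.standard_normal((C, n))  # input channel i lives in block i
k = rng.standard_normal((C, C))      # hypothetical k[o][i], one kernel tap
x = image.reshape(-1)

# Plaintext r: block o is filled with k[o][(o + r) % C] (a "diagonal").
plaintexts = [np.repeat([k[o][(o + r) % C] for o in range(C)], n)
              for r in range(C)]

# Rotating the input by r blocks places input channel (o + r) % C under
# output block o, so the running sum visits every input channel once.
y = np.zeros(C * n)
for r in range(C):
    y += plaintexts[r] * rotate(x, r * n)

expected = k @ image                 # row o: sum over i of k[o][i] * channel i
```

Each of the `C` plaintexts mixes weights of all output channels, yet after the additions block `o` of `y` contains exactly output channel `o`. Repeating this for each of the 9 taps of a 3×3 kernel would use 9 × C plaintexts, in line with the 9 * c_in count guessed above when c_in = c_out.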
And in the c++ code, the slot number is set such that after 16 rotation, it returns to the original ciphertext.
I will close this issue, thanks!
Sorry for bothering you again.
At the end of the conv function, the index passed to EvalRotate is negative, which according to the OpenFHE documentation means the rotation is toward the right. However, the paper and our discussion suggest that the rotation should be toward the left. Is there anything wrong with my understanding?
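The sign convention mentioned in the question can be simulated in plaintext. The sketch below assumes, as the question states, that a negative EvalRotate index rotates right (`result[i] = v[(i + k) % N]`); `eval_rotate` is a hypothetical NumPy stand-in, not the OpenFHE API. With 16 blocks of 1024 slots, rotating right by one block is the same as rotating left by 15 blocks, so the two directions differ only in which block order the accumulation follows.

```python
import numpy as np

def eval_rotate(v, k):
    """Plaintext stand-in for EvalRotate's sign convention:
    result[i] = v[(i + k) % N], so k > 0 rotates left and k < 0 right."""
    return np.roll(v, -k)

N, block = 16384, 1024
v = np.arange(N, dtype=float)

right = eval_rotate(v, -block)            # negative index: one block right
left = eval_rotate(v, (16 - 1) * block)   # fifteen blocks to the left
```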
Thanks!
Hello, thank you for the question!
Actually, it depends on how the rotations of the kernel are generated. I think rotating to the right is the way to go, so consider the source code to be the correct option.
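To illustrate why the direction depends on how the kernel is generated, here is a small plaintext sketch (one slot per channel, random hypothetical weights) showing that two kernel layouts give the same result when each is paired with the matching rotation direction. Which layout the repository uses is not asserted here.

```python
import numpy as np

def rotate(v, k):
    """OpenFHE-style sign: positive k rotates left, negative k right."""
    return np.roll(v, -k)

C = 4                                 # one slot per channel (toy assumption)
rng = np.random.default_rng(2)
k = rng.standard_normal((C, C))       # k[o][i]: output o, input i
x = rng.standard_normal(C)

# Layout A: diagonal r holds k[o][(o + r) % C]; pair it with LEFT rotations.
diag_a = [np.array([k[o][(o + r) % C] for o in range(C)]) for r in range(C)]
y_left = sum(diag_a[r] * rotate(x, r) for r in range(C))

# Layout B: diagonal r holds k[o][(o - r) % C]; pair it with RIGHT rotations.
diag_b = [np.array([k[o][(o - r) % C] for o in range(C)]) for r in range(C)]
y_right = sum(diag_b[r] * rotate(x, -r) for r in range(C))
```

Both `y_left` and `y_right` equal `k @ x`; mixing layout A with right rotations (or B with left) would pair weights with the wrong input channels.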
In the kernel-generating code, i.e. layer[0]: conv1 + bn1, there are several steps (the LaTeX preview is broken; paste it into VS Code to render the equations):
- The kernel is generated in the following form (using the paper's notation, the first subscript is the output channel and the second is the input channel):
$$\begin{matrix}
k_{0,0} & k_{1,1} & \cdots & k_{15,15} \\
k_{0,1} & k_{1,2} & \cdots & k_{15,0} \\
\vdots & \vdots & & \vdots \\
k_{0,15} & k_{1,0} & \cdots & k_{15,14}
\end{matrix}$$
- The generated kernel is masked and the BN layer is combined into it.
- Roll the kernel, which changes the matrix to
$$\begin{matrix}
k_{0,0} & k_{1,1} & \cdots & k_{15,15} \\
k_{15,0} & k_{0,1} & \cdots & k_{14,15} \\
\vdots & \vdots & & \vdots \\
k_{1,0} & k_{2,1} & \cdots & k_{0,15}
\end{matrix}$$
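The two matrices above can be reproduced with index labels, assuming the "roll" step shifts row r to the right by r positions (this matches every row written out above, but is a reconstruction, not the notebook's code). After rolling, every entry in column j refers to input channel j, which is presumably what "align the image data" refers to later in the thread.

```python
import numpy as np

C = 16
# First step: row r, column o holds the label of k_{o, (o + r) mod C}.
diag = np.array([[f"k{o},{(o + r) % C}" for o in range(C)]
                 for r in range(C)])

# "Roll the kernel": shift row r to the right by r positions.
rolled = np.array([np.roll(diag[r], r) for r in range(C)])

# After rolling, row r, column j holds k_{(j - r) mod C, j}.
print(rolled[1][:2], rolled[1][-1])
```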
I still cannot understand why the rotation is toward the right.
Thanks again!
I got it.
When the kernel data is generated, the vectors are rotated toward the right to align them with the image data.
Then right rotations are used to sum over the different input channels.
Thanks!
I am happy that you managed to get it! Feel free to open new issues if you have any doubt :-)