Question about the split model
I am a beginner in deep learning and have read the code carefully, but I still have a few questions I would like to ask.
In the backpropagation phase of the vertical federated framework, the bottom model needs to calculate a loss of its own, computed from its output and the gradient sent back by the top model. The comment in the code states:

> read grad of : input of top model(also output of bottom models), which will be used as bottom model's target
I don't understand why the bottom model's loss is calculated in this way.
It is because of the chain rule. The bottom model's loss is defined as (output_tensor_bottom_model * grad_output_bottom_model) so that the gradients of the bottom model's parameters come out correctly, i.e., equal to the gradients of the top model's loss w.r.t. the bottom model's parameters.
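To make the chain-rule argument concrete, here is a minimal PyTorch sketch. The two linear layers, the MSE top loss, and all names except `output_tensor_bottom_model` / `grad_output_bottom_model` are illustrative assumptions, not the repository's actual code:

```python
import torch
import torch.nn as nn

bottom_model = nn.Linear(4, 8)   # stand-in for one party's bottom model
top_model = nn.Linear(8, 1)      # stand-in for the top model (held by the other party)

x = torch.randn(16, 4)
y = torch.randn(16, 1)

# --- forward: bottom output is sent to the top model ---
output_tensor_bottom_model = bottom_model(x)

# The top model treats the received activations as a leaf tensor
# (in a real deployment they arrive over the network, detached anyway).
top_input = output_tensor_bottom_model.detach().requires_grad_(True)
loss = nn.functional.mse_loss(top_model(top_input), y)  # illustrative top loss
loss.backward()

# grad of the top model's input (also the bottom model's output),
# which is sent back and used as the bottom model's "target"
grad_output_bottom_model = top_input.grad

# --- backward on the bottom side: the special surrogate loss ---
bottom_loss = torch.sum(output_tensor_bottom_model * grad_output_bottom_model)
bottom_loss.backward()

# By the chain rule, d(bottom_loss)/d(theta) = sum_i g_i * d(h_i)/d(theta)
# = d(loss)/d(theta), so bottom_model's parameter gradients are exactly
# what end-to-end backpropagation of the top loss would have produced.
```

Because `grad_output_bottom_model` is treated as a constant, differentiating the product `sum(h * g)` w.r.t. the bottom parameters reproduces the full chain rule without the bottom party ever seeing the top model's loss.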
You can refer to the Tencent FATE framework for more examples of how this special VFL loss function is used:
https://github.com/FederatedAI/FATE/blob/master/python/federatedml/nn/hetero_nn/model/hetero_nn_bottom_model.py