What's the difference between get_function() and grads_check()
drcut opened this issue · 1 comment
Hi,
I noticed that on the master branch, in checkmate/tf2/wrapper.py, there are two functions that seem to do the same thing: get_function(model, input_shape, label_shape, optimizer, loss) and grads_check(data, label). What's the difference between them?
Thanks
Hi @drcut! Thanks for trying out our repository, and I apologize for the delay in replying.
Referenced code: checkmate/checkmate/tf2/wrapper.py, lines 51 to 60 in commit e88a3d9.
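Here, grads_check is the function that we trace with TensorFlow for static (graph) execution. The traced function must take exactly two arguments, data and label, to preserve the semantics of how the resulting tf.function is later called. However, we also need to pass along extra information such as the loss, the optimizer, and the type/shape information. Therefore, I wrap grads_check with get_function: get_function takes those extra values as its arguments, and grads_check keeps a reference to them through a closure. So the two are not alternatives; get_function is the outer factory that supplies the configuration, and grads_check is the two-argument step that actually gets traced.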
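The closure pattern described above can be sketched in plain Python. This is a hypothetical illustration, not the code in wrapper.py: the toy model, the `mse` loss, and using a list's `append` as a stand-in "optimizer" are all assumptions made so the example runs without TensorFlow. It shows that the outer function receives the extra configuration, while the function it hands back only takes `data` and `label`.

```python
import inspect


def get_function(model, input_shape, label_shape, optimizer, loss):
    """Hypothetical factory: returns a two-argument step that closes over
    model, shapes, optimizer, and loss from this enclosing scope."""

    def grads_check(data, label):
        # The shapes come from the enclosing scope (closure), not from
        # the caller; the caller only ever passes data and label.
        if len(data) != input_shape[0] or len(label) != label_shape[0]:
            raise ValueError("data/label do not match the captured shapes")
        prediction = model(data)
        loss_value = loss(label, prediction)
        optimizer(loss_value)  # stand-in for applying gradients (assumption)
        return loss_value

    return grads_check


def toy_model(xs):
    # Hypothetical "model": doubles each input.
    return [2.0 * x for x in xs]


def mse(labels, preds):
    # Mean squared error over plain Python lists.
    return sum((y - p) ** 2 for y, p in zip(labels, preds)) / len(labels)


seen_losses = []
step = get_function(toy_model, (3,), (3,), seen_losses.append, mse)

# The returned function exposes only (data, label), even though it
# still has access to everything passed to get_function.
params = list(inspect.signature(step).parameters)
result = step([1.0, 2.0, 3.0], [2.0, 4.0, 7.0])
print(params, result)
```

With TensorFlow installed, the returned `step` could be passed to `tf.function` in the same way; the signature that `tf.function` traces would still be just `(data, label)`, while the loss, optimizer, and shape information travel along inside the closure.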
Let me know if this answers your question.
Thanks,
Paras