parasj/checkmate

What's the difference between get_function() and grads_check()

drcut opened this issue · 1 comment

drcut commented

Hi,
In the master branch, in the file checkmate/tf2/wrapper.py, there are two functions that seem to serve the same purpose: get_function(model, input_shape, label_shape, optimizer, loss) and grads_check(data, label). What's the difference between them?
Thanks

Hi @drcut! Thanks for trying out our repository, and I apologize for the delayed reply.

```python
import tensorflow as tf


def get_function(model, input_shape, label_shape, optimizer, loss):
    # Factory: the nested grads_check closes over model and loss,
    # so it only needs (data, label) as explicit arguments.
    @tf.function
    def grads_check(data, label):
        with tf.GradientTape() as check_tape:
            predictions = model(data)
            loss_val = loss(label, predictions)
        gradients = check_tape.gradient(loss_val, model.trainable_variables)
        return predictions, loss_val, gradients

    return grads_check
```

Here, grads_check is the function that TensorFlow traces (via @tf.function) for static graph execution. It must take exactly two arguments, data and label, to preserve the semantics of how the resulting tf.function is used later. However, grads_check also needs access to extra information such as the loss, the optimizer, and type/shape information. Therefore, I wrap grads_check in get_function: get_function acts as a factory, and grads_check, defined inside it, keeps a reference to those arguments through its closure.
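To illustrate the closure pattern on its own, here is a minimal plain-Python sketch with no TensorFlow involved. The names `make_step_fn` and `mse` and the toy `lambda` model are hypothetical stand-ins for `get_function`, the Keras loss, and the model; the `@tf.function` tracing step is deliberately left out, so this only shows how the outer function's arguments are captured while the inner function keeps a two-argument signature.

```python
from typing import Callable, List, Sequence, Tuple

Model = Callable[[float], float]
Loss = Callable[[Sequence[float], Sequence[float]], float]
StepFn = Callable[[Sequence[float], Sequence[float]], Tuple[List[float], float]]


def make_step_fn(model: Model, loss: Loss) -> StepFn:
    """Hypothetical factory mirroring get_function: model and loss are
    captured in the closure of the returned inner function."""

    def step(data: Sequence[float], label: Sequence[float]) -> Tuple[List[float], float]:
        # Only two parameters, like grads_check; model and loss
        # are looked up in the enclosing scope.
        predictions = [model(x) for x in data]
        loss_val = loss(label, predictions)
        return predictions, loss_val

    return step


def mse(label: Sequence[float], predictions: Sequence[float]) -> float:
    # Mean squared error over paired elements (hypothetical loss).
    return sum((y - p) ** 2 for y, p in zip(label, predictions)) / len(label)


step = make_step_fn(lambda x: 2.0 * x, mse)
preds, loss_val = step([1.0, 2.0, 3.0], [2.0, 4.0, 7.0])
print(preds, loss_val)

# The captured arguments show up as the inner function's free variables.
print(sorted(step.__code__.co_freevars))  # ['loss', 'model']
```

The caller only ever sees `step(data, label)`, while the extra configuration stays bound inside the closure; `functools.partial` would be an alternative way to fix the extra arguments.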

Let me know if this answers your question.

Thanks,
Paras