The problem of str_acc computation.
jj199603 opened this issue · 4 comments
Hello, when I first read the source code, I ran into this problem.
The source code that computes str_acc and bit_acc is as follows:
```python
def get_secret_acc(secret_true, secret_pred):
    with tf.variable_scope("acc"):
        secret_pred = tf.round(secret_pred)
        correct_pred = tf.count_nonzero(secret_pred - secret_true, axis=1)
        str_acc = 1.0 - tf.count_nonzero(correct_pred - tf.to_int64(tf.shape(secret_pred)[1])) / tf.size(correct_pred, out_type=tf.int64)
        bit_acc = tf.reduce_sum(correct_pred) / tf.size(secret_pred, out_type=tf.int64)
        return bit_acc, str_acc
```
Why does line 4 use tf.count_nonzero(secret_pred - secret_true, axis=1)? This seems to compute the number of erroneous bits for each example in the batch. However, str_acc and bit_acc use it as the number of correct bits. I think it may be a typo, but the training log in TensorBoard shows reasonable results.
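For concreteness, a small standalone sketch of the behavior in question (the messages below are made-up values, and NumPy stands in for the TensorFlow ops):

```python
import numpy as np

# Hypothetical batch of 2 messages, 4 bits each.
secret_true = np.array([[1, 0, 1, 1],
                        [0, 1, 0, 0]])
secret_pred = np.array([[1, 0, 0, 1],   # 1 bit wrong
                        [0, 1, 0, 0]])  # all bits correct

# Mirrors tf.count_nonzero(secret_pred - secret_true, axis=1):
diff = np.count_nonzero(secret_pred - secret_true, axis=1)
print(diff)  # [1 0] -- the number of WRONG bits per example, not correct ones
```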
The tf.count_nonzero is over the message axis, so it will be a tensor of length batch_size. The values of this tensor will range from 0 to message_size. The str_acc computes what percentage of these correct_preds equal the message_size (i.e., the entire message is correct). The bit_acc computes the equivalent of mean(correct_preds / message_size), but in an admittedly less readable way.
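Restated in a more readable form (a hypothetical NumPy sketch, assuming correct_pred really held the number of correct bits per example, with message_size = 4):

```python
import numpy as np

# Hypothetical per-example correct-bit counts for a batch of 3 messages.
correct_pred = np.array([4, 3, 4])
message_size = 4

# str_acc: fraction of examples where the whole message matches.
str_acc = np.mean(correct_pred == message_size)   # 2/3

# bit_acc: mean fraction of correct bits; the readable form of
# sum(correct_pred) / (batch_size * message_size).
bit_acc = np.mean(correct_pred / message_size)    # (4+3+4)/12
print(str_acc, bit_acc)
```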
Thanks for your reply! I fully understand your explanation of str_acc and bit_acc, but I don't understand why you use (secret_pred - secret_true) as the input to count_nonzero. If the predicted string is equal to the ground truth, this tensor should be a zero tensor of shape (batch_size, secret_size), so correct_pred represents the number of erroneous bits in each input string? Maybe I misunderstand. Thanks.
You are right! I pushed a bugfix that addresses this. It turns out there were two issues: the one you found, plus a missing sigmoid. Remarkably, this doesn't seem to change the values dramatically.
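For readers landing here later, this is a sketch of what the corrected logic could look like, reconstructed from the two fixes described above (a sigmoid before rounding, and treating the nonzero count as an error count; the variable name wrong_pred is mine). The bugfix commit in the repository is the authoritative version:

```python
import tensorflow as tf  # TF1-style API, matching the snippet above

def get_secret_acc(secret_true, secret_pred):
    # Sketch only; see the repo's bugfix commit for the actual change.
    with tf.variable_scope("acc"):
        # Fix 1: apply a sigmoid so logits land in [0, 1] before rounding.
        secret_pred = tf.round(tf.sigmoid(secret_pred))
        # This counts the WRONG bits per example, so name it accordingly.
        wrong_pred = tf.count_nonzero(secret_pred - secret_true, axis=1)
        # Fix 2: a string is correct iff it has zero wrong bits.
        str_acc = 1.0 - tf.count_nonzero(wrong_pred) / tf.size(wrong_pred, out_type=tf.int64)
        # Fix 2 (cont.): bit accuracy is one minus the wrong-bit fraction.
        bit_acc = 1.0 - tf.reduce_sum(wrong_pred) / tf.size(secret_pred, out_type=tf.int64)
        return bit_acc, str_acc
```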
Thanks! This is great work!