tancik/StegaStamp

Problem with the str_acc computation.

jj199603 opened this issue · 4 comments

Hello, when I first read the source code, I had this problem.
The source code that computes str_acc and bit_acc is as follows:

```python
def get_secret_acc(secret_true, secret_pred):
    with tf.variable_scope("acc"):
        secret_pred = tf.round(secret_pred)
        correct_pred = tf.count_nonzero(secret_pred - secret_true, axis=1)
        str_acc = 1.0 - tf.count_nonzero(correct_pred - tf.to_int64(tf.shape(secret_pred)[1])) / tf.size(correct_pred, out_type=tf.int64)
        bit_acc = tf.reduce_sum(correct_pred) / tf.size(secret_pred, out_type=tf.int64)
        return bit_acc, str_acc
```
Why is line 4 `tf.count_nonzero(secret_pred - secret_true, axis=1)`? This seems to compute the number of error bits for each example in the batch. However, str_acc and bit_acc use it as the number of correct bits. I think it may be a typo, but the training log in TensorBoard shows a reasonable result.

The `tf.count_nonzero` is taken over the message axis, so the result is a tensor of length batch_size whose values range from 0 to message_size.
The str_acc computes what percentage of these correct_pred values equal message_size (i.e., the entire message is correct).
The bit_acc computes the equivalent of `mean(correct_pred / message_size)`, in an admittedly less readable way.
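The intended semantics can be illustrated with a small NumPy sketch (a hypothetical standalone version for clarity, not the repository's TensorFlow code; the bit values and shapes here are made up):

```python
import numpy as np

# Rounded predictions and ground-truth secrets: shape (batch_size, message_size)
secret_true = np.array([[1, 0, 1, 1],
                        [0, 1, 0, 0],
                        [1, 1, 1, 0]])
secret_pred = np.array([[1, 0, 1, 1],   # fully correct
                        [0, 1, 1, 0],   # one bit wrong
                        [1, 1, 1, 0]])  # fully correct

message_size = secret_true.shape[1]

# Count the bits that MATCH in each row -> tensor of length batch_size
correct_bits = np.count_nonzero(secret_pred == secret_true, axis=1)  # [4, 3, 4]

# bit_acc: fraction of all bits that are correct
bit_acc = correct_bits.sum() / secret_pred.size  # 11/12

# str_acc: fraction of messages that are entirely correct
str_acc = np.count_nonzero(correct_bits == message_size) / len(correct_bits)  # 2/3
```

Note the comparison `secret_pred == secret_true` counts matching bits, whereas the snippet in question counts non-zero entries of the difference, i.e., mismatching bits.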

Thanks for your reply! I fully understand your explanation of str_acc and bit_acc, but I don't understand why you use `(secret_pred - secret_true)` as the input to `count_nonzero`. If the predicted string equals the ground truth, this tensor should be a zero tensor of shape (batch_size, secret_size), so correct_pred represents the number of erroneous bits for each input string? Maybe I misunderstand. Thanks.

You are right! I pushed a bugfix that addresses this. It turns out there were two issues, the one you found plus the lack of a sigmoid. Remarkably this doesn't seem to dramatically change the values.
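For readers following along, a sketch of what the fix might look like, written as a NumPy stand-in for the TF version (this is an illustration of the two issues named above, not necessarily the actual committed code):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def get_secret_acc(secret_true, secret_logits):
    """Sketch of a fixed accuracy computation.

    secret_true:   (batch_size, message_size) array of 0/1 bits
    secret_logits: (batch_size, message_size) raw network outputs (pre-sigmoid)
    """
    # Issue 2: apply the missing sigmoid before rounding to hard 0/1 bits.
    secret_pred = np.round(sigmoid(secret_logits))
    # Issue 1: count MATCHING bits per example, not non-zero differences.
    correct_pred = np.count_nonzero(secret_pred == secret_true, axis=1)
    bit_acc = correct_pred.sum() / secret_pred.size
    str_acc = np.count_nonzero(correct_pred == secret_true.shape[1]) / len(correct_pred)
    return bit_acc, str_acc
```

With this version, correct_pred genuinely holds the number of correct bits per example, so bit_acc and str_acc use it consistently.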

Thanks! This is great work!