Loss function bug
Omar3esam opened this issue · 1 comment
Omar3esam commented
Updated the loss function to work with TF 2.2. The change is in the label_to_levels function:
# Index of the true class, taken from the one-hot label.
label_index = tf.cast(tf.argmax(label), tf.int32)
# One entry of 1 for every rank below the true class.
label_vec = tf.repeat(1, label_index)
# This line requires that label values begin at 0. If they start at a higher
# value it will yield an error.
num_zeros = self.num_classes - 1 - label_index
zero_vec = tf.zeros(shape=[num_zeros], dtype=tf.int32)
# Result has length num_classes - 1: ones first, then zeros.
levels = tf.concat([label_vec, zero_vec], axis=0)
return tf.cast(levels, tf.float32)
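For anyone reading along, the encoding this snippet produces can be sketched in plain Python. This is only an illustration of the output shape, not the library code, and the function name levels_for_label is made up for the example: a label index k out of num_classes classes becomes k ones followed by num_classes - 1 - k zeros.

```python
def levels_for_label(label_index, num_classes):
    """Hypothetical helper: ordinal 'levels' vector for one zero-based label."""
    # Labels must start at 0, otherwise the number of zeros goes negative.
    if not 0 <= label_index < num_classes:
        raise ValueError("label_index must be in [0, num_classes - 1]")
    ones = [1.0] * label_index
    zeros = [0.0] * (num_classes - 1 - label_index)
    return ones + zeros


# Label 2 out of 5 classes: two ones, then two zeros.
example = levels_for_label(2, 5)
```

The result always has num_classes - 1 entries, which is why labels that do not start at 0 break the zero padding.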
gmgeorg commented
With the latest merge, the label_to_levels function has been re-implemented from scratch using tf.sequence_mask() (vectorized instead of relying on tf.map_fn()). This makes the implementation update here obsolete.
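A minimal sketch of what a tf.sequence_mask()-based version could look like. This is not the merged code: the function name label_to_levels_sketch is hypothetical, and it assumes integer labels starting at 0 rather than the one-hot input used in the snippet above.

```python
import tensorflow as tf


def label_to_levels_sketch(labels, num_classes):
    """Hypothetical vectorized encoding: (batch,) int labels -> (batch, num_classes - 1) floats."""
    # Assumption: labels are zero-based class indices, not one-hot vectors.
    labels = tf.cast(labels, tf.int32)
    # Row i gets labels[i] ones followed by zeros, padded to num_classes - 1 columns.
    return tf.sequence_mask(labels, maxlen=num_classes - 1, dtype=tf.float32)


levels = label_to_levels_sketch(tf.constant([0, 2, 4]), num_classes=5)
```

Because the whole batch is handled in one call, no per-example tf.map_fn() loop is needed.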