ck37/coral-ordinal

Loss function bug

Omar3esam opened this issue · 1 comment

Updated the label_to_levels function used by the loss so that it works with TF 2.2:

import tensorflow as tf

def label_to_levels(self, label):
    # tf.argmax(label) recovers the class index, so `label` is expected to be
    # one-hot encoded. Class indices must start at 0; otherwise num_zeros
    # becomes negative and tf.zeros raises an error.
    class_idx = tf.cast(tf.argmax(label), tf.int32)
    label_vec = tf.repeat(1, class_idx)                     # class_idx ones
    num_zeros = self.num_classes - 1 - class_idx
    zero_vec = tf.zeros(shape=[num_zeros], dtype=tf.int32)  # remaining zeros
    levels = tf.concat([label_vec, zero_vec], axis=0)
    return tf.cast(levels, tf.float32)
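For reference, the encoding that label_to_levels builds can be written in plain Python: a class index k out of K classes becomes K - 1 binary levels, the first k of which are 1. This is a minimal sketch assuming an integer class index in [0, num_classes - 1]; the name label_to_levels_ref is hypothetical and not part of the repository.

```python
from typing import List


def label_to_levels_ref(label: int, num_classes: int) -> List[float]:
    """Hypothetical reference: encode an ordinal class index as K-1 binary levels.

    Level j is 1.0 when label > j, so class k yields k ones followed by zeros.
    """
    if not 0 <= label < num_classes:
        raise ValueError("label must be in [0, num_classes - 1]")
    return [1.0] * label + [0.0] * (num_classes - 1 - label)


print(label_to_levels_ref(2, 5))  # [1.0, 1.0, 0.0, 0.0]
```

The range check makes the "labels must start at 0" requirement from the snippet explicit instead of letting it surface as a shape error.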

With the latest merge, label_to_levels has been re-implemented from scratch using tf.sequence_mask(), which is vectorized instead of relying on tf.map_fn(). This makes the update proposed here obsolete.
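The vectorized approach described above can be sketched with tf.sequence_mask, whose output satisfies mask[i, j] = j < lengths[i], which is the same "level j is 1 when label > j" rule applied to a whole batch at once. This is not the repository's actual implementation: it is a minimal sketch assuming integer (not one-hot) labels starting at 0, and label_to_levels_vectorized is a hypothetical name.

```python
import tensorflow as tf


def label_to_levels_vectorized(labels, num_classes):
    """Hypothetical sketch: encode a batch of integer labels as ordinal levels.

    tf.sequence_mask(lengths, maxlen) sets mask[i, j] = j < lengths[i],
    so row i has labels[i] ones followed by zeros, for num_classes - 1 levels.
    """
    labels = tf.cast(tf.reshape(labels, [-1]), tf.int32)  # flatten to (batch,)
    return tf.sequence_mask(labels, maxlen=num_classes - 1, dtype=tf.float32)


levels = label_to_levels_vectorized(tf.constant([0, 2, 4]), num_classes=5)
print(levels.numpy())
```

Because the whole batch is encoded in one op, there is no per-example Python function for tf.map_fn to trace.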