CyberZHG/keras-radam

Issue about the dtype

TianrongChen opened this issue · 3 comments

Hi! Thanks for your implementation! I would like to use your repo in my own code, but I found a dtype conflict. My code is written in tf.float64. When I use your code, it reports an error: "Op has type float64 that does not match type float32". Is there any way to solve this problem?
Thanks again for your elegant implementation!

Try this?

from tensorflow.python.keras import backend as K

K.set_floatx('float64')

Tried. It does not work and shows the same error. I've attached part of my code here:

    sqrd_l2_norm = tf.reduce_sum(tf.stack(norm_list), keepdims=False)
    grads = tf.gradients(self._loss + weight_decay_param * sqrd_l2_norm, trainable_variables)
    clipped_grads, _ = tf.clip_by_global_norm(grads, self._config.max_gradient_norm)
    optimizer = RAdamOptimizer(learning_rate=1e-3)
    self._train_ops = optimizer.apply_gradients(zip(clipped_grads, trainable_variables), global_step=global_step, name='train_step')

The error occurs on the last line, at `apply_gradients`. Do you have any insight? Thanks!
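Not confirmed in this thread, but one common pitfall with `K.set_floatx` is ordering: it only affects variables created *after* the call, so any weights built earlier stay float32 and will clash with float64 tensors exactly as in the error above. A minimal sketch of the correct ordering, using a plain Keras `Dense` layer as a stand-in for the model's variables:

```python
from tensorflow.keras import backend as K
import tensorflow as tf

# K.set_floatx must run BEFORE any variables or layers are created;
# variables built earlier keep their original float32 dtype, and mixing
# them with float64 tensors raises the "does not match" error.
K.set_floatx('float64')

# Any layer created from here on defaults to float64 weights.
layer = tf.keras.layers.Dense(4)
layer.build((None, 3))
print(layer.kernel.dtype)  # now float64 rather than the float32 default
```

If the variables collected in `trainable_variables` were created before `set_floatx` ran (e.g. the call sits after the model-building code), they remain float32, which would reproduce this mismatch even though `set_floatx` was "tried".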

stale commented

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.