tensorflow/mesh

Finetuning a `bfloat16` checkpoint with `float32`

saareliad opened this issue · 0 comments

I'm trying to fine-tune a released T5 checkpoint in float32,
but I get the following error:

2020-09-03 16:33:42.380962: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at save_restore_v2_ops.cc:184 : Invalid argument: tensor_name =
/block_018/layer_002/layer_norm/scale; expected dtype float does not equal original dtype bfloat16

Is what I'm trying to do supported? These are the relevant gin overrides I set:
--gin_param="get_variable_dtype.activation_dtype = 'float32'"
--gin_param="get_variable_dtype.master_dtype = 'float32'"
--gin_param="get_variable_dtype.slice_dtype = 'float32'"
--gin_file="gs://t5-data/pretrained_models/3B/operative_config.gin"
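If I understand correctly, these three overrides feed `get_variable_dtype`, so I expect them to amount to building something like the following (a rough sketch, assuming `mtf.VariableDType` keeps its usual constructor arguments):

```python
import tensorflow as tf
import mesh_tensorflow as mtf

# What I expect the three gin overrides above to translate into:
# variables mastered, sliced, and activated in float32.
variable_dtype = mtf.VariableDType(
    master_dtype=tf.float32,
    slice_dtype=tf.float32,
    activation_dtype=tf.float32)
```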

(We explicitly want float32)
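In case it's relevant, the workaround I was considering is to re-save the released checkpoint with its bfloat16 variables cast to float32 before restoring it into a float32 graph. A rough sketch (the paths and checkpoint step are placeholders, and I haven't verified this end to end):

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Placeholder paths -- point these at the real source and destination checkpoints.
src_ckpt = "gs://t5-data/pretrained_models/3B/model.ckpt-<step>"
dst_ckpt = "/tmp/3B_float32/model.ckpt-<step>"

reader = tf.train.load_checkpoint(src_ckpt)
dtype_map = reader.get_variable_to_dtype_map()

new_vars = []
for name, dtype in dtype_map.items():
    value = reader.get_tensor(name)
    if dtype == tf.bfloat16:
        # Cast bfloat16 entries to float32 so a float32 graph can restore them.
        value = value.astype(np.float32)
    new_vars.append(tf.Variable(value, name=name))

saver = tf.train.Saver(new_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, dst_ckpt)
```

Is there a supported way to do this directly, without rewriting the checkpoint first?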