google-deepmind/sonnet

ValueError("None values not supported.") when using all_reduce on gradients with TPU

a3cel2 opened this issue · 1 comment

Hi, I am trying to use Sonnet to adapt the Enformer training code provided here:
https://github.com/deepmind/deepmind-research/blob/master/enformer/enformer-training.ipynb

However, the authors recommend using a TPU instead of a GPU. I am following this guide as a template for distributed training, trying to adapt it to a TPU workflow:
https://colab.research.google.com/github/deepmind/sonnet/blob/v2/examples/distributed_cifar10.ipynb

My entire workflow is here:
https://colab.research.google.com/drive/1nQklf0EB9ME9b90hX2GJAau1aQP8nXXb#scrollTo=i0lUnjHzDzp3

What it boils down to is how I define the training step. If I define it without an all-reduce step, it works fine and trains:

def create_step_function_working(model, optimizer, head):
  def step(sequence, target, optimizer_clip_norm_global=0.2):
    with tf.GradientTape() as tape:
      outputs = model(sequence, is_training=True)['human']

      loss = tf.reduce_mean(
              tf.keras.losses.poisson(target, outputs))

    grads = tape.gradient(loss, model.trainable_variables)

    #replica_ctx = tf.distribute.get_replica_context()
    #grads = replica_ctx.all_reduce("mean", grads)

    optimizer.apply(grads, model.trainable_variables)
    return loss

  @tf.function
  def train_step(batch, head):
    per_replica_loss = tpu_strategy.run(step, args=(batch['sequence'], batch['target']))
    return tpu_strategy.reduce("sum", per_replica_loss, axis=None)
  return train_step

However, this results in undesirable behaviour: the replicas are no longer in sync! Even after 20 epochs there are slight differences in the replica weights, which get more exaggerated as training goes on (I suspect this is because of the dropout in the model):

# Train the model

with tpu_strategy.scope():
  learning_rate = tf.Variable(0., trainable=False, name='learning_rate')
  optimizer = snt.optimizers.Adam(learning_rate=learning_rate)
  num_warmup_steps = 5000
  target_learning_rate = 0.0005

  model = enformer.Enformer(channels=1536 // 4,  # Use 4x fewer channels to train faster.
                            num_heads=8,
                            num_transformer_layers=11,
                            pooling_type='max')

  train_step_human = create_step_function_working(model, optimizer, 'human')

  steps_per_epoch = 10
  num_epochs = 2

  data_it = iter(human_mouse_dataset)
  global_step = 0
  for epoch_i in range(num_epochs):
    for i in tqdm(range(steps_per_epoch)):
      global_step += 1

      if global_step > 1:
        learning_rate_frac = tf.math.minimum(
            1.0, global_step / tf.math.maximum(1.0, num_warmup_steps))      
        learning_rate.assign(target_learning_rate * learning_rate_frac)

      batch_human, batch_mouse = next(data_it)

      loss_human = train_step_human(batch=batch_human, head='human')

>>> model.trainable_variables[0]
0: <tf.Variable 'enformer/heads/head_human/head_human/linear/b:0' shape=(5313,) dtype=float32, numpy=
array([-2.0701702e-05, -2.0692447e-05, -2.0778976e-05, ...,
       -1.9857325e-05, -1.8381259e-05, -1.9378911e-05], dtype=float32)>,
  1: <tf.Variable 'enformer/heads/head_human/head_human/linear/b/replica_1:0' shape=(5313,) dtype=float32, numpy=
array([-2.0712672e-05, -2.0690546e-05, -2.0791686e-05, ...,
       -1.9846719e-05, -1.8348546e-05, -1.9401912e-05], dtype=float32)>
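
A minimal sketch of one way to quantify this drift (assuming the tpu_strategy and model defined above; max_replica_drift is just an illustrative name):

def max_replica_drift(strategy, model):
  # Largest absolute difference between any replica's copy of a variable
  # and the first replica's copy, across all trainable variables.
  worst = 0.0
  for var in model.trainable_variables:
    copies = strategy.experimental_local_results(var)
    reference = copies[0].numpy()
    for other in copies[1:]:
      worst = max(worst, float(abs(other.numpy() - reference).max()))
  return worst

# Stays at 0.0 when gradients are all-reduced every step; grows otherwise.
print(max_replica_drift(tpu_strategy, model))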

So I can add an all-reduce, as in the multi-GPU example, to fix this:

# Try it with all-reduce
def create_step_function_not_working(model, optimizer, head):
  def step(sequence, target, optimizer_clip_norm_global=0.2):
    with tf.GradientTape() as tape:
      outputs = model(sequence, is_training=True)['human']

      loss = tf.reduce_mean(
              tf.keras.losses.poisson(target, outputs))

    grads = tape.gradient(loss, model.trainable_variables)

    replica_ctx = tf.distribute.get_replica_context()
    grads = replica_ctx.all_reduce("mean", grads)

    optimizer.apply(grads, model.trainable_variables)
    return loss

  @tf.function
  def train_step(batch, head):
    per_replica_loss = tpu_strategy.run(step, args=(batch['sequence'], batch['target']))
    return tpu_strategy.reduce("sum", per_replica_loss, axis=None)
  return train_step

But this results in an error if I run the above block with this replacement:

  0%|          | 0/10 [00:48<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-31565318e900> in <module>()
     17       batch_human, batch_mouse = next(data_it)
     18 
---> 19       loss_human = train_step_human(batch=batch_human, head='human')

8 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    992           except Exception as e:  # pylint:disable=broad-except
    993             if hasattr(e, "ag_error_metadata"):
--> 994               raise e.ag_error_metadata.to_exception(e)
    995             else:
    996               raise

ValueError: in user code:

    <ipython-input-16-140fb8e0bfc4>:20 train_step  *
        per_replica_loss = tpu_strategy.run(step, args=(batch['sequence'], batch['target']))
    <ipython-input-16-140fb8e0bfc4>:13 step  *
        grads = replica_ctx.all_reduce("mean", grads)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3212 all_reduce  **
        return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/custom_gradient.py:309 __call__
        return self._d(self._f, a, k)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/custom_gradient.py:265 decorated
        return _graph_mode_decorator(wrapped, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/custom_gradient.py:372 _graph_mode_decorator
        args = nest.map_structure(ops.convert_to_tensor, args)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py:869 map_structure
        structure[0], [func(*x) for x in entries],
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py:869 <listcomp>
        structure[0], [func(*x) for x in entries],
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/profiler/trace.py:163 wrapped
        return func(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:1566 convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/constant_op.py:346 _constant_tensor_conversion_function
        return constant(v, dtype=dtype, name=name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/constant_op.py:272 constant
        allow_broadcast=True)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/constant_op.py:290 _constant_impl
        allow_broadcast=allow_broadcast))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/tensor_util.py:445 make_tensor_proto
        raise ValueError("None values not supported.")

    ValueError: None values not supported.

I'm not sure if I've missed something or if this is a genuine bug. I was not able to find enough documentation on Sonnet to solve this myself, and it is incompatible with the vanilla tf.distribute.experimental.TPUStrategy. If this is indeed not a bug, perhaps releasing the code for this specific example (TPU training for the Enformer model) would also help others in my situation.

Sorry, this was my mistake. In this particular model, not all graph variables are connected in each step: tape.gradient returns None for variables that don't contribute to the current head's loss, and all_reduce then fails trying to convert those Nones to tensors. So I need to ensure there are no None values in the gradients:
grads = tape.gradient(loss, model.trainable_variables, unconnected_gradients=tf.UnconnectedGradients.ZERO)
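
For completeness, here is the step function with that fix applied (same structure as the blocks above; create_step_function_fixed is just my name for this variant, and it uses the head argument instead of hard-coding 'human'):

def create_step_function_fixed(model, optimizer, head):
  def step(sequence, target, optimizer_clip_norm_global=0.2):
    # (optimizer_clip_norm_global is unused here, kept for parity with the
    # original signature.)
    with tf.GradientTape() as tape:
      outputs = model(sequence, is_training=True)[head]
      loss = tf.reduce_mean(
          tf.keras.losses.poisson(target, outputs))

    # Variables not connected to this head's loss now get zero gradients
    # instead of None, so all_reduce no longer fails on convert_to_tensor(None).
    grads = tape.gradient(loss, model.trainable_variables,
                          unconnected_gradients=tf.UnconnectedGradients.ZERO)

    replica_ctx = tf.distribute.get_replica_context()
    grads = replica_ctx.all_reduce("mean", grads)

    optimizer.apply(grads, model.trainable_variables)
    return loss

  @tf.function
  def train_step(batch, head):
    per_replica_loss = tpu_strategy.run(step, args=(batch['sequence'], batch['target']))
    return tpu_strategy.reduce("sum", per_replica_loss, axis=None)
  return train_step

With this change the all-reduce runs without the error and the replica weights stay in sync from step to step.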