tensorflow/model-optimization

'Pruning in Keras' Example: Unable to fine-tune pruned model on GPU (TF 2.9.2)

swapnilsayansaha opened this issue · 5 comments


Describe the bug
Most likely, sparse operators such as PruneLowMagnitude cannot be loaded and run on a GPU.

System information

TensorFlow version (installed from source or binary): Binary, 2.9.2, CUDA: 11.6

GPU: Nvidia RTX 3090 24 GB

OS: Ubuntu 20.04

TensorFlow Model Optimization version (installed from source or binary): Binary, 0.7.3

Python version: 3.8

Describe the expected behavior and the current behavior
Issue described here: tensorflow/tensorflow#58499
https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras
The pruning example code should work out of the box. However, fine-tuning the pruned model fails on the GPU. As a workaround, I force the fine-tuning of the pruned model to run on the CPU:

with tf.device('/cpu:0'):
    model_for_pruning.fit(train_images, train_labels,
                          batch_size=batch_size, epochs=epochs,
                          validation_split=validation_split,
                          callbacks=callbacks)

The unpruned model trains fine on the GPU, so this is not a problem with the CUDA drivers. Please do not suggest setting up a new conda/venv environment.

Without the tf.device('/cpu:0') context manager, the following error occurs:

UnknownError                              Traceback (most recent call last)
Cell In [5], line 8
      1 logdir = tempfile.mkdtemp()
      3 callbacks = [
      4   tfmot.sparsity.keras.UpdatePruningStep(),
      5   tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
      6 ]
----> 8 model_for_pruning.fit(train_images, train_labels,
      9                   batch_size=batch_size, epochs=epochs, validation_split=validation_split,
     10                   callbacks=callbacks)

File ~/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File ~/swapnil_debug_2/lib/python3.8/site-packages/tensorflow/python/eager/execute.py:54, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     52 try:
     53   ctx.ensure_initialized()
---> 54   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     55                                       inputs, attrs, num_outputs)
     56 except core._NotOkStatusException as e:
     57   if name is not None:

UnknownError: Graph execution error:

Detected at node 'sequential/prune_low_magnitude_conv2d/FloorMod' defined at (most recent call last):
    File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/traitlets/config/application.py", line 982, in launch_instance
      app.start()
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/usr/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 383, in do_execute
      res = shell.run_cell(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_2426111/471826281.py", line 8, in <module>
      model_for_pruning.fit(train_images, train_labels,
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1409, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1051, in train_function
      return step_function(self, iterator)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1040, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1030, in run_step
      outputs = model.train_step(data)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 889, in train_step
      y_pred = self(x, training=True)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 490, in __call__
      return super().__call__(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/sequential.py", line 374, in call
      return super(Sequential, self).call(inputs, training=training, mask=mask)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/functional.py", line 458, in call
      return self._run_internal_graph(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/functional.py", line 596, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 280, in call
      update_mask = utils.smart_cond(training, add_update, no_op)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 50, in smart_cond
      if isinstance(pred, variables.Variable):
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 54, in smart_cond
      pred, true_fn=true_fn, false_fn=false_fn, name=name)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 268, in add_update
      with tf.control_dependencies(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 310, in conditional_mask_update
      return tf.distribute.get_replica_context().merge_call(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 307, in mask_update_distributed
      return tf.cond(maybe_update_masks(), update_distributed, no_update)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 260, in maybe_update_masks
      if self._sparsity_m_by_n:
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 264, in maybe_update_masks
      return self._pruning_schedule(self._step_fn())[0]
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 246, in __call__
      sparsity)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 61, in _should_prune_in_step
      is_pruning_turn = tf.math.equal(
Node: 'sequential/prune_low_magnitude_conv2d/FloorMod'
JIT compilation failed.
	 [[{{node sequential/prune_low_magnitude_conv2d/FloorMod}}]] [Op:__inference_train_function_34086]
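The failing node, sequential/prune_low_magnitude_conv2d/FloorMod, comes from _should_prune_in_step in pruning_schedule.py, which decides on each training step whether the pruning mask should be updated. The arithmetic itself is trivial, which supports the reading that the failure lies in JIT-compiling the op ("JIT compilation failed") rather than in the pruning logic. Below is a plain-Python sketch of that check, assuming it mirrors the tfmot logic (step inside [begin_step, end_step], with a negative end_step meaning "never stop", and (step - begin_step) a multiple of frequency). The function name and signature are illustrative only, not part of the tfmot API.

```python
def should_prune_in_step(step: int, begin_step: int, end_step: int, frequency: int) -> bool:
    """Plain-Python sketch of the schedule check that contains the FloorMod op.

    Assumption: this mirrors tfmot's _should_prune_in_step; it is not the library code.
    """
    # Pruning is active from begin_step onward; a negative end_step means "keep pruning".
    in_pruning_range = step >= begin_step and (step <= end_step or end_step < 0)
    # Python's % is floor modulo, the same semantics as tf.math.floormod.
    is_pruning_turn = (step - begin_step) % frequency == 0
    return in_pruning_range and is_pruning_turn


# Example: which steps of an 844-step run would update the mask with frequency=100.
update_steps = [s for s in range(0, 845) if should_prune_in_step(s, 0, 844, 100)]
print(update_steps)
```

Running this on the CPU in plain Python is instant, so the only GPU-specific part of the failing node is how TensorFlow compiles FloorMod for the device.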

Code to reproduce the issue
https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras


Additional context
The problem isn't too serious: I can train the unpruned model on the GPU (for example, for 200 epochs), save its weights, load them, add the code needed to prune the model, and then fine-tune it on the CPU (for example, for 10 epochs). However, it is worth investigating why the fine-tuning cannot run on the GPU.
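When pruning is moved into a separate fine-tuning run of a different length, the pruning schedule's end_step should be recomputed for that run so the final sparsity is actually reached. The 'Pruning in Keras' guide derives it from the number of non-validation training images, the batch size, and the epoch count. A small sketch of that formula, assuming the 60,000 MNIST training images and the guide's batch_size=128 and validation_split=0.1; pruning_end_step is a hypothetical helper name:

```python
import math


def pruning_end_step(num_train_images: int, batch_size: int, epochs: int,
                     validation_split: float) -> int:
    """Hypothetical helper: step at which the pruning schedule should finish."""
    # Only the non-validation part of the training set produces optimizer steps.
    num_images = num_train_images * (1 - validation_split)
    steps_per_epoch = math.ceil(num_images / batch_size)
    return steps_per_epoch * epochs


print(pruning_end_step(60000, 128, 2, 0.1))   # guide's 2-epoch run
print(pruning_end_step(60000, 128, 10, 0.1))  # a 10-epoch fine-tuning run
```

With these assumptions there are 422 steps per epoch, so a 10-epoch fine-tuning run would need an end_step five times larger than the guide's 2-epoch value.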

Hi, this seems to be an issue with the FloorMod op on GPU rather than with the Pruning API. It is odd, since a similar bug was already fixed a year ago: tensorflow/tensorflow#46887

Could you double-check your TensorFlow version? If the bug still occurs in a recent TensorFlow version, we may need to reopen the issue above.

The TF version is 2.9.2 (GPU).

Similar bug with FloorMod on Windows 10, TF 2.10.0.

puelon commented

Having the same problem on an RTX 3090 with TensorFlow 2.10. I can't even run PQAT because pruning fails on the GPU.

gyojir commented

I had the same problem.
So I ran the following and got an error that libdevice.10.bc was not found.

import tensorflow as tf

# Minimal repro: compile a single FloorMod op with XLA.
@tf.function(jit_compile=True)
def floormod(a, b):
  return tf.math.floormod(a, b)

floormod(tf.constant(1.), tf.constant(1.))
tensorflow.python.framework.errors_impl.InternalError: libdevice not found at ./libdevice.10.bc [Op:__inference_floormod_49]

I added the following to the top of the program and it worked.

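import os
# Point XLA at the CUDA directory that contains nvvm/libdevice/libdevice.10.bc.
# Set this before TensorFlow is imported.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/path/to/cuda"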
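Instead of hard-coding /path/to/cuda, the directory can be detected at startup. A minimal sketch, assuming XLA expects the libdevice file at <cuda_data_dir>/nvvm/libdevice/libdevice.10.bc and that the candidate directories (e.g. /usr/local/cuda, /usr/local/cuda-11.6) are guesses you would adapt to your machine; find_cuda_data_dir and set_xla_cuda_data_dir are hypothetical helper names:

```python
import os
from pathlib import Path
from typing import Iterable, Optional

# Assumed location of libdevice relative to the CUDA data directory.
LIBDEVICE_RELPATH = Path("nvvm") / "libdevice" / "libdevice.10.bc"


def find_cuda_data_dir(candidates: Iterable[str]) -> Optional[str]:
    """Return the first candidate directory that contains the libdevice file."""
    for candidate in candidates:
        if candidate and (Path(candidate) / LIBDEVICE_RELPATH).is_file():
            return str(candidate)
    return None


def set_xla_cuda_data_dir(candidates: Iterable[str]) -> Optional[str]:
    """Append --xla_gpu_cuda_data_dir to XLA_FLAGS if a CUDA directory is found."""
    cuda_dir = find_cuda_data_dir(candidates)
    if cuda_dir is not None:
        # Preserve any XLA flags that were already set.
        existing = os.environ.get("XLA_FLAGS", "")
        flag = f"--xla_gpu_cuda_data_dir={cuda_dir}"
        os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()
    return cuda_dir
```

Usage: call set_xla_cuda_data_dir(["/usr/local/cuda", "/usr/local/cuda-11.6"]) at the very top of the script, before import tensorflow. If it returns None, no candidate contains libdevice and the flag is left unset.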

I hope this helps.