'Pruning in Keras' Example: Unable to fine-tune pruned model on GPU (TF 2.9.2)
swapnilsayansaha opened this issue · 5 comments
Describe the bug
Most likely, sparse operators such as PruneLowMagnitude cannot be loaded and run on a GPU.
System information
TensorFlow version (installed from source or binary): Binary, 2.9.2, CUDA: 11.6
GPU: Nvidia RTX 3090 24 GB
OS: Ubuntu 20.04
TensorFlow Model Optimization version (installed from source or binary): Binary, 0.7.3
Python version: 3.8
Describe the expected behavior and the current behavior
Issue described here: tensorflow/tensorflow#58499
https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras
The example code from the pruning guide should work out of the box. However, fine-tuning the pruned model fails on the GPU. As a workaround, I forced the fine-tuning of the pruned model onto the CPU:
# Run the pruning fine-tuning on the CPU to avoid the GPU failure
with tf.device('/cpu:0'):
    model_for_pruning.fit(train_images, train_labels,
                          batch_size=batch_size, epochs=epochs,
                          validation_split=validation_split,
                          callbacks=callbacks)
The unpruned model trains fine on the GPU, so this is not a problem with the CUDA drivers; please do not suggest setting up a new conda/venv environment.
Without the with tf.device('/cpu:0'): block, the following error occurs:
UnknownError Traceback (most recent call last)
Cell In [5], line 8
1 logdir = tempfile.mkdtemp()
3 callbacks = [
4 tfmot.sparsity.keras.UpdatePruningStep(),
5 tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
6 ]
----> 8 model_for_pruning.fit(train_images, train_labels,
9 batch_size=batch_size, epochs=epochs, validation_split=validation_split,
10 callbacks=callbacks)
File ~/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
File ~/swapnil_debug_2/lib/python3.8/site-packages/tensorflow/python/eager/execute.py:54, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
52 try:
53 ctx.ensure_initialized()
---> 54 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
55 inputs, attrs, num_outputs)
56 except core._NotOkStatusException as e:
57 if name is not None:
UnknownError: Graph execution error:
Detected at node 'sequential/prune_low_magnitude_conv2d/FloorMod' defined at (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
app.launch_new_instance()
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/traitlets/config/application.py", line 982, in launch_instance
app.start()
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 712, in start
self.io_loop.start()
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
self.asyncio_loop.run_forever()
File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
self._run_once()
File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
handle._run()
File "/usr/lib/python3.8/asyncio/events.py", line 81, in _run
self._context.run(self._callback, *self._args)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
await self.process_one()
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 499, in process_one
await dispatch(*args)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
await result
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
reply_content = await reply_content
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 383, in do_execute
res = shell.run_cell(
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
return super().run_cell(*args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
result = self._run_cell(
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
return runner(coro)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_2426111/471826281.py", line 8, in <module>
model_for_pruning.fit(train_images, train_labels,
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1409, in fit
tmp_logs = self.train_function(iterator)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1051, in train_function
return step_function(self, iterator)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1040, in step_function
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1030, in run_step
outputs = model.train_step(data)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 889, in train_step
y_pred = self(x, training=True)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 490, in __call__
return super().__call__(*args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
return fn(*args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/sequential.py", line 374, in call
return super(Sequential, self).call(inputs, training=training, mask=mask)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/functional.py", line 458, in call
return self._run_internal_graph(
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/functional.py", line 596, in _run_internal_graph
outputs = node.layer(*args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
return fn(*args, **kwargs)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 280, in call
update_mask = utils.smart_cond(training, add_update, no_op)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 50, in smart_cond
if isinstance(pred, variables.Variable):
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 54, in smart_cond
pred, true_fn=true_fn, false_fn=false_fn, name=name)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 268, in add_update
with tf.control_dependencies(
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 310, in conditional_mask_update
return tf.distribute.get_replica_context().merge_call(
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 307, in mask_update_distributed
return tf.cond(maybe_update_masks(), update_distributed, no_update)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 260, in maybe_update_masks
if self._sparsity_m_by_n:
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 264, in maybe_update_masks
return self._pruning_schedule(self._step_fn())[0]
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 246, in __call__
sparsity)
File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 61, in _should_prune_in_step
is_pruning_turn = tf.math.equal(
Node: 'sequential/prune_low_magnitude_conv2d/FloorMod'
JIT compilation failed.
[[{{node sequential/prune_low_magnitude_conv2d/FloorMod}}]] [Op:__inference_train_function_34086]
Code to reproduce the issue
https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras
Additional context
The problem isn't too serious: I can train the unpruned model on the GPU for, say, 200 epochs, save its weights, load them, add the code needed to prune the model, and then fine-tune it on the CPU for, say, 10 epochs. Still, it is worth investigating why the fine-tuning cannot run on the GPU.
Hi, this looks like an issue with the FloorMod op on GPU rather than with the Pruning API. That is odd, because a similar bug was already fixed a year ago: tensorflow/tensorflow#46887
Could you double-check your TensorFlow version? If the bug still exists in recent TensorFlow versions, we may need to reopen the issue above.
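For context on why a FloorMod node appears in a pruned model at all: the last frames of the traceback end in `_should_prune_in_step` in tfmot's `pruning_schedule.py`, where `is_pruning_turn` is computed with a floor modulo of the training step. Below is a plain-Python mirror of that check as I understand it from the tfmot source; `should_prune_in_step` is a hypothetical name, and the exact logic may differ between tfmot versions.

```python
def should_prune_in_step(step: int, begin_step: int, end_step: int, frequency: int) -> bool:
    """Plain-Python sketch (hypothetical helper) of the pruning schedule's step check."""
    # Pruning is only active in [begin_step, end_step]; a negative end_step
    # is treated as "keep pruning forever".
    in_range = step >= begin_step and (step <= end_step or end_step < 0)
    # Equivalent of the tf.math.floormod(...) call that becomes the failing
    # FloorMod node. Python's % on ints is also a floor modulo.
    is_pruning_turn = (step - begin_step) % frequency == 0
    return in_range and is_pruning_turn


# Which of the steps 0, 50, ..., 1150 would update the masks with
# begin_step=0, end_step=1000, frequency=100?
print([s for s in range(0, 1200, 50) if should_prune_in_step(s, 0, 1000, 100)])
# prints [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
```

Because this check runs on every training step of the pruning-wrapped layers, the unpruned model never executes the op, which would be consistent with it training fine on the GPU.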
The TensorFlow version is 2.9.2 (GPU).
Similar bug with floormod on Windows 10, TF 2.10.0.
I'm having the same problem on an RTX 3090 with TensorFlow 2.10. I can't even run PQAT because pruning fails on the GPU.
I had the same problem.
I ran the following, and it failed with an error saying that libdevice.10.bc could not be found:
import tensorflow as tf

@tf.function(jit_compile=True)
def floormod(a, b):
    return tf.math.floormod(a, b)

floormod(tf.constant(1.), tf.constant(1.))  # XLA-compiles floormod on the default device
tensorflow.python.framework.errors_impl.InternalError: libdevice not found at ./libdevice.10.bc [Op:__inference_floormod_49]
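The message suggests XLA searched for libdevice.10.bc, did not find it, and fell back to the current directory (`./`). In a standard CUDA toolkit install the file sits under `<cuda>/nvvm/libdevice/`, so one way to find the right value for `--xla_gpu_cuda_data_dir` is to check which candidate directory contains that path. The helper name `find_cuda_data_dir` and the candidate paths are assumptions for illustration; adjust them to your machine.

```python
import os
from pathlib import Path
from typing import Iterable, Optional

# Where libdevice lives relative to the root of a standard CUDA toolkit install.
LIBDEVICE_REL = Path("nvvm") / "libdevice" / "libdevice.10.bc"


def find_cuda_data_dir(candidates: Iterable[str]) -> Optional[str]:
    """Return the first candidate dir containing nvvm/libdevice/libdevice.10.bc, else None."""
    for cand in candidates:
        if cand and (Path(cand) / LIBDEVICE_REL).is_file():
            return cand
    return None


# Hypothetical candidates: CUDA_HOME (if set) plus common install prefixes.
candidates = [os.environ.get("CUDA_HOME", ""), "/usr/local/cuda", "/usr/local/cuda-11.6"]
found = find_cuda_data_dir(candidates)
print(found if found else "libdevice.10.bc not found in any candidate directory")
```

If it prints a directory, that directory is what `--xla_gpu_cuda_data_dir` should point to.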
I added the following to the top of the program and it worked.
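import os
# Set this at the top of the program, before TensorFlow is imported; replace /path/to/cuda with your CUDA install directory
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/path/to/cuda"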
I hope this helps.
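One caveat with the assignment above: it overwrites any `XLA_FLAGS` already set in the environment. A small sketch that keeps existing flags and only replaces `--xla_gpu_cuda_data_dir`; the helper name `with_cuda_data_dir` is hypothetical, it assumes whitespace-separated flags and a CUDA path without spaces, and `/path/to/cuda` is the same placeholder as in the comment.

```python
import os


def with_cuda_data_dir(existing_flags: str, cuda_dir: str) -> str:
    """Return an XLA_FLAGS value pointing XLA at cuda_dir, keeping other flags.

    Assumes flags are whitespace-separated and cuda_dir contains no spaces.
    """
    kept = [flag for flag in existing_flags.split()
            if not flag.startswith("--xla_gpu_cuda_data_dir=")]
    kept.append(f"--xla_gpu_cuda_data_dir={cuda_dir}")
    return " ".join(kept)


# Run this at the top of the program, before importing TensorFlow.
# "/path/to/cuda" is a placeholder for the directory that contains nvvm/libdevice.
os.environ["XLA_FLAGS"] = with_cuda_data_dir(os.environ.get("XLA_FLAGS", ""), "/path/to/cuda")
```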