jordipons/musicnn

Convert model to TensorFlow Lite

wo80 opened this issue · 2 comments

wo80 commented

I'm trying to convert the MSD_musicnn model to TFLite format using TensorFlow 2.15.1 on Debian 12. I have no experience training models, only running inference, so it's hard for me to tell what's wrong with my approach. I'm basically copying parts of the extractor.py code, saving the model with tf.compat.v1.saved_model.simple_save, and then trying to convert the saved model.

Here's the Python code:

import os
import tensorflow as tf
import models
import configuration as config

# placeholders need graph mode in TF2; harmless if eager execution is already disabled
tf.compat.v1.disable_eager_execution()

model = 'MSD_musicnn'
n_frames = 187

labels = config.MSD_LABELS
num_classes = len(labels)

with tf.name_scope('model'):
    x = tf.compat.v1.placeholder(tf.float32, [None, n_frames, config.N_MELS])
    is_training = tf.compat.v1.placeholder(tf.bool)
    y, timbral, temporal, cnn1, cnn2, cnn3, mean_pool, max_pool, penultimate = models.define_model(x, is_training, model, num_classes)
    normalized_y = tf.nn.sigmoid(y)

# tensorflow: loading model
sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())
saver = tf.compat.v1.train.Saver()
# checkpoint prefix is the model directory (trailing separator kept)
saver.restore(sess, os.path.join(os.path.dirname(__file__), model, ''))

# Saving
inputs = {"model/Placeholder_1": x}
outputs = {"model/Sigmoid": normalized_y, "model/dense_1/BiasAdd": y, "model/dense/BiasAdd": penultimate}
tf.compat.v1.saved_model.simple_save(sess, './msd-musicnn-3', inputs, outputs)

I'm executing the code in the musicnn directory and it successfully produces a saved model. But when I try to run

tflite_convert --saved_model_dir=./msd-musicnn-3 --output_file=./msd-musicnn-3.tflite

I get the following error:

2024-04-06 09:10:54.132240: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Traceback (most recent call last):
  File "/home/christian/tf/bin/tflite_convert", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/lite/python/tflite_convert.py", line 690, in main
    app.run(main=run_main, argv=sys.argv[:1])
  File "/home/christian/tf/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/christian/tf/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/lite/python/tflite_convert.py", line 673, in run_main
    _convert_tf2_model(tflite_flags)
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/lite/python/tflite_convert.py", line 274, in _convert_tf2_model
    converter = lite.TFLiteConverterV2.from_saved_model(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 2087, in from_saved_model
    saved_model = _load(saved_model_dir, tags)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/python/saved_model/load.py", line 912, in load
    result = load_partial(export_dir, None, tags, options)["root"]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/python/saved_model/load.py", line 1071, in load_partial
    root = load_v1_in_v2.load(
           ^^^^^^^^^^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/python/saved_model/load_v1_in_v2.py", line 309, in load
    result = loader.load(
             ^^^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/python/saved_model/load_v1_in_v2.py", line 290, in load
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/python/saved_model/load_v1_in_v2.py", line 185, in _extract_signatures
    signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/python/eager/wrap_function.py", line 344, in prune
    lift_map = lift_to_graph.lift_to_graph(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/python/eager/lift_to_graph.py", line 253, in lift_to_graph
    sources.update(op_selector.map_subgraph(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christian/tf/lib/python3.11/site-packages/tensorflow/python/ops/op_selector.py", line 417, in map_subgraph
    raise UnliftableError(
tensorflow.python.ops.op_selector.UnliftableError: A SavedModel signature needs an input for each placeholder the signature's outputs use. An output for signature 'serving_default' depends on a placeholder which is not an input (i.e. the placeholder is not fed a value).

Unable to lift tensor <tf.Tensor 'model/Sigmoid:0' shape=(None, 50) dtype=float32> because it depends transitively on placeholder <tf.Operation 'model/Placeholder_1' type=Placeholder> via at least one path, e.g.:

model/Sigmoid (Sigmoid)
 <- model/dense_1/BiasAdd (BiasAdd)
 <- model/dense_1/MatMul (MatMul)
 <- model/dropout_1/cond/Identity (Identity)
 <- model/dropout_1/cond (If)
 <- model/batch_normalization_10/batchnorm/add_1 (AddV2)
 <- model/batch_normalization_10/batchnorm/sub (Sub)
 <- model/batch_normalization_10/batchnorm/mul_2 (Mul)
 <- model/batch_normalization_10/batchnorm/mul (Mul)
 <- model/batch_normalization_10/batchnorm/Rsqrt (Rsqrt)
 <- model/batch_normalization_10/batchnorm/add (AddV2)
 <- model/batch_normalization_10/cond_1/Identity (Identity)
 <- model/batch_normalization_10/cond_1 (If)
 <- model/batch_normalization_10/moments/Squeeze_1 (Squeeze)
 <- model/batch_normalization_10/moments/variance (Mean)
 <- model/batch_normalization_10/moments/SquaredDifference (SquaredDifference)
 <- model/batch_normalization_10/moments/StopGradient (StopGradient)
 <- model/batch_normalization_10/moments/mean (Mean)
 <- model/dense/Relu (Relu)
 <- model/dense/BiasAdd (BiasAdd)
 <- model/dense/MatMul (MatMul)
 <- model/dropout/cond/Identity (Identity)
 <- model/dropout/cond (If)
 <- model/batch_normalization_9/batchnorm/add_1 (AddV2)
 <- model/batch_normalization_9/batchnorm/sub (Sub)
 <- model/batch_normalization_9/batchnorm/mul_2 (Mul)
 <- model/batch_normalization_9/batchnorm/mul (Mul)
 <- model/batch_normalization_9/batchnorm/Rsqrt (Rsqrt)
 <- model/batch_normalization_9/batchnorm/add (AddV2)
 <- model/batch_normalization_9/cond_1/Identity (Identity)
 <- model/batch_normalization_9/cond_1 (If)
 <- model/batch_normalization_9/moments/Squeeze_1 (Squeeze)
 <- model/batch_normalization_9/moments/variance (Mean)
 <- model/batch_normalization_9/moments/SquaredDifference (SquaredDifference)
 <- model/batch_normalization_9/moments/StopGradient (StopGradient)
 <- model/batch_normalization_9/moments/mean (Mean)
 <- model/flatten/Reshape (Reshape)
 <- model/concat_2 (ConcatV2)
 <- model/moments/Squeeze (Squeeze)
 <- model/moments/mean (Mean)
 <- model/concat_1 (ConcatV2)
 <- model/Add_1 (AddV2)
 <- model/Add (AddV2)
 <- model/transpose (Transpose)
 <- model/batch_normalization_6/cond/Identity (Identity)
 <- model/batch_normalization_6/cond (If)
 <- model/conv2d_5/Relu (Relu)
 <- model/conv2d_5/BiasAdd (BiasAdd)
 <- model/conv2d_5/Conv2D (Conv2D)
 <- model/Pad_1 (Pad)
 <- model/ExpandDims_1 (ExpandDims)
 <- model/concat (ConcatV2)
 <- model/Squeeze_4 (Squeeze)
 <- model/max_pooling2d_4/MaxPool (MaxPool)
 <- model/batch_normalization_5/cond/Identity (Identity)
 <- model/batch_normalization_5/cond (If)
 <- model/conv2d_4/Relu (Relu)
 <- model/conv2d_4/BiasAdd (BiasAdd)
 <- model/conv2d_4/Conv2D (Conv2D)
 <- model/batch_normalization/cond/Identity (Identity)
 <- model/batch_normalization/cond (If)
 <- model/batch_normalization/cond/Squeeze (Squeeze)
 <- model/Placeholder_1 (Placeholder)

Not sure what to make of this. It complains that the placeholder model/Placeholder_1 is not fed, but that's what I pass as input.

Any help would be appreciated, thanks!
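A note on the error: the keys of the inputs/outputs dicts passed to simple_save are only signature names; they don't rename any tensors. Inside the `model` name scope, TF auto-names the first placeholder (x) `model/Placeholder` and the second one (is_training) `model/Placeholder_1`. So the placeholder the error complains about is is_training, which the signature never feeds, and the trace reaches it through the batch_normalization cond. The toy sketch below rebuilds only the two placeholders in the same creation order and reads the exported signature back. The output y, the variable b, and the 96 mel bands are hypothetical stand-ins, not the real musicnn graph.

```python
import os
import tempfile

import tensorflow as tf

export_dir = os.path.join(tempfile.mkdtemp(), "naming-demo")

graph = tf.Graph()
with graph.as_default():
    with tf.name_scope("model"):
        # Same creation order as the script: x first, then is_training.
        x = tf.compat.v1.placeholder(tf.float32, [None, 187, 96])  # 96: assumed N_MELS
        is_training = tf.compat.v1.placeholder(tf.bool)
        # Hypothetical toy output that depends on x only.
        b = tf.compat.v1.get_variable("b", initializer=0.0)
        y = tf.sigmoid(tf.reduce_mean(x, axis=[1, 2]) + b)
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        # Same signature key as in the issue, mapped to x.
        tf.compat.v1.saved_model.simple_save(
            sess, export_dir, {"model/Placeholder_1": x}, {"model/Sigmoid": y})

# Read the signature back: the key is just a label, the tensor behind it is x.
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    meta_graph = tf.compat.v1.saved_model.load(sess, ["serve"], export_dir)
    signature = meta_graph.signature_def["serving_default"]
    input_tensor_name = signature.inputs["model/Placeholder_1"].name

print(x.op.name)            # model/Placeholder
print(is_training.op.name)  # model/Placeholder_1
print(input_tensor_name)    # model/Placeholder:0
```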

wo80 commented

OK, setting is_training to False at least lets me convert the model. Now let's see whether the converted model gives correct results...
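Here is why this seems to work, as far as I can tell. With is_training as a bool placeholder, the batch-norm and dropout layers keep both the training and the inference branch behind cond (If) ops. That is why the trace goes through dropout/cond and batch_normalization_*/cond, and why that placeholder must be fed. With a Python False, layers that take a training argument usually pick the inference branch while the graph is built, so the exported graph no longer depends on any bool placeholder. In the script above, that presumably means replacing `is_training = tf.compat.v1.placeholder(tf.bool)` with `is_training = False`. The toy reproduction below exports both variants with only x as input and converts each one with `tf.lite.TFLiteConverter.from_saved_model`. It relies on hypothetical helpers, `toy_dropout` and `export_toy_model`, not musicnn's code.

```python
import os
import tempfile

import tensorflow as tf


def toy_dropout(x, training):
    """Hypothetical helper mimicking how a layer treats its `training` argument."""
    if isinstance(training, bool):
        # Python bool: the branch is chosen while the graph is built; no cond op.
        return tf.nn.dropout(x, rate=0.5) if training else x
    # bool tensor: both branches stay in the graph behind a cond (If) op.
    return tf.cond(training, lambda: tf.nn.dropout(x, rate=0.5), lambda: x)


def export_toy_model(export_dir, use_placeholder):
    """Toy stand-in for the musicnn export script (not the real model)."""
    graph = tf.Graph()
    with graph.as_default():
        with tf.name_scope("model"):
            x = tf.compat.v1.placeholder(tf.float32, [None, 4])
            is_training = tf.compat.v1.placeholder(tf.bool) if use_placeholder else False
            w = tf.compat.v1.get_variable("w", initializer=tf.ones([4, 2]))
            y = tf.sigmoid(tf.matmul(toy_dropout(x, is_training), w))
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Only x is exported as an input, as in the original script.
            tf.compat.v1.saved_model.simple_save(sess, export_dir, {"x": x}, {"y": y})


def convert(export_dir):
    # Same entry point tflite_convert reaches for TF2 (see the traceback).
    return tf.lite.TFLiteConverter.from_saved_model(export_dir).convert()


root = tempfile.mkdtemp()

export_toy_model(os.path.join(root, "fixed"), use_placeholder=False)
tflite_model = convert(os.path.join(root, "fixed"))  # converts fine

export_toy_model(os.path.join(root, "placeholder"), use_placeholder=True)
try:
    convert(os.path.join(root, "placeholder"))
    error = None
except Exception as exc:  # expected: UnliftableError, as in the issue
    error = exc
```

The placeholder variant should fail while the SavedModel is being loaded, before any TFLite conversion starts, just as in the traceback.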

wo80 commented

The outputs of the TFLite model and the original model match 🥳
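One way to check that the outputs match is to run the same batch through the TF model and through `tf.lite.Interpreter`, then compare the results with a tolerance. The minimal sketch below uses a hypothetical toy function (`toy_model`) in place of musicnn, and the 1e-5 tolerance is an assumption. With the real model, the reference would come from `sess.run(normalized_y, feed_dict={x: batch})` with is_training fixed to False.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the real network: sigmoid(x @ W) with fixed weights.
W = np.linspace(-1.0, 1.0, 8, dtype=np.float32).reshape(4, 2)


@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def toy_model(x):
    return tf.sigmoid(tf.matmul(x, W))


def tflite_predict(tflite_model, batch):
    """Run one batch through a converted model with the TFLite interpreter."""
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]
    # Resize the input to this batch size, then (re)allocate buffers.
    interpreter.resize_tensor_input(inp["index"], batch.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(inp["index"], batch.astype(inp["dtype"]))
    interpreter.invoke()
    return interpreter.get_tensor(out["index"])


converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [toy_model.get_concrete_function()])
tflite_model = converter.convert()

batch = np.random.default_rng(0).standard_normal((3, 4)).astype(np.float32)
reference = toy_model(batch).numpy()           # original TF output
candidate = tflite_predict(tflite_model, batch)  # TFLite output
max_diff = float(np.max(np.abs(reference - candidate)))
match = bool(np.allclose(reference, candidate, atol=1e-5))
print(match, max_diff)
```

The resize step is there because a `None` batch dimension usually ends up as size 1 in the converted input tensor, so larger batches need an explicit resize before allocating.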