vilassn/whisper_android

IllegalArgumentException: Internal error: Failed to run on the given Interpreter

Coder-HuangBH opened this issue · 2 comments

2024-04-29 09:48:50.970 24207-24753 Whisper com.whispertflite E Error...
java.lang.IllegalArgumentException: Internal error: Failed to run on the given Interpreter: tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T,
at org.tensorflow.lite.NativeInterpreterWrapper.run(Native Method)
at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:247)
at org.tensorflow.lite.InterpreterImpl.runForMultipleInputsOutputs(InterpreterImpl.java:107)
at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:80)
at org.tensorflow.lite.InterpreterImpl.run(InterpreterImpl.java:100)
at org.tensorflow.lite.Interpreter.run(Interpreter.java:80)
at com.whispertflite.engine.WhisperEngine.runInference(WhisperEngine.java:147)
at com.whispertflite.engine.WhisperEngine.transcribeFile(WhisperEngine.java:74)
at com.whispertflite.asr.Whisper.threadFunction(Whisper.java:129)
at com.whispertflite.asr.Whisper.lambda$start$0$com-whispertflite-asr-Whisper(Whisper.java:76)
at com.whispertflite.asr.Whisper$$ExternalSyntheticLambda0.run(Unknown Source:2)
at java.lang.Thread.run(Thread.java:930)

The only difference between the successful run and the failing run is the tflite file. Here are the tensor dumps for each:
Success: Input Tensor Dump ===>
2024-04-29 09:50:13.724 24920-25027 WhisperEngineJava com.whispertflite D shape.length: 3
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D shape[1]: 80
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D shape[2]: 3000
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D dataType: FLOAT32
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D name: serving_default_input_ids:0
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D numBytes: 960000
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D index: 0
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D numDimensions: 3
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D numElements: 240000
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D shapeSignature.length: 3
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D ==================================================================
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D Output Tensor Dump ===>
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape.length: 2
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape[1]: 448
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D dataType: INT32
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D name: StatefulPartitionedCall:0
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numBytes: 1792
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D index: 1047
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numDimensions: 2
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numElements: 448
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D shapeSignature.length: 2
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D ==================================================================

Failure: Input Tensor Dump ===>
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape.length: 3
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[1]: 80
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[2]: 3000
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D dataType: FLOAT32
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D name: serving_default_input_ids:0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D numBytes: 960000
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D index: 0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D numDimensions: 3
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D numElements: 240000
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shapeSignature.length: 3
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D ==================================================================
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D Output Tensor Dump ===>
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape.length: 2
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[1]: 451
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D dataType: INT32
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D name: StatefulPartitionedCall:0
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numBytes: 1804
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D index: 559
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numDimensions: 2
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numElements: 451
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D shapeSignature.length: 2
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D ==================================================================
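
To compare the two dumps above field by field, a small sketch along these lines may help. `parse_dump` and `diff_dumps` are hypothetical helper names, not part of the app, and the sample lines are copied from the two output tensor dumps above.

```python
import re

# Matches the "key: value" tail of a logcat debug line, e.g.
# "... WhisperEngineJava com.whispertflite D shape[1]: 448"
_FIELD = re.compile(r"\bD\s+([\w.\[\]]+):\s*(\S+)\s*$")


def parse_dump(lines):
    """Collect the key/value fields from a list of logcat dump lines."""
    fields = {}
    for line in lines:
        match = _FIELD.search(line)
        if match:
            fields[match.group(1)] = match.group(2)
    return fields


def diff_dumps(ok, bad):
    """Return {field: (ok_value, bad_value)} for fields that differ."""
    keys = sorted(set(ok) | set(bad))
    return {k: (ok.get(k), bad.get(k)) for k in keys if ok.get(k) != bad.get(k)}


PREFIX_OK = "2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D "
PREFIX_BAD = "2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D "

OK_LINES = [PREFIX_OK + s for s in (
    "shape[1]: 448", "dataType: INT32", "numBytes: 1792", "index: 1047", "numElements: 448",
)]
BAD_LINES = [PREFIX_BAD + s for s in (
    "shape[1]: 451", "dataType: INT32", "numBytes: 1804", "index: 559", "numElements: 451",
)]

ok_fields = parse_dump(OK_LINES)
bad_fields = parse_dump(BAD_LINES)
for name, (ok_value, bad_value) in diff_dumps(ok_fields, bad_fields).items():
    print(f"{name}: success={ok_value} failure={bad_value}")

# Sanity check: INT32 elements are 4 bytes each, so numBytes should be numElements * 4.
for fields in (ok_fields, bad_fields):
    assert int(fields["numBytes"]) == int(fields["numElements"]) * 4
```

On these lines the input tensors are identical and the output tensors differ only in length (448 vs 451 elements, hence 1792 vs 1804 bytes) and in tensor index. The 451 may correspond to `max_new_tokens=450` plus one start token, but that reading is an assumption, not something the logs confirm.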

The script that generated the failing tflite file is as follows:

import tensorflow as tf
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer

whisperPath = "openai/whisper-tiny.en"
saved_model_dir = 'path/to/tf_whisper_saved'

tflite_model_path = 'path/to/whisper111.tflite'

feature_extractor = WhisperFeatureExtractor.from_pretrained(whisperPath)
tokenizer = WhisperTokenizer.from_pretrained(whisperPath, predict_timestamps=True)
processor = WhisperProcessor(feature_extractor, tokenizer)
model = TFWhisperForConditionalGeneration.from_pretrained(whisperPath, from_pt=True)

# Loading dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)

inputs = feature_extractor(
    ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
)
input_features = inputs.input_features

# Generating Transcription

generated_ids = model.generate(input_features=input_features)
print(generated_ids)
transcription = processor.tokenizer.decode(generated_ids[0])
print(transcription)
model.save(saved_model_dir)

# Convert the model

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the model

with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

class GenerateModel(tf.Module):
    def __init__(self, model):
        super(GenerateModel, self).__init__()
        self.model = model

    @tf.function(
        # shouldn't need static batch size, but throws exception without it (needs to be fixed)
        input_signature=[
            tf.TensorSpec((1, 80, 3000), tf.float32, name="input_ids"),
        ],
    )
    def serving(self, input_features):
        outputs = self.model.generate(
            input_features,
            max_new_tokens=450,  # change as needed
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}

saved_model_dir = '/content/tf_whisper_saved'

tflite_model_path = 'whisper-tiny.en.tflite'

tflite_model_path = 'path/to/whisper222.tflite'

tflite_model_path = 'path/to/whisper_vi222.tflite'

generate_model = GenerateModel(model=model)
tf.saved_model.save(generate_model, saved_model_dir, signatures={"serving_default": generate_model.serving})

# Convert the model

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the model

with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
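
To check the converted file off-device before loading it in the app, a sketch like the following could print the same shape information as the Java dump and attempt one inference in Python. `summarize_details` and `inspect_and_invoke` are hypothetical helper names; the sketch assumes the full `tensorflow` pip package (for `tf.lite.Interpreter` with Select TF ops) and feeds zero-filled features, so it only exercises the graph and does not produce a meaningful transcription. Whether it fails the same way as the Android runtime is not guaranteed.

```python
def summarize_details(details):
    """Format dicts shaped like tf.lite.Interpreter.get_input_details() entries."""
    rows = []
    for d in details:
        shape = [int(x) for x in d["shape"]]
        dtype = getattr(d["dtype"], "__name__", str(d["dtype"]))
        rows.append(f'{d["name"]}: shape={shape} dtype={dtype}')
    return rows


def inspect_and_invoke(model_path):
    """Print input/output tensor info for a .tflite file and run it once."""
    # Imported lazily so summarize_details works without TensorFlow installed.
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    inputs = interpreter.get_input_details()
    outputs = interpreter.get_output_details()
    print("\n".join(summarize_details(inputs) + summarize_details(outputs)))

    # Zero-filled input matching the model's declared shape and dtype.
    dummy = np.zeros(inputs[0]["shape"], dtype=inputs[0]["dtype"])
    interpreter.set_tensor(inputs[0]["index"], dummy)
    interpreter.invoke()
    return interpreter.get_tensor(outputs[0]["index"])
```

Calling `inspect_and_invoke(tflite_model_path)` right after the file is written would show whether the output length is 448 or 451 before the model ever reaches the device.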

Does the apk work properly, or does it show the same issue?

I got pretty much the same result. Your apk and the tflite models work locally (inference in Python works well, and the resulting transcription is correct). @vilassn, would you share the Python package configuration that you used in the apk? (Python 3.8, transformers, tensorflow, numpy, etc.)
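
For anyone posting their own configuration for comparison, a minimal sketch like this prints the installed versions using only the standard library (Python 3.8+). The package list is an assumption based on the imports in the script above; `package_versions` is a hypothetical helper name.

```python
import sys
from importlib import metadata


def package_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    wanted = ["tensorflow", "transformers", "datasets", "numpy"]
    for name, version in package_versions(wanted).items():
        print(name, version or "not installed")
```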

Thanks!