Toy regression model Android inference
Deepakpy opened this issue · 8 comments
Hello, I recently tried to use the Android Kotlin example with the toy regression dataset and model. However, I got an error from TensorFlow Lite.
Error:
2023-08-07 20:49:15.675 13218-13574 System.err flwr.android_client W java.lang.IllegalStateException: Internal error: Failed to run on the given Interpreter: tensorflow/lite/kernels/transpose.cc:59 op_context->perm->dims->data[0] != dims (2 != 3)
2023-08-07 20:49:15.676 13218-13574 System.err flwr.android_client W Node number 56 (TRANSPOSE) failed to prepare.
2023-08-07 20:49:15.677 13218-13574 System.err flwr.android_client W at org.tensorflow.lite.NativeSignatureRunnerWrapper.nativeInvoke(Native Method)
2023-08-07 20:49:15.686 13218-13574 System.err flwr.android_client W at org.tensorflow.lite.NativeSignatureRunnerWrapper.invoke(NativeSignatureRunnerWrapper.java:93)
2023-08-07 20:49:15.687 13218-13574 System.err flwr.android_client W at org.tensorflow.lite.NativeInterpreterWrapper.runSignature(NativeInterpreterWrapper.java:213)
2023-08-07 20:49:15.687 13218-13574 System.err flwr.android_client W at org.tensorflow.lite.Interpreter.runSignature(Interpreter.java:253)
2023-08-07 20:49:15.688 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient.runSignatureLocked(FlowerClient.kt:211)
2023-08-07 20:49:15.689 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient.training(FlowerClient.kt:162)
2023-08-07 20:49:15.689 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient.access$training(FlowerClient.kt:22)
2023-08-07 20:49:15.690 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient$trainOneEpoch$1.invoke(FlowerClient.kt:144)
2023-08-07 20:49:15.704 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient$trainOneEpoch$1.invoke(FlowerClient.kt:141)
2023-08-07 20:49:15.705 13218-13574 System.err flwr.android_client W at kotlin.sequences.TransformingSequence$iterator$1.next(Sequences.kt:210)
2023-08-07 20:49:15.705 13218-13574 System.err flwr.android_client W at kotlin.sequences.SequencesKt___SequencesKt.toCollection(_Sequences.kt:787)
2023-08-07 20:49:15.705 13218-13574 System.err flwr.android_client W at kotlin.sequences.SequencesKt___SequencesKt.toMutableList(_Sequences.kt:817)
2023-08-07 20:49:15.705 13218-13574 System.err flwr.android_client W at kotlin.sequences.SequencesKt___SequencesKt.toList(_Sequences.kt:808)
2023-08-07 20:49:15.706 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient.trainOneEpoch(FlowerClient.kt:145)
2023-08-07 20:49:15.706 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient.fit(FlowerClient.kt:96)
2023-08-07 20:49:15.706 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient.fit$default(FlowerClient.kt:89)
2023-08-07 20:49:15.706 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerServiceRunnable.handleFitIns(FlowerServiceRunnable.kt:95)
2023-08-07 20:49:15.706 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerServiceRunnable.handleMessage(FlowerServiceRunnable.kt:64)
2023-08-07 20:49:15.706 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerServiceRunnable$requestObserver$1.onNext(FlowerServiceRunnable.kt:41)
2023-08-07 20:49:15.707 13218-13574 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerServiceRunnable$requestObserver$1.onNext(FlowerServiceRunnable.kt:38)
2023-08-07 20:49:15.707 13218-13574 System.err flwr.android_client W at io.grpc.stub.ClientCalls$StreamObserverToCallListenerAdapter.onMessage(ClientCalls.java:478)
2023-08-07 20:49:15.707 13218-13574 System.err flwr.android_client W at io.grpc.internal.DelayedClientCall$DelayedListener.onMessage(DelayedClientCall.java:473)
2023-08-07 20:49:15.707 13218-13574 System.err flwr.android_client W at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1MessagesAvailable.runInternal(ClientCallImpl.java:660)
2023-08-07 20:49:15.707 13218-13574 System.err flwr.android_client W at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1MessagesAvailable.runInContext(ClientCallImpl.java:647)
2023-08-07 20:49:15.707 13218-13574 System.err flwr.android_client W at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
2023-08-07 20:49:15.708 13218-13574 System.err flwr.android_client W at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
2023-08-07 20:49:15.708 13218-13574 System.err flwr.android_client W at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1137)
2023-08-07 20:49:15.708 13218-13574 System.err flwr.android_client W at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:637)
2023-08-07 20:49:15.708 13218-13574 System.err flwr.android_client W at java.lang.Thread.run(Thread.java:1012)
The data is loaded with the function below:
suspend fun loadDatareg(
    context: Context,
    flowerClient: FlowerClient<Float2DArray, FloatArray>
) {
    val csvFile = "reg_data/data.csv"
    val inputStream = context.assets.open(csvFile)
    val reader = CSVReader(InputStreamReader(inputStream))
    val rows = reader.readAll()
    reader.close()
    val xTrain = Array(rows.size - 1) { FloatArray(2) } // -1 because the header row is skipped
    val yTrain = FloatArray(rows.size - 1)
    for (i in 1 until rows.size) {
        val row = rows[i]
        xTrain[i - 1][0] = row[0].toFloat()
        xTrain[i - 1][1] = row[1].toFloat()
        yTrain[i - 1] = row[2].toFloat() // Update this to match your dataset structure
    }
    try {
        flowerClient.addSample(xTrain, yTrain, true)
    } catch (e: ExecutionException) {
        throw RuntimeException("Failed to add sample to model", e.cause)
    } catch (e: InterruptedException) {
        // no-op
    }
    // return Pair(xTrain, yTrain)
}
This is the client:
private fun createFlowerClient0() {
    val buffer1 = loadMappedAssetFile(this, "model/toy_regression.tflite")
    val layersSizes = intArrayOf(8, 4)
    val sampleSpec = SampleSpec<Float2DArray, FloatArray>(
        { it.toTypedArray() },
        { it.toTypedArray() },
        { Array(it) { FloatArray(1) } },
        ::maxSquaredErrorLoss,
        ::placeholderAccuracy,
    )
    flowerClient1 = FlowerClient(buffer1, layersSizes, sampleSpec)
}
Any advice on what might be going wrong here while loading the model?
The features (xTrain) should be FloatArray because each one is simply an array of 2 numbers. That is, change SampleSpec<Float2DArray, FloatArray> to SampleSpec<FloatArray, FloatArray>, and do the same for FlowerClient. yTrain should be FloatArray(1). You also need to call flowerClient.addSample in a loop, once per row.
Please try the above. Hope it works.
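A minimal sketch of why the sample type matters, assuming (this is not confirmed in the thread) that FlowerClient stacks samples along a new first axis to form a batch: a batch of FloatArray samples is 2-D ([batch, 2]), while a batch of Float2DArray samples, such as the whole 11x2 xTrain matrix passed as one sample in the first loadDatareg, is 3-D ([batch, 11, 2]). That would be consistent with the TRANSPOSE error above, whose permutation has 2 entries while its input had 3 dimensions. The shapeOf helper is hypothetical and only inspects nested Kotlin arrays.

```kotlin
// Hypothetical helper (not part of Flower or TFLite): returns the shape of a
// nested Kotlin float array by following the first element at each level.
fun shapeOf(value: Any?): List<Int> = when (value) {
    is FloatArray -> listOf(value.size)
    is Array<*> -> listOf(value.size) + shapeOf(value.firstOrNull())
    else -> emptyList()
}

fun main() {
    // One sample per CSV row: a FloatArray holding the 2 features.
    val rowSample = floatArrayOf(7773f, 296.816f)
    // The whole dataset passed as a single sample: 11 rows x 2 features.
    val matrixSample = Array(11) { FloatArray(2) }

    // Assumption: a batch is built by stacking samples along a new first axis.
    val batchOfRows = Array(4) { rowSample }
    val batchOfMatrices = Array(4) { matrixSample }

    println(shapeOf(batchOfRows))     // [4, 2]     -> rank 2
    println(shapeOf(batchOfMatrices)) // [4, 11, 2] -> rank 3
}
```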
I can see that each input value is loaded into an array of size 11, which is the number of times the loop runs. Is this normal?
It also throws an error in the evaluation step:
2023-08-09 09:16:37.924 8019-8199 Flower Service Runnable flwr.android_client D Handling EvaluateIns
2023-08-09 09:16:37.929 8019-8199 System.err flwr.android_client W java.lang.IllegalArgumentException: Array lengths cannot be 0.
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at org.tensorflow.lite.TensorImpl.computeNumDimensions(TensorImpl.java:358)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at org.tensorflow.lite.TensorImpl.computeShapeOf(TensorImpl.java:324)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at org.tensorflow.lite.TensorImpl.getInputShapeIfDifferent(TensorImpl.java:253)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:242)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at org.tensorflow.lite.NativeInterpreterWrapper.runSignature(NativeInterpreterWrapper.java:194)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at org.tensorflow.lite.Interpreter.runSignature(Interpreter.java:253)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient.runSignatureLocked(FlowerClient.kt:203)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient.inference(FlowerClient.kt:119)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerClient.evaluate(FlowerClient.kt:105)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerServiceRunnable.handleEvaluateIns(FlowerServiceRunnable.kt:109)
2023-08-09 09:16:37.931 8019-8199 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerServiceRunnable.handleMessage(FlowerServiceRunnable.kt:66)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerServiceRunnable$requestObserver$1.onNext(FlowerServiceRunnable.kt:41)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at dev.flower.flower_tflite.FlowerServiceRunnable$requestObserver$1.onNext(FlowerServiceRunnable.kt:38)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at io.grpc.stub.ClientCalls$StreamObserverToCallListenerAdapter.onMessage(ClientCalls.java:478)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at io.grpc.internal.DelayedClientCall$DelayedListener.onMessage(DelayedClientCall.java:473)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1MessagesAvailable.runInternal(ClientCallImpl.java:660)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1MessagesAvailable.runInContext(ClientCallImpl.java:647)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1137)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:637)
2023-08-09 09:16:37.932 8019-8199 System.err flwr.android_client W at java.lang.Thread.run(Thread.java:1012)
Changes made:
suspend fun loadDatareg(
    context: Context,
    flowerClient: FlowerClient<FloatArray, FloatArray>
) {
    val csvFile = "reg_data/data.csv"
    val inputStream = context.assets.open(csvFile)
    val reader = CSVReader(InputStreamReader(inputStream))
    val rows = reader.readAll()
    reader.close()
    val xTrain = FloatArray(2)
    val yTrain = FloatArray(1)
    for (i in 1 until rows.size) { // start at 1 to skip the header row
        val row = rows[i]
        xTrain[0] = row[0].toFloat()
        xTrain[1] = row[1].toFloat()
        yTrain[0] = row[2].toFloat() // Update this to match your dataset structure
        try {
            flowerClient.addSample(xTrain, yTrain, true)
        } catch (e: ExecutionException) {
            throw RuntimeException("Failed to add sample to model", e.cause)
        } catch (e: InterruptedException) {
            // no-op
        }
    }
}
private fun createFlowerClient0() {
    val buffer1 = loadMappedAssetFile(this, "model/toy_regression.tflite")
    val layersSizes = intArrayOf(8, 4)
    val sampleSpec = SampleSpec<FloatArray, FloatArray>(
        { it.toTypedArray() },
        { it.toTypedArray() },
        { Array(it) { FloatArray(1) } },
        ::maxSquaredErrorLoss,
        ::placeholderAccuracy,
    )
    flowerClient1 = FlowerClient(buffer1, layersSizes, sampleSpec)
}
I don't know which data structure your screenshot refers to, but the content looks fine to me.
You are getting "Array lengths cannot be 0" from TFLite during evaluation because your test dataset is empty. For now, a workaround is to also add your data to the test dataset by calling:
flowerClient.addSample(xTrain, yTrain, false)
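The workaround amounts to registering every row twice, once with isTraining = true and once with false, so the test set is no longer empty. Below is a minimal sketch; SampleStore and addToBoth are hypothetical stand-ins that mirror the addSample signature quoted later in this thread (this is not the real FlowerClient, and it omits its locks). Evaluating on the training data only makes the evaluate step run; the resulting loss does not measure generalization.

```kotlin
// Hypothetical stand-in mirroring FlowerClient.addSample from this thread
// (no locking, no TFLite); used only to illustrate the workaround.
class SampleStore {
    val trainingSamples = mutableListOf<Pair<FloatArray, FloatArray>>()
    val testSamples = mutableListOf<Pair<FloatArray, FloatArray>>()

    fun addSample(bottleneck: FloatArray, label: FloatArray, isTraining: Boolean) {
        val samples = if (isTraining) trainingSamples else testSamples
        samples.add(bottleneck to label)
    }
}

// Workaround: add each row to both the training and the test set,
// so that the test set is not empty when evaluation runs.
fun addToBoth(store: SampleStore, x: FloatArray, y: FloatArray) {
    store.addSample(x, y, true)
    store.addSample(x, y, false)
}

fun main() {
    val store = SampleStore()
    addToBoth(store, floatArrayOf(1004f, 35.83f), floatArrayOf(716f))
    println("train=${store.trainingSamples.size}, test=${store.testSamples.size}")
}
```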
My assumption was that it should have the data below loaded:
7773,296.816,6765
9253,317.181,7923
15680,567.808,15933
6506,199.645,6526
14854,570.719,12374
6533,241.307,5358
6126,233.927,4373
16660,655.881,13995
13493,501.036,11339
12338,518.152,9174
1004,35.83,716
But it has only copied the last row's values into every entry of the array.
I expected it to contain something like:
bottlenecks = {float[11][]@24579}
0 = {float[2]@24551} [7773,296.816]
1 = {float[2]@24551} [9253,317.181]
2 = {float[2]@24551} [15680,567.808]
3 = {float[2]@24551} [6506,199.645]
4 = {float[2]@24551} [14854,570.719]
5 = {float[2]@24551} [6533,241.307]
6 = {float[2]@24551} [6126,233.927]
7 = {float[2]@24551} [16660,655.881]
8 = {float[2]@24551} [13493,501.036]
9 = {float[2]@24551} [12338,518.152]
10 = {float[2]@24551} [1004.0, 35.83]
The actual values in bottlenecks are:
bottlenecks = {float[11][]@24579}
0 = {float[2]@24551} [1004.0, 35.83]
1 = {float[2]@24551} [1004.0, 35.83]
2 = {float[2]@24551} [1004.0, 35.83]
3 = {float[2]@24551} [1004.0, 35.83]
4 = {float[2]@24551} [1004.0, 35.83]
5 = {float[2]@24551} [1004.0, 35.83]
6 = {float[2]@24551} [1004.0, 35.83]
7 = {float[2]@24551} [1004.0, 35.83]
8 = {float[2]@24551} [1004.0, 35.83]
9 = {float[2]@24551} [1004.0, 35.83]
10 = {float[2]@24551} [1004.0, 35.83]
The data structure it loads into trainingSamples:
trainingSamples = {ArrayList@24544} size = 11
0 = {Sample@24752} Sample(bottleneck=[F@8151c1d, label=[F@afbfe92)
bottleneck = {float[2]@24551} [1004.0, 35.83]
bottleneck {java.lang.Object}
label = {float[1]@24774} [716.0]
label {java.lang.Object}
shadow$klass = {Class@22767} "class dev.flower.flower_tflite.Sample"
shadow$monitor = 0
1 = {Sample@24753} Sample(bottleneck=[F@8151c1d, label=[F@afbfe92)
bottleneck = {float[2]@24551} [1004.0, 35.83]
bottleneck {java.lang.Object}
label = {float[1]@24774} [716.0]
label {java.lang.Object}
shadow$klass = {Class@22767} "class dev.flower.flower_tflite.Sample"
shadow$monitor = 0
2 = {Sample@24754} Sample(bottleneck=[F@8151c1d, label=[F@afbfe92)
3 = {Sample@24755} Sample(bottleneck=[F@8151c1d, label=[F@afbfe92)
All the indices hold the values from the last iteration of the for loop instead of one row each.
fun addSample(
    bottleneck: X, label: Y, isTraining: Boolean
) {
    val samples = if (isTraining) trainingSamples else testSamples
    val lock = if (isTraining) trainSampleLock else testSampleLock
    lock.write {
        samples.add(Sample(bottleneck, label))
    }
}
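It looks like this function, besides adding a new value, also overwrites the previous array values.
You made a common programming mistake related to "pass by reference". addSample stores a reference to the array you pass in, not a copy, so reusing the same xTrain and yTrain arrays on every iteration makes all samples point to the same two arrays, which end up holding the last row. Allocating new arrays inside the loop should fix it: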
-    val xTrain = FloatArray(2)
-    val yTrain = FloatArray(1)
     for (i in 1 until rows.size) {
+        val xTrain = FloatArray(2)
+        val yTrain = FloatArray(1)
         val row = rows[i]
         xTrain[0] = row[0].toFloat()
         xTrain[1] = row[1].toFloat()
         yTrain[0] = row[2].toFloat() // Update this to match your dataset structure
         try {
             flowerClient.addSample(xTrain, yTrain, true)
         } catch (e: ExecutionException) {
             throw RuntimeException("Failed to add sample to model", e.cause)
         } catch (e: InterruptedException) {
             // no-op
         }
     }
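The aliasing effect behind this bug can be reproduced without Flower or TFLite. In this minimal sketch, loadShared and loadFresh are hypothetical names; rows are plain lists of floats instead of CSVReader output. loadShared allocates one array outside the loop (as in the "Changes made" version), while loadFresh allocates a new one per row (as in the diff).

```kotlin
// Reproduces the bug from this thread: every stored sample refers to the
// same FloatArray object, so each write in the loop is visible in all entries.
fun loadShared(rows: List<List<Float>>): List<FloatArray> {
    val samples = mutableListOf<FloatArray>()
    val x = FloatArray(2) // allocated once, outside the loop (the bug)
    for (row in rows) {
        x[0] = row[0]
        x[1] = row[1]
        samples.add(x) // stores a reference to x, not a copy
    }
    return samples
}

// The fix from the diff: allocate a new array on every iteration.
fun loadFresh(rows: List<List<Float>>): List<FloatArray> {
    val samples = mutableListOf<FloatArray>()
    for (row in rows) {
        val x = FloatArray(2) // new object per row
        x[0] = row[0]
        x[1] = row[1]
        samples.add(x)
    }
    return samples
}

fun main() {
    val rows = listOf(listOf(7773f, 296.816f), listOf(1004f, 35.83f))
    println(loadShared(rows).map { it.toList() }) // both entries hold the last row
    println(loadFresh(rows).map { it.toList() })  // one distinct entry per row
}
```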
Thanks, that was a mistake on my side. The toy regression model works now.
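For comparison, the per-row loading can also be written so that every sample gets its own arrays by construction. This is a sketch under several assumptions: parseSamples is a hypothetical name, it parses CSV text in memory with String.split instead of the CSVReader used in the thread, it expects the same column layout (two features, then the label) plus a header row whose names here are placeholders, and it does not handle quoted fields. Each returned pair could then be passed to flowerClient.addSample(x, y, true).

```kotlin
// Hypothetical helper: turns CSV text (header + rows of "f1,f2,label") into
// one (features, label) pair per data row. Quoted fields are not supported.
fun parseSamples(csv: String): List<Pair<FloatArray, FloatArray>> =
    csv.lineSequence()
        .drop(1)                    // skip the header row
        .filter { it.isNotBlank() } // ignore trailing empty lines
        .map { line ->
            val cols = line.split(",").map { it.trim().toFloat() }
            // floatArrayOf creates new arrays for every row, so no aliasing.
            floatArrayOf(cols[0], cols[1]) to floatArrayOf(cols[2])
        }
        .toList()

fun main() {
    // Header names are placeholders; data rows are taken from this thread.
    val csv = """
        |x1,x2,y
        |7773,296.816,6765
        |9253,317.181,7923
        |1004,35.83,716
    """.trimMargin()
    for ((x, y) in parseSamples(csv)) {
        println("${x.toList()} -> ${y.toList()}")
    }
}
```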