FedCampus/FedKit

Toy regression model Android inference

Deepakpy opened this issue · 8 comments

Hello, I recently tried to use the Android Kotlin example with the toy regression dataset and model. However, there were errors related to TensorFlow.
Error:

2023-08-07 20:49:15.675 13218-13574 System.err              flwr.android_client                  W  java.lang.IllegalStateException: Internal error: Failed to run on the given Interpreter: tensorflow/lite/kernels/transpose.cc:59 op_context->perm->dims->data[0] != dims (2 != 3)
2023-08-07 20:49:15.676 13218-13574 System.err              flwr.android_client                  W  Node number 56 (TRANSPOSE) failed to prepare.
2023-08-07 20:49:15.677 13218-13574 System.err              flwr.android_client                  W      at org.tensorflow.lite.NativeSignatureRunnerWrapper.nativeInvoke(Native Method)
2023-08-07 20:49:15.686 13218-13574 System.err              flwr.android_client                  W      at org.tensorflow.lite.NativeSignatureRunnerWrapper.invoke(NativeSignatureRunnerWrapper.java:93)
2023-08-07 20:49:15.687 13218-13574 System.err              flwr.android_client                  W      at org.tensorflow.lite.NativeInterpreterWrapper.runSignature(NativeInterpreterWrapper.java:213)
2023-08-07 20:49:15.687 13218-13574 System.err              flwr.android_client                  W      at org.tensorflow.lite.Interpreter.runSignature(Interpreter.java:253)
2023-08-07 20:49:15.688 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerClient.runSignatureLocked(FlowerClient.kt:211)
2023-08-07 20:49:15.689 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerClient.training(FlowerClient.kt:162)
2023-08-07 20:49:15.689 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerClient.access$training(FlowerClient.kt:22)
2023-08-07 20:49:15.690 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerClient$trainOneEpoch$1.invoke(FlowerClient.kt:144)
2023-08-07 20:49:15.704 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerClient$trainOneEpoch$1.invoke(FlowerClient.kt:141)
2023-08-07 20:49:15.705 13218-13574 System.err              flwr.android_client                  W      at kotlin.sequences.TransformingSequence$iterator$1.next(Sequences.kt:210)
2023-08-07 20:49:15.705 13218-13574 System.err              flwr.android_client                  W      at kotlin.sequences.SequencesKt___SequencesKt.toCollection(_Sequences.kt:787)
2023-08-07 20:49:15.705 13218-13574 System.err              flwr.android_client                  W      at kotlin.sequences.SequencesKt___SequencesKt.toMutableList(_Sequences.kt:817)
2023-08-07 20:49:15.705 13218-13574 System.err              flwr.android_client                  W      at kotlin.sequences.SequencesKt___SequencesKt.toList(_Sequences.kt:808)
2023-08-07 20:49:15.706 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerClient.trainOneEpoch(FlowerClient.kt:145)
2023-08-07 20:49:15.706 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerClient.fit(FlowerClient.kt:96)
2023-08-07 20:49:15.706 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerClient.fit$default(FlowerClient.kt:89)
2023-08-07 20:49:15.706 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerServiceRunnable.handleFitIns(FlowerServiceRunnable.kt:95)
2023-08-07 20:49:15.706 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerServiceRunnable.handleMessage(FlowerServiceRunnable.kt:64)
2023-08-07 20:49:15.706 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerServiceRunnable$requestObserver$1.onNext(FlowerServiceRunnable.kt:41)
2023-08-07 20:49:15.707 13218-13574 System.err              flwr.android_client                  W      at dev.flower.flower_tflite.FlowerServiceRunnable$requestObserver$1.onNext(FlowerServiceRunnable.kt:38)
2023-08-07 20:49:15.707 13218-13574 System.err              flwr.android_client                  W      at io.grpc.stub.ClientCalls$StreamObserverToCallListenerAdapter.onMessage(ClientCalls.java:478)
2023-08-07 20:49:15.707 13218-13574 System.err              flwr.android_client                  W      at io.grpc.internal.DelayedClientCall$DelayedListener.onMessage(DelayedClientCall.java:473)
2023-08-07 20:49:15.707 13218-13574 System.err              flwr.android_client                  W      at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1MessagesAvailable.runInternal(ClientCallImpl.java:660)
2023-08-07 20:49:15.707 13218-13574 System.err              flwr.android_client                  W      at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1MessagesAvailable.runInContext(ClientCallImpl.java:647)
2023-08-07 20:49:15.707 13218-13574 System.err              flwr.android_client                  W      at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
2023-08-07 20:49:15.708 13218-13574 System.err              flwr.android_client                  W      at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
2023-08-07 20:49:15.708 13218-13574 System.err              flwr.android_client                  W      at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1137)
2023-08-07 20:49:15.708 13218-13574 System.err              flwr.android_client                  W      at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:637)
2023-08-07 20:49:15.708 13218-13574 System.err              flwr.android_client                  W      at java.lang.Thread.run(Thread.java:1012)

The data is loaded using the function below:

 suspend fun loadDatareg(
     context: Context,
     flowerClient: FlowerClient<Float2DArray, FloatArray>
 ) {
     val csvFile = "reg_data/data.csv"
     val inputStream = context.assets.open(csvFile)
     val reader = CSVReader(InputStreamReader(inputStream))
 
     val rows = reader.readAll()
     reader.close()
 
     val xTrain = Array(rows.size - 1) { FloatArray(2) } // -1 because skipping the header row
     val yTrain = FloatArray(rows.size - 1)
 
     for (i in 1 until rows.size) {
         val row = rows[i]
         xTrain[i - 1][0] = row[0].toFloat()
         xTrain[i - 1][1] = row[1].toFloat()
         yTrain[i - 1] = row[2].toFloat() // Update this to match your dataset structure
     }
 
     try {
         flowerClient.addSample(xTrain, yTrain, true)
     } catch (e: ExecutionException) {
         throw RuntimeException("Failed to add sample to model", e.cause)
     } catch (e: InterruptedException) {
         // no-op
     }
 
 //    return Pair(xTrain, yTrain)
 }

This is the client:

 private fun createFlowerClient0() {
     val buffer1 = loadMappedAssetFile(this, "model/toy_regression.tflite")
     val layersSizes = intArrayOf(8,4)
     val sampleSpec = SampleSpec<Float2DArray, FloatArray>(
         { it.toTypedArray() },
         { it.toTypedArray() },
         { Array(it) { FloatArray(1) } },
         ::maxSquaredErrorLoss,
         ::placeholderAccuracy,
     )
     flowerClient1 = FlowerClient(buffer1, layersSizes, sampleSpec)
 }

Any advice on what might be going wrong here when loading the model?

The features (xTrain) should be FloatArray, because each sample is simply an array of 2 numbers. That is, change SampleSpec<Float2DArray, FloatArray> to SampleSpec<FloatArray, FloatArray>, and do the same for FlowerClient. yTrain should be FloatArray(1). You also need to call flowerClient.addSample in a loop, once per row.

Please try the above. I hope it works.
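Why the element type matters can be seen by computing the rank of the batch that reaches the interpreter. The sketch below assumes, as the debugger output later in this thread suggests (`bottlenecks = {float[11][]}`), that FlowerClient stacks a list of samples into one array through the `{ it.toTypedArray() }` lambdas of SampleSpec, which adds one leading dimension. `nestedShape` is a hypothetical helper written only for this illustration; it is not part of TFLite or FedKit.

```kotlin
// Hypothetical helper for illustration only (not part of TFLite or FedKit):
// returns the shape of a nested float array, outermost dimension first.
fun nestedShape(value: Any): List<Int> = when (value) {
    is FloatArray -> {
        // Mimics the empty-array complaint TFLite raises while computing a shape.
        require(value.isNotEmpty()) { "Array lengths cannot be 0." }
        listOf(value.size)
    }
    is Array<*> -> {
        require(value.isNotEmpty()) { "Array lengths cannot be 0." }
        val first = requireNotNull(value[0]) { "Null elements are not supported." }
        listOf(value.size) + nestedShape(first)
    }
    else -> throw IllegalArgumentException("Unsupported type: ${value.javaClass.name}")
}

// Original approach: one Float2DArray sample holding all 11 rows x 2 features,
// stacked into a batch the way `{ it.toTypedArray() }` would (assumption).
val float2DSample: Array<FloatArray> = Array(11) { FloatArray(2) }
val batchOf2DSamples: Array<Array<FloatArray>> = listOf(float2DSample).toTypedArray()

// Suggested approach: one FloatArray sample per row, stacked the same way.
val rowSamples: List<FloatArray> = List(11) { FloatArray(2) }
val batchOfRowSamples: Array<FloatArray> = rowSamples.toTypedArray()
```

Under that assumption, `nestedShape(batchOf2DSamples)` is `[1, 11, 2]` (rank 3) and `nestedShape(batchOfRowSamples)` is `[11, 2]` (rank 2). A rank-3 input is consistent with the `2 != 3` check that failed in the TRANSPOSE node. An empty array makes the helper fail as well, which mirrors the "Array lengths cannot be 0" error from evaluation further down.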

I can see that each input value is loaded as an array of size 11, which is the number of times the loop runs.
(screenshot)
Is this normal?

It also throws an error in the evaluation step:

2023-08-09 09:16:37.924  8019-8199  Flower Service Runnable flwr.android_client                  D  Handling EvaluateIns
2023-08-09 09:16:37.929  8019-8199  System.err              flwr.android_client                  W  java.lang.IllegalArgumentException: Array lengths cannot be 0.
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at org.tensorflow.lite.TensorImpl.computeNumDimensions(TensorImpl.java:358)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at org.tensorflow.lite.TensorImpl.computeShapeOf(TensorImpl.java:324)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at org.tensorflow.lite.TensorImpl.getInputShapeIfDifferent(TensorImpl.java:253)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:242)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at org.tensorflow.lite.NativeInterpreterWrapper.runSignature(NativeInterpreterWrapper.java:194)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at org.tensorflow.lite.Interpreter.runSignature(Interpreter.java:253)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at dev.flower.flower_tflite.FlowerClient.runSignatureLocked(FlowerClient.kt:203)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at dev.flower.flower_tflite.FlowerClient.inference(FlowerClient.kt:119)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at dev.flower.flower_tflite.FlowerClient.evaluate(FlowerClient.kt:105)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at dev.flower.flower_tflite.FlowerServiceRunnable.handleEvaluateIns(FlowerServiceRunnable.kt:109)
2023-08-09 09:16:37.931  8019-8199  System.err              flwr.android_client                  W  	at dev.flower.flower_tflite.FlowerServiceRunnable.handleMessage(FlowerServiceRunnable.kt:66)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at dev.flower.flower_tflite.FlowerServiceRunnable$requestObserver$1.onNext(FlowerServiceRunnable.kt:41)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at dev.flower.flower_tflite.FlowerServiceRunnable$requestObserver$1.onNext(FlowerServiceRunnable.kt:38)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at io.grpc.stub.ClientCalls$StreamObserverToCallListenerAdapter.onMessage(ClientCalls.java:478)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at io.grpc.internal.DelayedClientCall$DelayedListener.onMessage(DelayedClientCall.java:473)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1MessagesAvailable.runInternal(ClientCallImpl.java:660)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1MessagesAvailable.runInContext(ClientCallImpl.java:647)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1137)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:637)
2023-08-09 09:16:37.932  8019-8199  System.err              flwr.android_client                  W  	at java.lang.Thread.run(Thread.java:1012)

Changes made:

suspend fun loadDatareg(
    context: Context,
    flowerClient: FlowerClient<FloatArray, FloatArray>
) {
    val csvFile = "reg_data/data.csv"
    val inputStream = context.assets.open(csvFile)
    val reader = CSVReader(InputStreamReader(inputStream))

    val rows = reader.readAll()
    reader.close()

    val xTrain = FloatArray(2) // the two features of one row
    val yTrain = FloatArray(1)

    for (i in 1 until rows.size) {
        val row = rows[i]
        xTrain[0] = row[0].toFloat()
        xTrain[1] = row[1].toFloat()
        yTrain[0] = row[2].toFloat() // Update this to match your dataset structure
        try {

            flowerClient.addSample(xTrain, yTrain, true)
        } catch (e: ExecutionException) {
            throw RuntimeException("Failed to add sample to model", e.cause)
        } catch (e: InterruptedException) {
            // no-op
        }

    }

}

private fun createFlowerClient0() {
    val buffer1 = loadMappedAssetFile(this, "model/toy_regression.tflite")
    val layersSizes = intArrayOf(8, 4)
    val sampleSpec = SampleSpec<FloatArray, FloatArray>(
        { it.toTypedArray() },
        { it.toTypedArray() },
        { Array(it) { FloatArray(1) } },
        ::maxSquaredErrorLoss,
        ::placeholderAccuracy,
    )
    flowerClient1 = FlowerClient(buffer1, layersSizes, sampleSpec)
}

I don't know which data structure your screenshot refers to, but the content looks fine to me.

You are getting "Array lengths cannot be 0" from TFLite during evaluation because your test dataset is empty.

For now, a workaround is to also add your data to the test dataset by calling:

flowerClient.addSample(xTrain, yTrain, false)
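The workaround above reuses the training rows for evaluation. If you would rather keep a few rows out of training, one option is a simple hold-out split. This sketch assumes the CSV has already been read into rows without the header; `Sample`, `SampleStore` and `loadRows` are hypothetical stand-ins modeled on the `addSample` quoted later in this thread, not FedKit's actual classes, and the every-fifth-row rule is an arbitrary choice for illustration.

```kotlin
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.write

// Hypothetical stand-ins, modeled on the addSample quoted in this thread.
data class Sample<X, Y>(val bottleneck: X, val label: Y)

class SampleStore<X, Y> {
    val trainingSamples = mutableListOf<Sample<X, Y>>()
    val testSamples = mutableListOf<Sample<X, Y>>()
    private val trainLock = ReentrantReadWriteLock()
    private val testLock = ReentrantReadWriteLock()

    fun addSample(bottleneck: X, label: Y, isTraining: Boolean) {
        val samples = if (isTraining) trainingSamples else testSamples
        val lock = if (isTraining) trainLock else testLock
        lock.write { samples.add(Sample(bottleneck, label)) }
    }
}

// Sends every `holdOutEvery`-th row to the test set and the rest to training,
// then fails early if the test set would still be empty at evaluation time.
fun loadRows(
    store: SampleStore<FloatArray, FloatArray>,
    rows: List<Array<String>>, // CSV rows, header already removed
    holdOutEvery: Int = 5
) {
    require(holdOutEvery > 0) { "holdOutEvery must be positive" }
    rows.forEachIndexed { i, row ->
        // Fresh arrays for every row, so samples never share storage.
        val x = floatArrayOf(row[0].toFloat(), row[1].toFloat())
        val y = floatArrayOf(row[2].toFloat())
        val isTraining = (i + 1) % holdOutEvery != 0
        store.addSample(x, y, isTraining)
    }
    check(store.testSamples.isNotEmpty()) { "Test set is empty; evaluation would fail." }
}
```

With the 11 rows from this thread, `holdOutEvery = 5` puts 2 rows in the test set and 9 in the training set.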

My assumption was that it should have the data below loaded:
7773,296.816,6765
9253,317.181,7923
15680,567.808,15933
6506,199.645,6526
14854,570.719,12374
6533,241.307,5358
6126,233.927,4373
16660,655.881,13995
13493,501.036,11339
12338,518.152,9174
1004,35.83,716

But it has only copied the last row's values to the entire array.

I expected it to have something like:
bottlenecks = {float[11][]@24579}
0 = {float[2]@24551} [7773,296.816]
1 = {float[2]@24551} [9253,317.181]
2 = {float[2]@24551} [15680,567.808]
3 = {float[2]@24551} [6506,199.645]
4 = {float[2]@24551} [14854,570.719]
5 = {float[2]@24551} [6533,241.307]
6 = {float[2]@24551} [6126,233.927]
7 = {float[2]@24551} [16660,655.881]
8 = {float[2]@24551} [13493,501.036]
9 = {float[2]@24551} [12338,518.152]
10 = {float[2]@24551} [1004.0, 35.83]

The actual values in bottlenecks are:
bottlenecks = {float[11][]@24579}
0 = {float[2]@24551} [1004.0, 35.83]
1 = {float[2]@24551} [1004.0, 35.83]
2 = {float[2]@24551} [1004.0, 35.83]
3 = {float[2]@24551} [1004.0, 35.83]
4 = {float[2]@24551} [1004.0, 35.83]
5 = {float[2]@24551} [1004.0, 35.83]
6 = {float[2]@24551} [1004.0, 35.83]
7 = {float[2]@24551} [1004.0, 35.83]
8 = {float[2]@24551} [1004.0, 35.83]
9 = {float[2]@24551} [1004.0, 35.83]
10 = {float[2]@24551} [1004.0, 35.83]


This is the data structure it loads into trainingSamples:
trainingSamples = {ArrayList@24544} size = 11
0 = {Sample@24752} Sample(bottleneck=[F@8151c1d, label=[F@afbfe92)
bottleneck = {float[2]@24551} [1004.0, 35.83]
bottleneck {java.lang.Object}
label = {float[1]@24774} [716.0]
label {java.lang.Object}
shadow$klass = {Class@22767} "class dev.flower.flower_tflite.Sample"
shadow$monitor = 0
1 = {Sample@24753} Sample(bottleneck=[F@8151c1d, label=[F@afbfe92)
bottleneck = {float[2]@24551} [1004.0, 35.83]
bottleneck {java.lang.Object}
label = {float[1]@24774} [716.0]
label {java.lang.Object}
shadow$klass = {Class@22767} "class dev.flower.flower_tflite.Sample"
shadow$monitor = 0
2 = {Sample@24754} Sample(bottleneck=[F@8151c1d, label=[F@afbfe92)
3 = {Sample@24755} Sample(bottleneck=[F@8151c1d, label=[F@afbfe92)

All the indices hold the values from the last iteration of the for loop instead of each row's values.

fun addSample(
    bottleneck: X, label: Y, isTraining: Boolean
) {
    val samples = if (isTraining) trainingSamples else testSamples
    val lock = if (isTraining) trainSampleLock else testSampleLock
    lock.write {
        samples.add(Sample(bottleneck, label))
    }
}

This function, along with adding a new value, also seems to overwrite the previous array values.

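You made a common programming mistake related to "pass by reference": each Sample stores a reference to the xTrain and yTrain arrays rather than a copy, so reusing the same two arrays for every row leaves every sample pointing at the last row's values (all the Samples in your dump share bottleneck=[F@8151c1d). Allocating new arrays inside the loop should fix it: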

-   val xTrain = FloatArray(2)
-   val yTrain = FloatArray(1)

    for (i in 1 until rows.size) {
+       val xTrain = FloatArray(2)
+       val yTrain = FloatArray(1)
        val row = rows[i]
        xTrain[0] = row[0].toFloat()
        xTrain[1] = row[1].toFloat()
        yTrain[0] = row[2].toFloat() // Update this to match your dataset structure
        try {

            flowerClient.addSample(xTrain, yTrain, true)
        } catch (e: ExecutionException) {
            throw RuntimeException("Failed to add sample to model", e.cause)
        } catch (e: InterruptedException) {
            // no-op
        }

    }
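The aliasing can be reproduced without TFLite or FlowerClient. In this sketch, `buildShared` and `buildFresh` are hypothetical functions that mirror the loop before and after the fix; a `MutableList` stores references to the arrays it is given, just as `addSample` does.

```kotlin
// Reproduces the bug: one array is reused for every row, and the list stores
// references to that same array, so every entry shows the last row.
fun buildShared(rows: List<FloatArray>): List<FloatArray> {
    val x = FloatArray(2)              // allocated once, outside the loop
    val out = mutableListOf<FloatArray>()
    for (row in rows) {
        x[0] = row[0]
        x[1] = row[1]
        out.add(x)                     // adds a reference, not a copy
    }
    return out
}

// The fix from the diff above: allocate a new array inside the loop.
fun buildFresh(rows: List<FloatArray>): List<FloatArray> {
    val out = mutableListOf<FloatArray>()
    for (row in rows) {
        val x = FloatArray(2)          // new array on every iteration
        x[0] = row[0]
        x[1] = row[1]
        out.add(x)
    }
    return out
}
```

With the fix, `buildFresh` returns independent arrays. Copying at insertion time (`out.add(x.copyOf())`) would also work.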

Thanks, that was a mistake on my side. Now the toy regression model works.