|
setContentView(R.layout.activity_main); |
|
|
|
Bitmap bitmap = null; |
|
Module module = null; |
|
try { |
|
// creating bitmap from packaged into app android asset 'image.jpg', |
|
// app/src/main/assets/image.jpg |
|
bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg")); |
|
// loading serialized torchscript module from packaged into app android asset model.pt, |
|
// app/src/model/assets/model.pt |
|
module = LiteModuleLoader.load(assetFilePath(this, "model.pt")); |
|
} catch (IOException e) { |
|
Log.e("PytorchHelloWorld", "Error reading assets", e); |
|
finish(); |
|
} |
|
|
|
// Display the loaded bitmap in the layout's ImageView (R.id.image).
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);

// Convert the bitmap into a float32 tensor, normalized with the torchvision
// mean/std constants and laid out in channels-last memory format.
final Tensor inputTensor =
    TensorImageUtils.bitmapToFloat32Tensor(
        bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
        TensorImageUtils.TORCHVISION_NORM_STD_RGB,
        MemoryFormat.CHANNELS_LAST);

// Forward pass: wrap the input as an IValue and unwrap the model's result as a single tensor.
final Tensor outputTensor = module
    .forward(IValue.from(inputTensor))
    .toTensor();

// Copy the output tensor's contents into a plain Java float array.
final float[] scores = outputTensor.getDataAsFloatArray();
|
|
|
// searching for the index with maximum score |
|
float maxScore = -Float.MAX_VALUE; |
|
int maxScoreIdx = -1; |
|
for (int i = 0; i < scores.length; i++) { |
|
if (scores[i] > maxScore) { |
|
maxScore = scores[i]; |
|
maxScoreIdx = i; |
|
} |
|
} |