This demo application, kept in a single GitHub repository, performs image classification on camera output as well as text classification.
To get the device camera output, it uses the Android CameraX API. All of the logic that works with CameraX is isolated in the org.pytorch.demo.vision.AbstractCameraXActivity class:
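void setupCameraX() {
    // Preview use case: stream camera frames into the TextureView.
    final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
    final Preview preview = new Preview(previewConfig);
    preview.setOnPreviewOutputUpdateListener(output -> mTextureView.setSurfaceTexture(output.getSurfaceTexture()));

    // Image analysis use case: request a 224x224 target resolution (the camera may pick
    // the closest supported size), run callbacks on a background thread, and keep only
    // the latest frame when the analyzer falls behind.
    final ImageAnalysisConfig imageAnalysisConfig =
        new ImageAnalysisConfig.Builder()
            .setTargetResolution(new Size(224, 224))
            .setCallbackHandler(mBackgroundHandler)
            .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
            .build();
    final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
    imageAnalysis.setAnalyzer(
        (image, rotationDegrees) -> {
            // The analyzer receives an ImageProxy; getImage() returns the wrapped android.media.Image.
            analyzeImage(image.getImage(), rotationDegrees);
        });

    // Bind both use cases to the activity lifecycle.
    CameraX.bindToLifecycle(this, preview, imageAnalysis);
}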
void analyzeImage(android.media.Image image, int rotationDegrees)
The analyzeImage method processes the camera output, an android.media.Image. It uses the aforementioned TensorImageUtils.imageYUV420CenterCropToFloat32Tensor method to convert the android.media.Image, which is in YUV420 format, into an input tensor.
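The kind of conversion that TensorImageUtils.imageYUV420CenterCropToFloat32Tensor performs can be illustrated with a simplified pure-Java sketch. This is not the library's implementation. The class name Yuv420CenterCrop, the nearest-neighbor resize, the BT.601 full-range color conversion, the ImageNet-style mean/std constants, and the assumption of tightly packed I420 planes are all illustrative assumptions. Real android.media.Image planes carry row and pixel strides. Rotation is ignored here.

```java
// Hypothetical helper for illustration only; not the PyTorch Android implementation.
final class Yuv420CenterCrop {
    // Assumed normalization constants (ImageNet-style mean/std for RGB in [0, 1]).
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    /**
     * Center-crops an I420 frame to a square, resizes it to size x size with
     * nearest-neighbor sampling, converts YUV to RGB (assumed BT.601 full range),
     * and returns normalized floats in CHW order (all R, then all G, then all B).
     * Assumes even width/height and tightly packed planes (no row/pixel stride).
     */
    static float[] toFloatChw(byte[] y, byte[] u, byte[] v, int width, int height, int size) {
        int side = Math.min(width, height);
        int x0 = (width - side) / 2;   // left edge of the centered square
        int y0 = (height - side) / 2;  // top edge of the centered square
        int plane = size * size;
        float[] out = new float[3 * plane];
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                // Map the output pixel to a source pixel inside the centered square.
                int sx = x0 + col * side / size;
                int sy = y0 + row * side / size;
                int yy = y[sy * width + sx] & 0xFF;
                // Chroma planes are subsampled by 2 in both directions.
                int ci = (sy / 2) * (width / 2) + (sx / 2);
                int uu = (u[ci] & 0xFF) - 128;
                int vv = (v[ci] & 0xFF) - 128;
                float r = clamp(yy + 1.402f * vv);
                float g = clamp(yy - 0.344136f * uu - 0.714136f * vv);
                float b = clamp(yy + 1.772f * uu);
                int idx = row * size + col;
                out[idx] = (r / 255f - MEAN[0]) / STD[0];
                out[plane + idx] = (g / 255f - MEAN[1]) / STD[1];
                out[2 * plane + idx] = (b / 255f - MEAN[2]) / STD[2];
            }
        }
        return out;
    }

    static float clamp(float x) {
        return Math.max(0f, Math.min(255f, x));
    }
}
```

The channel-planar (CHW) layout is used because that is the usual input layout for PyTorch image models. For a 6x2 frame with size 2, the sketch samples only columns 2 and 3, which form the centered 2x2 square.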
After obtaining the predicted scores from the model, it finds the top K classes with the highest scores and shows them in the UI.
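One possible way to select the top K classes is a bounded min-heap, sketched below. TopK.topK is a hypothetical helper, not part of the demo app or PyTorch Android. It takes the model's scores as a plain float[] and returns the indices of the k highest scores, best first.

```java
import java.util.PriorityQueue;

// Hypothetical helper for illustration only.
final class TopK {
    static int[] topK(float[] scores, int k) {
        k = Math.max(0, Math.min(k, scores.length));
        // Min-heap ordered by score; the root is the weakest entry currently kept.
        PriorityQueue<Integer> heap =
            new PriorityQueue<>((a, b) -> Float.compare(scores[a], scores[b]));
        for (int i = 0; i < scores.length; i++) {
            heap.offer(i);
            if (heap.size() > k) {
                heap.poll(); // evict the lowest score so at most k indices remain
            }
        }
        int[] result = new int[heap.size()];
        // Polling a min-heap yields ascending scores, so fill from the end.
        for (int i = result.length - 1; i >= 0; i--) {
            result[i] = heap.poll();
        }
        return result;
    }
}
```

The returned indices can then be mapped to human-readable labels by indexing into the array of class names. The heap keeps memory at O(k), which matters little for a few classes but avoids sorting the full score array. The order among tied scores is not specified.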
Another example is natural language processing: a text classifier based on an LSTM model trained on a Reddit comments dataset. This logic lives in TextClassificationActivity.
The result class names are packaged inside the TorchScript model and are read right after the module is loaded. The module has a get_classes method that returns List[str], which can be called using Module.runMethod(methodName):
mModule = Module.load(moduleFileAbsoluteFilePath);
IValue getClassesOutput = mModule.runMethod("get_classes");
The returned IValue can be converted to a Java array of IValue using IValue.toList(), and each element can be converted to a string using IValue.toStr():
IValue[] classesListIValue = getClassesOutput.toList();
String[] moduleClasses = new String[classesListIValue.length];
int i = 0;
for (IValue iv : classesListIValue) {
    moduleClasses[i++] = iv.toStr();
}
The entered text is converted to a Java byte array using UTF-8 encoding. Tensor.fromBlobUnsigned then creates a tensor of dtype=uint8 from that byte array.
byte[] bytes = text.getBytes(Charset.forName("UTF-8"));
final long[] shape = new long[]{1, bytes.length};
final Tensor inputTensor = Tensor.fromBlobUnsigned(bytes, shape);
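Java's byte type is signed, so a UTF-8 byte such as 0xC3 appears as -61 in the byte[] array. The sketch below uses a hypothetical class, Utf8Unsigned, which is not part of the demo. It shows the 0 to 255 values one would expect a dtype=uint8 tensor built from those bytes to hold. It also shows why the shape uses bytes.length rather than the character count.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical helper for illustration only.
final class Utf8Unsigned {
    static int[] toUnsigned(String text) {
        byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
        int[] values = new int[bytes.length];
        for (int i = 0; i < bytes.length; i++) {
            values[i] = bytes[i] & 0xFF; // reinterpret the signed Java byte as 0..255
        }
        return values;
    }
}
```

For "h\u00e9" ("hé"), text.length() is 2, but the UTF-8 encoding has 3 bytes with unsigned values 104, 195 and 169. The input shape would therefore be {1, 3}.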
Running inference on the model is similar to the previous examples:
Tensor outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor();
After that, the code processes the output, finding classes with the highest scores.