/flink-onnx-pytorch

Streaming machine learning using PyTorch, Flink, and ONNX

Primary LanguageScalaMIT LicenseMIT

Scala CI

PyTorch on Flink with ONNX

This repository shows an example of using Apache Flink to serve an ONNX model that was created by PyTorch. While the initial model is simple, it shows the process by which a machine learning model can be created in Python, and then served by a Scala program written for streaming with Flink.

Training in PyTorch

The training/ directory provides an example Jupyter Notebook that uses PyTorch to create a simple network. The network is exported to ONNX through PyTorch. The network takes a variable length floating point tensor and adds a constant offset. This network could be easily replaced with a fully trained neural network.

Packaging the JAR

The ONNX model is packaged as a resource, so that it can be distributed to the Flink TaskManagers in the same JAR as the code.

The Flink application follows the standard template for SBT. The JAR can be built using sbt as follows:

sbt clean assembly

The Scala code is also tested to ensure quality, and those tests are run by sbt. These tests exercise the inference of the ONNX model.

The contents of the JAR can be inspected to see the included ONNX model and the onnxruntime library.

jar tf ./target/scala-2.11/flink-onnx-pytorch-assembly-0.0.1.jar

Serving with onnxruntime

The ONNX model is served using the onnxruntime library. The Java API is used, which provides a similar interface to Python. The OrtModel class is added as a convenience wrapper to handle loading from the resources directory.

Running the model inference is wrapped by the AddFive class, which extends the RichMapFunction. This allows the open and close methods to handle the model loading and closing. The map method runs the input value through the ONNX model.

The types need to be handled properly, since fractional values default to Double in Scala unless specifically identified.

Streaming in Flink

With the JAR built, the Flink job can be submitted. The job is written over a static dataset of numbers as a simple example. This could be easily replaced by another source. The job outputs the results to stdout for simplicity, which could also be replaced by another sink. The DataStream API is used in this example.

To run the JAR on a local cluster, make sure to start it first.

./flink-1.13.0/bin/start-cluster.sh

The com.datacolin.Job can be submitted to the cluster through the CLI:

./flink-1.13.0/bin/flink run -c com.datacolin.Job ./flink-onnx-pytorch/target/scala-2.11/flink-onnx-pytorch-assembly-0.0.1.jar

The local cluster will provide output in the log/ directory, which can be watched:

tail -f ./flink-1.13.0/log/*.out

The static dataset should be output with the added offset.

Future directions

This demonstration provides the basic working pieces to get a Python machine learning model running in a streaming Scala application in Flink. There are a number of interesting real-time applications to try next.