EmergentOrder/onnx-scala

Is it possible to pass batch size dynamically?

novakov-alexey-zz opened this issue · 2 comments

My batch size is 32, but the very last batch can be smaller, for example 27 images.
Is it possible to make the batch size a variable and pass it to the model inference method?

val batch = calcBatchSize()
// below line won't compile
val out = model.fullModel[Float, "ImageClassification", "Batch" ##: "Features" ##: TSNil, batch #: 512 #: SNil](
  Tuple(input)
)

First off, the type parameters here are for the output, so it shouldn't be "Features" but (predicted) "Class" (or whatever your outputs are).

You can represent typed tensors with symbolic dimensions (i.e. for dynamic batch size) like so:

def input: Tensor[Float, Tuple3["ImageClassification", "Batch" ##: "Class" ##: TSNil, Dimension #: 512 #: SNil]] = ???

However, the fullModel call and the fine-grained ops will not work directly on such tensors, because they need to obtain the value of the dimension via implicits (at compile time).

You can either:
A) pad your inputs to the correct (static) dimension size or
B) Cast an Int to a Dimension:

     val batch: Dimension = calcBatchSize().asInstanceOf[Dimension]
     val data = Array.fill(batch*3*224*224){42f}
     
     //In NCHW tensor image format
     val shape = batch #: 3 #: 224 #: 224 #: SNil
     val tensorShapeDenotation = "Batch" ##: "Features" ##: "Height" ##: "Width" ##: TSNil
     
     val tensorDenotation: String & Singleton = "Image"
     
     val imageTens = Tensor(data,tensorDenotation,tensorShapeDenotation,shape)
     
     val out = model.fullModel[Float, "ImageClassification", "Batch" ##: "Class" ##: TSNil, batch.type #: 1000 #: SNil](Tuple(imageTens))

For this to work, the ONNX model has to represent the dimension we are concerned with correctly (as a symbolic dimension), just as the static dimensions in the code must match those in the model.
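Option A (padding) can be sketched in plain Scala, independent of onnx-scala. Here `padBatch` and `features` are hypothetical names for illustration; `features` stands for the flattened per-image size (e.g. 3*224*224), and the idea is to pad the final, short batch up to the static batch size with zeros, run inference, and then keep only the first `actualBatch` outputs:

```scala
// Hypothetical sketch of option A: pad the last batch to the static batch size.
// After inference, discard the outputs for the padded (zero-filled) rows.
def padBatch(data: Array[Float], actualBatch: Int, staticBatch: Int, features: Int): Array[Float] = {
  require(data.length == actualBatch * features, "data must hold actualBatch images")
  require(actualBatch <= staticBatch, "actual batch cannot exceed the static size")
  if (actualBatch == staticBatch) data
  else data ++ Array.fill((staticBatch - actualBatch) * features)(0f)
}

// e.g. a final batch of 27 images padded to the static size of 32
// (tiny feature size used here just to keep the example small):
val padded = padBatch(Array.fill(27 * 4)(1f), actualBatch = 27, staticBatch = 32, features = 4)
```

With this approach the tensor shape stays fully static (32 #: ... #: SNil), so the implicits resolve at compile time as usual.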

I am working on a transfer-learning project where the last layer is deleted, so the output denotation is "Features" in my case :-), but thank you for pointing that out.

Option B works fine for me. Thank you for the answer.