larq/compute-engine

experimental_enable_bitpacked_activations failure for same padding

lkindrat-xmos opened this issue

I understand that this feature is experimental, but I ran into the following issue: it looks like when two binary convolutions follow each other, an LceQuantize op is injected between them if the first convolution uses 'same' padding. The padding of the second convolution seems to have no effect on the outcome:

import tensorflow as tf
import larq as lq
import larq_compute_engine as lce

def myconv2d(padding):
    return lq.layers.QuantConv2D(
        filters=32,
        kernel_size=3,
        strides=2,
        padding=padding,
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )

def mymodel(first_padding):
    img = tf.keras.layers.Input(shape=(8, 8, 32))
    x = myconv2d(first_padding)(img)
    x = myconv2d("valid")(x)  # padding doesn't matter
    return tf.keras.Model(img, x)

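# With 'valid' padding on the first conv, no LceQuantize op is injected
# between the two binary convolutions.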
with open("model_ok.tflite", "wb") as f:
    f.write(
        lce.convert_keras_model(
            mymodel("valid"), experimental_enable_bitpacked_activations=True,
        )
    )

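# With 'same' (zero) padding on the first conv, an LceQuantize op is injected
# between the two binary convolutions, regardless of the second conv's padding.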
with open("model_bug.tflite", "wb") as f:
    f.write(
        lce.convert_keras_model(
            mymodel("same"), experimental_enable_bitpacked_activations=True,
        )
    )

Here are the output flatbuffers:
larq_bitpacked_padding_bug.zip

Package versions:
tensorflow: 2.3.0
larq: 0.10.1 (from pypi)
larq-compute-engine: 0.4.3 (from pypi)

Thanks for the issue :) This is expected behaviour if pad_values=0 is set in the Larq QuantConv2D, which is the default. If pad_values=1 is used, then your example should work. The reason we don't perform the correct transformation in the zero-padding case in the converter is that the LCE optimised kernels don't support writing bitpacked output with 'same-zero' padding.
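For reference, here is a minimal sketch of that fix applied to the snippet above (pad_values is the QuantConv2D argument referred to in this reply):

def myconv2d(padding):
    # Pad with +1 instead of 0 so the intermediate activations can stay
    # bitpacked even when 'same' padding is used.
    return lq.layers.QuantConv2D(
        filters=32,
        kernel_size=3,
        strides=2,
        padding=padding,
        pad_values=1,
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )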

Thanks for the clarification! In this case, would it be a good idea to raise a warning in the converter (i.e. when using experimental_enable_bitpacked_activations with 'same-zero' padding)?

Also, if the reference implementation supports bitpacked output with 'same-zero' padding, shouldn't this cause a runtime error in the optimized implementation instead (and raise a warning as well during conversion)?
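For illustration, a rough user-side check along these lines could flag the pattern before conversion (just a sketch, not part of the LCE converter; it assumes QuantConv2D instances expose padding and pad_values as attributes):

import warnings

import larq as lq

def warn_on_same_zero_padding(model):
    # Hypothetical helper: warn about binary convolutions whose output the
    # optimized kernels cannot write bitpacked ('same' padding with pad value 0).
    for layer in model.layers:
        if (
            isinstance(layer, lq.layers.QuantConv2D)
            and layer.padding == "same"
            and layer.pad_values == 0
        ):
            warnings.warn(
                f"Layer {layer.name} uses same-zero padding; "
                "its output will not be written as bitpacked activations."
            )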

Sometime soon we want to make experimental_enable_bitpacked_activations=True the default.

I completely agree that we should work out a good way to raise warnings for models which won't convert in an 'optimal' way -- essentially any egregious violation of our model optimisation guide. It's slightly complicated for a few reasons:

  • We need to decide where to raise warnings (in Larq, or in the LCE converter?).
  • Many of the warnings we could raise are target-dependent, but the flatbuffer files produced by the LCE converter are target-independent.
  • There are different degrees of 'sub-optimal'. In the particular example you raise, when using the LCE optimized kernels on Aarch64, not being able to write bitpacked activations doesn't have a huge measurable latency impact (< 10%).
  • We don't want to spam users with lots of warnings about sub-optimal code, especially as there are legitimate use-cases where a user doesn't care if something will run slowly (for example, the user might be doing research into new BNN architecture designs and using patterns which aren't currently optimised in LCE, such as grouped convolutions, but could be implemented in the future).

^I've copied the above into issue #542.

Passing pad_values=1 worked as you said. Feel free to close this issue, thanks!