apache/tvm

Data-Aware Calibration

ZihengJiang opened this issue · 8 comments

The current calibration algorithm uses global_scale, a user-configured value, to estimate the scale of intermediate results. To reduce accuracy loss, we need to implement a data-aware calibration algorithm: given a small calibration dataset (e.g., 100 samples), achieve better scale estimation.

Implementation

Collecting Intermediate Results

To collect the intermediate results of every operation in the graph, we may need to implement a visitor that returns every Call as an output. The API would be: stats = collect_stats(graph, data)
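As a sketch of what collect_stats could look like, here is a TVM-independent toy version: Call, evaluate, and the example graph below are stand-ins for Relay's actual node types and evaluator, and the traversal simply records every Call's output in post order.

```python
class Call:
    """Toy stand-in for a Relay Call node: an operator applied to args."""
    def __init__(self, op, *args):
        self.op, self.args = op, args

def evaluate(expr, data):
    """Recursively evaluate the toy graph on `data`."""
    if isinstance(expr, Call):
        args = [evaluate(a, data) for a in expr.args]
        return expr.op(*args)
    return data if expr == "input" else expr  # leaf: input or constant

def collect_stats(expr, data, stats=None):
    """Visit every Call and record its intermediate output (post-order).
    Subtrees are re-evaluated for simplicity; a real pass would instead
    rewrite the graph so all Calls become outputs and run it once."""
    if stats is None:
        stats = []
    if isinstance(expr, Call):
        for a in expr.args:
            collect_stats(a, data, stats)
        stats.append(evaluate(expr, data))
    return stats

# Example graph: relu(x * 2) + 1, evaluated at x = 3
graph = Call(lambda a, b: a + b,
             Call(lambda v: max(v, 0),
                  Call(lambda a, b: a * b, "input", 2)),
             1)
print(collect_stats(graph, 3))  # intermediate outputs: [6, 6, 7]
```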

Modification for SimulatedQuantize

It is kind of hard to save the mapping from each original operator to the operator after annotation. So we may need to add an option to simulated_quantize, such as mode='origin', which denotes that sq will not simulate the rounding/saturation error but instead return the input directly. With this, we can collect the intermediate results of the original graph.
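A minimal sketch of the proposed option; the signature here is hypothetical, and TVM's real simulated_quantize operator takes additional parameters (e.g., clip range and a sign flag):

```python
def simulated_quantize(x, scale, nbit=8, mode="simulate"):
    """Fake-quantize a value, or pass it through when mode='origin'."""
    if mode == "origin":
        # Pass-through: no rounding/saturation error, so the annotated
        # graph reproduces the original graph's intermediate results.
        return x
    qmax = 2 ** (nbit - 1) - 1            # e.g. 127 for int8
    q = round(x / scale)                  # simulate rounding error
    q = max(-qmax - 1, min(qmax, q))      # simulate saturation error
    return q * scale

print(simulated_quantize(100.0, scale=0.5))                 # saturates: 63.5
print(simulated_quantize(100.0, scale=0.5, mode="origin"))  # 100.0
```

With mode='origin' the annotated graph computes exactly what the original graph computes, which is what the statistics-collection step needs.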

Calibration Procedure

Having the calibration data, we can adjust the scales according to the output of the annotated graph. This is actually an optimization problem: the objective can be the KL divergence between the outputs of the original graph and the annotated graph, and the adjustment method can be simply search-based or learning-based. There should be lots of room for exploration.
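A minimal sketch of the search-based variant, assuming plain Python lists of activation values. The histogram binning, candidate generation, and handling of empty bins (skipped rather than smoothed) are all simplifications:

```python
import math

def kl_divergence(p, q):
    """KL(p || q) over two discrete distributions; zero bins are skipped
    for simplicity rather than smoothed as a real implementation would."""
    return sum(pi * math.log(pi / qi)
               for pi, qi in zip(p, q) if pi > 0 and qi > 0)

def quantize_dequantize(xs, scale, nbit=8):
    """Round and saturate each value to nbit, then map back to floats."""
    qmax = 2 ** (nbit - 1) - 1
    return [max(-qmax - 1, min(qmax, round(x / scale))) * scale for x in xs]

def histogram(xs, lo, hi, bins=32):
    """Normalized histogram of xs over [lo, hi]."""
    counts = [0] * bins
    for x in xs:
        i = min(bins - 1, max(0, int((x - lo) / (hi - lo) * bins)))
        counts[i] += 1
    total = sum(counts)
    return [c / total for c in counts]

def search_scale(activations, candidates, bins=32):
    """Grid search: pick the scale minimizing KL divergence between the
    original and fake-quantized activation distributions."""
    lo, hi = min(activations), max(activations)
    p = histogram(activations, lo, hi, bins)
    best, best_kl = None, float("inf")
    for s in candidates:
        q = histogram(quantize_dequantize(activations, s), lo, hi, bins)
        kl = kl_divergence(p, q)
        if kl < best_kl:
            best, best_kl = s, kl
    return best

acts = [math.sin(i * 0.1) * 3 for i in range(1000)]
candidates = [3.0 / 127 * f for f in (0.5, 0.75, 1.0, 1.25)]
print(search_scale(acts, candidates))
```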

eqy commented

I am very interested in this problem, and I have a few different approaches I want to try out. However, I wonder if requiring intermediate results to be collected is too labor-intensive or unnecessary. In the AutoTVM setting, we use domain-specific features because they are powerful and allow for transfer learning. However, transfer learning is not as important in this setting, and it could be that blackbox features (e.g., just the scale configuration) are enough to build a usable machine learning model (just as knob features are usable in AutoTVM). If that turns out to work, we can save a significant amount of effort. It seems slightly expensive to compute both the original graph and the annotated graph every time we want to try a different calibration configuration, so I think it would be great if we could get away with doing even less work :).

I'll send updates on this approach as I get results. You may have already guessed that performance is important here: I am first tuning the workloads required for quantization so that trying many calibration configurations is not too slow.

ajtulloch commented

FWIW, we've got some experience with this kind of thing, and one observation is that you don't need a sophisticated calibration procedure: simple L2 minimization (i.e., not L∞/min-max calibration) performs completely fine, on par with or better than much more sophisticated algorithms. cf. https://github.com/pytorch/pytorch/blob/master/caffe2/quantization/server/norm_minimization.cc, etc.
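A sketch of what L2 (MSE) threshold search looks like. This is in the spirit of, not ported from, the linked NormMinimization code; the linear candidate sweep and pure-Python loop are simplifications:

```python
def mse_scale(values, nbit=8, num_candidates=100):
    """Pick the clipping threshold (hence scale) minimizing the total
    squared quantization error over a linear sweep of candidates."""
    qmax = 2 ** (nbit - 1) - 1
    max_abs = max(abs(v) for v in values)
    best_scale, best_err = None, float("inf")
    for i in range(1, num_candidates + 1):
        threshold = max_abs * i / num_candidates
        scale = threshold / qmax
        err = 0.0
        for v in values:
            # quantize, saturate, dequantize; accumulate squared error
            q = max(-qmax - 1, min(qmax, round(v / scale)))
            err += (v - q * scale) ** 2
        if err < best_err:
            best_scale, best_err = scale, err
    return best_scale

weights = [0.01 * i for i in range(-100, 101)]
print(mse_scale(weights))
```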

liangfu commented

Can we have a simple, quick, hard-coded quantile-based solution for this feature as a first step? More reliable solutions (either search-based or learning-based) can be added later, so that we can implement this feature progressively.
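Such a quantile-based first step could be as small as this (a pure-Python sketch; the 0.9999 default is an arbitrary illustration, not a recommended value):

```python
def quantile_scale(values, q=0.9999, nbit=8):
    """Hard-coded quantile calibration: clip at the q-th quantile of
    absolute values instead of the max, so rare outliers do not blow
    up the scale."""
    qmax = 2 ** (nbit - 1) - 1
    xs = sorted(abs(v) for v in values)
    idx = min(len(xs) - 1, int(q * len(xs)))
    return xs[idx] / qmax

acts = list(range(100)) + [1000]       # one large outlier
print(quantile_scale(acts, q=0.99))    # clips the outlier: 99 / 127
print(quantile_scale(acts, q=1.0))     # max-based: 1000 / 127
```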

eqy commented

@ajtulloch Haha, I was about to say that we have independently discovered the same thing. There is some literature that seems to try some more complex (or efficient?) methods (e.g., https://arxiv.org/abs/1810.05488) but I am not sure that the results are better than just plain old mean-squared error.
@liangfu #2753 is a WIP to add part of the support for this; basically, the remaining steps at a high level are:
[ ] profiling activation support (Relay makes this easy and I have an internal branch that does this)
[ ] per-channel scales (needed for MobileNet, possibly DenseNet)

The second part is a bit more involved, as it essentially requires setting up/inferring the number of scales we need for each operator and then attaching them to the graph. It seems the data-aware calibration method will also require some API changes, as it will be done after a first version of the graph has already been constructed and requires profiling the intermediate activations (e.g., to compute L2/MSE).
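For the layout part, the core inference can start as simply as locating the 'C' axis in the layout string. This is a simplified sketch with a hypothetical helper name; real Relay layouts can have split axes such as 'NCHW8c' that need more care:

```python
def channel_info(shape, layout):
    """Infer the channel axis and channel count from a data layout
    string such as 'NCHW' or 'NHWC'."""
    axis = layout.index("C")
    return axis, shape[axis]

print(channel_info((1, 64, 56, 56), "NCHW"))  # (1, 64)
print(channel_info((1, 56, 56, 64), "NHWC"))  # (3, 64)
```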

eqy commented

#3294 provides an initial implementation of data-aware calibration and per-channel domain scale support.

Currently, the support focuses on vision models and likely makes many assumptions that will break down with other types of models. Similarly, parts of the per-channel scale support rely on brittle heuristics, but we expect to improve these over time as we evaluate more models and understand more about the general patterns in quantization.

The previous quantization process used three passes: annotate, "calibrate" (which only calibrated weights), and realize, which materializes concrete domain scale compensations between layers and the required casts between int8 and int32.

At a high level, this PR changes the pass order/structure slightly to:

  1. annotate (same as before, with a few additional properties of the graph captured, such as data layout). Data layout is important for inferring the number of channels and the channel dimension in data/weight tensors.
  2. First calibration. This first calibration step calibrates weights as before, but also marks intermediate activations as outputs of the graph so that they can be profiled.
  3. Profiling/evaluation: the network is executed on a small set of training data (currently on the order of a small mini batch or 32 samples) so that the intermediate activations can be profiled.
  4. Second calibration. We use MSE to choose a domain scale to calibrate each of the profiled activations.
  5. Scale matching (only for per-channel scales). Because quantization targets integer hardware, scales must match in the middle of a kernel. Here, we manipulate the per-channel scales to compensate for any differences after the multiply phase of the convolution and tune them so that scales match during accumulation.
  6. realize is the same as before but can handle vector domain scales instead of scalar domain scales.
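Step 5 exists because integer accumulation can only add values that share a scale. The following sketch shows the standard per-channel requantization arithmetic that motivates it (not the PR's actual pass): each output channel's accumulator carries scale s_in * s_w[c], and a per-channel multiplier brings everything to one common output scale.

```python
def requantize_per_channel(acc, s_in, s_w, s_out, nbit=8):
    """Fold per-output-channel scales into a single output scale.
    acc[c] holds an int32 accumulator whose real value is
    acc[c] * s_in * s_w[c]; we rescale everything to the common s_out."""
    qmax = 2 ** (nbit - 1) - 1
    out = []
    for c, a in enumerate(acc):
        m = s_in * s_w[c] / s_out        # per-channel requantize multiplier
        q = round(a * m)                  # a real kernel uses a fixed-point
        out.append(max(-qmax - 1, min(qmax, q)))  # multiplier plus shift
    return out

acc = [100, -50]   # int32 accumulators for two output channels
print(requantize_per_channel(acc, s_in=0.1, s_w=[0.2, 0.4], s_out=0.5))  # [4, -4]
```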

Some current issues include:

  • Inferring the correct number of channels and the channel dimension becomes very difficult with arbitrary data layouts.
  • Inferring the number of channels can also be tricky with a wide variety of operators and operator broadcasting conventions (e.g., how many channels should the constant tensor be assigned in an add(activation, constant) operator when the constant has shape (1, 1000)?). In many cases, more than local information is needed.
  • Concatenating multiple tensors currently simply results in the maximum scales across all tensors being used with per-layer scales.
  • We require a dataset for calibration and a uniform API for passing this data. Currently, this is relatively straightforward for vision models (e.g., we can use the MXNet format and MXNet iterators), but it is not clear how we want to support other domains.

Current accuracy results (in flux):

| Network | per-layer scale | per-channel scale |
| --- | --- | --- |
| resnet18_v1 | 0.68636 | 0.67964 |
| resnet50_v1 | 0.74358 | 0.75114 |
| resnet101_v1 | 0.77042 | 0.76862 |
| mobilenet1.0 | 0.6424 | 0.66812 |
| mobilenetv2_1.0 | 0.66636 | 0.33368 |
| densenet121 | 0.0038 | 0.73362 |
| inceptionv3 | 0.70678 | 0.73626 |

vinx13 commented

I have implemented KL divergence based calibration (ported from MXNet) in my branch https://github.com/vinx13/tvm/tree/feature/calibration.
Similar to #3294, it first runs on a few samples and collects profiles. Activation scales are chosen to minimize the KL divergence between the original and quantized distributions.
Weight scales can either be rounded to a power of two or taken directly from the maximum without rounding.
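Both weight-scale choices are one-liners; a sketch (the function name and signature are illustrative, not the branch's API):

```python
import math

def weight_scale(max_abs, nbit=8, power_of_two=True):
    """Choose a weight scale from the max absolute weight. With
    power_of_two=True the scale is rounded up to the next power of two,
    which lets the realized graph use shifts instead of multiplies."""
    qmax = 2 ** (nbit - 1) - 1
    scale = max_abs / qmax
    if power_of_two:
        scale = 2.0 ** math.ceil(math.log2(scale))
    return scale

print(weight_scale(0.9))                      # 2**-7 = 0.0078125
print(weight_scale(0.9, power_of_two=False))  # 0.9 / 127
```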
It also allows special handling of bias by introducing a new kind, QAnnotateKind.BIAS. Some experiments showed that increasing the bias nbits from 8 to 16 can be helpful, but the optimal number of bits for bias is still unclear.
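The bias-width question can be probed with a toy error measurement like the following (a hypothetical experiment, not the branch's code); it only shows that 16-bit quantization of a wide-range bias has far lower round-off error, not what the end-to-end accuracy impact is:

```python
def quantize_error(values, nbit):
    """Mean absolute error from symmetric quantization at nbit bits,
    using a max-based scale."""
    qmax = 2 ** (nbit - 1) - 1
    scale = max(abs(v) for v in values) / qmax
    err = 0.0
    for v in values:
        err += abs(v - round(v / scale) * scale)
    return err / len(values)

bias = [0.001 * i for i in range(-500, 501)]  # biases spanning a wide range
print(quantize_error(bias, 8), quantize_error(bias, 16))
```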
Currently it stores profile results in memory. This may cause memory issues as the number of layers or the number of samples in the calibration set grows.
@eqy I plan to adapt this to the calibrate interface in #3294 and prepare a PR.
Any comments are welcome.

#3538 is now merged. @vinx13 @ZihengJiang it would be great to create follow-up tutorials demonstrating the usage of the quantize pass.

Closed by #3538. Let us open a new thread for tutorials.