BUG1989/caffe-int8-convert-tools

Questions about the code in threshold_distribution

tianylijun opened this issue · 5 comments

```python
import copy

import numpy as np


def threshold_distribution(distribution, target_bin=128):
    """
    Return the best threshold value.
    Ref: https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
    Args:
        distribution: np.ndarray, activations that have been histogrammed and normalized; size is 2048
        target_bin: int, the number of bins used for quantization; the Int8 default is 128
    Returns:
        target_threshold: int, the bin index with the minimum KL divergence
    """
    distribution = distribution[1:]
    length = distribution.size
    threshold_sum = sum(distribution[target_bin:])
    kl_divergence = np.zeros(length - target_bin)

    for threshold in range(target_bin, length):
        sliced_nd_hist = copy.deepcopy(distribution[:threshold])

        # generate reference distribution p
        p = sliced_nd_hist.copy()
        p[threshold-1] += threshold_sum
        threshold_sum = threshold_sum - distribution[threshold]

        # is_nonzeros[k] indicates whether hist[k] is nonzero
        is_nonzeros = (p != 0).astype(np.int64)

        quantized_bins = np.zeros(target_bin, dtype=np.int64)
        # calculate how many bins should be merged to generate quantized distribution q
        num_merged_bins = sliced_nd_hist.size // target_bin
        # <----- Question: when this does not divide evenly, num_merged_bins is the same for
        # sliced_nd_hist.size = 128, 129, 130, ..., so isn't this only meaningful when
        # sliced_nd_hist.size is a multiple of 128?

        # merge hist into target_bin bins
        for j in range(target_bin):
            start = j * num_merged_bins
            stop = start + num_merged_bins
            quantized_bins[j] = sliced_nd_hist[start:stop].sum()
        quantized_bins[-1] += sliced_nd_hist[target_bin * num_merged_bins:].sum()
        # <----- Question: what does quantized_bins[-1] mean? Is it a mistake; shouldn't it be
        # quantized_bins[target_bin-1]?

        # ... (the rest of the function is not quoted in this issue)
```
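To make both flagged lines concrete, the sketch below recomputes `num_merged_bins` and the number of leftover source bins that the integer-division merge dumps into `quantized_bins[-1]`. The helper name `merge_stats` is hypothetical; it only mirrors the arithmetic of the quoted snippet.

```python
import numpy as np


def merge_stats(size, target_bin=128):
    """Hypothetical helper: return (num_merged_bins, leftover) for a sliced
    histogram of `size` bins, using the same integer division as the snippet."""
    num_merged_bins = size // target_bin
    # every source bin past target_bin * num_merged_bins is added to quantized_bins[-1]
    leftover = size - target_bin * num_merged_bins
    return num_merged_bins, leftover


for size in (128, 129, 200, 255, 256, 300, 2047):
    m, left = merge_stats(size)
    print(f"size={size:4d}  num_merged_bins={m:2d}  extra bins in last bucket={left}")

# A flat (hypothetical) histogram of 255 bins shows the uneven last bucket.
hist = np.ones(255)
m, _ = merge_stats(hist.size)
quantized_bins = np.array([hist[j * m:(j + 1) * m].sum() for j in range(128)])
quantized_bins[-1] += hist[128 * m:].sum()
print(quantized_bins[0], quantized_bins[-1])  # 1.0 128.0
```

For every size from 128 to 255, `num_merged_bins` stays at 1, so the last quantized bucket absorbs anywhere from 0 to 127 extra source bins; the merge is only uniform when the size is an exact multiple of 128.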

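The handling here is indeed pretty poor. Originally, to make the threshold computation faster, I borrowed this part of MXNet's implementation.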
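One way to avoid the uneven last bucket is to merge at fractional boundaries, so each of the `target_bin` buckets covers exactly `size / target_bin` source bins. The sketch below is an alternative written for illustration only; it is not the code in this repository or in MXNet, and the name `merge_fractional` is hypothetical.

```python
import numpy as np


def merge_fractional(hist, target_bin=128):
    """Hypothetical alternative: merge `hist` into `target_bin` buckets of equal
    (possibly fractional) width, splitting source bins across bucket boundaries."""
    hist = np.asarray(hist, dtype=np.float64)
    width = hist.size / target_bin  # source bins per bucket, may be fractional
    quantized = np.zeros(target_bin, dtype=np.float64)
    for j in range(target_bin):
        start, stop = j * width, (j + 1) * width
        lo, hi = int(np.floor(start)), int(np.ceil(stop))
        for k in range(lo, min(hi, hist.size)):
            # fraction of source bin [k, k + 1) that falls inside bucket [start, stop)
            overlap = min(k + 1, stop) - max(k, start)
            if overlap > 0:
                quantized[j] += hist[k] * overlap
    return quantized


q = merge_fractional(np.ones(200))
print(q[0], q[-1], q.sum())  # every bucket gets 1.5625 source bins' worth of mass
```

With 200 equal source bins, every bucket receives the same mass instead of the last one receiving 73 bins' worth. If this were used in `threshold_distribution`, the step that expands q back to the length of p would presumably need the same fractional treatment.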

@tianylijun

> What does quantized_bins[-1] mean? Is it a mistake; shouldn't it be quantized_bins[target_bin-1]?

Isn't quantized_bins[-1] just quantized_bins[target_bin-1]?
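Because `quantized_bins` is created with exactly `target_bin` elements, NumPy's negative indexing makes the two spellings address the same slot. A minimal check:

```python
import numpy as np

target_bin = 128
quantized_bins = np.zeros(target_bin, dtype=np.int64)

# A negative index counts from the end of the array, so -1 is the last element,
# i.e. index target_bin - 1 for an array of length target_bin.
quantized_bins[-1] += 5
print(quantized_bins[target_bin - 1])  # 5
```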

@BUG1989

> distribution = distribution[1:]

May I ask why distribution[0] is dropped?

@z13974509906
distribution[0] holds the count of zero values, which do not need to be quantized.
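To see why bin 0 can dominate, the sketch below histograms hypothetical post-ReLU activations (about half are exactly zero) into 2048 bins over [0, max|x|]. It assumes a plain `np.histogram` setup rather than the tool's own histogram code; note that in this setup bin 0 covers [0, max/2048), so it holds the exact zeros plus any very small magnitudes.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical post-ReLU activations: roughly half are exactly zero.
activations = np.maximum(rng.standard_normal(100_000), 0.0)

max_val = np.abs(activations).max()
hist, edges = np.histogram(np.abs(activations), bins=2048, range=(0.0, max_val))

zeros = int(np.count_nonzero(activations == 0))
print("bin 0:", hist[0], "exact zeros:", zeros, "largest other bin:", hist[1:].max())
```

Left in, that single spike would dominate the reference distribution p in the KL search, even though zero is represented exactly at any scale.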

虫叔 (BUG1989), if zero values don't need to be quantized, shouldn't the length of quantized_bins be 127? Or is it still 128 because negative values can go down to -128?
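For reference when weighing 127 against 128, the sketch below counts the distinct magnitude levels produced by a symmetric int8 scheme that clips to [-127, 127] with scale 127 / T. Both the scheme and the threshold T = 2.0 are assumptions for illustration; the thread does not state which range the tool targets.

```python
import numpy as np

threshold = 2.0                      # hypothetical clipping threshold T
scale = 127.0 / threshold            # assumed symmetric int8 scale
x = np.linspace(-3.0, 3.0, 10_001)   # hypothetical activations, some beyond T
q = np.clip(np.round(x * scale), -127, 127).astype(np.int8)

levels = np.unique(np.abs(q))
print(levels.min(), levels.max(), levels.size)  # 0 127 128
```

Under this assumption, |q| takes 128 distinct values, 0 through 127, so a count of 128 includes zero as one of the levels, while the nonzero magnitudes alone number 127.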