Code issues in threshold_distribution
tianylijun opened this issue · 5 comments
def threshold_distribution(distribution, target_bin=128):
    """
    Return the best threshold value.
    Ref: https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
    Args:
        distribution: list, activations has been processed by histogram and normalize, size is 2048
        target_bin: int, the num of bin that is used by quantize, Int8 default value is 128
    Returns:
        target_threshold: int, num of bin with the minimum KL
    """
    distribution = distribution[1:]
    length = distribution.size
    threshold_sum = sum(distribution[target_bin:])
    kl_divergence = np.zeros(length - target_bin)
    for threshold in range(target_bin, length):
        sliced_nd_hist = copy.deepcopy(distribution[:threshold])
        # generate reference distribution p
        p = sliced_nd_hist.copy()
        p[threshold-1] += threshold_sum
        threshold_sum = threshold_sum - distribution[threshold]
        # is_nonzeros[k] indicates whether hist[k] is nonzero
        is_nonzeros = (p != 0).astype(np.int64)
        #
        quantized_bins = np.zeros(target_bin, dtype=np.int64)
        # calculate how many bins should be merged to generate quantized distribution q
        num_merged_bins = sliced_nd_hist.size // target_bin  # <----- If this does not divide evenly, num_merged_bins comes out the same for sliced_nd_hist.size = 128/129/130/...; doesn't this only really make sense when sliced_nd_hist.size is a multiple of 128?
        # merge hist into num_quantized_bins bins
        for j in range(target_bin):
            start = j * num_merged_bins
            stop = start + num_merged_bins
            quantized_bins[j] = sliced_nd_hist[start:stop].sum()
        quantized_bins[-1] += sliced_nd_hist[target_bin * num_merged_bins:].sum()  # <----- What does quantized_bins[-1] mean? Is this a typo for quantized_bins[target_bin-1]?
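To make the integer-division question concrete, here is a minimal sketch. It is not part of the repository, and the helper name `merge_plan` is made up for illustration. It only repeats the `//` arithmetic from the code above and reports how many source bins are left over for the last quantized bin.

```python
def merge_plan(size, target_bin=128):
    """Hypothetical helper: mirror the `//` arithmetic in the snippet above."""
    # Number of source bins merged into each quantized bin.
    num_merged_bins = size // target_bin
    # Source bins not covered by the loop; the snippet adds them to quantized_bins[-1].
    leftover = size - target_bin * num_merged_bins
    return num_merged_bins, leftover


for size in (128, 129, 130, 255, 256):
    print(size, merge_plan(size))
```

For every size from 128 to 255, `num_merged_bins` stays at 1, so the first 128 source bins map one-to-one and all remaining bins are folded into the last quantized bin.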
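The handling here is indeed pretty poor. Back then, to make the threshold computation faster, I borrowed this part of the implementation from MXNet.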
What does quantized_bins[-1] mean? Is this a typo for quantized_bins[target_bin-1]?
Isn't quantized_bins[-1] here exactly quantized_bins[target_bin-1]?
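A quick way to check this is the sketch below. It uses a plain Python list as a stand-in for the array (an assumption made to keep the example dependency-free); a 1-D NumPy array treats index -1 the same way.

```python
target_bin = 128

# Stand-in for np.zeros(target_bin, dtype=np.int64) from the snippet.
quantized_bins = [0] * target_bin

# Write through the negative index ...
quantized_bins[-1] += 5

# ... and read back through the explicit last index.
same_value = quantized_bins[target_bin - 1]
print(same_value)
```

Since the container has exactly `target_bin` elements, both expressions address the same last element.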
@z13974509906
distribution[0] is the count of zero values, so it does not need to be quantized.
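As a small illustration of what `distribution = distribution[1:]` does under that reading, the sketch below uses a plain list and made-up values (both are assumptions; the real input is the normalized 2048-bin histogram mentioned in the docstring).

```python
num_bins = 2048                      # size stated in the docstring

# Made-up histogram: pretend half of the mass sits in bin 0.
distribution = [0.0] * num_bins
distribution[0] = 0.5
distribution[1] = 0.5

# Same slicing as in threshold_distribution: drop bin 0 before the search.
trimmed = distribution[1:]
print(len(trimmed), trimmed[0])
```

After the slice, 2047 bins remain and the former bin 1 becomes index 0, so the threshold loop never sees the bin-0 count.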
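虫叔, if zero values don't need to be quantized, shouldn't the length of quantized_bins be 127? Or is it that negative values can go down to -128, so it is still computed with 128?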