jun-fang/PWLQ

Why are the bit settings in PWLQ so strange?


In pwlq.py, lines 60 to 67, when the middle region is quantized with bits bits, tail_neg and tail_pos are each quantized with bits-1 bits:

    ## option 2: non-overlapping
    if pw_opt == 2:
        qw_tail_neg = uniform_affine_quantizer(w, 
            bits=bits-1, scale_bits=scale_bits, minv=-abs_max, maxv=-break_point)
        qw_tail_pos = uniform_affine_quantizer(w, 
            bits=bits-1, scale_bits=scale_bits, minv=break_point, maxv=abs_max)
        qw_middle = uniform_symmetric_quantizer(w, 
            bits=bits, scale_bits=scale_bits, minv=-break_point, maxv=break_point)
    
        qw = torch.where(-break_point < w, qw_middle, qw_tail_neg)
        qw = torch.where(break_point > w, qw, qw_tail_pos)

Doesn't that make the total num_levels over the full range 2 ** bits + 2 * 2 ** (bits - 1) = 2 * 2 ** bits, i.e. as many as a (bits + 1)-bit quantizer? Is this right, or is there something wrong with my understanding?

You are right; I was confused by it too when I first saw the code. But look at the last paragraph of Section 3.2 in the paper, which says: "We emphasize that b-bit PWLQ represents FP32 values into b-bit integers to support b-bit multiply-accumulate operations, even though in total, it has the same number of quantization levels as (b+1)-bit uniform quantization."
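
To make the counting concrete, here is a minimal sketch of the arithmetic for option 2. It is not code from the repository: region_code_counts and summarize are hypothetical helpers, and the counts are nominal (2 ** bits codes per quantizer), ignoring any codes the actual quantizers in pwlq.py may leave unused.

```python
def region_code_counts(bits):
    """Nominal number of integer codes per region for option 2 (non-overlapping).

    Hypothetical helper: assumes a quantizer called with k bits has 2 ** k codes.
    """
    return {
        "tail_neg": 2 ** (bits - 1),  # quantized with bits-1
        "middle": 2 ** bits,          # quantized with bits
        "tail_pos": 2 ** (bits - 1),  # quantized with bits-1
    }


def summarize(bits):
    """Return (total levels over the full range, widest per-region code in bits)."""
    counts = region_code_counts(bits)
    total_levels = sum(counts.values())
    # A region with n codes needs (n - 1).bit_length() bits to index them.
    widest = max((n - 1).bit_length() for n in counts.values())
    return total_levels, widest


for b in (2, 4, 8):
    total, widest = summarize(b)
    print(f"b={b}: total levels={total}, 2**(b+1)={2 ** (b + 1)}, widest region code={widest} bits")
```

Under these assumptions the total always equals 2 ** (b + 1), while no single region needs more than b bits for its codes, which matches the sentence quoted from the paper.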