wjc852456/pytorch-quant

compressed model size

Opened this issue · 0 comments

Hi, thanks for sharing your work. I used your code to quantize my own U-Net model, and I have a question: the size of the quantized model (8 bits) is the same as before. Is there a problem with my usage? Here is the code:

import argparse
from collections import OrderedDict

import torch

import quant            # assumed import; adjust to this repo's actual module path
from unet import UNet   # assumed import; my own U-Net definition

parser = argparse.ArgumentParser(description='PyTorch Quantization')
parser.add_argument('--test', type=int, default=1, help='test data distribution')
parser.add_argument('--batch_size', type=int, default=1, help='input batch size for training')
parser.add_argument('--n_sample', type=int, default=10, help='number of samples to infer the scaling factor')
parser.add_argument('--gpu', default="0", help='index of gpus to use')
parser.add_argument('--ngpu', type=int, default=1, help='number of gpus to use')
parser.add_argument('--logdir', default='log/default', help='folder to save the log')

parser.add_argument('--replace_bn', type=int, default=0, help='decide if replace bn layer')
parser.add_argument('--map_bn', type=int, default=0, help='decide if map bn layer to conv layer')

parser.add_argument('--input_size', type=int, default=224, help='input size of image')
parser.add_argument('--shuffle', type=int, default=1, help='data shuffle')
parser.add_argument('--overflow_rate', type=float, default=0.0, help='overflow rate')

parser.add_argument('--quant_method', default='linear', help='linear|minmax|log|tanh|scale')
parser.add_argument('--param_bits', type=int, default=8, help='bit-width for parameters')
parser.add_argument('--bn_bits', type=int, default=8, help='bit-width for running mean and std')
parser.add_argument('--fwd_bits', type=int, default=8, help='bit-width for layer output')
args = parser.parse_args()

model = UNet().cuda()
model.load_state_dict(torch.load('./unet.pth.tar'))
if args.replace_bn:
    quant.replace_bn(model)
if args.map_bn:
    quant.bn2conv(model)

print("=================quantize parameters==================")
if args.param_bits < 32:
    state_dict = model.state_dict()
    state_dict_quant = OrderedDict()
    sf_dict = OrderedDict()
    for k, v in state_dict.items():
        if 'running' in k:  # BN running mean/var use the bn_bits setting
            if args.bn_bits >= 32:
                print("Ignoring {}".format(k))
                state_dict_quant[k] = v
                continue
            else:
                bits = args.bn_bits
        else:
            bits = args.param_bits

        sf = None  # scaling factor, only computed for the linear method
        if args.quant_method == 'linear':
            sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=args.overflow_rate)
            v_quant = quant.linear_quantize(v, sf, bits=bits)
        elif args.quant_method == 'log':
            v_quant = quant.log_minmax_quantize(v, bits=bits)
        elif args.quant_method == 'minmax':
            v_quant = quant.min_max_quantize(v, bits=bits)
        else:
            v_quant = quant.tanh_quantize(v, bits=bits)
        state_dict_quant[k] = v_quant
        print("k={0:<35}, bits={1:<5}, sf={2}".format(k, bits, sf))
    model.load_state_dict(state_dict_quant)

print("======================================================")
print("=================quantize activation==================")
if args.fwd_bits < 32:
    model = quant.duplicate_model_with_quant(model, bits=args.fwd_bits,
                                             overflow_rate=args.overflow_rate,
                                             counter=args.n_sample,
                                             type=args.quant_method)
print("======================================================")
torch.save(model.state_dict(), './quantization_model.pth.tar')
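
For reference, a minimal way to check what is actually stored in the saved checkpoints (a sketch, not part of the repo; it assumes the two checkpoint files produced above). If the quantization utilities do simulated quantization, i.e. the values are rounded to 8-bit levels but still kept as float32 tensors, both files will report the same number of bytes, which would explain the unchanged file size:

import torch

# Compare per-tensor dtype and raw storage size of the original and the
# quantized checkpoints. Simulated (fake) quantization keeps float32
# storage, so the on-disk size does not shrink even at 8 bits.
orig = torch.load('./unet.pth.tar')
quantized = torch.load('./quantization_model.pth.tar')

def total_bytes(state_dict):
    return sum(v.numel() * v.element_size()
               for v in state_dict.values() if torch.is_tensor(v))

for k, v in quantized.items():
    if torch.is_tensor(v):
        print(k, v.dtype, v.numel() * v.element_size(), 'bytes')

print('original  :', total_bytes(orig), 'bytes')
print('quantized :', total_bytes(quantized), 'bytes')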