Question: apply trainable scale for qdq `linear` and `matmul`
yiliu30 opened this issue · 3 comments
In a quantization scenario where fake quantization is used to assess the accuracy of a new algorithm with a trainable scale, we can implement it for an eager model by replacing the Linear module with QDQLinear, as demonstrated below:
import torch

class QDQLinear(torch.nn.Module):
    def __init__(self, orig_linear: torch.nn.Module) -> None:
        super().__init__()
        self._orig_linear = orig_linear
        # the scale must be a floating-point tensor to be trainable
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor) -> torch.Tensor:
        # new qdq method that uses `self.trainable_scale` to q-dq the tensor:
        # int_input = q(input)
        # qdq_input = dq(int_input)
        # return qdq_input
        raise NotImplementedError

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = self.qdq_tensor(input)
        return torch.nn.functional.linear(input, self._orig_linear.weight, self._orig_linear.bias)

### replace all `Linear` with `QDQLinear`
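For the replacement step, a minimal recursive swap could look like the sketch below (swap_linear is a helper name introduced here for illustration):

def swap_linear(model: torch.nn.Module) -> None:
    # recursively replace every nn.Linear child with a QDQLinear wrapper
    for name, child in model.named_children():
        if isinstance(child, torch.nn.Linear):
            setattr(model, name, QDQLinear(child))
        else:
            swap_linear(child)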
However, some models use torch.matmul to perform essentially the same computation as torch.nn.Linear. We also want to apply the aforementioned QDQ method to torch.matmul, but this cannot be achieved through module swapping. We could probably define a custom TorchDispatchMode that replaces every aten.mm with qdq followed by aten.mm, thereby applying qdq to all input tensors of torch.matmul or torch.nn.Linear; see the sketch below. However, I'm currently unsure how to handle the trainable_scale. Do you happen to have any suggestions?
Thank you very much!
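A minimal sketch of that dispatch-mode idea, assuming a free-standing qdq_tensor helper and a fixed scale held by the mode (both introduced here for illustration; how to make the scale trainable is exactly the open question):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

def qdq_tensor(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # toy symmetric int8 fake-quant with `scale` as the clipping range
    step = scale / 127
    return torch.clamp(torch.round(t / step), -128, 127) * step

class QDQDispatchMode(TorchDispatchMode):
    def __init__(self, scale: torch.Tensor):
        super().__init__()
        self.scale = scale

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # torch.matmul on 2-D tensors lowers to aten.mm (F.linear with a
        # bias lowers to aten.addmm, which would need the same treatment)
        if func == torch.ops.aten.mm.default:
            args = tuple(qdq_tensor(a, self.scale) for a in args)
        return func(*args, **kwargs)

# with QDQDispatchMode(torch.tensor(1.0)):
#     out = model(x)

Note that __torch_dispatch__ intercepts ops below the autograd layer, which is part of why threading a trainable scale through it is not straightforward.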
Hi,
If you want to pre-process the input to the module, I think a forward pre-hook would work?

def qdq_tensor(module, input):
    # `input` is a tuple of the module's positional arguments;
    # return a modified tuple (or None to leave the input unchanged)
    pass

your_layer.register_forward_pre_hook(qdq_tensor)

Is that enough for you?
Oh, sorry. There are a few errors in the question.
I want to pre-process the module's weight (for weight-only quantization).
class QDQLinear(torch.nn.Module):
    def __init__(self, orig_linear: torch.nn.Module) -> None:
        super().__init__()
        self._orig_linear = orig_linear
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # new qdq method that uses the trainable scale to q-dq the tensor:
        # int_input = q(input, scale)
        # qdq_input = dq(int_input)
        # return qdq_input
        raise NotImplementedError

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qdq_weight = self.qdq_tensor(self._orig_linear.weight, self.trainable_scale)  # <-- q-dq the weight
        return torch.nn.functional.linear(input, qdq_weight, self._orig_linear.bias)

### replace all `Linear` with `QDQLinear`
If you want to keep these two parameters separate and combine them only when mod.weight is used, I would suggest parametrization:
import torch
from torch import nn
from torch.nn.utils.parametrize import register_parametrization

class QDQParam(nn.Module):
    def forward(self, orig_linear_weight, scale):
        # recomputed every time `m.weight` is accessed
        return qdq_tensor(orig_linear_weight, scale)

    def right_inverse(self, orig_linear_weight):
        # initial values of the two underlying parameters; the scale
        # must be floating point so it can receive gradients
        return orig_linear_weight, torch.tensor(1.0)

m = nn.Linear(2, 2)
register_parametrization(m, "weight", QDQParam())
More details at https://pytorch.org/tutorials/intermediate/parametrizations.html
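For reference, a sketch of how the parametrized module could then be trained, assuming a toy qdq_tensor stand-in (introduced here for illustration, not the issue's actual q-dq method). After registration, the two underlying tensors are exposed as parameters named parametrizations.weight.original0 (the weight) and parametrizations.weight.original1 (the scale), so the scale is updated like any other parameter:

import torch
from torch import nn
from torch.nn.utils.parametrize import register_parametrization

def qdq_tensor(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # toy symmetric int8 fake-quant with `scale` as the clipping range
    step = scale / 127
    return torch.clamp(torch.round(w / step), -128, 127) * step

class QDQParam(nn.Module):
    def forward(self, orig_linear_weight, scale):
        return qdq_tensor(orig_linear_weight, scale)

    def right_inverse(self, orig_linear_weight):
        return orig_linear_weight, torch.tensor(1.0)

m = nn.Linear(2, 2)
register_parametrization(m, "weight", QDQParam())
print([name for name, _ in m.named_parameters()])
# ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']

opt = torch.optim.SGD(m.parameters(), lr=1e-3)  # includes the scale
out = m(torch.randn(4, 2))  # m.weight == qdq_tensor(original0, original1)
out.sum().backward()        # autograd tracks the scale through the parametrization
opt.step()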