albanD/subclass_zoo

Question: applying a trainable scale to QDQ `linear` and `matmul`

yiliu30 opened this issue · 3 comments

In a quantization scenario where fake quantization (quantize-dequantize, or QDQ) is used to evaluate the accuracy of a new algorithm with a trainable scale, we can implement it for an eager model by replacing each `Linear` module with a `QDQLinear`, as shown below:

```python
import torch

class QDQLinear(torch.nn.Module):
    def __init__(self, orig_linear: torch.nn.Linear) -> None:
        super().__init__()
        self._orig_linear = orig_linear
        # The scale must be a floating-point tensor to be able to require gradients
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor) -> torch.Tensor:
        # ... new qdq method that uses `self.trainable_scale` to quantize-dequantize the tensor:
        # int_input = q(input)
        # qdq_input = dq(int_input)
        # return qdq_input
        return input  # placeholder until the q/dq logic is filled in

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = self.qdq_tensor(input)
        return torch.nn.functional.linear(input, self._orig_linear.weight, self._orig_linear.bias)

# replace all `Linear` with `QDQLinear`
```
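The "replace all `Linear` with `QDQLinear`" step can be done by walking the model's children recursively. A minimal sketch, assuming the wrapper is built from the original `nn.Linear` by a factory callable; `swap_linear` is a hypothetical helper name, and `Wrapper` is only a stand-in for `QDQLinear` so the snippet runs on its own:

```python
import torch
import torch.nn as nn


def swap_linear(module: nn.Module, make_wrapper) -> nn.Module:
    """Recursively replace every nn.Linear child with make_wrapper(child) (hypothetical helper)."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            # Do not recurse into the new wrapper, so the Linear is wrapped only once
            setattr(module, name, make_wrapper(child))
        else:
            swap_linear(child, make_wrapper)
    return module


class Wrapper(nn.Module):
    # Hypothetical stand-in for QDQLinear: simply delegates to the wrapped Linear
    def __init__(self, orig_linear: nn.Linear) -> None:
        super().__init__()
        self._orig_linear = orig_linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._orig_linear(x)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 2)))
swap_linear(model, Wrapper)
```

This only reaches `nn.Linear` modules; direct `torch.matmul` calls inside a `forward` are untouched, which is exactly the limitation raised next.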

However, some models use `torch.matmul` to do the same thing as `torch.nn.Linear`. We also want to apply the QDQ method above to `torch.matmul`, but this cannot be done by module swapping, because `torch.matmul` is a function call rather than a module.

We could probably write a custom `TorchDispatchMode` that replaces every `aten.mm` with a q-dq step followed by `aten.mm`, applying q-dq to the input tensors of both `torch.matmul` and `torch.nn.Linear`. However, I'm not sure how to handle `trainable_scale` in that setup. Do you have any suggestions?

Thank you very much!
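A minimal sketch of one possible direction, under several assumptions: as far as I understand, `__torch_dispatch__` runs below autograd, so ops that use `trainable_scale` inside a `TorchDispatchMode` would not be recorded and the scale would get no gradient. The sketch therefore swaps in a `TorchFunctionMode` (from `torch.overrides`), which intercepts `torch.matmul` and `torch.nn.functional.linear` above autograd. `QDQFunctionMode` and `fake_qdq` are hypothetical names, and the symmetric int8 fake quantization with a straight-through estimator only stands in for the unspecified q/dq method; it is applied to the second operand, which is the weight for `F.linear`.

```python
import torch
import torch.nn.functional as F
from torch.overrides import TorchFunctionMode


def fake_qdq(t: torch.Tensor, scale: torch.Tensor, qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    # Hypothetical stand-in q/dq: symmetric int8 fake quantization.
    # The straight-through estimator treats round/clamp as identity in the
    # backward pass, so both `t` and `scale` receive gradients.
    x = t / scale
    x_q = torch.clamp(torch.round(x), qmin, qmax)
    x_q = x + (x_q - x).detach()
    return x_q * scale


class QDQFunctionMode(TorchFunctionMode):
    """Applies fake q-dq to the second operand of matmul / linear (hypothetical helper)."""

    def __init__(self, scale: torch.Tensor) -> None:
        super().__init__()
        self.scale = scale

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # The mode is disabled while this handler runs, so the ops in fake_qdq
        # and the final func call are not intercepted again.
        if func in (torch.matmul, F.linear) and len(args) >= 2:
            args = (args[0], fake_qdq(args[1], self.scale)) + tuple(args[2:])
        return func(*args, **kwargs)


# Functional call site
scale = torch.nn.Parameter(torch.tensor(0.05))
x = torch.randn(4, 8)
w = torch.randn(8, 3)
with QDQFunctionMode(scale):
    out = torch.matmul(x, w)
out.sum().backward()

# Module call site: nn.Linear.forward calls F.linear, so no module swap is needed
lin = torch.nn.Linear(8, 3)
lin_scale = torch.nn.Parameter(torch.tensor(0.05))
with QDQFunctionMode(lin_scale):
    y = lin(x)
y.sum().backward()
```

The `@` operator may reach the mode as a different `func` than `torch.matmul`, so it might need its own check; treat this as a starting point rather than a complete interception list.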

Hi

If you want to pre-process the input to the module, I think a forward pre-hook would work:

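```python
def qdq_tensor(module, args):
    # A forward pre-hook gets the module and the tuple of positional inputs;
    # a returned tuple replaces those inputs before forward() runs
    return tuple(args)  # placeholder: apply q/dq to each input here

your_layer.register_forward_pre_hook(qdq_tensor)
```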

Is that enough for you?
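To check that the tuple returned by a forward pre-hook really replaces the module's input, here is a small example; `halve_input` is a hypothetical hook that only scales the input by 0.5 for illustration and is not a q-dq method:

```python
import torch

torch.manual_seed(0)
lin = torch.nn.Linear(3, 2)


def halve_input(module, args):
    # Return a new tuple of positional inputs; forward() receives these instead
    (x,) = args
    return (x * 0.5,)


handle = lin.register_forward_pre_hook(halve_input)
x = torch.randn(4, 3)
hooked = lin(x)
handle.remove()  # the returned handle detaches the hook again
plain = lin(x * 0.5)
```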

Oh, sorry, there are a few errors in the question.
I actually want to pre-process the module's weight, not its input (for weight-only quantization).

```python
import torch

class QDQLinear(torch.nn.Module):
    def __init__(self, orig_linear: torch.nn.Linear) -> None:
        super().__init__()
        self._orig_linear = orig_linear
        # The scale must be a floating-point tensor to be able to require gradients
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # ... new qdq method that uses `scale` (i.e. `self.trainable_scale`) to quantize-dequantize the tensor:
        # int_input = q(input, scale)
        # qdq_input = dq(int_input)
        # return qdq_input
        return input  # placeholder until the q/dq logic is filled in

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qdq_weight = self.qdq_tensor(self._orig_linear.weight, self.trainable_scale)  # <---------- q-dq weight
        return torch.nn.functional.linear(input, qdq_weight, self._orig_linear.bias)

# replace all `Linear` with `QDQLinear`
```
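For `trainable_scale` to learn anything, `qdq_tensor` has to be differentiable with respect to the scale, but `round()` has zero gradient almost everywhere. A common workaround is a straight-through estimator (STE). The sketch below fills the placeholder with an assumed symmetric int8 fake quantization using an STE (the int8 range, `init_scale`, and the MSE-to-full-precision objective are all assumptions, not the author's method) and runs one optimizer step on the scale only:

```python
import torch
import torch.nn.functional as F


class QDQLinear(torch.nn.Module):
    def __init__(self, orig_linear: torch.nn.Linear, init_scale: float = 0.01) -> None:
        super().__init__()
        self._orig_linear = orig_linear
        self.trainable_scale = torch.nn.Parameter(torch.tensor(init_scale))

    def qdq_tensor(self, input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Assumed stand-in for the new q/dq method: symmetric int8 fake quantization
        x = input / scale
        x_q = torch.clamp(torch.round(x), -128, 127)  # q: snap to the integer grid
        x_q = x + (x_q - x).detach()                  # STE: identity gradient through round/clamp
        return x_q * scale                            # dq: back to floating point

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qdq_weight = self.qdq_tensor(self._orig_linear.weight, self.trainable_scale)
        return F.linear(input, qdq_weight, self._orig_linear.bias)


torch.manual_seed(0)
layer = QDQLinear(torch.nn.Linear(16, 4))
# Train only the scale; the original weight and bias stay frozen
layer._orig_linear.requires_grad_(False)
opt = torch.optim.SGD([layer.trainable_scale], lr=1e-3)

x = torch.randn(8, 16)
target = layer._orig_linear(x)  # full-precision output used as the reference
loss = F.mse_loss(layer(x), target)
loss.backward()
opt.step()
```

Freezing the original weight keeps the experiment about the scale alone; unfreezing it would let the weight and scale be tuned jointly.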

If you do want to keep these two parameters separate and combine them only when `mod.weight` is accessed, I would suggest a parametrization:

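```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrize import register_parametrization

def qdq_tensor(weight, scale):
    return weight  # placeholder: apply q/dq to `weight` using `scale` here

class QDQParam(nn.Module):
    def forward(self, orig_linear_weight, scale):
        # Called on every access to `m.weight`; combines the two stored tensors
        return qdq_tensor(orig_linear_weight, scale)

    def right_inverse(self, orig_linear_weight):
        # Split the weight into the two stored tensors; the scale is a float so it can require grad
        return orig_linear_weight, torch.tensor(1.0)

m = nn.Linear(2, 2)
register_parametrization(m, "weight", QDQParam())
```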

More details at https://pytorch.org/tutorials/intermediate/parametrizations.html
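As far as I know from the `torch.nn.utils.parametrize` documentation, when `right_inverse` returns a tuple, its elements are stored as `original0` and `original1` under `m.parametrizations.weight`, and `forward` recomputes `m.weight` from them on each access. A quick check of that layout and of the gradient reaching the scale; the hypothetical `qdq` below only multiplies by the scale so the gradient path is visible, it is not a real quantizer:

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrize import register_parametrization


def qdq(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder for the real q/dq: just rescales the weight
    return weight * scale


class QDQParam(nn.Module):
    def forward(self, orig_linear_weight, scale):
        return qdq(orig_linear_weight, scale)

    def right_inverse(self, orig_linear_weight):
        # Stored as original0 (weight) and original1 (scalar scale on the same device/dtype)
        scale = torch.ones((), dtype=orig_linear_weight.dtype, device=orig_linear_weight.device)
        return orig_linear_weight, scale


m = nn.Linear(2, 2)
register_parametrization(m, "weight", QDQParam())

names = {name for name, _ in m.named_parameters()}
scale = m.parametrizations.weight.original1
m(torch.randn(3, 2)).sum().backward()
```

Since the scale ends up in `m.parameters()`, passing `m.parameters()` to the optimizer after registering the parametrization trains it together with the weight.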