
Linear decomposition toolkit for neural networks.

PyDec is a linear decomposition toolkit for neural networks, built on PyTorch. With a small amount of code, it decomposes the tensors computed in the forward pass into given components. The results can be used for tasks such as attribution analysis.


  • Fast. Computes the decomposition during the forward pass and benefits from GPU acceleration.
  • Run once, decompose anywhere. A single forward pass yields the decomposition of every hidden state, provided you keep a reference to it.
  • Applicable to networks such as Transformers, CNNs, and RNNs.


Requirements and Installation

  • PyTorch version >= 1.11.0
  • Python version >= 3.7
  • To install PyDec and develop locally:
git clone https://github.com/njunlp/pydec
cd pydec
pip install --editable ./
  • To install the latest stable release:
pip install pydec
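As an optional sanity check after installing, the sketch below verifies the requirements listed above using only the standard library plus the packages being checked. The helpers parse_version and check_requirements are hypothetical names, not part of PyDec, and the version parsing is deliberately simple (it keeps the first three numeric fields and ignores build tags and pre-release suffixes).

```python
import sys


def parse_version(version: str) -> tuple:
    """Hypothetical helper: '1.13.1+cu117' -> (1, 13, 1), '2.1.0a0' -> (2, 1, 0)."""
    numbers = []
    # Drop a local build tag such as '+cu117', then look at major.minor.patch.
    for piece in version.split("+")[0].split(".")[:3]:
        # Keep only the leading digits of each piece, e.g. '0a0' -> 0.
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    # Pad short versions such as '1.11' to three fields.
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)


def check_requirements() -> list:
    """Hypothetical helper: returns a list of problems; an empty list means all checks passed."""
    problems = []
    if sys.version_info < (3, 7):
        problems.append("Python >= 3.7 is required")
    try:
        import torch
    except ImportError:
        problems.append("PyTorch is not installed")
    else:
        if parse_version(torch.__version__) < (1, 11, 0):
            problems.append("PyTorch >= 1.11.0 is required, found " + str(torch.__version__))
    try:
        import pydec  # noqa: F401  (only checks that the package can be imported)
    except ImportError:
        problems.append("PyDec is not installed")
    return problems


if __name__ == "__main__":
    issues = check_requirements()
    print("OK" if not issues else "\n".join(issues))
```

Running the file prints OK when every check passes, or one line per problem otherwise.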

Getting Started

Example: decompose a tiny network

As a simple example, here is a small model with two linear layers and a ReLU activation. We'll create an instance of it and decompose its output:

import torch
import pydec

class TinyModel(torch.nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(4, 10)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        return x

tinymodel = TinyModel()

Given an input x, the output of the model is:

x = torch.rand(4)

print("Input tensor:")
print(x)

print("\n\nOutput tensor:")
print(tinymodel(x))


Input tensor:
tensor([0.7023, 0.3492, 0.7771, 0.0157])

Output tensor:
tensor([0.2751, 0.3626], grad_fn=<AddBackward0>)

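To decompose the output, pass the model a Composition initialized from x instead of x itself. pydec.zeros creates an all-zero Composition with c_num=x.size(0) components, one per input feature, and pydec.diagonal_init writes feature i of x into component i: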

c = pydec.zeros(x.size(), c_num=x.size(0))
c = pydec.diagonal_init(c, src=x, dim=0)

print("Input composition:")
print(c)

c_out = tinymodel(c)

print("\n\nOutput composition:")
print(c_out)


Input composition:
    tensor([0.7023, 0.0000, 0.0000, 0.0000]),
    tensor([0.0000, 0.3492, 0.0000, 0.0000]),
    tensor([0.0000, 0.0000, 0.7771, 0.0000]),
    tensor([0.0000, 0.0000, 0.0000, 0.0157]),
    tensor([0., 0., 0., 0.])}

Output composition:
    tensor([-0.0418, -0.0296]),
    tensor([0.0566, 0.0332]),
    tensor([0.1093, 0.1147]),
    tensor([ 0.0015, -0.0018]),
    tensor([0.1497, 0.2461]),
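The input composition printed above has a simple structure: one component per feature, holding x[i] at position i and zeros elsewhere, plus an all-zero residual. The sketch below rebuilds that layout with plain torch.diag to make it explicit; it is only an illustration of the structure, not a replacement for pydec.diagonal_init, whose internals may differ.

```python
import torch

x = torch.tensor([0.7023, 0.3492, 0.7771, 0.0157])

# Row i is component i: x[i] at position i, zeros elsewhere (shape: c_num x 4).
components = torch.diag(x)
# Nothing is left unattributed at the input, so the residual is all zeros.
residual = torch.zeros_like(x)

print(components)
# Adding every component and the residual gives back x exactly.
print(torch.equal(components.sum(dim=0) + residual, x))  # True
```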

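The first four tensors of the output composition are the contributions of the four features of x to the output, and the last one is the residual, the part not attributed to any input feature. For the first output coordinate, -0.0418 + 0.0566 + 0.1093 + 0.0015 + 0.1497 = 0.2753, which matches 0.2751 up to rounding of the printed values. In general, summing all components together with the residual recovers the original output: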

print("Sum of each component:")


Sum of each component:
tensor([0.2751, 0.3626], grad_fn=<AddBackward0>)
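To see why the components plus the residual add back up to the output, here is a from-scratch sketch in plain PyTorch that decomposes TinyModel without PyDec. Its rules are assumptions chosen for simplicity: each linear layer's weight is applied to every component while its bias goes to the residual, and ReLU multiplies every part by the same 0/1 mask taken from the full pre-activation, which is exact because ReLU(z) = z * (z > 0). PyDec provides its own decomposition algorithms, and its defaults may split contributions differently, so the numbers need not match the output shown above. The helpers affine and decompose_tiny are hypothetical names.

```python
import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 10)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(self.activation(self.linear1(x)))


def affine(layer, comps, res):
    # The weight is linear, so it is applied to each component separately;
    # the bias belongs to no input feature and is added to the residual only.
    return comps @ layer.weight.T, res @ layer.weight.T + layer.bias


def decompose_tiny(model, x):
    """Hypothetical helper: returns (comps, res) with comps.sum(0) + res == model(x)."""
    comps = torch.diag(x)          # one component per input feature, shape (4, 4)
    res = torch.zeros_like(x)      # nothing unattributed at the input

    comps, res = affine(model.linear1, comps, res)   # shapes (4, 10) and (10,)

    # ReLU(z) == z * (z > 0): scaling every part by the same mask keeps the sum exact.
    mask = (comps.sum(0) + res > 0).to(x.dtype)
    comps, res = comps * mask, res * mask

    comps, res = affine(model.linear2, comps, res)   # shapes (4, 2) and (2,)
    return comps, res


torch.manual_seed(0)
model = TinyModel()
x = torch.rand(4)

with torch.no_grad():
    comps, res = decompose_tiny(model, x)
    out = model(x)

print(comps)   # per-feature contributions, shape (4, 2)
print(res)     # residual, shape (2,)
print(torch.allclose(comps.sum(0) + res, out, atol=1e-5))
```

Applying the mask to the residual as well as to the components is what keeps the identity exact: masking only the components would leave the residual carrying pre-activation values that ReLU should have zeroed.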


The full documentation provides example implementations for real-world models, as well as tutorials, notes, and the Python API reference.

Linear Decomposition Theory

For the principles and theory behind PyDec, see our paper Local Interpretation of Transformer Based on Linear Decomposition.
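As a minimal worked illustration (not the paper's full treatment), suppose a vector is split into m components and a residual, x = c_1 + ... + c_m + r. An affine layer preserves this split exactly, and ReLU can be handled by one simple exact rule, multiplying every term by the same 0/1 mask. The ReLU rule here is an assumption for illustration and is not necessarily the one used in the paper or in PyDec.

```latex
\begin{aligned}
Wx + b &= \sum_{i=1}^{m} W c_i + \left( W r + b \right),
  \qquad x = \sum_{i=1}^{m} c_i + r, \\
\operatorname{ReLU}(z) &= \mathbb{1}[z > 0] \odot z
  = \sum_{i=1}^{m} \mathbb{1}[z > 0] \odot c_i + \mathbb{1}[z > 0] \odot r,
  \qquad z = \sum_{i=1}^{m} c_i + r.
\end{aligned}
```

In both lines the new components are the transformed c_i and the new residual is the transformed r, so their sum always equals the layer's actual output.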