Integral Neural Networks in PyTorch

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

TorchIntegral: official implementation of the "Integral Neural Networks" paper (CVPR 2023).

This library is the official PyTorch implementation of the "Integral Neural Networks" paper.

Requirements

  • pytorch 2.0+
  • torchvision
  • numpy
  • scipy
  • Cython
  • catalyst
  • pytorchcv

Installation

git clone https://github.com/TheStageAI/TorchIntegral.git
pip install TorchIntegral/

or

pip install git+https://github.com/TheStageAI/TorchIntegral.git

Usage examples

Convert your model to an integral model:

import torch
import torch_integral as inn
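from torchvision.models import resnet18, ResNet18_Weights

# `pretrained=True` is deprecated in recent torchvision releases; use the weights API
model = resnet18(weights=ResNet18_Weights.DEFAULT)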
wrapper = inn.IntegralWrapper(init_from_discrete=True)

# Specify continuous dimensions which you want to prune
continuous_dims = {
    "layer4.0.conv1.weight": [0],
    "layer4.1.conv1.weight": [0, 1]
}

# Convert to integral model
inn_model = wrapper(model, (1, 3, 224, 224), continuous_dims)
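Conceptually, an integral layer replaces the sum over input channels with an integral over a continuous channel coordinate, which is then approximated by a numerical quadrature. The sketch below illustrates this idea in plain Python for a linear layer. The weight function, the input function, the unit interval [0, 1], the trapezoidal rule and the helper names are all assumptions chosen for illustration, not TorchIntegral's internal implementation.

```python
from typing import Callable, List


def uniform_grid(n: int) -> List[float]:
    """Return n equally spaced points on [0, 1] (n >= 2)."""
    return [i / (n - 1) for i in range(n)]


def trapezoidal_weights(grid: List[float]) -> List[float]:
    """Quadrature weights of the composite trapezoidal rule on a sorted grid."""
    weights = [0.0] * len(grid)
    for j in range(len(grid) - 1):
        h = grid[j + 1] - grid[j]
        # Each interval contributes h/2 to both of its endpoints.
        weights[j] += h / 2
        weights[j + 1] += h / 2
    return weights


def integral_linear(
    weight_fn: Callable[[float, float], float],
    input_fn: Callable[[float], float],
    out_points: List[float],
    num_in_points: int,
) -> List[float]:
    """Approximate y(s) = integral over t in [0, 1] of W(s, t) * f(t) dt for each s."""
    grid = uniform_grid(num_in_points)
    q = trapezoidal_weights(grid)
    return [
        sum(q_j * weight_fn(s, t_j) * input_fn(t_j) for q_j, t_j in zip(q, grid))
        for s in out_points
    ]


# Hypothetical weight and input functions: W(s, t) = s * t and f(t) = 1,
# so the exact output is y(s) = s / 2.
outputs_coarse = integral_linear(lambda s, t: s * t, lambda t: 1.0, [0.0, 0.5, 1.0], 5)
outputs_fine = integral_linear(lambda s, t: s * t, lambda t: 1.0, [0.0, 0.5, 1.0], 17)
print(outputs_coarse, outputs_fine)
```

Because the layer is defined by a function rather than a fixed matrix, the same weights can be evaluated with 5 or 17 input points and give (here, exactly) the same output. This is the property that makes resizing a trained model possible.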

Set the distribution from which the random number of integration points of each group is sampled during training:

inn_model.groups[0].reset_distribution(inn.UniformDistribution(8, 16))
inn_model.groups[1].reset_distribution(inn.UniformDistribution(16, 48))
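During training, the number of integration points of a group is sampled anew, so the model learns to work at many sizes. The sketch below imitates that behaviour with Python's `random` module. The inclusive bounds and the hypothetical helper `sample_grid` are assumptions about what `UniformDistribution(8, 16)` means, not a description of its actual implementation.

```python
import random
from typing import List


def sample_grid(rng: random.Random, low: int, high: int) -> List[float]:
    """Sample a number of points n uniformly from [low, high] (inclusive, assumed)
    and return a uniform grid of n points on [0, 1]."""
    n = rng.randint(low, high)  # randint includes both endpoints
    return [i / (n - 1) for i in range(n)]


rng = random.Random(0)  # fixed seed for reproducibility
# One grid per (hypothetical) training step for a group with UniformDistribution(8, 16).
sizes = [len(sample_grid(rng, 8, 16)) for _ in range(5)]
print(sizes)
```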

Train the integral model using standard training methods. Once the model is trained, resample (prune) it to an arbitrary size:

inn_model.groups[0].resize(12)
inn_model.groups[1].resize(16)
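Resizing a group amounts to evaluating the learned weight function on a new grid with fewer (or more) points. Below is a minimal sketch of that resampling step for a single 1-D row of weights, assuming the samples lie on a uniform grid over [0, 1] and using piecewise-linear interpolation. TorchIntegral's own interpolation scheme may differ, and `resample_linear` is a hypothetical helper.

```python
from typing import List


def resample_linear(values: List[float], new_size: int) -> List[float]:
    """Resample samples taken on a uniform grid over [0, 1] to new_size points
    using piecewise-linear interpolation (an assumed interpolation scheme)."""
    n = len(values)
    if n < 2 or new_size < 2:
        raise ValueError("need at least two points on both grids")
    result = []
    for k in range(new_size):
        # Position of the new node measured in units of the old grid spacing.
        pos = k * (n - 1) / (new_size - 1)
        left = min(int(pos), n - 2)
        frac = pos - left
        result.append((1 - frac) * values[left] + frac * values[left + 1])
    return result


# Weights of one (hypothetical) channel row: samples of f(t) = 4t at 5 points.
row = [0.0, 1.0, 2.0, 3.0, 4.0]
print(resample_linear(row, 3))  # [0.0, 2.0, 4.0]
```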

After resampling, the integral model can be converted into an ordinary discrete model and evaluated as usual:

discrete_model = inn_model.transform_to_discrete()
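One way to see why the converted model behaves like an ordinary network: once the grid is fixed, the quadrature weights can be multiplied into the sampled weight matrix, leaving a plain matrix-vector product. The sketch below checks this identity in plain Python with the trapezoidal rule. That the conversion works exactly this way is an assumption, and all names are hypothetical.

```python
from typing import List


def trapezoid_weights(n: int) -> List[float]:
    """Trapezoidal quadrature weights for n uniform points on [0, 1]."""
    h = 1.0 / (n - 1)
    return [h / 2 if j in (0, n - 1) else h for j in range(n)]


def fold_quadrature(weight: List[List[float]], q: List[float]) -> List[List[float]]:
    """Multiply column j of the sampled weight matrix by quadrature weight q[j]."""
    return [[w_ij * q_j for w_ij, q_j in zip(row, q)] for row in weight]


def matvec(matrix: List[List[float]], x: List[float]) -> List[float]:
    """Plain matrix-vector product, as computed by a discrete linear layer."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in matrix]


# Hypothetical weight function W(s, t) = s + t sampled on a 3 x 4 grid.
out_grid = [i / 2 for i in range(3)]
in_grid = [j / 3 for j in range(4)]
weight = [[s + t for t in in_grid] for s in out_grid]
x = [1.0, 2.0, 3.0, 4.0]
q = trapezoid_weights(len(in_grid))

# Quadrature evaluation of the integral layer...
y_integral = [sum(q[j] * weight[i][j] * x[j] for j in range(4)) for i in range(3)]
# ...equals a plain matrix-vector product with the folded (discrete) weights.
y_discrete = matvec(fold_quadrature(weight, q), x)
print(y_integral, y_discrete)
```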

One can use torch_integral.graph to build a dependency graph for structured pruning:

from torch_integral import IntegralTracer

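# `model` is the network to prune; `L1Pruner` is not defined or imported in this
# snippet and stands for any pruning criterion applied to a group.
groups = IntegralTracer(model, example_input=(3, 28, 28)).build_groups()
pruner = L1Pruner()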

for group in groups:
    pruner(group, 0.5)
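As an illustration of what an L1 criterion typically does, here is a hypothetical pure-Python helper that ranks channels by the L1 norm of their weights and keeps the largest ones. Treating `0.5` as the fraction of channels to remove is an assumption, and `l1_prune_indices` is not part of TorchIntegral.

```python
from typing import List


def l1_prune_indices(channels: List[List[float]], ratio: float) -> List[int]:
    """Return sorted indices of channels kept after removing a `ratio` fraction
    of the channels with the smallest L1 norm (hypothetical helper)."""
    if not 0.0 <= ratio < 1.0:
        raise ValueError("ratio must be in [0, 1)")
    norms = [sum(abs(w) for w in ch) for ch in channels]
    num_keep = len(channels) - int(len(channels) * ratio)
    # Rank channel indices by decreasing L1 norm and keep the top num_keep.
    ranked = sorted(range(len(channels)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:num_keep])


# Four channels with L1 norms 2.0, 0.2, 3.0 and 0.7.
channels = [[1.0, -1.0], [0.1, 0.1], [3.0, 0.0], [-0.5, 0.2]]
print(l1_prune_indices(channels, 0.5))  # [0, 2]
```

In a real group, the same indices would have to be removed from every tensor that shares the pruned dimension, which is what the dependency graph is for.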

Integrating a function using numerical quadratures:

from torch_integral.quadrature import TrapezoidalQuadrature, integrate
import torch

def function(grid):
    return torch.sin(10 * grid[0])

quadrature = TrapezoidalQuadrature(integration_dims=[0])
grid = [torch.linspace(0, 3.1415, 100)]
integrate(quadrature, function, grid)
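For readers who want to see what the trapezoidal quadrature computes, here is a dependency-free sketch of the composite trapezoidal rule on the same integrand and interval (using `math.pi` instead of `3.1415`). It illustrates the rule itself and is not the `torch_integral.quadrature` implementation.

```python
import math
from typing import Callable, List


def trapezoid(f: Callable[[float], float], grid: List[float]) -> float:
    """Composite trapezoidal rule on a sorted (not necessarily uniform) grid."""
    total = 0.0
    for a, b in zip(grid[:-1], grid[1:]):
        # Area under the chord of f over the interval [a, b].
        total += (b - a) * (f(a) + f(b)) / 2
    return total


n = 100
grid = [math.pi * k / (n - 1) for k in range(n)]
result = trapezoid(lambda x: math.sin(10 * x), grid)  # exact value: 0
check = trapezoid(math.sin, grid)                      # exact value: 2
print(result, check)
```

Since sin(10x) completes five full periods on [0, π], its integral is 0; the second call shows the O(h²) error of the rule on a smooth integrand.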

More examples can be found in the examples directory.

Frequently asked questions

See FAQ for frequently asked questions.

TODO

  • Add a model zoo.
  • Fix tracing of reshape and view operations.
  • Add integral self-attention and batch norm layers.
  • Fix serialization of parametrized models.

Further research

Here are some ideas for the community to continue this research:

  • Weight function parametrization with SIREN.
  • Combine INNs and neural ODEs.
  • For more flexible weight tensor parametrization, let the weight function have breakpoints.
  • Multiple TSP for the total variation minimization task.
  • Since INNs have lower total variation, it would be interesting to check the robustness of such models to adversarial attacks.
  • Train integral GANs.
  • Research different numerical quadratures, for example Monte Carlo integration or Bayesian quadrature.
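Several of these ideas revolve around the total variation of the weights along a continuous dimension and around TSP as a way to reorder channels. Below is a minimal sketch, assuming total variation is measured as the sum of L1 distances between consecutive channels, with a greedy nearest-neighbour ordering (a simple TSP heuristic, not an optimal solver) used to reduce it. Both helpers are hypothetical.

```python
from typing import List


def total_variation(channels: List[List[float]], order: List[int]) -> float:
    """Sum of L1 distances between consecutive channels in the given order."""
    return sum(
        sum(abs(a - b) for a, b in zip(channels[i], channels[j]))
        for i, j in zip(order[:-1], order[1:])
    )


def greedy_order(channels: List[List[float]], start: int = 0) -> List[int]:
    """Nearest-neighbour tour: repeatedly append the closest unvisited channel."""
    def dist(i: int, j: int) -> float:
        return sum(abs(a - b) for a, b in zip(channels[i], channels[j]))

    order = [start]
    unvisited = set(range(len(channels))) - {start}
    while unvisited:
        nxt = min(unvisited, key=lambda j: dist(order[-1], j))
        order.append(nxt)
        unvisited.remove(nxt)
    return order


channels = [[0.0], [3.0], [1.0], [2.0]]
identity = list(range(len(channels)))
order = greedy_order(channels)
print(order, total_variation(channels, identity), total_variation(channels, order))
# [0, 2, 3, 1] 6.0 3.0
```

A greedy tour is cheap but can be far from optimal on larger inputs; an exact or multiple-TSP formulation, as suggested above, would be the natural next step.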

References

If this work was useful for you, please cite it with:

@InProceedings{Solodskikh_2023_CVPR,
    author    = {Solodskikh, Kirill and Kurbanov, Azim and Aydarkhanov, Ruslan and Zhelavskaya, Irina and Parfenov, Yury and Song, Dehua and Lefkimmiatis, Stamatios},
    title     = {Integral Neural Networks},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {16113-16122}
}

and

@misc{TorchIntegral,
    author    = {Kurbanov, Azim and Solodskikh, Kirill},
    title     = {TorchIntegral},
    year      = {2023},
    url       = {https://github.com/TheStageAI/TorchIntegral}
}