TorchIntegral

Integral Neural Networks in PyTorch

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

TorchIntegral: official implementation of the "Integral Neural Networks" paper (CVPR 2023).


HIRING!

We are looking for great DL researchers and DL software engineers to join us on our journey to build the most efficient neural networks by diving deep into the physical and mathematical intuition behind them.

Job description: AI Engineer

Send your CV to: hr@thestage.ai

Email subject: 'DL researcher' or 'DL SE'

TheStage.ai website: thestage.ai


This library is the official PyTorch implementation of the "Integral Neural Networks" paper.


Requirements

  • pytorch 2.0+
  • torchvision
  • numpy
  • scipy
  • Cython
  • catalyst
  • pytorchcv

Installation

git clone https://github.com/TheStageAI/TorchIntegral.git
pip install TorchIntegral/

or

pip install git+https://github.com/TheStageAI/TorchIntegral.git

Usage examples

Convert your model to integral model:

import torch
import torch_integral as inn
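from torchvision.models import resnet18, ResNet18_Weights

# Load a pretrained discrete model (weights= replaces the deprecated pretrained=True)
model = resnet18(weights=ResNet18_Weights.DEFAULT)
wrapper = inn.IntegralWrapper(init_from_discrete=True)

# Specify the continuous dimensions you want to prune, per parameter name.
# For a conv weight of shape (out_channels, in_channels, kH, kW),
# dim 0 is the output channels and dim 1 is the input channels.
continuous_dims = {
    "layer4.0.conv1.weight": [0],
    "layer4.1.conv1.weight": [0, 1]
}

# Convert to integral model; (1, 3, 224, 224) is the shape of an example input
inn_model = wrapper(model, (1, 3, 224, 224), continuous_dims)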
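Roughly speaking, an integral layer treats the weights along a continuous dimension as samples of a weight function and replaces the sum over that dimension with a numerical integral. Initializing from a discrete model (as init_from_discrete=True suggests) then amounts to choosing function values so that the integral reproduces the pretrained layer. The plain-Python sketch below illustrates one way this can work, assuming trapezoidal quadrature on a uniform grid over [0, 1]: each discrete weight is divided by its quadrature weight. The helpers trapezoid_weights, init_from_discrete and integral_neuron are hypothetical and are not part of the torch_integral API.

```python
# Hypothetical sketch in plain Python (not the torch_integral API):
# a discrete neuron rewritten as a trapezoidal quadrature sum on [0, 1].

def trapezoid_weights(n, a=0.0, b=1.0):
    """Trapezoidal quadrature weights for n uniformly spaced points on [a, b]."""
    if n < 2:
        raise ValueError("need at least two points")
    h = (b - a) / (n - 1)
    weights = [h] * n
    weights[0] = weights[-1] = h / 2
    return weights


def init_from_discrete(discrete_weights):
    """Divide each discrete weight by its quadrature weight, so that the
    quadrature sum in integral_neuron reproduces the original dot product."""
    q = trapezoid_weights(len(discrete_weights))
    values = [w / qi for w, qi in zip(discrete_weights, q)]
    return values, q


def integral_neuron(function_values, quad_weights, inputs):
    """Quadrature approximation of the integral of W(y) * x(y) over y."""
    return sum(q * w * x for q, w, x in zip(quad_weights, function_values, inputs))


discrete_w = [0.5, -1.0, 2.0, 0.25]
x = [1.0, 2.0, 3.0, 4.0]
values, q = init_from_discrete(discrete_w)
print(integral_neuron(values, q, x))                 # quadrature form
print(sum(w * xi for w, xi in zip(discrete_w, x)))   # original discrete form
```

The two printed values agree up to floating-point rounding, which is the property an initialization from discrete weights needs.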

Set the distribution from which the number of integration points of each group is randomly sampled during training:

inn_model.groups[0].reset_distribution(inn.UniformDistribution(8, 16))
inn_model.groups[1].reset_distribution(inn.UniformDistribution(16, 48))
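Drawing a random number of integration points for each group means the weight function is trained at many discretization sizes, so it should still work after being resampled to a different size. A minimal sketch of the sampling step is below, assuming UniformDistribution(low, high) draws an integer uniformly from [low, high] with both ends inclusive and that the points are spaced uniformly on [0, 1]. UniformPointsSampler and uniform_grid are hypothetical names, not library classes.

```python
import random


class UniformPointsSampler:
    """Hypothetical stand-in for a distribution over the number of integration
    points; assumes both bounds are inclusive."""

    def __init__(self, low, high):
        if low < 2 or high < low:
            raise ValueError("need 2 <= low <= high")
        self.low = low
        self.high = high

    def sample(self, rng=random):
        # randint includes both endpoints
        return rng.randint(self.low, self.high)


def uniform_grid(n):
    """n equally spaced integration points on [0, 1]."""
    return [i / (n - 1) for i in range(n)]


rng = random.Random(0)             # seeded for reproducibility
sampler = UniformPointsSampler(8, 16)
for step in range(3):              # e.g. one draw per training step
    n = sampler.sample(rng)
    grid = uniform_grid(n)
    print(f"step {step}: {n} points, last point {grid[-1]}")
```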

Train the integral model using standard training methods. Once the model is trained, resample (prune) it to an arbitrary size:

inn_model.groups[0].resize(12)
inn_model.groups[1].resize(16)
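Resizing a group amounts to re-sampling its weight function on a grid with a different number of points. The sketch below shows the idea along one continuous dimension using linear interpolation between neighbouring samples; linear interpolation is a simplification chosen for brevity and may differ from the interpolation the library actually uses. resample_linear is a hypothetical helper.

```python
def resample_linear(values, m):
    """Resample a function sampled at len(values) uniform points on [0, 1]
    onto m uniform points on [0, 1], using linear interpolation."""
    n = len(values)
    if n < 2 or m < 2:
        raise ValueError("need at least two points on both grids")
    out = []
    for i in range(m):
        t = i / (m - 1) * (n - 1)   # target point in source-index coordinates
        j = min(int(t), n - 2)      # index of the left neighbour
        frac = t - j                # relative distance from the left neighbour
        out.append((1 - frac) * values[j] + frac * values[j + 1])
    return out


# A weight function W(y) = y sampled at 16 points, resampled to 12 points.
source = [k / 15 for k in range(16)]
print(resample_linear(source, 12))
```

Because linear interpolation reproduces a linear function exactly, the 12 resampled values lie on the same line y = x, up to rounding.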

After resampling, the integral model can be converted into an ordinary discrete model and evaluated as usual:

discrete_model = inn_model.get_unparametrized_model()

You can use torch_integral.graph to build a dependency graph for structured pruning:

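from torch_integral import IntegralTracer

# Trace the model and group parameters that share a prunable dimension
groups = IntegralTracer(model, example_input=(3, 28, 28)).build_groups()

# L1Pruner is not defined in this snippet: it stands for your own
# structured-pruning routine, called on each group with the argument 0.5
pruner = L1Pruner()

for group in groups:
    pruner(group, 0.5)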
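For illustration, a minimal L1-norm criterion is sketched below: rank the channels along the pruned dimension by the L1 norm of their weights and keep the largest ones. It assumes the second argument in pruner(group, 0.5) is the fraction of channels to remove. l1_keep_indices is a hypothetical helper, not the library's pruner.

```python
def l1_keep_indices(rows, prune_ratio):
    """Return sorted indices of the rows (e.g. output channels) to keep after
    removing the fraction prune_ratio with the smallest L1 norm."""
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError("prune_ratio must be in [0, 1)")
    norms = [sum(abs(v) for v in row) for row in rows]
    n_keep = max(1, round(len(rows) * (1.0 - prune_ratio)))
    ranked = sorted(range(len(rows)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:n_keep])


# Four output channels of a toy layer; prune half of them by L1 norm.
weights = [[0.1, -0.2], [1.0, 0.5], [-0.05, 0.0], [0.7, -0.9]]
print(l1_keep_indices(weights, 0.5))  # [1, 3]
```

In structured pruning, the same kept indices would then be applied to every tensor in the group, for example to the output channels of one convolution and the input channels of the next, so that the shapes stay consistent.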

Integrating a function using numerical quadratures:

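import math

import torch
from torch_integral.quadrature import TrapezoidalQuadrature, integrate

def function(grid):
    # grid[0] holds the sample points along integration dimension 0
    return torch.sin(10 * grid[0])

quadrature = TrapezoidalQuadrature(integration_dims=[0])
grid = [torch.linspace(0, math.pi, 100)]  # 100 points on [0, pi]
result = integrate(quadrature, function, grid)  # the exact integral is 0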
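To sanity-check such a result without PyTorch, the composite trapezoidal rule can be written in a few lines of plain Python. This is an independent sketch of the textbook method, not the implementation of TrapezoidalQuadrature. With 100 uniformly spaced points, the estimate for sin(10x) on [0, pi] should be very close to the exact value (1 - cos(10 pi)) / 10 = 0.

```python
import math


def trapezoid(f, a, b, n):
    """Composite trapezoidal rule with n equally spaced points on [a, b]."""
    if n < 2:
        raise ValueError("need at least two points")
    h = (b - a) / (n - 1)
    ys = [f(a + i * h) for i in range(n)]
    # Interior points get weight h, the two endpoints get weight h / 2.
    return h * (sum(ys) - 0.5 * (ys[0] + ys[-1]))


approx = trapezoid(lambda t: math.sin(10 * t), 0.0, math.pi, 100)
print(approx)  # very close to 0, the exact value of the integral
```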

More examples can be found in the examples directory.

Frequently asked questions

See the FAQ for frequently asked questions.

TODO

  • Add a model zoo.
  • Fix tracing of reshape and view operations.
  • Add integral self-attention and batch norm layers.

Further research

Here are some ideas for the community to continue this research:

  • Weight function parametrization with SIREN.
  • Combine INNs and neural ODEs.
  • Allow the weight function to have breakpoints for a more flexible weight tensor parametrization.
  • Multiple TSP for the total variation minimization task.
  • Since INNs have lower total variation, test how robust such models are to adversarial attacks.
  • Train integral GANs.
  • Explore different numerical quadratures, for example Monte Carlo integration or Bayesian quadrature.

References

If this work was useful for you, please cite it with:

@InProceedings{Solodskikh_2023_CVPR,
    author    = {Solodskikh, Kirill and Kurbanov, Azim and Aydarkhanov, Ruslan and Zhelavskaya, Irina and Parfenov, Yury and Song, Dehua and Lefkimmiatis, Stamatios},
    title     = {Integral Neural Networks},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {16113-16122}
}

and

@misc{TorchIntegral,
    author    = {Kurbanov, Azim and Solodskikh, Kirill},
    title     = {TorchIntegral},
    year      = {2023},
    url       = {https://github.com/TheStageAI/TorchIntegral}
}