
Deep learning operations reinvented (for pytorch, tensorflow, chainer, gluon and others)

Primary language: Python. License: MIT.

Video: einops package examples

einops


Flexible and powerful tensor operations for readable and reliable code. Supports numpy, pytorch, tensorflow, and others.


Tutorial / Documentation

The tutorial is the most convenient way to see einops in action (and currently serves as the documentation).

Installation

Plain and simple:

pip install einops

einops has no mandatory dependencies (the code examples additionally require jupyter, pillow and the backends). To install the latest GitHub version:

pip install https://github.com/arogozhnikov/einops/archive/master.zip

API

einops has a minimalistic yet powerful API.

Three operations are provided (see the einops tutorial for examples):

from einops import rearrange, reduce, repeat
# rearrange elements according to the pattern
output_tensor = rearrange(input_tensor, 't b c -> b c t')
# combine rearrangement and reduction
output_tensor = reduce(input_tensor, 'b c (h h2) (w w2) -> b h w c', 'mean', h2=2, w2=2)
# copy along a new axis 
output_tensor = repeat(input_tensor, 'h w -> h w c', c=3)

And two corresponding layers (einops keeps a separate version for each framework) with the same API:

from einops.layers.chainer import Rearrange, Reduce
from einops.layers.gluon import Rearrange, Reduce
from einops.layers.keras import Rearrange, Reduce
from einops.layers.torch import Rearrange, Reduce
from einops.layers.tensorflow import Rearrange, Reduce

Layers behave like the operations and take the same parameters, except for the first argument (the input tensor), which is passed during the call:

layer = Rearrange(pattern, **axes_lengths)
layer = Reduce(pattern, reduction, **axes_lengths)

# apply created layer to a tensor / variable
x = layer(x)

Example of using layers within a model:

# example given for pytorch, but code in other frameworks is almost identical  
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from einops.layers.torch import Rearrange

model = Sequential(
    Conv2d(3, 6, kernel_size=5),
    MaxPool2d(kernel_size=2),
    Conv2d(6, 16, kernel_size=5),
    MaxPool2d(kernel_size=2),
    # flattening
    Rearrange('b c h w -> b (c h w)'),
    # 16 channels * 5 * 5 spatial positions, which holds e.g. for 3x32x32 input images
    Linear(16*5*5, 120),
    ReLU(),
    Linear(120, 10), 
)

Additionally, two auxiliary functions are provided:

from einops import asnumpy, parse_shape
# einops.asnumpy converts tensors of imperative frameworks to numpy
numpy_tensor = asnumpy(input_tensor)
# einops.parse_shape gives the lengths of the axes of interest
parse_shape(input_tensor, 'batch _ h w') # e.g. {'batch': 64, 'h': 128, 'w': 160}

Naming

einops stands for Einstein-Inspired Notation for operations (though "Einstein operations" is more attractive and easier to remember).

The notation was loosely inspired by Einstein summation (in particular by the numpy.einsum operation).

Why use einops notation?

Semantic information (being verbose in expectations)

y = x.view(x.shape[0], -1)
y = rearrange(x, 'b c h w -> b (c h w)')

While these two lines do the same job in some contexts, the second one provides information about the input and output. In other words, einops focuses on the interface (what the input and output are), not on how the output is computed.

The next operation looks similar:

y = rearrange(x, 'time c h w -> time (c h w)')

But it gives the reader a hint: this is not an independent batch of images we are processing, but rather a sequence (video).

Semantic information makes code easier to read and maintain.

More checks

Reconsider the same example:

y = x.view(x.shape[0], -1) # x: (batch, 256, 19, 19)
y = rearrange(x, 'b c h w -> b (c h w)')

The second line checks that the input has four dimensions, and you can also specify particular axis lengths. That is better than writing comments about shapes: comments are not checked and, as we know, don't prevent mistakes.

y = x.view(x.shape[0], -1) # x: (batch, 256, 19, 19)
y = rearrange(x, 'b c h w -> b (c h w)', c=256, h=19, w=19)

Result is strictly determined

Below are two ways to define a depth-to-space operation:

# depth-to-space
rearrange(x, 'b c (h h2) (w w2) -> b (c h2 w2) h w', h2=2, w2=2)
rearrange(x, 'b c (h h2) (w w2) -> b (h2 w2 c) h w', h2=2, w2=2)

There are at least four more ways to do it. Which one does the framework use?

These details are usually ignored, since most of the time it makes no difference which variant is used. But it can make a big difference (e.g. if you use grouped convolutions at the next stage), and you'd like to specify this in your code.

Uniformity

reduce(x, 'b c (x dx) -> b c x', 'max', dx=2)
reduce(x, 'b c (x dx) (y dy) -> b c x y', 'max', dx=2, dy=3)
reduce(x, 'b c (x dx) (y dy) (z dz) -> b c x y z', 'max', dx=2, dy=3, dz=4)

These examples demonstrate that there is no need for separate operations for 1d/2d/3d pooling: they are all defined in a uniform way.

Space-to-depth and depth-to-space are defined in many frameworks. But how about width-to-height?

rearrange(x, 'b c h (w w2) -> b c (h w2) w', w2=2)

Framework independent behavior

Even simple functions are defined differently by different frameworks:

y = x.flatten() # or flatten(x)

Suppose x has shape (3, 4, 5); then y has shape:

  • numpy, cupy, chainer: (60,)
  • keras, tensorflow.layers, mxnet and gluon: (3, 20)
  • pytorch: no such function

Independence of framework terminology

Example: tile vs. repeat causes a lot of confusion. To copy an image along its width:

np.tile(image, (1, 2))    # in numpy
image.repeat(1, 2)        # pytorch's repeat ~ numpy's tile

With einops you don't need to decipher which axis was repeated:

repeat(image, 'h w -> h (tile w)', tile=2)  # in numpy
repeat(image, 'h w -> h (tile w)', tile=2)  # in pytorch
repeat(image, 'h w -> h (tile w)', tile=2)  # in tf
repeat(image, 'h w -> h (tile w)', tile=2)  # in jax
repeat(image, 'h w -> h (tile w)', tile=2)  # in mxnet
... (etc.)

Supported frameworks

Einops works with ...

Contributing

The best ways to contribute are:

  • share your feedback. Experimental APIs currently require third-party testing.
  • spread the word about einops
  • if you like explaining things, alternative tutorials can be helpful
  • translating examples into languages other than English is also a good idea
  • finally, use einops notation in your papers to strictly define the operations you use!

Supported python versions

einops works with python 3.5 or later.

There is nothing specific to python 3 in the code; we simply need to move forward, and the decision is not to support python 2.