Hopfield Networks is All You Need

Hubert Ramsauer¹, Bernhard Schäfl¹, Johannes Lehner¹, Philipp Seidl¹, Michael Widrich¹, Lukas Gruber¹, Markus Holzleitner¹, Milena Pavlović³,⁴, Geir Kjetil Sandve⁴, Victor Greiff³, David Kreil², Michael Kopp², Günter Klambauer¹, Johannes Brandstetter¹, Sepp Hochreiter¹,²

¹ ELLIS Unit Linz and LIT AI Lab, Institute for Machine Learning, Johannes Kepler University Linz, Austria
² Institute of Advanced Research in Artificial Intelligence (IARAI)
³ Department of Immunology, University of Oslo, Norway
⁴ Department of Informatics, University of Oslo, Norway


A detailed blog post on this paper, together with the necessary background on Hopfield networks, is available at this link.

The transformer and BERT models pushed the performance on NLP tasks to new levels via their attention mechanism. We show that this attention mechanism is the update rule of a modern Hopfield network with continuous states. This new Hopfield network can store a number of patterns that grows exponentially with the dimension, converges after one update, and has exponentially small retrieval errors. The number of stored patterns must be traded off against convergence speed and retrieval error. The new Hopfield network has three types of energy minima (fixed points of the update):

  1. global fixed point averaging over all patterns,
  2. metastable states averaging over a subset of patterns, and
  3. fixed points which store a single pattern.
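In the paper, this update rule takes the form ξ_new = X softmax(β Xᵀ ξ), where the columns of X are the stored patterns, ξ is the state (query) pattern, and β is an inverse temperature. The following is a minimal, dependency-free sketch of a single update step; the function names and the list-of-vectors representation are our own assumptions and are not part of hflayers. With β = 0 the softmax is uniform and the update returns the average of all stored patterns (the global fixed point), while a large β concentrates the weight on the most similar stored pattern.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def hopfield_update(stored: List[Vector], xi: Vector, beta: float) -> Vector:
    """One update step xi_new = X softmax(beta * X^T xi).

    `stored` holds the stored patterns x_1..x_N (the columns of X),
    `xi` is the current state (query) pattern.
    """
    # Similarity of the state pattern to every stored pattern: beta * X^T xi.
    scores = [beta * sum(a * b for a, b in zip(x, xi)) for x in stored]
    p = softmax(scores)
    # Weighted sum of the stored patterns: X p.
    dim = len(xi)
    return [sum(p_i * x[k] for p_i, x in zip(p, stored)) for k in range(dim)]


if __name__ == "__main__":
    patterns = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    noisy_query = [0.9, 0.2, 0.1]
    print(hopfield_update(patterns, noisy_query, beta=0.0))   # global average
    print(hopfield_update(patterns, noisy_query, beta=50.0))  # close to the first pattern
```

For intermediate β, or when several stored patterns are similar to the query, the result lies between these two extremes and averages over a subset of patterns, which corresponds to a metastable state.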

Transformers learn an attention mechanism by constructing an embedding of patterns and queries into an associative space. In their first layers, transformer and BERT models preferably operate in the global averaging regime, while in higher layers they operate in metastable states. The gradient in transformers is maximal in the regime of metastable states, is uniformly distributed when averaging globally, and vanishes when a fixed point is near a stored pattern. Based on the Hopfield network interpretation, we analyzed learning of transformer and BERT architectures. Learning starts with attention heads that average, and then most of them switch to metastable states. However, the majority of heads in the first layers still average and can be replaced by averaging operations like the Gaussian weighting that we propose. In contrast, heads in the last layers steadily learn and seem to use metastable states to collect information created in lower layers. These heads seem a promising target for improving transformers. Neural networks that integrate Hopfield networks equivalent to attention heads outperform other methods on immune repertoire classification, where the Hopfield network stores several hundred thousand patterns.
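One simple way to tell these regimes apart, which we assume here for illustration, is to look at how the softmax (attention) weights of a head are spread: count how many of the largest weights are needed to cover a given share, say 90%, of the total mass. Roughly, a count close to the number of patterns indicates global averaging, a small count greater than one indicates a metastable state, and a count of one indicates a fixed point near a single stored pattern. The function name and the 0.9 threshold below are hypothetical choices, not taken from the repository.

```python
from typing import Sequence


def patterns_for_mass(weights: Sequence[float], mass: float = 0.9) -> int:
    """Return how many of the largest weights are needed to reach `mass`.

    `weights` is one row of softmax (attention) values, assumed to sum to 1.
    """
    if not 0.0 < mass <= 1.0:
        raise ValueError("mass must be in (0, 1]")
    total = 0.0
    for count, w in enumerate(sorted(weights, reverse=True), start=1):
        total += w
        if total >= mass:
            return count
    # Only reached if rounding keeps the running sum just below `mass`.
    return len(weights)


if __name__ == "__main__":
    print(patterns_for_mass([0.25, 0.25, 0.25, 0.25]))  # averaging: 4
    print(patterns_for_mass([0.5, 0.45, 0.05]))          # metastable: 2
    print(patterns_for_mass([0.97, 0.01, 0.01, 0.01]))   # single pattern: 1
```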

With this repository, we provide a PyTorch implementation of a new layer called “Hopfield”, which makes it possible to equip deep learning architectures with Hopfield networks as new memory concepts.

The full paper is available at https://arxiv.org/abs/2008.02217.

Requirements

The software was developed and tested on the following 64-bit operating systems:

  • CentOS Linux release 8.1.1911 (Core)
  • macOS 10.15.5 (Catalina)

As the development environment, Python 3.8.3 in combination with PyTorch 1.6.0 was used (a PyTorch version of at least 1.5.0 should be sufficient). More details on how to install PyTorch are available on the official project page.
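To check whether an existing environment meets the PyTorch minimum stated above, a small helper like the following can be used. It is a sketch of our own, not part of the repository; it only parses the leading numeric part of the version string, so local build suffixes such as `+cu101` are ignored.

```python
import re
from typing import Tuple

MIN_TORCH = (1, 5, 0)  # minimum PyTorch version mentioned above


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn strings such as '1.6.0' or '1.6.0+cu101' into (1, 6, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    # A missing patch number (e.g. '2.1') is treated as 0.
    return tuple(int(part or 0) for part in match.groups())


def torch_is_recent_enough(version: str, minimum: Tuple[int, ...] = MIN_TORCH) -> bool:
    return parse_version(version) >= minimum


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        ok = torch_is_recent_enough(torch.__version__)
        print(f"PyTorch {torch.__version__}: {'OK' if ok else 'too old'}")
```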

Installation

The recommended way to install the software is to use pip/pip3:

$ pip3 install git+https://github.com/ml-jku/hopfield-layers

To successfully run the Jupyter notebooks contained in examples, additional third-party modules are needed:

$ pip3 install -r examples/requirements.txt

The installation of the Jupyter software itself is not covered. More details on how to install Jupyter are available at the official installation page.

Usage

To get up and running with Hopfield-based networks, only one argument needs to be set: the size (depth) of the input.

from hflayers import Hopfield

hopfield = Hopfield(input_size=...)
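As we understand the paper's formulation, the Hopfield layer associates state (query) patterns R with stored (key) patterns Y through learned projections, Z = softmax(β R W_Q W_Kᵀ Yᵀ) Y W_K W_V, which has the same form as transformer attention when β = 1/√d_k. The sketch below writes this out with plain Python lists to make the data flow explicit. It is an illustration under our own assumptions (fixed rather than trained weight matrices, helper names of our choosing) and not the hflayers implementation, which is a trainable PyTorch module.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(n x m) @ (m x p) -> (n x p)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


def hopfield_association(R: Matrix, Y: Matrix, W_Q: Matrix, W_K: Matrix,
                         W_V: Matrix, beta: float) -> Matrix:
    """Z = softmax(beta * R W_Q (Y W_K)^T) (Y W_K) W_V, one row per state pattern."""
    Q = matmul(R, W_Q)  # projected state (query) patterns
    K = matmul(Y, W_K)  # projected stored (key) patterns
    V = matmul(K, W_V)  # values derived from the projected stored patterns
    scores = [[beta * v for v in row] for row in matmul(Q, transpose(K))]
    return matmul(softmax_rows(scores), V)


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    stored = [[1.0, 0.0], [0.0, 1.0]]
    state = [[2.0, 0.0]]
    Z = hopfield_association(state, stored, identity, identity, identity, beta=10.0)
    print(Z)  # close to the first stored pattern
```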

It is also possible to replace commonly used pooling functions with a Hopfield-based one. Internally, a state pattern is trained, which in turn is used to compute pooling weights with respect to the input.

from hflayers import HopfieldPooling

hopfield_pooling = HopfieldPooling(input_size=...)
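Conceptually, the pooling step uses a single state (query) pattern q to weight the N input patterns: the output is Σᵢ softmax(β Xq)ᵢ xᵢ, so a sequence of any length is reduced to one vector. The sketch below keeps q fixed instead of training it and omits any learned projections, so it only illustrates the pooling idea and is not the library's HopfieldPooling module; the function name is hypothetical. Because each weight depends only on the similarity of a pattern to q, the result does not change when the inputs are reordered.

```python
import math
from typing import List, Sequence

Vector = List[float]


def hopfield_pool(inputs: Sequence[Vector], query: Vector, beta: float = 1.0) -> Vector:
    """Pool a sequence of patterns into one vector using a (here fixed) query pattern."""
    scores = [beta * sum(a * b for a, b in zip(x, query)) for x in inputs]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]  # pooling weights, one per input pattern
    return [sum(w * x[k] for w, x in zip(weights, inputs)) for k in range(len(query))]


if __name__ == "__main__":
    sequence = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    q = [1.0, 0.0]  # hypothetical stand-in for a trained state pattern
    print(hopfield_pool(sequence, q, beta=5.0))
    print(hopfield_pool(list(reversed(sequence)), q, beta=5.0))  # same up to rounding
```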

A second variant of our Hopfield-based modules employs a trainable lookup mechanism that is fixed with respect to the input. Internally, one or more stored patterns and pattern projections are trained (optionally in a non-shared manner), which in turn are used as a lookup mechanism independent of the input data.

from hflayers import HopfieldLayer

hopfield_lookup = HopfieldLayer(input_size=...)
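Because the stored patterns and their projections in this variant are parameters rather than functions of the input, the layer behaves like a learned key-value memory that the input queries. The sketch below imitates that lookup with fixed, hand-picked memory contents; `memory_lookup`, `memory_keys`, and `memory_values` are hypothetical names, and in the actual module the memory would be trained rather than set by hand.

```python
import math
from typing import List, Sequence

Vector = List[float]


def memory_lookup(query: Vector, memory_keys: Sequence[Vector],
                  memory_values: Sequence[Vector], beta: float = 1.0) -> Vector:
    """Retrieve a softmax-weighted mix of memory values for one input query.

    The memory (keys and values) does not depend on the input; only the
    query does. In a trained layer the memory would be learned parameters.
    """
    scores = [beta * sum(a * b for a, b in zip(k, query)) for k in memory_keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    out_dim = len(memory_values[0])
    return [sum(w * v[j] for w, v in zip(weights, memory_values)) for j in range(out_dim)]


if __name__ == "__main__":
    keys = [[1.0, 0.0], [0.0, 1.0]]                      # hypothetical stored patterns
    values = [[10.0, 20.0, 30.0], [-1.0, -2.0, -3.0]]    # hypothetical pattern projections
    print(memory_lookup([3.0, 0.0], keys, values, beta=10.0))  # close to the first value row
```

Note that the output dimension is set by the memory values, not by the query, which mirrors how a projection can map the retrieved patterns into a different space.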

Its usage is as simple as that of the main module, and it is equally powerful.

Examples

In general, the Hopfield layer is designed to implement or substitute layers such as:

  • Pooling layers: We consider the Hopfield layer a pooling layer if only one static state (query) pattern exists. It then effectively pools over the sequence, with the pooling weights given by the softmax values over the stored patterns. Therefore, our Hopfield layer can act as a pooling layer.

  • Permutation equivariant layers: Our Hopfield layer can be used as a plug-in replacement for permutation equivariant layers. Since the Hopfield layer is an associative memory, it assumes no dependency between the input patterns.

  • GRU & LSTM layers: Our Hopfield layer can be used as a plug-in replacement for GRU & LSTM layers. When substituting GRU & LSTM layers, positional encoding may optionally be considered.

  • Attention layers: Our Hopfield layer can act as an attention layer, where state (query) and stored (key) patterns are different, and need to be associated.
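When every input pattern serves as both a state and a stored pattern, each output is computed from the whole set without any notion of order. Permuting the inputs therefore permutes the outputs in the same way, which is the property a permutation equivariant layer needs. The sketch below checks this for the plain association without learned projections; `self_associate` is a hypothetical helper of our own, not a library function.

```python
import math
from typing import List

Vector = List[float]


def self_associate(patterns: List[Vector], beta: float = 1.0) -> List[Vector]:
    """Update every pattern x_i with softmax_j(beta * <x_i, x_j>) over all patterns x_j."""
    outputs = []
    dim = len(patterns[0])
    for xi in patterns:
        scores = [beta * sum(a * b for a, b in zip(xi, xj)) for xj in patterns]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        outputs.append([sum(e / total * xj[k] for e, xj in zip(exps, patterns))
                        for k in range(dim)])
    return outputs


if __name__ == "__main__":
    x = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
    perm = [2, 0, 1]
    out = self_associate(x)
    out_perm = self_associate([x[i] for i in perm])
    # Output for the permuted input equals the permuted output (up to rounding).
    print(all(abs(a - b) < 1e-12
              for i, row in zip(perm, out_perm) for a, b in zip(out[i], row)))
```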

The examples folder contains several demonstrations of how to use the Hopfield, HopfieldPooling, and HopfieldLayer modules. To successfully run the contained Jupyter notebooks, additional third-party modules like pandas and seaborn are required.

  • Bit Pattern Set: The dataset of this demonstration falls into the category of binary classification tasks in the domain of Multiple Instance Learning (MIL) problems. Each bag comprises a collection of bit pattern instances, where each instance is a sequence of 0s and 1s. The positive class has specific bit patterns injected, which are absent in the negative one. This demonstration shows that Hopfield, HopfieldPooling and HopfieldLayer are capable of learning and filtering each bag with respect to the class-defining bit patterns.

  • Latch Sequence Set: We study an easy example of learning long-term dependencies by using a simple latch task, see Hochreiter and Mozer. The essence of this task is that a sequence of inputs is presented, beginning with one of two symbols, A or B, and after a variable number of time steps, the model has to output a corresponding symbol. Thus, the task requires memorizing the original input over time. Note that the two class-defining symbols may appear only at the first position of a sequence. This task was specifically designed to demonstrate the capability of recurrent neural networks to capture long-term dependencies. This demonstration shows that Hopfield, HopfieldPooling and HopfieldLayer adapt extremely fast to this specific task, concentrating only on the first entry of the sequence.

  • Attention-based Deep Multiple Instance Learning: The dataset of this demonstration falls into the category of binary classification tasks in the domain of Multiple Instance Learning (MIL) problems, see Ilse and Tomczak. Each bag comprises a collection of 28x28 grayscale images/instances, where each instance is a sequence of pixel values in the range [0, 255]. The number of instances per bag is drawn from a Gaussian with specified mean and variance. The positive class is defined by the presence of the target number (digit), and the negative class by its absence.
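The latch task and the bit pattern set described above are simple to generate synthetically. The following sketch produces toy versions of both; the sequence lengths, bag sizes, filler alphabet, and function names are our own assumptions and do not reproduce the exact datasets used in the example notebooks.

```python
import random
from typing import List, Sequence, Tuple


def make_latch_sequence(rng: random.Random, length: int) -> Tuple[List[str], str]:
    """Latch task: the first symbol (A or B) is the label; the rest is filler."""
    label = rng.choice("AB")
    filler = [rng.choice("xyz") for _ in range(length - 1)]  # never 'A' or 'B'
    return [label] + filler, label


def make_bit_bag(rng: random.Random, bag_size: int, width: int,
                 signal: Sequence[int], positive: bool) -> List[List[int]]:
    """Bit pattern MIL bag: positive bags contain `signal`, negative ones never do."""
    signal = list(signal)
    bag = []
    while len(bag) < bag_size:
        instance = [rng.randint(0, 1) for _ in range(width)]
        if instance != signal:  # keep random instances free of the class-defining pattern
            bag.append(instance)
    if positive:
        bag[rng.randrange(bag_size)] = signal  # inject the class-defining pattern
    return bag


if __name__ == "__main__":
    rng = random.Random(0)
    seq, label = make_latch_sequence(rng, length=10)
    print(seq, "->", label)
    target = [1, 0, 1, 1, 0, 0, 1, 0]
    print(target in make_bit_bag(rng, 5, 8, target, positive=True))   # True
    print(target in make_bit_bag(rng, 5, 8, target, positive=False))  # False
```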

Disclaimer

Some implementations in this repository are based on existing ones from the official PyTorch repository (v1.6.0) and have been extended and modified accordingly. The involved parts are listed in the following:

License

This repository is BSD-style licensed (see LICENSE), except where noted otherwise.