
[NeurIPS 2023] On Sparse Modern Hopfield Model


On Sparse Modern Hopfield Model

This repository contains the code for the paper On Sparse Modern Hopfield Model. You can use it to reproduce the results reported in the paper.

Citations

If you find this work useful, please consider citing our paper. The BibTeX entry is:

@inproceedings{
  hu2023sparse,
  title={On Sparse Modern Hopfield Model},
  author={Jerry Yao-Chieh Hu and Donglin Yang and Dennis Wu and Chenwei Xu and Bo-Yu Chen and Han Liu},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://arxiv.org/abs/2309.12673}
}

Environmental Setup

You can set up the experimental environment by running the following commands:

$ conda create -n sparse_hopfield python=3.8
$ conda activate sparse_hopfield
$ pip3 install -r requirements.txt

Examples

In layers.py, we implement the dense Hopfield, the sparse Hopfield, and the generalized sparse Hopfield layers. The mode argument selects the variant, as shown below:

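# Assumes HopfieldPooling is importable from layers.py in this repo.
from layers import HopfieldPooling

dense_hp = HopfieldPooling(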
    d_model=d_model,
    n_heads=n_heads,
    mix=True,
    update_steps=update_steps,
    dropout=dropout,
    mode="softmax",
    scale=scale,
    num_pattern=num_pattern) # Dense Hopfield

sparse_hp = HopfieldPooling(
    d_model=d_model,
    n_heads=n_heads,
    mix=True,
    update_steps=update_steps,
    dropout=dropout,
    mode="sparsemax",
    scale=scale,
    num_pattern=num_pattern) # Sparse Hopfield

entmax_hp = HopfieldPooling(
    d_model=d_model,
    n_heads=n_heads,
    mix=True,
    update_steps=update_steps,
    dropout=dropout,
    mode="entmax",
    scale=scale,
    num_pattern=num_pattern) # Hopfield with Entmax-15

gsh_hp = HopfieldPooling(
    d_model=d_model,
    n_heads=n_heads,
    mix=True,
    update_steps=update_steps,
    dropout=dropout,
    mode="gsh",
    scale=scale,
    num_pattern=num_pattern) # Generalized Sparse Hopfield with learnable alpha
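To illustrate what the mode argument changes conceptually, here is a minimal, dependency-free sketch of the two simplest normalizers, softmax and sparsemax, together with a single retrieval step that weights stored patterns by their normalized similarity to a query. It is not the code in layers.py: the function names softmax, sparsemax, and retrieve, the beta parameter, and the toy patterns are assumptions made for illustration only.

```python
import math


def softmax(scores):
    """Dense normalizer: every entry receives strictly positive weight."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sparsemax(scores):
    """Sparse normalizer: Euclidean projection of the scores onto the simplex.

    Low-scoring entries can receive exactly zero weight.
    """
    sorted_scores = sorted(scores, reverse=True)
    cumulative = 0.0
    tau = 0.0
    for k, s in enumerate(sorted_scores, start=1):
        cumulative += s
        # The support condition holds for a prefix of the sorted scores,
        # so the last k that satisfies it determines the threshold tau.
        if 1.0 + k * s > cumulative:
            tau = (cumulative - 1.0) / k
    return [max(s - tau, 0.0) for s in scores]


def retrieve(patterns, query, beta=1.0, normalizer=sparsemax):
    """One retrieval step (illustrative): a weighted sum of stored patterns."""
    scores = [beta * sum(p_d * q_d for p_d, q_d in zip(p, query)) for p in patterns]
    weights = normalizer(scores)
    return [sum(w * p[d] for w, p in zip(weights, patterns)) for d in range(len(query))]


if __name__ == "__main__":
    # Toy memory of three 2-d patterns and a noisy query near the first one.
    patterns = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
    query = [0.9, 0.1]
    print("softmax weights:  ", softmax([4.0 * 0.9, 4.0 * 0.1, -4.0 * 0.9]))
    print("sparsemax weights:", sparsemax([4.0 * 0.9, 4.0 * 0.1, -4.0 * 0.9]))
    print("dense retrieval:  ", retrieve(patterns, query, beta=4.0, normalizer=softmax))
    print("sparse retrieval: ", retrieve(patterns, query, beta=4.0, normalizer=sparsemax))
```

In this sketch, the softmax weights are all nonzero, so every stored pattern contributes to the retrieved vector, whereas sparsemax can drop weakly matching patterns entirely. The entmax and gsh modes are not covered here.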

Experimental Validation of Theoretical Results

Plotting

$ python3 Plotting.py

Multiple Instance Learning (MIL) Tasks

MNIST MIL Experiments

(The bit pattern experiments may be unstable, so please refer to the MNIST MIL experiments for now.)

$ python3 mnist_mil_main.py --bag_size <BAG_SIZE>

Settings used for each bag size:

  • Bag Size 5: default setting
  • Bag Size 20: default setting
  • Bag Size 30: default setting
  • Bag Size 50: default setting
  • Bag Size 80: default setting
  • Bag Size 100: dropout = 0.1

Real-World MIL Tasks

Dataset preparation

Download and unzip the dataset:

$ wget http://www.cs.columbia.edu/~andrews/mil/data/MIL-Data-2002-Musk-Corel-Trec9-MATLAB.tgz 
$ tar zxvf ./MIL-Data-2002-Musk-Corel-Trec9-MATLAB.tgz 

Training and Evaluation

$ python3 real_world_mil.py --dataset <DATASET> --mode <MODE>

Argument options

  • dataset: fox, tiger, ucsb, elephant
  • mode: sparse, standard
  • cpus_per_trial: number of CPUs to use for a single run (set this carefully for hyperparameter tuning)
  • gpus_per_trial: number of GPUs to use for a single run (set this carefully for hyperparameter tuning; no larger than 1)
  • gpus_id: which GPUs to use (e.g. --gpus_id=0, 1 means cuda:0 and cuda:1 are used for this script)
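As a rough illustration of how the options above fit together, the following sketch builds a hypothetical argparse parser with the same flag names. It is not the parser used by real_world_mil.py: the defaults, the types, the helper names build_parser, gpu_fraction, and parse_gpu_ids, and the comma-separated handling of gpus_id are all assumptions.

```python
import argparse


def gpu_fraction(text):
    """Parse gpus_per_trial and enforce the documented 'no larger than 1' limit."""
    value = float(text)  # allowing fractional values is an assumption
    if value < 0 or value > 1:
        raise argparse.ArgumentTypeError("gpus_per_trial must be between 0 and 1")
    return value


def parse_gpu_ids(text):
    """Turn a string such as '0, 1' into [0, 1] (assumed format)."""
    return [int(part) for part in text.split(",") if part.strip()]


def build_parser():
    # Hypothetical parser mirroring the documented option names only.
    parser = argparse.ArgumentParser(description="Sketch of the real-world MIL options")
    parser.add_argument("--dataset", choices=["fox", "tiger", "ucsb", "elephant"], required=True)
    parser.add_argument("--mode", choices=["sparse", "standard"], required=True)
    parser.add_argument("--cpus_per_trial", type=int, default=1)  # default is an assumption
    parser.add_argument("--gpus_per_trial", type=gpu_fraction, default=0.0)  # default is an assumption
    parser.add_argument("--gpus_id", type=str, default="0")  # default is an assumption
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--dataset", "fox", "--mode", "sparse", "--gpus_per_trial", "1", "--gpus_id", "0, 1"]
    )
    print(args.dataset, args.mode, args.gpus_per_trial, parse_gpu_ids(args.gpus_id))
```

Check the script itself (for example with python3 real_world_mil.py --help, if it supports that) for the actual defaults and accepted formats.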

Acknowledgment

The authors would like to thank the anonymous reviewers and program chairs for constructive comments.

JH is partially supported by the Walter P. Murphy Fellowship. HL is partially supported by NIH R01LM1372201, NSF CAREER1841569, DOE DE-AC02-07CH11359, DOE LAB 20-2261, and NSF TRIPODS1740735. This research was supported in part through the computational resources and staff contributions provided for the Quest high performance computing facility at Northwestern University, which is jointly supported by the Office of the Provost, the Office for Research, and Northwestern University Information Technology. The content is solely the responsibility of the authors and does not necessarily represent the official views of the funding agencies.

The experiments in this work benefit from the following open-source codes: