DISCLAIMER: the code does not work properly! If you are still considering using it, be prepared to debug it thoroughly.
Reproduction of the AlphaTensor paper for 2x2 matrices. Parts of the code are inspired by this repo, but heavily refactored.
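AlphaTensor treats fast matrix multiplication as the search for a low-rank decomposition of the matrix-multiplication tensor: each rank-one term is one scalar multiplication. For 2x2 matrices the minimal rank is 7, achieved by Strassen's algorithm. The sketch below is not taken from this repository. It builds the 4x4x4 tensor under an assumed row-major flattening (the paper's indexing convention may differ), and checks that Strassen's seven factor triples reconstruct it exactly. The helper names `matmul_tensor`, `reconstruct` and `multiply_with_factors` are hypothetical.

```python
from itertools import product

# Assumed convention: entry (r, c) of an n x n matrix is flattened to index r * n + c.


def matmul_tensor(n):
    """Dense n^2 x n^2 x n^2 tensor T with T[a][b][c] = 1 iff A[a] * B[b] contributes to C[c]."""
    size = n * n
    t = [[[0] * size for _ in range(size)] for _ in range(size)]
    for i, j, k in product(range(n), repeat=3):
        # C[i][k] += A[i][j] * B[j][k]
        t[i * n + j][j * n + k][i * n + k] = 1
    return t


def reconstruct(U, V, W):
    """Sum of the rank-one terms u_r (x) v_r (x) w_r over all factor triples."""
    size = len(U[0])
    return [
        [
            [sum(u[a] * v[b] * w[c] for u, v, w in zip(U, V, W)) for c in range(size)]
            for b in range(size)
        ]
        for a in range(size)
    ]


# Strassen's factors; entries of each matrix X are ordered [X11, X12, X21, X22].
U = [  # coefficients on A for the products M1..M7
    [1, 0, 0, 1], [0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 0, 1],
    [1, 1, 0, 0], [-1, 0, 1, 0], [0, 1, 0, -1],
]
V = [  # coefficients on B for the products M1..M7
    [1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, -1], [-1, 0, 1, 0],
    [0, 0, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1],
]
W = [  # coefficient of each M_r in [C11, C12, C21, C22]
    [1, 0, 0, 1], [0, 0, 1, -1], [0, 1, 0, 1], [1, 0, 1, 0],
    [-1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0],
]


def multiply_with_factors(A, B, U, V, W):
    """Multiply two 2x2 matrices using one scalar multiplication per factor triple."""
    a = [A[0][0], A[0][1], A[1][0], A[1][1]]
    b = [B[0][0], B[0][1], B[1][0], B[1][1]]
    m = [
        sum(ui * ai for ui, ai in zip(u, a)) * sum(vi * bi for vi, bi in zip(v, b))
        for u, v in zip(U, V)
    ]
    c = [sum(w[idx] * mr for w, mr in zip(W, m)) for idx in range(4)]
    return [[c[0], c[1]], [c[2], c[3]]]


if __name__ == "__main__":
    assert reconstruct(U, V, W) == matmul_tensor(2)
    print("rank:", len(U))  # 7 multiplications instead of the naive 8
    print(multiply_with_factors([[1, 2], [3, 4]], [[5, 6], [7, 8]], U, V, W))
    # [[19, 22], [43, 50]]
```

A search procedure in the AlphaTensor style is looking for factor triples like `U`, `V`, `W` above; `reconstruct(...) == matmul_tensor(2)` is the correctness condition any candidate must satisfy.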
An optional first step that will make everything easier: create a dedicated conda environment.
conda create --name alphastrassen python=3.8
conda activate alphastrassen
Install a PyTorch build compatible with your CUDA version:
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch # for CUDA >=11.3
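After installing, a quick check that PyTorch was built with CUDA support and can see a GPU may save debugging time later. A minimal sketch; the helper name `torch_setup_info` is hypothetical, while `torch.version.cuda` (None for CPU-only builds) and `torch.cuda.is_available()` are standard PyTorch attributes:

```python
def torch_setup_info():
    """Report whether PyTorch is installed and whether it can use CUDA."""
    try:
        import torch
    except ImportError:
        return {"installed": False}
    return {
        "installed": True,
        "torch_version": torch.__version__,
        # CUDA version PyTorch was compiled against; None for CPU-only builds
        "cuda_build_version": torch.version.cuda,
        # True only if a usable GPU and driver are present at runtime
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    print(torch_setup_info())
```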
Then install our project in editable mode:
git clone https://github.com/migonch/alphastrassen.git
cd alphastrassen
pip install -e .