This repository contains the code for the paper "Tensor Normal Training for Deep Learning Models" (Ren and Goldfarb, NeurIPS 2021).
Environment: Python 3.7.6; GCC 7.3.0; CUDA 11.0.3
Python packages: torch 1.8.1, torchvision 0.9.1, numpy 1.18.0, scipy 1.4.1, pytz 2019.3, psutil 5.6.7
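To compare a local setup against the package versions listed above, a small check script can help. This is a minimal sketch that is not part of the repository; the script and the names EXPECTED, installed_version and find_mismatches are hypothetical. It reads installed versions with importlib.metadata (Python 3.8+), falls back to setuptools' pkg_resources on Python 3.7, and ignores local build tags such as "+cu110" that PyTorch wheels often carry.

```python
import sys

try:
    from importlib.metadata import version, PackageNotFoundError  # Python 3.8+
except ImportError:  # Python 3.7: fall back to setuptools' pkg_resources
    import pkg_resources

    def version(name):
        return pkg_resources.get_distribution(name).version

    PackageNotFoundError = pkg_resources.DistributionNotFound

# Package versions listed in this README (hypothetical helper, not repo code).
EXPECTED = {
    "torch": "1.8.1",
    "torchvision": "0.9.1",
    "numpy": "1.18.0",
    "scipy": "1.4.1",
    "pytz": "2019.3",
    "psutil": "5.6.7",
}


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def find_mismatches(expected, lookup=installed_version):
    """Map each package whose installed version differs from the expected one
    to an (expected, installed) pair; installed is None if the package is absent."""
    mismatches = {}
    for name, want in expected.items():
        have = lookup(name)
        if have is not None:
            have = have.split("+")[0]  # drop local build tags like "+cu110"
        if have != want:
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "(listed: 3.7.6)")
    for name, (want, have) in find_mismatches(EXPECTED).items():
        print(f"{name}: listed {want}, installed {have}")
```

The script does not check GCC or CUDA; for the latter, `torch.version.cuda` reports the CUDA version that the installed PyTorch build was compiled against.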
See Demo.ipynb for the results and the commands used to produce them.

Citation:
@article{ren2021tensor,
title={Tensor Normal Training for Deep Learning Models},
author={Ren, Yi and Goldfarb, Donald},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}