Accelerated Neuroevolution of Physics-informed Neural Networks: Benchmarks and Experimental Results

This repository introduces the use of evolutionary algorithms for solving differential equations. The solution is obtained by optimizing a deep neural network whose loss function is defined by the residual terms of the differential equations. Recent studies have used stochastic gradient descent (SGD) variants to train these physics-informed neural networks (PINNs), but these methods can struggle to find accurate solutions due to optimization challenges. When solving differential equations, it is important to find the globally optimal parameters of the network, rather than a solution that merely fits the training data well. SGD searches along a single gradient direction, so it may not be the best approach in this case. In contrast, evolutionary algorithms explore many candidate solutions in parallel, which helps them avoid getting stuck in local optima and potentially find more accurate solutions. However, evolutionary algorithms can be slow, which can make them difficult to use in practice. We therefore implement all benchmark problems and code in JAX, which leverages hardware accelerators such as GPUs and TPUs to accelerate the training and evaluation of deep learning models.
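To make the residual-based loss concrete, the sketch below trains nothing; it only shows how such a loss can be written in JAX for a toy ODE, du/dx + u = 0 with u(0) = 1. The network sizes, the ODE, and all function names here are illustrative choices, not the benchmark problems in this repository.

```python
import jax
import jax.numpy as jnp

# Toy PINN: a small MLP u(x; params) whose loss is the mean squared ODE
# residual of du/dx + u = 0 on collocation points, plus a boundary term.
def init_params(key, sizes=(1, 16, 16, 1)):
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def mlp(params, x):
    h = jnp.atleast_1d(x)
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return (h @ w + b)[0]  # scalar network output u(x)

def pinn_loss(params, xs):
    u = lambda x: mlp(params, x)
    du = jax.vmap(jax.grad(u))(xs)       # du/dx at each collocation point
    residual = du + jax.vmap(u)(xs)      # residual of du/dx + u = 0
    bc = (u(jnp.array(0.0)) - 1.0) ** 2  # boundary condition u(0) = 1
    return jnp.mean(residual ** 2) + bc

key = jax.random.PRNGKey(0)
params = init_params(key)
xs = jnp.linspace(0.0, 1.0, 32)
loss = pinn_loss(params, xs)
```

Because the loss is a pure function of the parameters, it can be fed either to a gradient-based optimizer via `jax.grad` or treated as a black-box fitness value by an evolutionary algorithm.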

We provide a set of five benchmark problems with associated performance metrics and baseline results to support the development of evolutionary algorithms for enhanced PINN training. We utilize the SGD and Covariance Matrix Adaptation Evolution Strategy (CMA-ES) implementations from the Optax and EvoJAX packages, respectively. We have also implemented the xNES+NAG algorithm in JAX. To facilitate testing of other optimizers on each of the proposed benchmark problems, the PINN architecture and training loss definitions are wrapped within a simple black-box function that takes neural network weight parameters as input and returns the training loss as the equivalent fitness function evaluation. The CMA-ES code is provided as an example of the interface between the benchmark suite and EvoJAX, and the xNES+NAG code illustrates the plug-and-play nature of the benchmark suite.
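The black-box interface described above can be sketched as follows: a structured parameter pytree is flattened into a single vector, and the loss is exposed as a fitness function of that vector. This is a minimal illustration of the interface style, not the repository's actual wrapper; the toy quadratic loss stands in for a PINN training loss, and `make_fitness_fn` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Expose a loss over structured parameters as a black-box fitness
# function: flat parameter vector in, scalar loss (fitness) out.
def make_fitness_fn(loss_fn, example_params):
    flat0, unravel = ravel_pytree(example_params)

    @jax.jit
    def fitness(flat_params):
        return loss_fn(unravel(flat_params))

    return flat0, fitness

# Toy quadratic "loss" standing in for a PINN residual loss.
example_params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
def toy_loss(p):
    return jnp.sum(p["w"] ** 2) + p["b"] ** 2

flat0, fitness = make_fitness_fn(toy_loss, example_params)

# An evolutionary algorithm can evaluate a whole population of candidate
# parameter vectors in parallel with vmap on the accelerator.
population = jnp.stack([flat0, 0.5 * flat0])
losses = jax.vmap(fitness)(population)
```

Vectorizing the fitness evaluation over the population with `jax.vmap` is what makes population-based search practical on GPUs and TPUs, since all candidates are scored in a single batched call.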

As a baseline, we evaluate the performance and speed of the widely adopted CMA-ES and of xNES+NAG for training PINNs. We report the loss and training time of CMA-ES, xNES+NAG, and SGD, all implemented in JAX, on the five benchmark problems. Our results show that evolutionary algorithms, as exemplified by the probabilistic model-based CMA-ES and xNES+NAG, have the potential to be more effective than SGD for solving differential equations. We hope that our work will support the exploration and development of alternative optimization algorithms for the complex task of optimizing PINNs.