
JAX-PI

This repository is a comprehensive implementation of physics-informed neural networks (PINNs), integrating several advanced network architectures and training algorithms from these papers.

The repository also provides an extensive set of benchmark examples that demonstrate the effectiveness and robustness of our implementation. Training is supported on both single and multiple GPUs, while evaluation is currently limited to a single GPU.
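For readers new to PINNs, the core idea is to penalize the residual of the governing PDE at a set of collocation points. The snippet below is a minimal sketch in plain JAX, not the JAX-PI API, for the 1D advection equation $u_t + c\,u_x = 0$ (the Quickstart example below). The function names, the wave speed c, and the test solution are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def advection_residual(u_fn, t, x, c):
    """Pointwise residual u_t + c * u_x of the 1D advection equation at (t, x)."""
    u_t = jax.grad(u_fn, argnums=0)(t, x)  # time derivative via automatic differentiation
    u_x = jax.grad(u_fn, argnums=1)(t, x)  # space derivative via automatic differentiation
    return u_t + c * u_x


def residual_loss(u_fn, t_batch, x_batch, c):
    """Mean squared PDE residual over a batch of collocation points."""
    residuals = jax.vmap(lambda t, x: advection_residual(u_fn, t, x, c))(t_batch, x_batch)
    return jnp.mean(residuals ** 2)


if __name__ == "__main__":
    c = 1.0  # assumed wave speed
    exact = lambda t, x: jnp.sin(x - c * t)  # exact solution, so the residual should vanish
    t = jnp.linspace(0.0, 1.0, 64)
    x = jnp.linspace(0.0, 2.0 * jnp.pi, 64)
    print(residual_loss(exact, t, x, c))  # close to zero
```

In an actual PINN, u_fn would be a neural network, and this residual loss (together with initial and boundary condition terms) would be minimized with respect to the network parameters.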

Updates

  • May 2024: We have released the code for our latest paper, "PirateNets: Physics-informed Deep Learning with Residual Adaptive Networks". Please see the pirate branch of this repository for the implementation and examples.

Installation

Ensure that you have Python 3.8 or later installed on your system. Our code is GPU-only. We highly recommend using the most recent versions of JAX and jaxlib, along with compatible CUDA and cuDNN versions. The code has been tested and confirmed to work with the following versions:

  • JAX 0.4.26
  • CUDA 12.4
  • cuDNN 8.9

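You can upgrade pip and install the latest versions of JAX and jaxlib with the following commands (note that a plain pip install of jaxlib is CPU-only, so for GPU support install the CUDA-enabled build that matches your CUDA version, as described in the official JAX installation guide):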

pip3 install -U pip
pip3 install --upgrade jax jaxlib
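Because the code is GPU-only, it can help to confirm that JAX actually sees your GPUs after installation. The sketch below uses only public JAX functions (jax.__version__, jax.default_backend, jax.devices); the helper name describe_jax_setup is our own and not part of JAX-PI.

```python
import jax


def describe_jax_setup():
    """Summarize the installed JAX version and the devices JAX can use."""
    return {
        "jax_version": jax.__version__,
        # "gpu" when a CUDA-enabled jaxlib is active; typically "cpu" otherwise
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    info = describe_jax_setup()
    print(info)
    if info["backend"] != "gpu":
        print("Warning: JAX is not using a GPU; check your jaxlib/CUDA installation.")
```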

Install JAX-PI with the following commands:

git clone https://github.com/PredictiveIntelligenceLab/jaxpi.git
cd jaxpi
pip install .

Quickstart

We use Weights & Biases to log and monitor training metrics. Please make sure Weights & Biases is installed and logged in to your account before proceeding; see the official Weights & Biases quickstart guide for setup instructions.

To illustrate how to use our code, we will use the advection equation as an example. First, from the directory where you cloned the repository, navigate to the advection directory within the examples folder:

cd jaxpi/examples/advection

To train the model, run the following command:

python3 main.py 

To customize your experiment, you can specify a different config file:

python3 main.py --config=configs/sota.py 

Our code supports multi-GPU execution out of the box. You can select the GPUs to use with the CUDA_VISIBLE_DEVICES environment variable. For example, to use the first two GPUs (0 and 1), run:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py

Note on Memory Usage: Different models and examples may require varying amounts of GPU memory. If you encounter an out-of-memory error, you can decrease the batch size using the --config.batch_size_per_device option.
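To illustrate what a per-device batch size means under data parallelism, the sketch below splits a global batch of collocation points across all visible devices with jax.pmap and averages the per-device losses with jax.lax.pmean. This is a generic JAX pattern, not necessarily how JAX-PI is implemented internally; the helper names, the toy loss, and the array shapes are assumptions. The global batch is batch_size_per_device times the number of devices, so lowering the per-device size reduces memory use on each GPU.

```python
import jax
import jax.numpy as jnp


def shard_batch(batch, n_devices):
    """Reshape a global batch (n_devices * per_device, ...) to (n_devices, per_device, ...)."""
    if batch.shape[0] % n_devices != 0:
        raise ValueError("global batch size must be divisible by the number of devices")
    per_device = batch.shape[0] // n_devices
    return batch.reshape((n_devices, per_device) + batch.shape[1:])


def toy_loss(points):
    """Placeholder per-device loss (mean squared norm) standing in for a PINN loss."""
    return jnp.mean(jnp.sum(points ** 2, axis=-1))


# pmap runs the function on every device; pmean averages the local losses across devices.
parallel_loss = jax.pmap(
    lambda pts: jax.lax.pmean(toy_loss(pts), axis_name="devices"),
    axis_name="devices",
)


if __name__ == "__main__":
    n_devices = jax.local_device_count()  # reflects CUDA_VISIBLE_DEVICES on GPU machines
    batch_size_per_device = 128  # analogous to --config.batch_size_per_device
    batch = jax.random.uniform(jax.random.PRNGKey(0), (n_devices * batch_size_per_device, 2))
    print(parallel_loss(shard_batch(batch, n_devices)))  # same averaged value on every device
```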

To evaluate the model's performance, you can switch to evaluation mode with the following command:

python3 main.py --config.mode=eval

Examples

The following table summarizes our results on a range of benchmarks. Each row lists the benchmark, its relative $L^2$ error, and links to the corresponding model checkpoint and Weights & Biases logs.

| Benchmark | Relative $L^2$ Error | Checkpoint | Weights & Biases |
|---|---|---|---|
| Allen–Cahn equation | $5.37 \times 10^{-5}$ | allen_cahn | allen_cahn |
| Advection equation | $6.88 \times 10^{-4}$ | adv | adv |
| Stokes flow | $8.04 \times 10^{-5}$ | stokes | stokes |
| Kuramoto–Sivashinsky equation | $1.61 \times 10^{-1}$ | ks | ks |
| Lid-driven cavity flow | $1.58 \times 10^{-1}$ | ldc | ldc |
| Navier–Stokes flow in tori | $3.53 \times 10^{-1}$ | ns_tori | ns_tori |
| Navier–Stokes flow around a cylinder | - | ns_cylinder | ns_cylinder |
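The relative $L^2$ error in the table above is conventionally defined as $\|u_{\text{pred}} - u_{\text{ref}}\|_2 / \|u_{\text{ref}}\|_2$ over the evaluation grid. A minimal sketch, assuming the prediction and the reference solution are arrays sampled on the same grid; the function name is ours, and the repository's evaluation code may differ in detail:

```python
import jax.numpy as jnp


def relative_l2_error(u_pred, u_ref):
    """Relative L2 error ||u_pred - u_ref||_2 / ||u_ref||_2 over all grid points."""
    # Flatten so that multi-dimensional (e.g. space-time) grids are handled uniformly.
    u_pred = jnp.ravel(u_pred)
    u_ref = jnp.ravel(u_ref)
    return jnp.linalg.norm(u_pred - u_ref) / jnp.linalg.norm(u_ref)
```

For example, scaling a reference solution by 1.1 gives a relative $L^2$ error of about 0.1, independent of the solution's magnitude.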

Decaying Navier–Stokes flow in tori

[Figure: ns_tori]

Vortex shedding

[Figures: ns_cylinder]

Gray–Scott

[Figure: Gray–Scott]

Ginzburg–Landau

[Figure: Ginzburg–Landau]

Citation

@article{wang2023expert,
  title={An Expert's Guide to Training Physics-informed Neural Networks},
  author={Wang, Sifan and Sankaran, Shyam and Wang, Hanwen and Perdikaris, Paris},
  journal={arXiv preprint arXiv:2308.08468},
  year={2023}
}

@article{wang2024piratenets,
  title={PirateNets: Physics-informed Deep Learning with Residual Adaptive Networks},
  author={Wang, Sifan and Li, Bowen and Chen, Yuhan and Perdikaris, Paris},
  journal={arXiv preprint arXiv:2402.00326},
  year={2024}
}