A JAX/Flax implementation of the RAFT optical flow estimator (https://arxiv.org/abs/2003.12039), ported from PyTorch (https://docs.pytorch.org/vision/main/models/raft.html). Checkpoints have been ported, too. The implementation has been tested to reproduce the original results.
With pre-trained checkpoints, jax-raft achieves the following metrics on Sintel (train), compared to the original PyTorch implementation. EPE is the average end-point error (lower is better). This comparison uses the raft_large_C_T_SKHT_V2 and raft_small_C_T_V2 checkpoints, respectively. FPS was measured on a single RTX 3090 Ti.
| Model | EPE (clean) ↓ | EPE (final) ↓ | FPS |
|---|---|---|---|
| raft_large (jax-raft) | 0.649 | 1.020 | 11.8 |
| raft_large (PyTorch) | 0.649 | 1.020 | 8.1 |
| raft_small (jax-raft) | 1.993 | 3.268 | 36.6 |
| raft_small (PyTorch) | 1.998 | 3.279 | 15.0 |
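For reference, the EPE in the table is the Euclidean distance between predicted and ground-truth flow vectors, averaged over pixels. The snippet below is a minimal JAX sketch of that formula. It assumes both flow fields have shape `(..., 2)`, with the last axis holding the `(u, v)` displacement. The helper name `end_point_error` is hypothetical and is not part of jax-raft. How values are aggregated across a dataset (per pixel vs. per image) may differ from the Sintel validation script.

```python
import jax.numpy as jnp


def end_point_error(flow_pred, flow_gt):
    """Mean end-point error (hypothetical helper, not part of jax-raft).

    Assumes both arrays have shape (..., 2), where the last axis holds
    the (u, v) displacement of each pixel.
    """
    # Euclidean distance between predicted and ground-truth vectors, per pixel
    per_pixel = jnp.linalg.norm(flow_pred - flow_gt, axis=-1)
    # Average over every pixel (and over the batch, if one is present)
    return jnp.mean(per_pixel)


# Two pixels: one off by (3, 4) -> distance 5, one exact -> distance 0
pred = jnp.array([[[3.0, 4.0], [0.0, 0.0]]])  # shape (1, 2, 2)
gt = jnp.zeros((1, 2, 2))
print(end_point_error(pred, gt))  # 2.5
```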
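Install with pip:

```shell
pip install git+https://github.com/alebeck/jax-raft
```

Usage:

```python
from jax_raft import raft_large  # or raft_small

# Build the model and load the ported pretrained weights
model, variables = raft_large(pretrained=True)

# Estimate optical flow from image1 to image2 (a pair of consecutive frames);
# see the examples directory for the expected input layout and preprocessing
model.apply(variables, image1, image2, train=False)
```

In the `scripts` directory, we provide scripts for converting official PyTorch RAFT checkpoints to Flax and for validation on Sintel. The `examples` directory contains example usage scripts.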
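A core step in porting PyTorch checkpoints to Flax is reordering weight axes. PyTorch `Conv2d` stores kernels as `(out, in, kh, kw)`, while Flax `nn.Conv` expects `(kh, kw, in, out)`. Likewise, a PyTorch `Linear` weight `(out, in)` becomes a Flax `Dense` kernel `(in, out)`. The snippet below is a minimal, generic sketch of that layout conversion using NumPy. It is not the conversion script shipped in `scripts`, and the helper names are hypothetical. No kernel flip is needed because both frameworks implement convolution as cross-correlation.

```python
import numpy as np


def torch_conv_to_flax(weight, bias=None):
    """Hypothetical helper: PyTorch Conv2d params -> Flax nn.Conv params.

    PyTorch layout (out, in, kh, kw) is transposed to Flax layout (kh, kw, in, out).
    """
    params = {"kernel": np.transpose(np.asarray(weight), (2, 3, 1, 0))}
    if bias is not None:
        params["bias"] = np.asarray(bias)
    return params


def torch_linear_to_flax(weight, bias=None):
    """Hypothetical helper: PyTorch Linear params -> Flax nn.Dense params.

    PyTorch layout (out, in) is transposed to Flax layout (in, out).
    """
    params = {"kernel": np.asarray(weight).T}
    if bias is not None:
        params["bias"] = np.asarray(bias)
    return params


# A fake 2-out, 3-in, 5x7 conv kernel in PyTorch layout
w = np.arange(2 * 3 * 5 * 7, dtype=np.float32).reshape(2, 3, 5, 7)
flax_conv = torch_conv_to_flax(w, bias=np.zeros(2, dtype=np.float32))
print(flax_conv["kernel"].shape)  # (5, 7, 3, 2)
```

Each returned dict would then be placed under the matching module path in the Flax `params` tree. Normalization layers need their own name mapping (e.g. PyTorch `weight` to Flax `scale`), which is not shown here.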