Jam is a collection of ML models (mostly vision models for now) implemented in Flax/Haiku. It includes model implementations, as well as pretrained weights converted from other sources.
Jam is currently written to allow easy access to some pretrained models that provide PyTorch checkpoints. These pretrained models may be used for a variety of purposes, such as transfer learning, or as feature extractors in some vision-based RL tasks. There are preliminary examples for training some of these models from scratch, but they are not yet fully tested or benchmarked. The following models are currently available:
- ConvNeXt (via torchvision), flax
- ResNet (via torchvision), haiku and flax
- MVP (via https://github.com/ir413/mvp/), flax
- NFNet (via https://github.com/google-deepmind/deepmind-research/blob/master/nfnets), haiku and flax
- R3M (via https://github.com/facebookresearch/r3m/tree/main), haiku and flax
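Since the pretrained weights above are converted from PyTorch checkpoints, the main mechanical step is re-laying-out parameter arrays to match Flax/Haiku conventions. The sketch below is not Jam's actual converter; it is a minimal illustration, using only numpy, of the two most common layout changes (the function names are hypothetical):

```python
import numpy as np

def torch_conv_kernel_to_flax(kernel: np.ndarray) -> np.ndarray:
    """Transpose a PyTorch conv kernel (O, I, H, W) to the
    (H, W, I, O) layout expected by flax.linen.Conv."""
    return np.transpose(kernel, (2, 3, 1, 0))

def torch_linear_to_flax(weight: np.ndarray) -> np.ndarray:
    """PyTorch nn.Linear stores weights as (out, in);
    flax.linen.Dense expects (in, out)."""
    return weight.T

# Example: a 3x3 conv mapping 3 input channels to 64 output channels.
torch_kernel = np.zeros((64, 3, 3, 3))
flax_kernel = torch_conv_kernel_to_flax(torch_kernel)
print(flax_kernel.shape)  # (3, 3, 3, 64)
```

Beyond transposes, a real converter also has to rename parameters (e.g. PyTorch's dotted module paths to Flax's nested dicts) per model family, which is where most of the per-model work in a library like this lives.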
See the examples for usage.