Image models based on the Apple MLX framework for Apple Silicon machines.

Apple MLX is a great framework for running machine learning models on Apple Silicon machines.

This repository converts image models from timm/torchvision to the Apple MLX framework. The weights are simply converted from `.pth` to `.npz`; the models are not trained again.

I don't have enough compute power (and time) to train all the models from scratch (someone buy me a maxed-out Mac, please).
```bash
pip install mlx-image
```
Model weights are available on the mlx-vision space on HuggingFace.
To create a model with weights:
```python
from mlxim.model import create_model

# load weights from HuggingFace
model = create_model("resnet18")

# load weights from a local file
model = create_model("resnet18", weights="path/to/weights.npz")
```
To list all available models:
```python
from mlxim.model import list_models
list_models()
```
**Warning**: as of today (2024-03-05), mlx does not support `nn.Conv2d` with `groups` or `dilation` greater than 1 (e.g. `resnext`, `regnet`, `efficientnet`).
Go to results-imagenet-1k.csv to check every model converted and its performance on ImageNet-1K with different settings.
Training is similar to PyTorch, thanks to some tools mlx-im provides. Here's an example of how to train a model with mlx-im:
```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader

train_dataset = LabelFolderDataset(
    root_dir="path/to/train",
    class_map={0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
model = create_model("resnet18")  # pretrained weights loaded from HF
optimizer = optim.Adam(learning_rate=1e-3)

def train_step(model, inputs, targets):
    logits = model(inputs)
    loss = mx.mean(nn.losses.cross_entropy(logits, targets))
    return loss

# returns loss and gradients w.r.t. the model parameters
train_step_fn = nn.value_and_grad(model, train_step)

model.train()
for epoch in range(10):
    for batch in train_loader:
        x, target = batch
        loss, grads = train_step_fn(model, x, target)
        optimizer.update(model, grads)
        mx.eval(model.state, optimizer.state)
```
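After training, you'll usually want a quick sanity check of top-1 accuracy, the same metric `validation.py` reports on ImageNet-1K. Below is a hedged, framework-agnostic sketch (plain Python, not part of mlx-im): it takes per-sample logit lists and integer targets, so you can apply the same logic to model outputs regardless of framework.

```python
def top1_accuracy(logits, targets):
    """Fraction of samples whose argmax logit matches the target label.

    logits: list of per-class score lists, one per sample
    targets: list of integer class labels
    """
    assert len(logits) == len(targets) and len(targets) > 0
    correct = 0
    for scores, label in zip(logits, targets):
        # index of the highest score = predicted class
        pred = max(range(len(scores)), key=lambda i: scores[i])
        if pred == label:
            correct += 1
    return correct / len(targets)

# toy check: 2 of 3 predictions match their labels
print(top1_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0]))  # 2/3
```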
The validation.py script runs every time a `.pth` model is converted to mlx; it checks that the converted model performs similarly to the original on ImageNet-1K. The configuration file config/validation.yaml sets the parameters for the validation script.
You can download the ImageNet-1K validation set from mlx-vision space on HuggingFace at this link.
mlx-im tries to be as close as possible to PyTorch:

- `DataLoader` -> you can define your own `collate_fn` and also use `num_workers` to speed up data loading
- `Dataset` -> mlx-im already supports `LabelFolderDataset` (the good old PyTorch `ImageFolder`) and `FolderDataset` (a generic folder with images in it)
- `ModelCheckpoint` -> keeps track of the best model and saves it to disk (similar to PyTorchLightning); it also suggests early stopping
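The `class_map` accepted by `LabelFolderDataset` maps a label to either one folder name or a list of folder names, so several folders can share one label. The helper below is a hypothetical illustration (not part of mlx-im) of how such a map inverts into a folder-name → label lookup:

```python
def invert_class_map(class_map):
    """Turn {label: folder_or_folders} into {folder_name: label}."""
    folder_to_label = {}
    for label, folders in class_map.items():
        # a value may be a single folder name or a list of names
        if isinstance(folders, str):
            folders = [folders]
        for name in folders:
            folder_to_label[name] = label
    return folder_to_label

class_map = {0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
print(invert_class_map(class_map))
# {'class_0': 0, 'class_1': 1, 'class_2': 2, 'class_3': 2}
```

Note how `class_2` and `class_3` both resolve to label 2, matching the `class_map` in the training example above.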
This is a work in progress, so any help is appreciated.
I am working on it in my spare time, so I can't guarantee frequent updates.
If you love coding and want to contribute, follow the instructions in CONTRIBUTING.md.
- [ ] add `register_model` (similar to timm)
- [ ] inference script (similar to train/validation)
- [ ] SwinTransformer
- [ ] DenseNet
- [ ] MobileNet
If you have any questions, please email riccardomusmeci92@gmail.com.