timmdocs

Documentation for Ross Wightman's timm image model library. https://fastai.github.io/timmdocs/

Primary language: Jupyter Notebook. License: Apache License 2.0 (Apache-2.0).


PyTorch Image Models (timm)

timm is a deep-learning library created by Ross Wightman. It is a collection of state-of-the-art (SOTA) computer vision models, layers, utilities, optimizers, schedulers, data loaders, and augmentations, along with training and validation scripts that can reproduce ImageNet training results.

Install

pip install timm

Or, for an editable install:

git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .

How to use

Create a model

import timm 
import torch

model = timm.create_model('resnet34')
x     = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 1000])

Creating a model with timm is that simple. The create_model function is a factory function that can create any of the over 300 models that are part of the timm library.

To create a pretrained model, simply pass in pretrained=True.

pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /Users/amanarora/.cache/torch/hub/checkpoints/resnet34-43635321.pth
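The log line shows where the weights are cached. This download goes through torch.hub, which by default saves checkpoints under <torch home>/hub/checkpoints/ and uses the last component of the URL as the file name; <torch home> is $TORCH_HOME if set, else $XDG_CACHE_HOME/torch, else ~/.cache/torch. The stdlib-only helper below sketches that lookup, assuming torch.hub.set_dir() has not been called. Its name, expected_checkpoint_path, is hypothetical, and in real code torch.hub.get_dir() is the reliable way to ask. Newer timm releases may download some weights from other hosts, such as the Hugging Face Hub, which keeps its own cache.

```python
import os
from urllib.parse import urlparse


def expected_checkpoint_path(url: str) -> str:
    """Sketch of torch.hub's default cache path for a downloaded checkpoint.

    Assumes torch.hub.set_dir() was not called; follows the
    TORCH_HOME -> XDG_CACHE_HOME/torch -> ~/.cache/torch fallback order.
    """
    torch_home = os.path.expanduser(
        os.getenv(
            "TORCH_HOME",
            os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"),
        )
    )
    # Checkpoints are stored in <torch home>/hub/checkpoints/
    model_dir = os.path.join(torch_home, "hub", "checkpoints")
    # The file keeps the name of the last path component of the URL.
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)


url = (
    "https://github.com/rwightman/pytorch-image-models/releases/download/"
    "v0.1-weights/resnet34-43635321.pth"
)
print(expected_checkpoint_path(url))
```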

To create a model with a custom number of classes, simply pass in num_classes=<number_of_classes>.

import timm 
import torch

model = timm.create_model('resnet34', num_classes=10)
x     = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 10])
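Passing num_classes resizes the final classifier, which is why the output above is torch.Size([1, 10]). As a worked example, resnet34's classifier is a single fully connected layer on 512 globally pooled features (the channel count of its last stage). Its parameter count is in_features × num_classes weights plus num_classes biases. The helper name linear_head_params is hypothetical.

```python
def linear_head_params(in_features: int, num_classes: int) -> int:
    """Fully connected layer: one weight per (feature, class) pair plus one bias per class."""
    return in_features * num_classes + num_classes


RESNET34_FEATURES = 512  # channels after resnet34's last stage, after global pooling

print(linear_head_params(RESNET34_FEATURES, 1000))  # 513000
print(linear_head_params(RESNET34_FEATURES, 10))    # 5130
```

Only this layer changes size; the rest of the network is the same whichever num_classes is requested.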

List Models with Pretrained Weights

timm.list_models() returns a complete list of the models available in timm. To list only the models that have pretrained weights, pass pretrained=True to list_models.

avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]
(271,
 ['adv_inception_v3',
  'cspdarknet53',
  'cspresnet50',
  'cspresnext50',
  'densenet121'])

At the time of writing, a total of 271 models with pretrained weights are available in timm!

Search for model architectures by wildcard

You can also search for model architectures using a wildcard pattern:

all_densenet_models = timm.list_models('*densenet*')
all_densenet_models
['densenet121',
 'densenet121d',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenet264',
 'densenet264d_iabn',
 'densenetblur121d',
 'tv_densenet121']
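The patterns are shell-style wildcards, where * matches any run of characters. The sketch below assumes they behave like Python's fnmatch.filter and reuses the names from the output above to show how anchoring the pattern changes the result.

```python
import fnmatch

# Model names copied from the list_models('*densenet*') output above.
names = [
    "densenet121", "densenet121d", "densenet161", "densenet169",
    "densenet201", "densenet264", "densenet264d_iabn",
    "densenetblur121d", "tv_densenet121",
]

# '*densenet*' matches 'densenet' anywhere in the name.
print(len(fnmatch.filter(names, "*densenet*")))  # 9

# 'densenet*' requires the name to start with 'densenet', so 'tv_densenet121' drops out.
print(fnmatch.filter(names, "densenet*") == names[:-1])  # True

# 'densenet121*' keeps only names that begin with 'densenet121'.
print(fnmatch.filter(names, "densenet121*"))  # ['densenet121', 'densenet121d']
```

In timm itself, a pattern can be combined with the pretrained flag, for example timm.list_models('*densenet*', pretrained=True).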