Primary language: Python. License: MIT.

TensorNets

High level network definitions with pre-trained weights in TensorFlow (tested with >= 1.2.0).

Guiding principles

  • Applicability. Many people already have their own ML workflows and want to plug a new model into them. TensorNets fits in easily because it is designed as a set of simple functional interfaces without custom classes.
  • Manageability. Models are written in tf.contrib.layers, which is lightweight like PyTorch and Keras and gives easy access to every weight and end-point. It is also easy to deploy and extend the collection of pre-processing functions and pre-trained weights.
  • Readability. Recent TensorFlow APIs allow more factoring and less indentation. For example, all the Inception variants are implemented in about 500 lines of code in TensorNets, versus 2000+ lines in the official TensorFlow models.

A quick example

Each network (see full list) is not a custom class but a function that takes a tf.Tensor as its input and returns a tf.Tensor as its output. Here is an example with ResNet50:

import tensorflow as tf
import tensornets as nets

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
model = nets.ResNet50(inputs)

assert isinstance(model, tf.Tensor)

You can load an example image with utils.load_img, which returns a np.ndarray in NHWC format:

from tensornets import utils
img = utils.load_img('cat.png', target_size=256, crop_size=224)
assert img.shape == (1, 224, 224, 3)
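
The target_size=256 and crop_size=224 arguments suggest a resize followed by a center crop. As a rough illustration of the cropping arithmetic only (this is not the utils.load_img source, and the helper name center_crop_box is hypothetical), the crop window could be computed like this:

```python
def center_crop_box(height, width, crop_size):
    """Return (top, left, bottom, right) of a centered square crop.

    Hypothetical helper for illustration; not part of TensorNets.
    """
    if crop_size > height or crop_size > width:
        raise ValueError("crop_size must not exceed the image size")
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return top, left, top + crop_size, left + crop_size


# A 256x256 image cropped to 224x224 keeps rows and columns 16..239.
box = center_crop_box(256, 256, 224)
print(box)
```

Slicing an HWC array with this box (rows top:bottom, columns left:right) and adding a leading batch axis would give the (1, 224, 224, 3) shape asserted above.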

Once your network is created, you can run it with regular TensorFlow APIs 😊 because every network in TensorNets returns a tf.Tensor. Loading the pre-trained weights and applying the matching pre-processing each take a single call, pretrained() and preprocess(), which is enough to reproduce the original results:

with tf.Session() as sess:
    nets.pretrained(model)
    img = nets.preprocess(model, img)
    preds = sess.run(model, {inputs: img})

You can see the most probable classes:

print(utils.decode_predictions(preds, top=2)[0])
[(u'n02124075', u'Egyptian_cat', 0.28067636), (u'n02127052', u'lynx', 0.16826575)]
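
To make the top=2 argument concrete, here is a minimal sketch of top-k decoding on a toy class list. It only illustrates the idea of ranking class probabilities; the function top_k and the toy labels are assumptions for this example and do not reflect how utils.decode_predictions is implemented.

```python
def top_k(probabilities, labels, k=2):
    """Return the k (label, probability) pairs with the highest probability."""
    ranked = sorted(zip(labels, probabilities),
                    key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy three-class example; the numbers are made up for illustration.
toy_labels = ["cat", "dog", "bird"]
toy_probs = [0.7, 0.1, 0.2]
print(top_k(toy_probs, toy_labels, k=2))
```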

TensorNets enables us to deploy well-known architectures and benchmark those results faster ⚡️. For more information, you can check out the lists of utilities, examples, and architectures.

Utilities

An example output of utils.print_summary(model):

Scope: resnet50
Total layers: 54
Total weights: 320
Total parameters: 25,636,712

An example output of utils.print_weights(model):

Scope: resnet50
conv1/conv/weights:0 (7, 7, 3, 64)
conv1/conv/biases:0 (64,)
conv1/bn/beta:0 (64,)
conv1/bn/gamma:0 (64,)
conv1/bn/moving_mean:0 (64,)
conv1/bn/moving_variance:0 (64,)
conv2/block1/0/conv/weights:0 (1, 1, 64, 256)
conv2/block1/0/conv/biases:0 (256,)
conv2/block1/0/bn/beta:0 (256,)
conv2/block1/0/bn/gamma:0 (256,)
...
  • utils.get_weights(model) returns a list of all the tf.Tensor weights shown above
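
The "Total parameters" figure in the summary is presumably the sum of the element counts of all weight tensors. As a sanity-check sketch, assuming only the ten shapes visible in the listing above (the rest are elided, so this is a partial sum, not the full total), the count can be reproduced in plain Python:

```python
from functools import reduce
from operator import mul

# Shapes copied from the first ten entries of the listing above.
shapes = [
    (7, 7, 3, 64), (64,), (64,), (64,), (64,), (64,),
    (1, 1, 64, 256), (256,), (256,), (256,),
]


def num_elements(shape):
    """Number of scalars in a tensor of the given shape."""
    return reduce(mul, shape, 1)


# Partial sum over the listed weights only.
partial_total = sum(num_elements(s) for s in shapes)
print(partial_total)
```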

An example output of utils.print_outputs(model):

Scope: resnet50
conv1/pad:0 (?, 230, 230, 3)
conv1/conv/BiasAdd:0 (?, 112, 112, 64)
conv1/bn/batchnorm/add_1:0 (?, 112, 112, 64)
conv1/relu:0 (?, 112, 112, 64)
pool1/pad:0 (?, 114, 114, 64)
pool1/MaxPool:0 (?, 56, 56, 64)
conv2/block1/0/conv/BiasAdd:0 (?, 56, 56, 256)
conv2/block1/0/bn/batchnorm/add_1:0 (?, 56, 56, 256)
conv2/block1/1/conv/BiasAdd:0 (?, 56, 56, 64)
conv2/block1/1/bn/batchnorm/add_1:0 (?, 56, 56, 64)
conv2/block1/1/relu:0 (?, 56, 56, 64)
...
  • utils.get_outputs(model) returns a list of all the tf.Tensor end-points shown above
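
The spatial sizes in this listing follow the usual arithmetic for unpadded ("valid") convolution and pooling applied after an explicit padding step. The sketch below checks the first few shapes. The 7x7 kernel comes from the weight listing; the stride of 2 and the 3x3 pooling window are assumptions based on the standard ResNet design, not values read from the TensorNets source.

```python
def valid_output_size(input_size, kernel_size, stride):
    """Spatial output size of a convolution or pooling with no implicit padding."""
    return (input_size - kernel_size) // stride + 1


# conv1: 224 padded by 3 per side -> 230, then a 7x7 kernel (stride 2 assumed).
conv1 = valid_output_size(224 + 2 * 3, kernel_size=7, stride=2)
# pool1: conv1 output padded by 1 per side, then a 3x3 window (stride 2 assumed).
pool1 = valid_output_size(conv1 + 2 * 1, kernel_size=3, stride=2)
print(conv1, pool1)
```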

Examples

  • Comparison of different networks:
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = [
    nets.MobileNet75(inputs),
    nets.MobileNet100(inputs),
    nets.SqueezeNet(inputs),
]

img = utils.load_img('cat.png', target_size=256, crop_size=224)
imgs = nets.preprocess(models, img)

with tf.Session() as sess:
    nets.pretrained(models)
    for (model, img) in zip(models, imgs):
        preds = sess.run(model, {inputs: img})
        print(utils.decode_predictions(preds, top=2)[0])
  • Transfer learning:
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
outputs = tf.placeholder(tf.float32, [None, 50])
model = nets.DenseNet169(inputs, is_training=True, classes=50)

loss = tf.losses.softmax_cross_entropy(outputs, model)
train = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(loss)

with tf.Session() as sess:
    nets.pretrained(model)
    # `data` is a placeholder name for your own iterable of NumPy batches:
    # x in NHWC format, y one-hot encoded with 50 classes.
    for (x, y) in data:
        sess.run(train, {inputs: x, outputs: y})
  • Using multi-GPU:
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = []

with tf.device('gpu:0'):
    models.append(nets.ResNeXt50(inputs))

with tf.device('gpu:1'):
    models.append(nets.DenseNet201(inputs))

from tensornets.preprocess import fb_preprocess
img = utils.load_img('cat.png', target_size=256, crop_size=224)
img = fb_preprocess(img)

with tf.Session() as sess:
    nets.pretrained(models)
    preds = sess.run(models, {inputs: img})
    for pred in preds:
        print(utils.decode_predictions(pred, top=2)[0])
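
The transfer-learning example above expects labels in one-hot format with 50 classes. As a minimal sketch of that encoding (the helper one_hot is hypothetical and written in plain Python; in practice you would likely build a NumPy array of the same shape):

```python
def one_hot(class_indices, num_classes):
    """Encode integer class indices as one-hot rows (lists of floats)."""
    rows = []
    for index in class_indices:
        if not 0 <= index < num_classes:
            raise ValueError("class index out of range: %d" % index)
        row = [0.0] * num_classes
        row[index] = 1.0
        rows.append(row)
    return rows


# Three samples with class indices 3, 0 and 49 out of 50 classes.
labels = one_hot([3, 0, 49], num_classes=50)
print(len(labels), len(labels[0]), labels[0][3])
```

Each row has exactly one 1.0, so a batch of N labels has shape (N, 50), matching the outputs placeholder.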

Performances

  • The top-k errors were obtained with TensorNets using a single center crop (224x224, except 299x299 for Inception3, Inception4, InceptionResNet2, and ResNet50v2 through ResNet152v2) and may differ slightly from the original ones.
  • The computation times were measured on NVIDIA Tesla P100 (3584 cores, 16 GB global memory) with cuDNN 6.0 and CUDA 8.0.
Network           Top-1 error (%)  Top-5 error (%)  Speed (ms)  References
ResNet50          25.076           7.884            195.4       [paper] [tf-slim] [torch-fb] [caffe] [keras]
ResNet101         23.574           7.208            311.7       [paper] [tf-slim] [torch-fb] [caffe]
ResNet152         23.362           6.914            439.1       [paper] [tf-slim] [torch-fb] [caffe]
ResNet50v2        24.442           7.174            209.7       [paper] [tf-slim] [torch-fb]
ResNet101v2       23.064           6.476            326.2       [paper] [tf-slim] [torch-fb]
ResNet152v2       22.300           6.066            455.2       [paper] [tf-slim] [torch-fb]
ResNet200v2       21.898           5.998            618.3       [paper] [tf-slim] [torch-fb]
ResNeXt50         22.518           6.418            267.4       [paper] [torch-fb]
ResNeXt101        21.426           5.928            427.9       [paper] [torch-fb]
WideResNet50      22.308           6.238            358.1       [paper] [torch]
Inception1        32.962           12.122           165.1       [paper] [tf-slim] [caffe-zoo]
Inception2        26.420           8.450            134.3       [paper] [tf-slim]
Inception3        22.092           6.220            314.6       [paper] [tf-slim] [keras]
Inception4        19.854           5.032            582.1       [paper] [tf-slim]
InceptionResNet2  19.660           4.806            656.8       [paper] [tf-slim]
DenseNet121       25.550           8.174            202.9       [paper] [torch]
DenseNet169       24.092           7.172            219.1       [paper] [torch]
DenseNet201       22.988           6.700            272.0       [paper] [torch]
MobileNet25       48.346           24.150           29.27       [paper] [tf-slim]
MobileNet50       35.594           14.390           42.32       [paper] [tf-slim]
MobileNet75       31.520           11.710           57.23       [paper] [tf-slim]
MobileNet100      29.474           10.416           70.69       [paper] [tf-slim]
SqueezeNet        44.656           21.432           71.43       [paper] [caffe]
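
For readers unfamiliar with the metric, a top-k error is the percentage of samples whose true label is not among the k highest-scoring classes. The following is a minimal sketch of that definition on made-up scores; it is not the evaluation script used to produce the numbers above, and the function name top_k_error is an assumption for this example.

```python
def top_k_error(scores, true_labels, k):
    """Percentage of samples whose true label is not among the k highest scores."""
    misses = 0
    for row, label in zip(scores, true_labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if label not in ranked[:k]:
            misses += 1
    return 100.0 * misses / len(true_labels)


# Toy scores for 4 samples over 3 classes; values are made up.
scores = [
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.1, 0.2, 0.7],
    [0.5, 0.4, 0.1],
]
true_labels = [0, 2, 2, 1]
print(top_k_error(scores, true_labels, k=1))
print(top_k_error(scores, true_labels, k=2))
```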