
ptrnets


A collection of pretrained networks in PyTorch, readily available for transfer learning tasks such as neural system identification.

Installation

pip install ptrnets

Usage

Find a list of all available models like this:

from ptrnets import AVAILABLE_MODELS

print(AVAILABLE_MODELS)

Import a model like this:

from ptrnets import simclr_resnet50x2

model = simclr_resnet50x2(pretrained=True)

You can access intermediate representations in two ways:

Probing the model

You can conveniently access the intermediate representations of a forward pass with the ptrnets.utils.mlayer.probe_model function. Example:

import torch
from ptrnets import resnet50
from ptrnets.utils.mlayer import probe_model

model = resnet50(pretrained=True)
available_layers = [name for name, _ in model.named_modules()]
layer_name = "layer2.1"
assert layer_name in available_layers, f"Layer {layer_name} not available. Choose from {available_layers}"

model_probe = probe_model(model, layer_name)

x = torch.rand(1, 3, 224, 224)
output = model_probe(x)

Note: if the input is too small for a full forward pass through the network, you might need to wrap the call in a try-except block to catch the resulting RuntimeError.
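For readers curious how layer probing can work in general, here is a minimal sketch of the standard PyTorch forward-hook technique. It is not the ptrnets implementation, and we make no claim that probe_model works this way. The names probe_layer and _StopForward are hypothetical. The sketch registers a hook on the named submodule, stores its output, and raises an exception to abort the rest of the forward pass. As a side effect, later layers never run, so an input that is too small for the classifier head does not trigger a RuntimeError.

```python
import torch
from torch import nn


class _StopForward(Exception):
    """Hypothetical marker exception raised inside the hook to abort the forward pass."""


def probe_layer(model: nn.Module, layer_name: str, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return the output of `layer_name` for input `x`."""
    # Look up the submodule by its dotted name, e.g. "layer2.1"
    modules = dict(model.named_modules())
    if layer_name not in modules:
        raise KeyError(f"Layer {layer_name} not available. Choose from {list(modules)}")
    captured = {}

    def hook(module, inputs, output):
        captured["out"] = output  # keep the intermediate representation
        raise _StopForward  # skip every layer after the target

    handle = modules[layer_name].register_forward_hook(hook)
    try:
        model(x)
    except _StopForward:
        pass
    finally:
        handle.remove()  # always detach the hook so the model is left unchanged
    return captured["out"]


# Usage on a toy network whose Linear head expects 4x4 inputs
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 2),
)
x_small = torch.rand(1, 3, 2, 2)  # too small for the Linear head
features = probe_layer(toy, "1", x_small)  # stops after the ReLU, so no shape error
print(features.shape)  # torch.Size([1, 8, 2, 2])
```

Aborting via an exception is only one design choice. The alternative is to let the forward pass finish and catch whatever error the later layers raise, which is the situation the note above describes.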

Clipping the model

ptrnets.utils.mlayer.clip_model creates a copy of the model up to a specific layer. Because the model is smaller, a forward pass can run faster. However, the output is only guaranteed to be the same as the original model's if the architecture is fully sequential up until that layer.

Example:

import torch
from ptrnets import vgg16
from ptrnets.utils.mlayer import clip_model, probe_model

model = vgg16(pretrained=True)
available_layers = [name for name, _ in model.named_modules()]
layer_name = "features.18"
assert layer_name in available_layers, f"Layer {layer_name} not available. Choose from {available_layers}"

model_clipped = clip_model(model, layer_name)  # Creates new model up to the layer

x = torch.rand(1, 3, 224, 224)
output = model_clipped(x)

assert torch.allclose(output, probe_model(model, layer_name)(x)), "Output of clipped model is not the same as the original model"
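The caveat about fully sequential architectures can be illustrated with a minimal sketch in plain PyTorch. This is not how ptrnets.utils.mlayer.clip_model is implemented. The helper clip_children and the Residual block are hypothetical. The sketch deep-copies a module's top-level children up to a given name and chains them. For a purely sequential model this reproduces the intermediate output. For a block whose forward() adds a skip connection, chaining the children drops that connection and the outputs diverge.

```python
import copy

import torch
from torch import nn


def clip_children(model: nn.Module, layer_name: str) -> nn.Sequential:
    """Hypothetical helper: deep-copy the top-level children of `model` up to and
    including `layer_name`, and chain them in registration order."""
    names = [name for name, _ in model.named_children()]
    if layer_name not in names:
        raise KeyError(f"{layer_name!r} is not a top-level child. Choose from {names}")
    end = names.index(layer_name) + 1
    return nn.Sequential(*copy.deepcopy(list(model.children())[:end]))


class Residual(nn.Module):
    """Non-sequential block: forward() adds a skip connection that children() cannot express."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(x + self.conv(x))


torch.manual_seed(0)
x = torch.rand(1, 4, 8, 8)

# Sequential model: the clipped copy reproduces the intermediate output.
seq = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 4, 3, padding=1))
seq_clipped = clip_children(seq, "1")
print(torch.allclose(seq_clipped(x), seq[1](seq[0](x))))  # True

# Residual block: chaining conv -> relu drops the skip connection, so outputs differ.
res = Residual()
res_clipped = clip_children(res, "relu")
print(torch.allclose(res_clipped(x), res(x)))  # False
```

Registration order of children is also not guaranteed to match the order in which forward() calls them. This is another reason why clipping is only safe for fully sequential prefixes, and why the example above compares against probe_model.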

Contributing

Pull requests are welcome. Please see instructions here.