Rust bindings for PyTorch. The goal of the `tch` crate is to provide thin wrappers around the C++ PyTorch API (a.k.a. libtorch). It aims to stay as close as possible to the original C++ API. More idiomatic Rust bindings could then be developed on top of it. The documentation can be found on docs.rs.

The code generation part for the C API on top of libtorch comes from ocaml-torch.
This crate requires the C++ PyTorch library (libtorch) in version v1.1.0 to be available on your system. You can either install it manually and let the build script know about it via the `LIBTORCH` environment variable, or leave `LIBTORCH` unset, in which case the build script will try downloading and extracting a pre-built binary version of libtorch.
- Get `libtorch` from the PyTorch website download section and extract the content of the zip file.
- Add the following to your `.bashrc` or equivalent, where `/path/to/libtorch` is the path to the directory that was created when unzipping the file.
```bash
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```
- You should now be able to run some examples, e.g. `cargo run --example basics`.
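The lookup order described above (use `LIBTORCH` if it is set, otherwise fall back to downloading a pre-built binary) can be sketched in plain Rust. This is a hypothetical simplification, not the crate's actual build script: the `LibtorchSource` enum and both functions are made up for illustration, and the `include` subdirectory name is assumed from the layout of the extracted zip file (only `lib` is mentioned above).

```rust
use std::path::{Path, PathBuf};

/// Where libtorch comes from (hypothetical type, for illustration only).
#[derive(Debug, PartialEq)]
enum LibtorchSource {
    /// LIBTORCH is set: use the headers and libraries under that directory.
    Local { include_dir: PathBuf, lib_dir: PathBuf },
    /// LIBTORCH is unset or empty: fall back to a pre-built download.
    Download,
}

/// `libtorch_env` is the value of the LIBTORCH environment variable, if any.
/// Taking it as a parameter instead of reading the environment directly
/// keeps the function easy to test.
fn libtorch_source(libtorch_env: Option<&str>) -> LibtorchSource {
    match libtorch_env {
        Some(dir) if !dir.is_empty() => {
            let root = Path::new(dir);
            LibtorchSource::Local {
                // Assumed layout of the unzipped libtorch directory.
                include_dir: root.join("include"),
                lib_dir: root.join("lib"),
            }
        }
        _ => LibtorchSource::Download,
    }
}

/// Mirrors `export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH`.
fn ld_library_path(libtorch_dir: &str, existing: Option<&str>) -> String {
    let lib = Path::new(libtorch_dir).join("lib");
    match existing {
        // Prepend libtorch's lib directory to the existing search path.
        Some(old) if !old.is_empty() => format!("{}:{}", lib.display(), old),
        // Nothing to append: avoid leaving a trailing ':'.
        _ => lib.display().to_string(),
    }
}
```

A build script would call it as `libtorch_source(std::env::var("LIBTORCH").ok().as_deref())`. Unlike the shell line, `ld_library_path` does not leave a trailing `:` when `LD_LIBRARY_PATH` was previously empty.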
This crate provides a tensor type which wraps PyTorch tensors. Here is a minimal example of how to perform some tensor operations.
```rust
extern crate tch;
use tch::Tensor;

fn main() {
    // Build a 1-D tensor from a slice, multiply every element by 2 and print it.
    let t = Tensor::of_slice(&[3, 1, 4, 1, 5]);
    let t = t * 2;
    t.print();
}
```
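The `t * 2` line works because tch overloads Rust's arithmetic operators for its tensor type. The toy sketch below shows the same pattern with a hypothetical `ToyTensor` backed by a `Vec<i64>`; it is not tch's implementation. Since `Mul::mul` takes its left operand by value, `t * 2` moves `t`, which is why the example rebinds the result with `let t = t * 2;`. The sketch also implements the operator for `&ToyTensor` so the original can be kept; check the tch documentation for which borrowed forms its `Tensor` supports.

```rust
use std::ops::Mul;

/// A hypothetical, CPU-only stand-in for a 1-D integer tensor.
#[derive(Debug, Clone, PartialEq)]
struct ToyTensor(Vec<i64>);

impl ToyTensor {
    /// Copy the slice into a new toy tensor.
    fn of_slice(data: &[i64]) -> Self {
        ToyTensor(data.to_vec())
    }
}

/// `tensor * scalar` consumes the tensor and returns a new one.
impl Mul<i64> for ToyTensor {
    type Output = ToyTensor;
    fn mul(self, rhs: i64) -> ToyTensor {
        ToyTensor(self.0.into_iter().map(|x| x * rhs).collect())
    }
}

/// Borrowing variant: `&t * 2` leaves `t` usable afterwards.
impl Mul<i64> for &ToyTensor {
    type Output = ToyTensor;
    fn mul(self, rhs: i64) -> ToyTensor {
        ToyTensor(self.0.iter().map(|x| x * rhs).collect())
    }
}
```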
The `nn` API can be used to create neural network architectures. For example, the following code defines a simple model with one hidden layer and trains it on the MNIST dataset using the Adam optimizer.
```rust
extern crate tch;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device};

const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;

fn net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}

pub fn run() -> failure::Fallible<()> {
    let m = tch::vision::mnist::load_dir("data")?;
    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());
    let opt = nn::Adam::default().build(&vs, 1e-3)?;
    for epoch in 1..200 {
        let loss = net
            .forward(&m.train_images)
            .cross_entropy_for_logits(&m.train_labels);
        opt.backward_step(&loss);
        let test_accuracy = net
            .forward(&m.test_images)
            .accuracy_for_logits(&m.test_labels);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::from(&loss),
            100. * f64::from(&test_accuracy),
        );
    }
    Ok(())
}
```
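The `net` function composes layers in a sequential container: each call appends a step, and `forward` runs the steps in order. The sketch below mimics that builder pattern in plain Rust on `Vec<f64>` inputs. `Seq`, `seq`, `linear` and `relu` here are hypothetical stand-ins, not tch types: the sketch uses a single `add_fn` method for every kind of step and has no trainable parameters or variable store.

```rust
/// A layer maps an input vector to an output vector.
type Layer = Box<dyn Fn(&[f64]) -> Vec<f64>>;

/// Hypothetical plain-Rust analogue of a sequential container.
struct Seq {
    layers: Vec<Layer>,
}

fn seq() -> Seq {
    Seq { layers: Vec::new() }
}

impl Seq {
    /// Append a step; returning `self` allows chained calls.
    fn add_fn(mut self, f: impl Fn(&[f64]) -> Vec<f64> + 'static) -> Seq {
        self.layers.push(Box::new(f));
        self
    }

    /// Feed the input through every step in insertion order.
    fn forward(&self, xs: &[f64]) -> Vec<f64> {
        self.layers.iter().fold(xs.to_vec(), |acc, layer| layer(&acc[..]))
    }
}

/// Dense layer y = W x + b, with W given as one row per output unit.
fn linear(weights: Vec<Vec<f64>>, bias: Vec<f64>) -> impl Fn(&[f64]) -> Vec<f64> {
    move |xs: &[f64]| {
        weights
            .iter()
            .zip(&bias)
            .map(|(row, b)| row.iter().zip(xs).map(|(w, x)| w * x).sum::<f64>() + b)
            .collect()
    }
}

/// Element-wise max(x, 0).
fn relu(xs: &[f64]) -> Vec<f64> {
    xs.iter().map(|x| x.max(0.0)).collect()
}
```

In tch the linear layers additionally register their weights in the variable store under the given path (e.g. `vs / "layer1"`), which is what lets the optimizer built from `&vs` update them.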
More details on the training loop can be found in the detailed tutorial.
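For readers wondering what the two helpers in the loop compute, here is a plain-Rust sketch of our reading of them: `cross_entropy_for_logits` as the batch mean of the negative log-softmax at the true label, and `accuracy_for_logits` as the fraction of rows whose largest logit sits at the label index. The functions below are hypothetical re-implementations on `Vec<f64>` rows for illustration; they are not tch code, which operates on tensors.

```rust
/// Numerically stable log-softmax of one row of logits.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + logits.iter().map(|&z| (z - max).exp()).sum::<f64>().ln();
    logits.iter().map(|&z| z - log_sum_exp).collect()
}

/// Mean over the batch of -log p(true class).
fn cross_entropy_for_logits(batch: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = batch
        .iter()
        .zip(labels)
        .map(|(row, &label)| -log_softmax(row)[label])
        .sum();
    total / batch.len() as f64
}

/// Index of the largest logit (the first one wins on ties).
fn argmax(row: &[f64]) -> usize {
    row.iter()
        .enumerate()
        .fold((0, f64::NEG_INFINITY), |(best_i, best_v), (i, &v)| {
            if v > best_v { (i, v) } else { (best_i, best_v) }
        })
        .0
}

/// Fraction of rows whose highest logit is at the label's index.
fn accuracy_for_logits(batch: &[Vec<f64>], labels: &[usize]) -> f64 {
    let correct = batch
        .iter()
        .zip(labels)
        .filter(|(row, label)| argmax(row) == **label)
        .count();
    correct as f64 / batch.len() as f64
}
```

One sanity check this makes easy: with uniform logits over 10 classes the loss is ln 10 ≈ 2.30, so an untrained MNIST model whose logits are close to uniform should start with a loss roughly in that range.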
The pretrained-models example illustrates how to use a pre-trained computer vision model on an image. The weights, which have been extracted from the PyTorch implementation, can be downloaded here: resnet18.ot and resnet34.ot.
The example can then be run via the following command:
```bash
cargo run --example pretrained-models -- resnet18.ot tiger.jpg
```
This should print the top 5 ImageNet categories for the image. The code for this example is pretty simple:
```rust
extern crate tch;
use tch::nn::ModuleT; // brings `forward_t` into scope
use tch::vision::imagenet;

// `image_file` is the image to classify (e.g. tiger.jpg) and `weight_file`
// holds the pre-trained weights (e.g. resnet18.ot).
fn classify(image_file: &str, weight_file: &str) -> failure::Fallible<()> {
    // First the image is loaded and resized to 224x224.
    let image = imagenet::load_image_and_resize(image_file)?;
    // A variable store is created to hold the model parameters.
    let vs = tch::nn::VarStore::new(tch::Device::Cpu);
    // Then the model is built on this variable store, and the weights are loaded.
    let resnet18 = tch::vision::resnet::resnet18(vs.root(), imagenet::CLASS_COUNT);
    vs.load(weight_file)?;
    // Apply the forward pass of the model to get the logits and convert them
    // to probabilities via a softmax.
    let output = resnet18
        .forward_t(&image.unsqueeze(0), /*train=*/ false)
        .softmax(-1);
    // Finally print the top 5 categories and their associated probabilities.
    for (probability, class) in imagenet::top(&output, 5).iter() {
        println!("{:50} {:5.2}%", class, 100.0 * probability)
    }
    Ok(())
}
```
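The last two steps, turning logits into probabilities and keeping the five most likely classes, can be sketched without tch. This is an illustration under assumptions: `softmax` and `top_k` are hypothetical helpers on plain slices, and we have not checked that `imagenet::top` breaks ties or orders its results exactly this way.

```rust
/// Softmax over one row of logits (subtracting the max for stability).
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Return the `k` most probable (probability, class name) pairs, highest first.
fn top_k<'a>(probs: &[f64], class_names: &[&'a str], k: usize) -> Vec<(f64, &'a str)> {
    let mut pairs: Vec<(f64, &'a str)> = probs
        .iter()
        .cloned()
        .zip(class_names.iter().cloned())
        .collect();
    // Sort by descending probability; total_cmp gives a total order on f64.
    pairs.sort_by(|a, b| b.0.total_cmp(&a.0));
    pairs.truncate(k);
    pairs
}
```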
Further examples include:
- A simplified version of char-rnn illustrating character-level language modeling using Recurrent Neural Networks.
- Neural style transfer uses a pre-trained VGG-16 model to compose an image in the style of another image (pre-trained weights: vgg16.ot).
- Some ResNet examples on CIFAR-10.
- A tutorial showing how to deploy/run some Python-trained models using TorchScript JIT.
- Some Reinforcement Learning examples using the OpenAI Gym environment. This includes a policy gradient example as well as an A2C implementation that can run on Atari games.
- A transfer learning tutorial showing how to fine-tune a pre-trained ResNet model on a very small dataset.
`tch-rs` is distributed under the terms of both the MIT license and the Apache license (version 2.0), at your option. See LICENSE-APACHE and LICENSE-MIT for more details.