rust automatic differentiation

rad

rad is a symbolic automatic differentiation library for rust

since it’s symbolic, it doesn’t evaluate expressions to calculate the derivative, and it mostly looks like regular code

it’s still in an exploratory development phase, and should not be used in production code (or in any code tbh)

how to use

there are no examples yet, and since it’s still in development it’s not super well documented, but the general idea is that you can write down expressions in almost the same style you’d write normal code:

let expr = cos(c(1.0f32) + sin(v(X)));

let diff = expr.diff::<X>();
let value = diff.calc_x(0.0);

the only difference is that you have to wrap constants in a c function, and the variables in a v function

how it works

functions like c, v, cos, etc. all return structs. these implement traits like std::ops::{Add, Sub, Mul, Div}, which allows you to write normal-ish code. these structs also implement the Differentiable trait, which walks down the quasi-ast we’ve constructed and replaces each struct with its corresponding derivative

so for example:

  • c(1).diff() = c(0)
  • v(X).diff() = c(1)
  • (a + b).diff() = a.diff() + b.diff()
  • (a * b).diff() = a.diff() * b + a * b.diff()
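the rules above can be sketched in a few lines of standalone rust. this is not rad’s actual code: Const, VarX, Sum, Prod and this single-variable Differentiable trait are hypothetical stand-ins, made up only to show how operator overloading can build the quasi-ast and how diff can rewrite it

```rust
use std::ops::{Add, Mul};

/// Constant node (hypothetical stand-in for what `c` returns).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Const(pub f64);

/// The only variable in this sketch (hypothetical stand-in for `v(X)`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VarX;

/// Node for `a + b`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Sum<A, B>(pub A, pub B);

/// Node for `a * b`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Prod<A, B>(pub A, pub B);

/// Operators do not compute anything, they only wrap operands in a new node.
macro_rules! impl_ops {
    ($ty:ty $(, $g:ident)*) => {
        impl<$($g,)* R> Add<R> for $ty {
            type Output = Sum<Self, R>;
            fn add(self, rhs: R) -> Self::Output { Sum(self, rhs) }
        }
        impl<$($g,)* R> Mul<R> for $ty {
            type Output = Prod<Self, R>;
            fn mul(self, rhs: R) -> Self::Output { Prod(self, rhs) }
        }
    };
}
impl_ops!(Const);
impl_ops!(VarX);
impl_ops!(Sum<A, B>, A, B);
impl_ops!(Prod<A, B>, A, B);

/// Rewrites an expression tree into the tree of its derivative with respect to x.
pub trait Differentiable {
    type Output;
    fn diff(self) -> Self::Output;
}

// c.diff() = 0
impl Differentiable for Const {
    type Output = Const;
    fn diff(self) -> Const { Const(0.0) }
}

// x.diff() = 1
impl Differentiable for VarX {
    type Output = Const;
    fn diff(self) -> Const { Const(1.0) }
}

// (a + b).diff() = a.diff() + b.diff()
impl<A: Differentiable, B: Differentiable> Differentiable for Sum<A, B> {
    type Output = Sum<A::Output, B::Output>;
    fn diff(self) -> Self::Output { Sum(self.0.diff(), self.1.diff()) }
}

// (a * b).diff() = a.diff() * b + a * b.diff()
impl<A, B> Differentiable for Prod<A, B>
where
    A: Differentiable + Clone,
    B: Differentiable + Clone,
{
    type Output = Sum<Prod<A::Output, B>, Prod<A, B::Output>>;
    fn diff(self) -> Self::Output {
        let Prod(a, b) = self;
        Sum(Prod(a.clone().diff(), b.clone()), Prod(a, b.diff()))
    }
}
```

with these definitions, (VarX * VarX).diff() produces Sum(Prod(Const(1.0), VarX), Prod(VarX, Const(1.0))), i.e. 1 * x + x * 1. nothing gets simplified, the tree is only rewritten node by node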

other details

domain

the library doesn’t assume what domain you are differentiating in, so you could differentiate on the reals (f32, f64), integers (i32, etc), naturals (u32, etc), or literally any other thing, as long as it implements the needed traits from std::ops. that means that you can differentiate expressions that work on strings, as long as those expressions only contain additions

use std::borrow::Cow;

impl Domain for Cow<'static, str> {
  const ZERO: Self = Self::Borrowed("");

  const ONE: Self = Self::Borrowed("a");
}

let string: Cow<'static, str> = "hey".into();

let expr = c(string) + v(X);
let r = expr.diff::<X>();

assert_eq!(r, c(Cow::Borrowed("")) + c(Cow::Borrowed("a")));

(technically it’s on Cow<'static, str>, since it’s the only standard string type that implements std::ops::Add with itself on the right-hand side, as far as i know; String only implements Add<&str>)
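to see why Cow works as a domain here, this small standalone sketch uses only the standard library (the concat helper is a made-up name, not part of rad). it shows that + on two Cow<'static, str> values is plain concatenation, with the empty string acting as the neutral element

```rust
use std::borrow::Cow;

/// Adds two `Cow` strings using the standard `Add<Cow<str>> for Cow<str>` impl.
pub fn concat(a: Cow<'static, str>, b: Cow<'static, str>) -> Cow<'static, str> {
    a + b
}
```

for example, concat(Cow::Borrowed("hey"), Cow::Borrowed("a")) gives "heya", and adding Cow::Borrowed("") on either side leaves the other string unchanged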

as you can see in the example above, we do have to implement Domain for the type, to define the neutral elements for addition (Domain::ZERO) and multiplication (Domain::ONE)

multiple variables

rad supports as many variables as you want. X, Y, and Z are defined for convenience, but you can define your own custom variables easily

multivariate example:

let f = cos(v(X)) + sin(v(Y)) + c(1.0f32);

// df / dx
let rx = f.diff::<X>();
// df / dy
let ry = f.diff::<Y>();

// d^2 f / dxdy
let rxy = rx.diff::<Y>();
let ryx = ry.diff::<X>();

note: even though mathematically rxy = ryx (if X and Y are independent), it’s not guaranteed that rxy and ryx will actually be equal in rust code. currently there’s no simplification, so the two expressions will probably be represented differently internally, which causes the == operator to return false
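worked out by hand for the f in the multivariate example, both mixed partials are zero. this is only the math, not what rad prints: without simplification, each zero would presumably still be stored as an unsimplified expression tree, and the two trees need not have the same shape

```latex
f(x, y) = \cos x + \sin y + 1

\frac{\partial f}{\partial x} = -\sin x
\qquad
\frac{\partial f}{\partial y} = \cos y

\frac{\partial}{\partial y}\left(\frac{\partial f}{\partial x}\right)
  = \frac{\partial}{\partial y}\left(-\sin x\right) = 0
\qquad
\frac{\partial}{\partial x}\left(\frac{\partial f}{\partial y}\right)
  = \frac{\partial}{\partial x}\left(\cos y\right) = 0
```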

defining variables is easy:

#[derive(PartialEq, Debug, Clone)]
pub struct MyVar;
impl Var for MyVar {}

let f = cos(v(MyVar));