Update the derivative function's signature
Clad currently supports differentiation of a function with respect to specific arguments. As a result, the derivative function's signature cannot be deduced at compile time to assign the DerivedFnType template parameter. Hence, a wrapper function is produced with the signature void (*)(Args..., OutputParamType_t<Args, void>...), where OutputParamType_t appends an extra argument of type void * for each argument of the function. This wrapper is referred to in the source as the overload function, but the user assumes that the function being returned is the one with the signature void (*)(Args..., $_d_args), where _d_args are the derivatives of the function with respect to the arguments the user specified in the differentiation call. For instance:
double foo(double a, double b);
// clad::gradient(foo, "b");
void foo_grad(double a, double b, void *_d_a, void *_d_b); // overload function
void foo_grad(double a, double b, double *_d_b); // derivative function
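For illustration, the effect of OutputParamType_t on the overload's parameter list can be sketched roughly like this (a simplified stand-in written for this issue, not Clad's actual definition):

#include <type_traits>

// Simplified stand-in: every original parameter type is mapped to void *,
// which is how the overload gains one extra void * slot per argument.
template <typename T, typename R = void>
using OutputParamType_t = R *;

// For double foo(double, double) the overload type then becomes:
using OverloadTy = void (*)(double, double,
                            OutputParamType_t<double>, OutputParamType_t<double>);
static_assert(std::is_same_v<OverloadTy,
                             void (*)(double, double, void *, void *)>, "");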
The overload function includes a call to the derivative function and performs the necessary typecasting:
double foo(double a, double b) {
    a *= b;
    return a;
}

void foo_grad_1(double a, double b, double *_d_b) { // derivative function
    double _d_a = 0.;
    double _t0 = a;
    a *= b;
    _d_a += 1;
    {
        a = _t0;
        double _r_d0 = _d_a;
        _d_a = 0.;
        _d_a += _r_d0 * b;
        *_d_b += a * _r_d0;
    }
}

void foo_grad_1(double a, double b, void *_temp__d_b0, void *_d_1) { // overload function
    double *_d_b = (double *)_temp__d_b0;
    foo_grad_1(a, b, _d_b);
}
The overload function is the one actually returned to the user, so it is the one executed every time. The user still provides arguments as if they were calling the internal derivative function, so Clad appends nullptr for the derivative arguments that are not used, which sit at the end of the function's signature.
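For illustration, a typical call site from the user's side might look roughly like this (a sketch assuming Clad's clad::gradient/execute interface; the argument values are arbitrary):

#include "clad/Differentiator/Differentiator.h"

double foo(double a, double b) {
    a *= b;
    return a;
}

int main() {
    // clad::gradient returns a wrapper whose underlying pointer is the
    // overload function described above.
    auto grad = clad::gradient(foo, "b");

    double d_b = 0.;
    // The user supplies only the adjoint they asked for; the remaining void *
    // slots of the overload are filled with nullptr by Clad.
    grad.execute(3.0, 4.0, &d_b);
    return 0;
}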
One issue with this approach arose when handling global CUDA kernels. In this case, the derivative function has to be made a device function, while the overload becomes the actual global kernel that is executed. However, if the user used shared memory inside the original kernel, it cannot be cloned into the device function, as shared memory can only be declared inside a global function in CUDA.
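To make the CUDA scenario concrete, here is a rough sketch of the shape of the problem (the kernel name, body, and generated signatures are hypothetical illustrations, not Clad's actual output):

// Original kernel written by the user: it declares shared memory.
__global__ void kernel(double *in, double *out) {
    __shared__ double buf[256];
    int i = threadIdx.x;
    buf[i] = in[i];
    __syncthreads();
    out[i] = buf[i] * 2.0;
}

// Under the current design, Clad would have to emit:
//  - the derivative function as a __device__ function, and
//  - the overload as the __global__ kernel that actually gets launched.
// Per this issue, the __shared__ declaration from the original kernel cannot
// simply be cloned into the __device__ derivative, which is where the
// approach breaks down.
__device__ void kernel_grad(double *in, double *out,
                            double *_d_in, double *_d_out);      // derivative
__global__ void kernel_grad_overload(double *in, double *out,
                                     void *_d_in, void *_d_out); // launched kernel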
Updating the templates after they are initialized is quite complicated, especially since a constructor of a templated class is called inside clad::gradient.
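For context, here is a simplified, non-Clad sketch of why the pointer type is hard to change after the fact (the class and names below are hypothetical stand-ins, not Clad's actual CladFunction):

// Hypothetical stand-in for a templated wrapper: the derived-function pointer
// type is baked into the class at the point where the constructor is called,
// i.e. inside clad::gradient in the real code.
template <typename DerivedFnType>
class FunctionWrapper {
    DerivedFnType m_fn;
public:
    explicit FunctionWrapper(DerivedFnType fn) : m_fn(fn) {}
    template <typename... Args>
    void execute(Args... args) { m_fn(args...); }
};

// The argument to differentiate by ("b") is only a runtime string, so the
// precise per-argument signature void (*)(double, double, double *) cannot be
// chosen here; only the uniform void (*)(Args..., void *...) shape is known:
void foo_grad_overload(double a, double b, void *d_a, void *d_b) {}
using OverloadTy = void (*)(double, double, void *, void *);

FunctionWrapper<OverloadTy> make_wrapper() {
    return FunctionWrapper<OverloadTy>(&foo_grad_overload);
}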
An alternative to the overloaded function could look like this:
// clad::gradient(foo, "b");
void foo_grad_1_b(double a, double b, double *_d_a, double *_d_b) { // derivative function
    double _d_a0 = 0.;
    double _t0 = a;
    a *= b;
    _d_a0 += 1;
    {
        a = _t0;
        double _r_d0 = _d_a0;
        _d_a0 = 0.;
        _d_a0 += _r_d0 * b;
        *_d_b += a * _r_d0;
    }
}
This way:
- We know the function signature at compile time: void (*)(Args..., $Args_pointers)
- The users don't have to create all the adjoint variables; they can still pass nullptr when they want to (see the sketch after this list)
- The derivative function has a different name according to the arg it's being derived by, so the same function can be derived using every possible arg combination without any conflict
There is also the argument that we lose the capability of differentiating with respect to only certain args: the user must now provide every adjoint.
This issue aims to gather every argument and idea on the way forward and arrive at a conclusion.