Clad fails to use proper pullbacks for functions with reference arguments
Opened this issue · 1 comment
gojakuch commented
Reproducer:
#include "clad/Differentiator/Differentiator.h"
#include <iostream>
// Print "<expr>: <value>" followed by a newline. The argument is parenthesised
// so low-precedence expressions (e.g. `a & b`) print their value instead of
// mis-parsing, and the macro expands to one expression without a trailing `;`
// so `show(x);` is safe inside an unbraced if/else.
#define show(x) (std::cout << #x << ": " << (x) << "\n")
// Takes its argument by non-const reference but neither reads nor writes it,
// and always returns 0. Its pullback should therefore contribute nothing to
// the adjoint of x; per the issue title, reference arguments like this one are
// where Clad fails to use a proper pullback.
double fff(double &x) {
return 0;
}
// Function differentiated by clad::gradient in this reproducer. The result is
// just v, so the expected gradient is du = 0, dv = 1. The body binds a
// reference to u and passes u to fff, which takes `double&`.
double fn(double u, double v) {
// NOTE(review): w is never used below; fff is called with u directly.
// Confirm whether fff(w) was intended.
double &w = u;
fff(u);
return v;
}
int main() {
  // Ask Clad for the reverse-mode gradient of fn with respect to u and v.
  auto d_fn = clad::gradient(fn);

  double u = 3;
  double v = 5;

  // Adjoint outputs start at zero; Clad presumably accumulates into them.
  double du = 0;
  double dv = 0;

  show(fn(u, v));
  d_fn.execute(u, v, &du, &dv);

  // Expected: du = 0 and dv = 1, since fn returns v.
  show(du);
  show(dv);
  return 0;
}
An important note: the same should also work for constructors:
#include "clad/Differentiator/Differentiator.h"
#include <iostream>
// Print "<expr>: <value>" followed by a newline. The argument is parenthesised
// so low-precedence expressions (e.g. `a & b`) print their value instead of
// mis-parsing, and the macro expands to one expression without a trailing `;`
// so `show(x);` is safe inside an unbraced if/else.
#define show(x) (std::cout << #x << ": " << (x) << "\n")
class SafeTestClass {
public:
  // No-op default constructor. It is user-provided (not `= default`), so the
  // type remains non-trivially default-constructible.
  SafeTestClass() {}

  // Accepts x by non-const reference but never reads or modifies it. The
  // reference parameter is what this reproducer uses to exercise Clad's
  // constructor handling.
  SafeTestClass(double &x) {}
};
// User-supplied derivatives for SafeTestClass's `double&` constructor. Clad
// presumably finds these by name inside clad::custom_derivatives::class_functions.
// Confirm the expected names and signatures against the Clad version in use.
namespace clad {
namespace custom_derivatives {
namespace class_functions {
// Forward part of the reverse pass: returns the primal object built from x,
// paired with an adjoint object built from *d_x. Both use the
// SafeTestClass(double&) constructor. d_x is presumably the adjoint of x.
clad::ValueAndAdjoint<SafeTestClass, SafeTestClass>
constructor_reverse_forw(clad::ConstructorReverseForwTag<SafeTestClass>, double& x, double *d_x) {
return {SafeTestClass(x), SafeTestClass(*d_x)};
}
// Reverse part: empty. SafeTestClass(double&) stores nothing and never reads
// x, so there is no adjoint to propagate into *d_x.
void constructor_pullback(SafeTestClass *c, double &x, SafeTestClass *d_c, double* d_x) {
}
}}} // namespace clad::custom_derivatives::class_functions
// Constructor variant of the reproducer. The result is just v, so the
// expected gradient is du = 0, dv = 1. A reference w to u is passed to the
// SafeTestClass(double&) constructor. Clad should presumably use the custom
// constructor_reverse_forw / constructor_pullback declared for that constructor.
double fn(double u, double v) {
double &w = u;
SafeTestClass s3(w);
return v;
}
int main() {
  // Ask Clad for the reverse-mode gradient of fn with respect to u and v.
  auto d_fn = clad::gradient(fn);

  double u = 3;
  double v = 5;

  // Adjoint outputs start at zero; Clad presumably accumulates into them.
  double du = 0;
  double dv = 0;

  show(fn(u, v));
  d_fn.execute(u, v, &du, &dv);

  // Expected: du = 0 and dv = 1, since fn returns v.
  show(du);
  show(dv);
  return 0;
}