stan-dev/math

`pow(base, exp)` gradient inconsistency at `base=0.0`

nhuurre opened this issue

The following model fails to initialize because of a bad gradient when s = 1 or s = 2, but initializes fine when s = 3 or s = 4.

```stan
data {
  int<lower=1,upper=4> s;
}
parameters {
  vector[2] x;
}
model {
  if (s == 1)
    target += distance(x,x);
  else if (s == 2)
    target += pow(squared_distance(x,x), 0.5);
  else if (s == 3)
    target += pow(squared_distance(x,x), 0.5 + 1e-16);
  else if (s == 4)
    target += abs(x[1] - x[1]) + abs(x[2] - x[2]);
}
```

This is surprising because s = 1 looks like a well-formed (though redundant) model, and in any case s = 1, s = 2, and s = 3 should all be equivalent. (s = 4 is a one-dimensional version of s = 1.)

The behaviour of s = 2 and s = 3 differs because `pow(base, exp)` skips the gradient calculation at `base = 0` in the general case (as if the gradient were zero), but delegates the `exp = 0.5` case to `sqrt()`, which always computes the mathematically correct gradient (infinity at `base = 0`). A zero gradient ends up being accidentally correct, whereas an infinite gradient is an unrecoverable error.
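A minimal Python sketch of the mismatch for s = 2 and s = 3 (s = 1 is not modelled, since the issue does not say how `distance()` computes its gradient). The two derivative functions are hypothetical stand-ins that mimic the behaviour described above; they are not Stan Math code, and the reverse-mode step through `squared_distance(x, x)` is written out by hand. Because both arguments are the same `x`, each component receives a `+2*diff` and a `-2*diff` contribution with `diff = 0`: a zero outer derivative therefore gives a zero gradient, while an infinite one gives `inf * 0.0`, which is NaN. The sketch also checks that `0.5 + 1e-16` is still distinct from `0.5` in double precision, which is why s = 3 does not reach the `exp = 0.5` special case.

```python
import math

# Hypothetical stand-ins (NOT Stan Math code) for the two derivative
# paths the issue describes for d pow(base, exp) / d base.

def dpow_dbase_general(base: float, exp: float) -> float:
    """General pow() path: gradient calculation skipped at base == 0,
    which behaves as if the derivative were 0.0."""
    if base == 0.0:
        return 0.0
    return exp * base ** (exp - 1.0)


def dsqrt(base: float) -> float:
    """sqrt() path (used for exp == 0.5): exact derivative, +inf at 0."""
    if base == 0.0:
        return math.inf
    return 0.5 / math.sqrt(base)


def grad_through_sqdist_self(x: list[float], adj_base: float) -> list[float]:
    """Hand-written reverse-mode step for squared_distance(x, x).

    d/dx_i of sum_j (x_j - y_j)^2 is 2*(x_i - y_i) and d/dy_i is
    -2*(x_i - y_i); since y is x, both contributions land on x_i.
    """
    grad = []
    for xi in x:
        diff = xi - xi  # exactly 0.0
        grad.append(adj_base * (2.0 * diff) + adj_base * (-2.0 * diff))
    return grad


x = [0.3, -1.7]
base = sum((xi - xi) ** 2 for xi in x)  # squared_distance(x, x) == 0.0

exp_s3 = 0.5 + 1e-16
print(exp_s3 == 0.5)  # False: s = 3 does not hit the exp == 0.5 special case

g_s3 = grad_through_sqdist_self(x, dpow_dbase_general(base, exp_s3))
g_s2 = grad_through_sqdist_self(x, dsqrt(base))
print(g_s3)  # [0.0, 0.0]
print(g_s2)  # [nan, nan]  (inf * 0.0 is nan)
```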

The mathematically correct value of the derivative of `pow(base, exp)` with respect to `base` at `base = 0.0` is:

  • 0.0 when exp > 1.0
  • 1.0 when exp = 1.0
  • positive infinity when 0.0 < exp < 1.0
  • 0.0 when exp = 0.0
  • negative infinity when exp < 0.0
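The list follows from the power rule. Writing $b$ for `base` and $e$ for `exp`, and taking the one-sided limit as $b$ approaches 0 from above:

```math
\frac{\partial}{\partial b}\, b^{e} = e\, b^{e-1},
\qquad
\lim_{b \to 0^{+}} e\, b^{e-1} =
\begin{cases}
0 & e > 1,\\
1 & e = 1,\\
+\infty & 0 < e < 1,\\
0 & e = 0,\\
-\infty & e < 0.
\end{cases}
```

For $e = 0$ the product $e\,b^{e-1}$ is identically zero for every $b > 0$, so the limit is 0 even though $b^{e-1}$ itself diverges.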

It would be reasonable to round infinity down to a finite but near-maximum floating-point number.
Moreover, Stan only requires gradients to be correct almost everywhere. Boundary points like `base = 0` matter only in expressions like `distance(x, x)`, where the gradients cancel anyway and the only question is whether the intermediate gradient values are finite. So the most numerically stable choice might be to "round" the infinity all the way down to zero.
I'm inclined to think `pow()`'s current behaviour is the right one for us, and that `sqrt()` (and possibly `cbrt()`) should be "fixed" to also have a vanishing gradient at zero.
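To compare the two roundings discussed above, here is a small Python sketch. `sqrt_derivative` and its `policy` argument are hypothetical (not part of Stan Math), and the arithmetic is plain IEEE double precision rather than Stan's autodiff. Both roundings give a finite result when the derivative meets a zero partial, but a near-maximum value overflows back to infinity as soon as it is multiplied by a factor larger than 1, which is one way to read the argument that zero is the more numerically stable choice.

```python
import math
import sys

DBL_MAX = sys.float_info.max  # largest finite double, about 1.8e308


def sqrt_derivative(base: float, policy: str) -> float:
    """Derivative of sqrt(base) for base >= 0.

    Hypothetical helper: at base == 0 the exact value (+inf) is replaced
    according to `policy`:
      "max"  -> round down to the largest finite double
      "zero" -> round all the way down to 0.0 (the proposal above)
    """
    if base == 0.0:
        if policy == "max":
            return DBL_MAX
        if policy == "zero":
            return 0.0
        raise ValueError(f"unknown policy: {policy}")
    return 0.5 / math.sqrt(base)


for policy in ("max", "zero"):
    d = sqrt_derivative(0.0, policy)
    cancelled = d * 0.0  # multiplied by a zero partial, as in distance(x, x)
    scaled = d * 2.0     # multiplied by any factor > 1 along the way
    print(policy, cancelled, math.isfinite(scaled))
# max 0.0 False
# zero 0.0 True
```

Away from zero both policies return the exact derivative (for example, `sqrt_derivative(4.0, "zero")` is `0.25`), so the change would only affect the boundary point.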

Current Version:

v4.8.1