stan-dev/math

Improve numerical stability of von_mises_lpdf

venpopov opened this issue · 0 comments

Description

Currently the von_mises_lpdf function can overflow for large values of kappa. Because of this, using it in Stan usually requires a conditional statement that switches to a normal approximation for kappa > 100. For example, following the recommendation in the Stan Functions Reference, brms redefines the von_mises distribution as

real von_mises2_lpdf(real y, real mu, real kappa) {
  if (kappa < 100) {
    return von_mises_lpdf(y | mu, kappa);
  } else {
    return normal_lpdf(y | mu, sqrt(1 / kappa));
  }
}

Because of this, the likelihood cannot be vectorized, and brms has to loop over the observations whenever kappa varies across conditions. The workaround also still causes problems, because sqrt(1 / kappa) becomes 0 once kappa overflows to infinity, which makes normal_lpdf reject with scale = 0.
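The scale collapse is plain double-precision behavior and can be reproduced outside Stan (a Python sketch, not Stan code):

```python
# When the linear predictor for log(kappa) exceeds ~709, exp() overflows
# to inf in double precision. The normal approximation's scale
# sqrt(1 / kappa) then collapses to exactly 0, which is what triggers the
# "scale = 0" rejections in normal_lpdf.
kappa = 1e309                      # literal beyond double range: parsed as inf
scale = (1.0 / kappa) ** 0.5
print(kappa, scale)                # inf 0.0
```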

A simple change can substantially improve the numerical stability of the function and make the conditional approximation unnecessary.

A log_modified_bessel_first_kind() function already exists, but it is used only for the log-likelihood calculation, not for the partial derivatives. On lines 81-86:

    edge<2>(ops_partials).partials_
        = cos_mu_minus_y
          - modified_bessel_first_kind(1, kappa_val)
                / modified_bessel_first_kind(0, kappa_val);

the value of $I_1(\kappa)/I_0(\kappa)$ is always between 0 and 1:

[figure: plot of $I_1(\kappa)/I_0(\kappa)$ against $\kappa$, rising from 0 and saturating below 1]

but currently both the numerator and the denominator overflow for large kappa.
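The overflow is easy to reproduce with SciPy's iv (the unscaled modified Bessel function of the first kind), used here as a stand-in for modified_bessel_first_kind:

```python
import math
from scipy.special import iv  # I_n(x), unscaled

# I_0(kappa) grows like e^kappa / sqrt(2*pi*kappa), so both terms of the
# ratio overflow to inf once kappa exceeds roughly 713 in double
# precision, and the ratio itself becomes inf/inf = nan.
num, den = iv(1, 1000.0), iv(0, 1000.0)
print(num, den, num / den)  # inf inf nan
assert math.isinf(num) and math.isinf(den)
```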

This can be replaced by computing the ratio with the log_modified_bessel_first_kind function and exponentiating the difference:

    edge<2>(ops_partials).partials_
        = cos_mu_minus_y
          - exp(log_modified_bessel_first_kind(1, kappa_val)
                - log_modified_bessel_first_kind(0, kappa_val));
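Computed in log space, the ratio stays finite and inside (0, 1) for arbitrarily large kappa. A SciPy sketch of the same idea, where log(ive(n, kappa)) + kappa plays the role of log_modified_bessel_first_kind (ive is the exponentially scaled Bessel function, so the + kappa terms cancel in the subtraction):

```python
import numpy as np
from scipy.special import ive  # exponentially scaled: ive(n, x) = iv(n, x) * exp(-x)

def bessel_ratio(kappa):
    # log I_n(kappa) = log(ive(n, kappa)) + kappa; the kappa terms cancel
    # in the difference, mirroring the proposed log-space subtraction.
    return np.exp(np.log(ive(1, kappa)) - np.log(ive(0, kappa)))

for kappa in (1.0, 100.0, 1e4, 1e8):
    r = bessel_ratio(kappa)
    assert 0.0 < r < 1.0  # finite and bounded, even where iv() overflows
    print(kappa, r)
```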

Consequences

I ran the following two models, once with the current von_mises_lpdf code on the develop branch and once with the proposed change:

Model 1: original model generated by brms:

library(brms)
library(cmdstanr)
stancode(bf(dev_rad ~ 1, kappa ~ 1 + (1|ID)),
         bmm::OberauerLin_2017,
         von_mises())

Truncated output below. The conditional statement on kappa forces a loop over the observations, making vectorization impossible:

functions {
  // more stuff here
   real von_mises2_lpdf(real y, real mu, real kappa) {
     if (kappa < 100) {
       return von_mises_lpdf(y | mu, kappa);
     } else {
       return normal_lpdf(y | mu, sqrt(1 / kappa));
     }
   }
}
// truncated other blocks ....
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    // initialize linear predictor term
    vector[N] kappa = rep_vector(0.0, N);
    mu += Intercept;
    kappa += Intercept_kappa;
    for (n in 1:N) {
      // add more terms to the linear predictor
      kappa[n] += r_1_kappa_1[J_1[n]] * Z_1_kappa_1[n];
    }
    mu = inv_tan_half(mu);
    kappa = exp(kappa);
    for (n in 1:N) {
      target += von_mises2_lpdf(Y[n] | mu[n], kappa[n]);
    }
  }
  // priors including constants
  target += lprior;
  target += std_normal_lpdf(z_1[1]);
}

Model 2: Vectorized version without conditional approximation:

Truncated model code:

model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    // initialize linear predictor term
    vector[N] kappa = rep_vector(0.0, N);
    mu += Intercept;
    kappa += Intercept_kappa;
    for (n in 1:N) {
      // add more terms to the linear predictor
      kappa[n] += r_1_kappa_1[J_1[n]] * Z_1_kappa_1[n];
    }
    mu = inv_tan_half(mu);
    kappa = exp(kappa);
    target += von_mises_lpdf(Y | mu, kappa);
  }
  // priors including constants
  target += lprior;
  target += std_normal_lpdf(z_1[1]);
}

Results

I recorded the sampling time (on a Mac with an M3 Max) and the warnings about rejected proposals and overflow.

| model | stanmath version | sampling time | warnings |
| --- | --- | --- | --- |
| 1 (original) | current | 130 s | 3 rejected proposals due to scale = 0 for normal_lpdf |
| 2 (vectorized) | current | 84 s | 8 rejected proposals due to numeric overflow |
| 1 (original) | changed as above | 184 s | 2 rejected proposals due to scale = 0 for normal_lpdf |
| 2 (vectorized) | changed as above | 107 s | none |

Despite the warnings in the other runs, all models converged to similar posterior estimates.

Using the log_modified_bessel_first_kind() function adds some overhead, because it is slower than the non-log version. But the end result is still faster, because it allows vectorization. It also produced no overflow warnings in any of the 10 test runs I did.

Current Version:

v4.8.1