Improve numerical stability of von_mises_lpdf
venpopov opened this issue · 0 comments
Description
Currently the von_mises_lpdf function can overflow for large values of kappa. Because of this, using it in Stan usually requires a conditional statement that switches to a normal-distribution approximation for kappa > 100. For example, following the recommendation in the Stan Functions Reference, brms redefines the von_mises distribution as:
real von_mises2_lpdf(vector y, vector mu, real kappa) {
  if (kappa < 100) {
    return von_mises_lpdf(y | mu, kappa);
  } else {
    return normal_lpdf(y | mu, sqrt(1 / kappa));
  }
}
Because of this conditional, the likelihood cannot be vectorized, and brms needs to loop over the observations if kappa varies across conditions. The workaround also still breaks down, because sqrt(1 / kappa) becomes 0 once kappa overflows to infinity, producing an invalid scale of 0 for normal_lpdf.
A simple change can substantially improve the stability of the function and make that conditional approximation unnecessary.
Stan Math already provides a log_modified_bessel_first_kind() function, but it is used only for the log-likelihood calculation, not for the partial derivatives. On lines 81-86:
edge<2>(ops_partials).partials_
= cos_mu_minus_y
- modified_bessel_first_kind(1, kappa_val)
/ modified_bessel_first_kind(0, kappa_val);
Here the ratio of Bessel functions approaches 1 as kappa grows, but currently both the numerator and the denominator overflow for large kappa, so the partial evaluates to NaN.
This can be fixed by computing the ratio in log space with the log_modified_bessel_first_kind function:
edge<2>(ops_partials).partials_
= cos_mu_minus_y
- exp(log_modified_bessel_first_kind(1, kappa_val)
- log_modified_bessel_first_kind(0, kappa_val));
Consequences
I ran the following two models, with the current von_mises_lpdf code on the develop branch, and with my proposed changes:
Model 1: original model generated by brms:
library(brms)
library(cmdstanr)
stancode(bf(dev_rad ~ 1, kappa ~ 1 + (1 | ID)),
         bmm::OberauerLin_2017,
         von_mises())
Truncated output below; note that the conditional statement for kappa makes vectorization impossible:
functions {
  // more stuff here
  real von_mises2_lpdf(real y, real mu, real kappa) {
    if (kappa < 100) {
      return von_mises_lpdf(y | mu, kappa);
    } else {
      return normal_lpdf(y | mu, sqrt(1 / kappa));
    }
  }
}
// truncated other blocks ....
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    // initialize linear predictor term
    vector[N] kappa = rep_vector(0.0, N);
    mu += Intercept;
    kappa += Intercept_kappa;
    for (n in 1:N) {
      // add more terms to the linear predictor
      kappa[n] += r_1_kappa_1[J_1[n]] * Z_1_kappa_1[n];
    }
    mu = inv_tan_half(mu);
    kappa = exp(kappa);
    for (n in 1:N) {
      target += von_mises2_lpdf(Y[n] | mu[n], kappa[n]);
    }
  }
  // priors including constants
  target += lprior;
  target += std_normal_lpdf(z_1[1]);
}
Model 2: Vectorized version without conditional approximation:
Truncated model code:
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    // initialize linear predictor term
    vector[N] kappa = rep_vector(0.0, N);
    mu += Intercept;
    kappa += Intercept_kappa;
    for (n in 1:N) {
      // add more terms to the linear predictor
      kappa[n] += r_1_kappa_1[J_1[n]] * Z_1_kappa_1[n];
    }
    mu = inv_tan_half(mu);
    kappa = exp(kappa);
    target += von_mises_lpdf(Y | mu, kappa);
  }
  // priors including constants
  target += lprior;
  target += std_normal_lpdf(z_1[1]);
}
Results
I recorded the sampling time (on a Mac M3 Max) and the warnings about rejected proposals and overflow.
| model | stanmath version | sampling time | warnings |
|---|---|---|---|
| 1 (original) | current | 130 s | 3 rejected proposals due to scale = 0 for normal_lpdf |
| 2 (vectorized) | current | 84 s | 8 rejected proposals due to numeric overflow |
| 1 (original) | changed as above | 184 s | 2 rejected proposals due to scale = 0 for normal_lpdf |
| 2 (vectorized) | changed as above | 107 s | none |
Despite the warnings, all models converged to similar posterior estimates.
Using the log_modified_bessel_first_kind() function adds some overhead, because it is slower than the non-log version, but the end result is still faster because it allows vectorization. It also produced no warnings about overflow in the 10 test runs I did.
Current Version:
v4.8.1