Clarification on using loo_moment_match() with non-Stan objects
wlandau opened this issue · 4 comments
I am working on a model averaging problem with very simple models, and I am getting intermittently high Pareto k values even on simple well-behaved simulated datasets. I would like to apply the moment matching correction to both non-longitudinal JAGS models and longitudinal Stan models. The latter case is trivially easy with moment_match = TRUE in loo(), but I do not have a stanfit object in the former case.
Would you help me understand how to use loo_moment_match() in the case where my model fit is a posterior::as_draws_df() data frame with columns for parameters and pointwise log likelihoods? To set up a sufficiently motivating scenario, I converted the roaches example from the vignette into JAGS. I also put constrained priors on the scale parameters for the sake of learning what to do with the unconstrain_pars, log_prob_upars, and log_lik_i_upars arguments of loo_moment_match(). (Is it even appropriate to consider "unconstrained parameters" without HMC?)
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
data(roaches, package = "rstanarm")
roaches$roach1 <- sqrt(roaches$roach1)
x <- roaches[, c("roach1", "treatment", "senior")]
data <- list(
  N = nrow(x),
  K = ncol(x),
  x = as.matrix(x),
  y = roaches$y,
  offset = log(roaches[, "exposure2"])
)
model_text <- "
model {
  for (n in 1:N) {
    y[n] ~ dpois(exp(inprod(x[n,], beta) + intercept + offset[n]))
  }
  for (k in 1:K) {
    beta[k] ~ dnorm(0, 1 / (scale_beta * scale_beta))
  }
  intercept ~ dnorm(0, 1 / (scale_alpha * scale_alpha))
  scale_beta ~ dunif(0, 10)
  scale_alpha ~ dnorm(0, 10) T(0,)
  for (n in 1:N) {
    log_lik[n] <- log(dpois(y[n], exp(inprod(x[n,], beta) + intercept + offset[n])))
  }
}
"
file <- tempfile()
writeLines(model_text, file)
tmp <- capture.output({
  model <- rjags::jags.model(
    file = file,
    data = data,
    n.chains = 4,
    n.adapt = 2e3
  )
  stats::update(model, n.iter = 2e3, quiet = TRUE)
  coda <- rjags::coda.samples(
    model = model,
    variable.names = c(
      "beta",
      "intercept",
      "scale_beta",
      "scale_alpha",
      "log_lik"
    ),
    n.iter = 4e3
  )
})
fit <- posterior::as_draws_df(coda)
print(fit) # This is the model fit object I can work with.
#> # A draws_df: 4000 iterations, 4 chains, and 268 variables
#> beta[1] beta[2] beta[3] intercept log_lik[1] log_lik[2] log_lik[3]
#> 1 0.16 -0.55 -0.30 2.5 -19 -16 -2.1
#> 2 0.16 -0.55 -0.29 2.5 -19 -16 -2.1
#> 3 0.16 -0.52 -0.27 2.5 -17 -14 -2.1
#> 4 0.16 -0.55 -0.38 2.5 -16 -14 -2.1
#> 5 0.16 -0.56 -0.25 2.5 -16 -14 -2.1
#> 6 0.16 -0.59 -0.26 2.5 -18 -15 -2.1
#> 7 0.16 -0.60 -0.25 2.5 -19 -16 -2.0
#> 8 0.16 -0.58 -0.34 2.5 -18 -16 -2.0
#> 9 0.16 -0.58 -0.30 2.5 -18 -15 -2.1
#> 10 0.16 -0.60 -0.33 2.5 -17 -15 -2.1
#> log_lik[4]
#> 1 -2.2
#> 2 -2.2
#> 3 -2.3
#> 4 -2.2
#> 5 -2.2
#> 6 -2.2
#> 7 -2.2
#> 8 -2.2
#> 9 -2.2
#> 10 -2.2
#> # ... with 15990 more draws, and 260 more variables
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}
# Convergence looks okay.
fit %>%
  select(starts_with(c("beta", "intercept", "scale"))) %>%
  posterior::summarize_draws() %>%
  print()
#> Warning: Dropping 'draws_df' class as required metadata was removed.
#> # A tibble: 6 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_t…¹
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 beta[1] 0.161 0.161 0.00193 0.00194 0.158 0.164 1.00 1704. 2993.
#> 2 beta[2] -0.566 -0.566 0.0248 0.0246 -0.607 -0.524 1.00 3388. 5597.
#> 3 beta[3] -0.312 -0.312 0.0334 0.0335 -0.368 -0.259 1.00 5957. 8365.
#> 4 intercept 2.52 2.52 0.0260 0.0260 2.48 2.56 1.00 1339. 2777.
#> 5 scale_alpha 0.905 0.889 0.156 0.153 0.673 1.19 1.00 9336. 8030.
#> 6 scale_beta 0.816 0.569 0.835 0.298 0.272 2.20 1.00 2056. 864.
#> # … with abbreviated variable name ¹ess_tail
# LOO without the moment matching correction is straightforward.
log_lik <- as.matrix(dplyr::select(fit, tidyselect::starts_with("log_lik")))
#> Warning: Dropping 'draws_df' class as required metadata was removed.
r_eff <- loo::relative_eff(x = log_lik, chain_id = fit$.chain)
loo <- loo::loo(x = log_lik, r_eff = r_eff)
#> Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
# But we get high Pareto k values.
print(loo)
#>
#> Computed from 16000 by 262 log-likelihood matrix
#>
#> Estimate SE
#> elpd_loo -5462.1 696.5
#> p_loo 261.3 57.6
#> looic 10924.3 1393.0
#> ------
#> Monte Carlo SE of elpd_loo is NA.
#>
#> Pareto k diagnostic values:
#> Count Pct. Min. n_eff
#> (-Inf, 0.5] (good) 239 91.2% 537
#> (0.5, 0.7] (ok) 9 3.4% 76
#> (0.7, 1] (bad) 7 2.7% 11
#> (1, Inf) (very bad) 7 2.7% 1
#> See help('pareto-k-diagnostic') for details.
# How do I use loo_moment_match() in this situation?
# loo::loo_moment_match(
#   x = fit,
#   post_draws = function(x) as.matrix(x),
#   log_lik_i = function(x, i) x[[sprintf("log_lik[%s]", i)]],
#   unconstrain_pars = "???", # Do we even need to consider the unconstrained space for non-HMC MCMC?
#   log_prob_upars = "???", # Here is where I start to get lost.
#   log_lik_i_upars = "???" # Same here.
# )
Created on 2022-12-01 with reprex v2.0.2
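For reference, here is a rough sketch of what the three "???" pieces might compute for the JAGS model above. Moment matching shifts and rescales the draws, so it needs a space where any real-valued vector is a valid parameter value, which is why the unconstrained space matters even without HMC. Everything below is an assumption rather than a tested recipe: the helper names are hypothetical, the parameter columns are assumed to be ordered beta[1..K], intercept, scale_beta, scale_alpha, and the argument lists are simplified. As I read the loo documentation, the callbacks are called as unconstrain_pars(x, pars, ...), log_prob_upars(x, upars, ...), and log_lik_i_upars(x, upars, i, ...), so these helpers would still need thin wrappers.

```r
# Sketch only: hypothetical helpers, base R, not part of loo.
# Assumed column order of the draws matrix:
#   beta[1..K], intercept, scale_beta, scale_alpha

# Constrained -> unconstrained: logit for scale_beta in (0, 10), log for scale_alpha > 0.
unconstrain_draws <- function(pars, K) {
  upars <- pars
  upars[, K + 2] <- qlogis(pars[, K + 2] / 10)
  upars[, K + 3] <- log(pars[, K + 3])
  upars
}

# Unconstrained -> constrained (inverse of the map above).
constrain_draws <- function(upars, K) {
  pars <- upars
  pars[, K + 2] <- 10 * plogis(upars[, K + 2])
  pars[, K + 3] <- exp(upars[, K + 3])
  pars
}

# Poisson log likelihood of observation i at each draw. The regression
# coefficients and intercept are unconstrained already, so no transform is needed.
log_lik_i_upars_sketch <- function(upars, i, data) {
  K <- data$K
  eta <- drop(upars[, seq_len(K), drop = FALSE] %*% data$x[i, ]) +
    upars[, K + 1] + data$offset[i]
  dpois(data$y[i], lambda = exp(eta), log = TRUE)
}

# Log posterior density (up to an additive constant) at each unconstrained draw:
# log likelihood + log prior + log Jacobian of the constraining transforms.
log_prob_upars_sketch <- function(upars, data) {
  K <- data$K
  pars <- constrain_draws(upars, K)
  beta <- pars[, seq_len(K), drop = FALSE]
  intercept <- pars[, K + 1]
  scale_beta <- pars[, K + 2]
  scale_alpha <- pars[, K + 3]

  log_lik <- numeric(nrow(upars))
  for (i in seq_len(data$N)) {
    log_lik <- log_lik + log_lik_i_upars_sketch(upars, i, data)
  }

  # JAGS dnorm(mean, precision): precision 10 means sd = 1 / sqrt(10).
  # The dunif(0, 10) prior and the truncation normalizer are constants and are dropped.
  log_prior <-
    rowSums(matrix(dnorm(beta, 0, scale_beta, log = TRUE), nrow = nrow(beta))) +
    dnorm(intercept, 0, scale_alpha, log = TRUE) +
    dnorm(scale_alpha, 0, 1 / sqrt(10), log = TRUE)

  # d/du [10 * plogis(u)] = 10 * plogis(u) * plogis(-u); d/dv [exp(v)] = exp(v).
  log_jacobian <- log(10) +
    plogis(upars[, K + 2], log.p = TRUE) +
    plogis(-upars[, K + 2], log.p = TRUE) +
    upars[, K + 3]

  log_lik + log_prior + log_jacobian
}
```

Here the draws matrix would be the parameter columns of fit only (no log_lik columns), in the assumed order, and data is the same list that was passed to JAGS.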
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.2.1 (2022-06-23)
#> os macOS Big Sur ... 10.16
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz America/Indiana/Indianapolis
#> date 2022-12-01
#> pandoc 2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> abind 1.4-5 2016-07-21 [1] CRAN (R 4.2.0)
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.2.0)
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.2.0)
#> checkmate 2.1.0 2022-04-21 [1] CRAN (R 4.2.0)
#> cli 3.4.1 2022-09-23 [1] CRAN (R 4.2.0)
#> coda 0.19-4 2020-09-30 [1] CRAN (R 4.2.0)
#> colorspace 2.0-3 2022-02-21 [1] CRAN (R 4.2.0)
#> DBI 1.1.3 2022-06-18 [1] CRAN (R 4.2.0)
#> digest 0.6.30 2022-10-18 [1] CRAN (R 4.2.0)
#> distributional 0.3.1 2022-09-02 [1] CRAN (R 4.2.0)
#> dplyr * 1.0.10 2022-09-01 [1] CRAN (R 4.2.0)
#> evaluate 0.18 2022-11-07 [1] CRAN (R 4.2.0)
#> fansi 1.0.3 2022-03-24 [1] CRAN (R 4.2.0)
#> farver 2.1.1 2022-07-06 [1] CRAN (R 4.2.0)
#> fastmap 1.1.0 2021-01-25 [1] CRAN (R 4.2.0)
#> fs 1.5.2 2021-12-08 [1] CRAN (R 4.2.0)
#> generics 0.1.3 2022-07-05 [1] CRAN (R 4.2.0)
#> ggplot2 3.4.0 2022-11-04 [1] CRAN (R 4.2.0)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.2.0)
#> gtable 0.3.1 2022-09-01 [1] CRAN (R 4.2.0)
#> highr 0.9 2021-04-16 [1] CRAN (R 4.2.0)
#> htmltools 0.5.3 2022-07-18 [1] CRAN (R 4.2.0)
#> knitr 1.41 2022-11-18 [1] CRAN (R 4.2.0)
#> lattice 0.20-45 2021-09-22 [1] CRAN (R 4.2.1)
#> lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.2.0)
#> loo 2.5.1 2022-03-24 [1] CRAN (R 4.2.0)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.2.0)
#> matrixStats 0.63.0 2022-11-18 [1] CRAN (R 4.2.0)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.2.0)
#> pillar 1.8.1 2022-08-19 [1] CRAN (R 4.2.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.2.0)
#> posterior 1.3.1 2022-09-06 [1] CRAN (R 4.2.0)
#> purrr 0.3.5 2022-10-06 [1] CRAN (R 4.2.0)
#> R.cache 0.16.0 2022-07-21 [1] CRAN (R 4.2.0)
#> R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.2.0)
#> R.oo 1.25.0 2022-06-12 [1] CRAN (R 4.2.0)
#> R.utils 2.12.2 2022-11-11 [1] CRAN (R 4.2.0)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.2.0)
#> reprex 2.0.2 2022-08-17 [1] CRAN (R 4.2.0)
#> rjags 4-13 2022-04-19 [1] CRAN (R 4.2.0)
#> rlang 1.0.6 2022-09-24 [1] CRAN (R 4.2.0)
#> rmarkdown 2.18 2022-11-09 [1] CRAN (R 4.2.0)
#> rstudioapi 0.14 2022-08-22 [1] CRAN (R 4.2.0)
#> scales 1.2.1 2022-08-20 [1] CRAN (R 4.2.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.2.0)
#> stringi 1.7.8 2022-07-11 [1] CRAN (R 4.2.0)
#> stringr 1.4.1 2022-08-20 [1] CRAN (R 4.2.0)
#> styler 1.8.1 2022-11-07 [1] CRAN (R 4.2.0)
#> tensorA 0.36.2 2020-11-19 [1] CRAN (R 4.2.0)
#> tibble 3.1.8 2022-07-22 [1] CRAN (R 4.2.0)
#> tidyselect 1.2.0 2022-10-10 [1] CRAN (R 4.2.1)
#> utf8 1.2.2 2021-07-24 [1] CRAN (R 4.2.0)
#> vctrs 0.5.1 2022-11-16 [1] CRAN (R 4.2.0)
#> withr 2.5.0 2022-03-03 [1] CRAN (R 4.2.0)
#> xfun 0.35 2022-11-16 [1] CRAN (R 4.2.0)
#> yaml 2.3.6 2022-10-18 [1] CRAN (R 4.2.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.2/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
Are there other convenient ways to make approximate LOO more robust?
Hi, you might have some luck using the generic moment matching functions from https://github.com/topipa/iwmm
You'll need to manually specify the target function or importance weight function, but it should work on a matrix object.
Thanks, @n-kall. Is iwmm a generic implementation of https://mc-stan.org/loo/articles/loo2-moment-matching.html, or is the underlying statistical method itself different too?
Yes, it is the same underlying mechanism, just generic (i.e. not tied to importance weights for leave-one-out posteriors). Given a log_ratio_fun, the moment_match function will return transformed draws and importance weights (and Pareto-k diagnostic values). k_threshold = 0.7 and split = TRUE would match the loo_moment_match defaults.
If you want to use it for the leave-one-out case, the log_ratio_fun should be a function that returns the negative log likelihood of the left-out observation. See the tests for an example. You'd likely need to wrap it in a loop (for each observation) and use the resulting draws+weights to calculate the elpd or other metrics you're interested in.
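To make the "loop over observations" idea concrete, here is a minimal base R sketch of the weights-to-elpd step. It deliberately leaves out moment matching and Pareto smoothing and does not call iwmm at all. It only shows how a log ratio equal to the negative log likelihood of the left-out observation turns into a pointwise elpd estimate. The function names are hypothetical, and log_lik is assumed to be a draws-by-observations matrix like the one in the reprex.

```r
# Sketch only: plain self-normalized importance sampling for LOO.
# No moment matching, no Pareto smoothing; helper names are hypothetical.

# Numerically stable log(sum(exp(v))).
log_sum_exp <- function(v) {
  m <- max(v)
  m + log(sum(exp(v - m)))
}

# Log importance ratio for leaving out observation i:
# the negative log likelihood of that observation at each draw.
log_ratio_loo <- function(log_lik, i) {
  -log_lik[, i]
}

# Pointwise elpd estimates from an S x N matrix of log likelihoods.
elpd_loo_is <- function(log_lik) {
  vapply(seq_len(ncol(log_lik)), function(i) {
    log_w <- log_ratio_loo(log_lik, i)
    log_w <- log_w - log_sum_exp(log_w) # normalize the weights on the log scale
    log_sum_exp(log_w + log_lik[, i])   # log of the weighted mean predictive density
  }, numeric(1))
}
```

With moment matching in the loop, the draws for a problematic observation get transformed, so both the weights and the log likelihood of the left-out observation would have to come from the transformed draws (and the weights that come back with them) instead of from the original log_lik column.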