stan-dev/loo

Clarification on using loo_moment_match() with non-Stan objects

wlandau opened this issue · 4 comments

I am working on a model averaging problem with very simple models, and I am getting intermittently high Pareto k values even on simple well-behaved simulated datasets. I would like to apply the moment matching correction to both non-longitudinal JAGS models and longitudinal Stan models. The latter case is trivially easy with moment_match = TRUE in loo(), but I do not have a stanfit object in the former case.

Could you help me understand how to use loo_moment_match() when my model fit is a posterior::as_draws_df() data frame with columns for parameters and pointwise log likelihoods? To set up a sufficiently motivating scenario, I converted the roaches example from the vignette into JAGS. I also put constrained priors on the scale parameters so I could learn what to do with the unconstrain_pars, log_prob_upars, and log_lik_i_upars arguments of loo_moment_match(). (Is it even appropriate to consider "unconstrained parameters" without HMC?)

    library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
    
    data(roaches, package = "rstanarm")
    roaches$roach1 <- sqrt(roaches$roach1)
    x <- roaches[, c("roach1", "treatment", "senior")]
    data <- list(
        N = nrow(x),
        K = ncol(x),
        x = as.matrix(x),
        y = roaches$y,
        offset = log(roaches[,"exposure2"])
    )
    
    model_text <- "
model {
  for (n in 1:N) {
    y[n] ~ dpois(exp(inprod(x[n,], beta) + intercept + offset[n]))
  }
  for (k in 1:K) {
    beta[k] ~ dnorm(0, 1 / (scale_beta * scale_beta))
  }
  intercept ~ dnorm(0, 1 / (scale_alpha * scale_alpha))
  scale_beta ~ dunif(0, 10)
  scale_alpha ~ dnorm(0, 10) T(0,)
  for (n in 1:N) {
    log_lik[n] <- log(dpois(y[n], exp(inprod(x[n,], beta) + intercept + offset[n])))
  }
}
"
    file <- tempfile()
    writeLines(model_text, file)
    
    tmp <- capture.output({
        model <- rjags::jags.model(
            file = file,
            data = data,
            n.chains = 4,
            n.adapt = 2e3
        )
        # Burn-in. update() for jags objects takes progress.bar, not quiet.
        stats::update(model, n.iter = 2e3, progress.bar = "none")
        coda <- rjags::coda.samples(
            model = model,
            variable.names = c(
                "beta",
                "intercept",
                "scale_beta",
                "scale_alpha",
                "log_lik"
            ),
            n.iter = 4e3
        )
    })
    
    fit <- posterior::as_draws_df(coda)
    print(fit) # This is the model fit object I can work with.
#> # A draws_df: 4000 iterations, 4 chains, and 268 variables
#>    beta[1] beta[2] beta[3] intercept log_lik[1] log_lik[2] log_lik[3]
#> 1     0.16   -0.55   -0.30       2.5        -19        -16       -2.1
#> 2     0.16   -0.55   -0.29       2.5        -19        -16       -2.1
#> 3     0.16   -0.52   -0.27       2.5        -17        -14       -2.1
#> 4     0.16   -0.55   -0.38       2.5        -16        -14       -2.1
#> 5     0.16   -0.56   -0.25       2.5        -16        -14       -2.1
#> 6     0.16   -0.59   -0.26       2.5        -18        -15       -2.1
#> 7     0.16   -0.60   -0.25       2.5        -19        -16       -2.0
#> 8     0.16   -0.58   -0.34       2.5        -18        -16       -2.0
#> 9     0.16   -0.58   -0.30       2.5        -18        -15       -2.1
#> 10    0.16   -0.60   -0.33       2.5        -17        -15       -2.1
#>    log_lik[4]
#> 1        -2.2
#> 2        -2.2
#> 3        -2.3
#> 4        -2.2
#> 5        -2.2
#> 6        -2.2
#> 7        -2.2
#> 8        -2.2
#> 9        -2.2
#> 10       -2.2
#> # ... with 15990 more draws, and 260 more variables
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}
    
    # Convergence looks okay.
    fit %>%
        select(starts_with(c("beta", "intercept", "scale"))) %>%
        posterior::summarize_draws() %>%
        print()
#> Warning: Dropping 'draws_df' class as required metadata was removed.
#> # A tibble: 6 × 10
#>   variable      mean median      sd     mad     q5    q95  rhat ess_bulk ess_t…¹
#>   <chr>        <dbl>  <dbl>   <dbl>   <dbl>  <dbl>  <dbl> <dbl>    <dbl>   <dbl>
#> 1 beta[1]      0.161  0.161 0.00193 0.00194  0.158  0.164  1.00    1704.   2993.
#> 2 beta[2]     -0.566 -0.566 0.0248  0.0246  -0.607 -0.524  1.00    3388.   5597.
#> 3 beta[3]     -0.312 -0.312 0.0334  0.0335  -0.368 -0.259  1.00    5957.   8365.
#> 4 intercept    2.52   2.52  0.0260  0.0260   2.48   2.56   1.00    1339.   2777.
#> 5 scale_alpha  0.905  0.889 0.156   0.153    0.673  1.19   1.00    9336.   8030.
#> 6 scale_beta   0.816  0.569 0.835   0.298    0.272  2.20   1.00    2056.    864.
#> # … with abbreviated variable name ¹​ess_tail
    
    # LOO without the moment matching correction is straightforward.
    log_lik <- as.matrix(dplyr::select(fit, tidyselect::starts_with("log_lik")))
#> Warning: Dropping 'draws_df' class as required metadata was removed.
    # relative_eff() expects likelihood values (exp(log_lik)), not log-likelihoods.
    r_eff <- loo::relative_eff(x = exp(log_lik), chain_id = fit$.chain)
    loo <- loo::loo(x = log_lik, r_eff = r_eff)
#> Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
    
    # But we get high Pareto k values.
    print(loo)
#> 
#> Computed from 16000 by 262 log-likelihood matrix
#> 
#>          Estimate     SE
#> elpd_loo  -5462.1  696.5
#> p_loo       261.3   57.6
#> looic     10924.3 1393.0
#> ------
#> Monte Carlo SE of elpd_loo is NA.
#> 
#> Pareto k diagnostic values:
#>                          Count Pct.    Min. n_eff
#> (-Inf, 0.5]   (good)     239   91.2%   537       
#>  (0.5, 0.7]   (ok)         9    3.4%   76        
#>    (0.7, 1]   (bad)        7    2.7%   11        
#>    (1, Inf)   (very bad)   7    2.7%   1         
#> See help('pareto-k-diagnostic') for details.
    
    # How do I use loo_moment_match() in this situation?
    # loo::loo_moment_match(
    #   x = fit,
    #   loo = loo, # The existing loo object is also a required argument.
    #   post_draws = function(x) as.matrix(x),
    #   log_lik_i = function(x, i) x[[sprintf("log_lik[%s]", i)]],
    #   unconstrain_pars = "???", # Do we even need to consider the unconstrained space for non-HMC MCMC?
    #   log_prob_upars = "???", # Here is where I start to get lost.
    #   log_lik_i_upars = "???" # Same here.
    # )

Created on 2022-12-01 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Big Sur ... 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/Indiana/Indianapolis
#>  date     2022-12-01
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package        * version date (UTC) lib source
#>  abind            1.4-5   2016-07-21 [1] CRAN (R 4.2.0)
#>  assertthat       0.2.1   2019-03-21 [1] CRAN (R 4.2.0)
#>  backports        1.4.1   2021-12-13 [1] CRAN (R 4.2.0)
#>  checkmate        2.1.0   2022-04-21 [1] CRAN (R 4.2.0)
#>  cli              3.4.1   2022-09-23 [1] CRAN (R 4.2.0)
#>  coda             0.19-4  2020-09-30 [1] CRAN (R 4.2.0)
#>  colorspace       2.0-3   2022-02-21 [1] CRAN (R 4.2.0)
#>  DBI              1.1.3   2022-06-18 [1] CRAN (R 4.2.0)
#>  digest           0.6.30  2022-10-18 [1] CRAN (R 4.2.0)
#>  distributional   0.3.1   2022-09-02 [1] CRAN (R 4.2.0)
#>  dplyr          * 1.0.10  2022-09-01 [1] CRAN (R 4.2.0)
#>  evaluate         0.18    2022-11-07 [1] CRAN (R 4.2.0)
#>  fansi            1.0.3   2022-03-24 [1] CRAN (R 4.2.0)
#>  farver           2.1.1   2022-07-06 [1] CRAN (R 4.2.0)
#>  fastmap          1.1.0   2021-01-25 [1] CRAN (R 4.2.0)
#>  fs               1.5.2   2021-12-08 [1] CRAN (R 4.2.0)
#>  generics         0.1.3   2022-07-05 [1] CRAN (R 4.2.0)
#>  ggplot2          3.4.0   2022-11-04 [1] CRAN (R 4.2.0)
#>  glue             1.6.2   2022-02-24 [1] CRAN (R 4.2.0)
#>  gtable           0.3.1   2022-09-01 [1] CRAN (R 4.2.0)
#>  highr            0.9     2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools        0.5.3   2022-07-18 [1] CRAN (R 4.2.0)
#>  knitr            1.41    2022-11-18 [1] CRAN (R 4.2.0)
#>  lattice          0.20-45 2021-09-22 [1] CRAN (R 4.2.1)
#>  lifecycle        1.0.3   2022-10-07 [1] CRAN (R 4.2.0)
#>  loo              2.5.1   2022-03-24 [1] CRAN (R 4.2.0)
#>  magrittr         2.0.3   2022-03-30 [1] CRAN (R 4.2.0)
#>  matrixStats      0.63.0  2022-11-18 [1] CRAN (R 4.2.0)
#>  munsell          0.5.0   2018-06-12 [1] CRAN (R 4.2.0)
#>  pillar           1.8.1   2022-08-19 [1] CRAN (R 4.2.0)
#>  pkgconfig        2.0.3   2019-09-22 [1] CRAN (R 4.2.0)
#>  posterior        1.3.1   2022-09-06 [1] CRAN (R 4.2.0)
#>  purrr            0.3.5   2022-10-06 [1] CRAN (R 4.2.0)
#>  R.cache          0.16.0  2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3      1.8.2   2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo             1.25.0  2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils          2.12.2  2022-11-11 [1] CRAN (R 4.2.0)
#>  R6               2.5.1   2021-08-19 [1] CRAN (R 4.2.0)
#>  reprex           2.0.2   2022-08-17 [1] CRAN (R 4.2.0)
#>  rjags            4-13    2022-04-19 [1] CRAN (R 4.2.0)
#>  rlang            1.0.6   2022-09-24 [1] CRAN (R 4.2.0)
#>  rmarkdown        2.18    2022-11-09 [1] CRAN (R 4.2.0)
#>  rstudioapi       0.14    2022-08-22 [1] CRAN (R 4.2.0)
#>  scales           1.2.1   2022-08-20 [1] CRAN (R 4.2.0)
#>  sessioninfo      1.2.2   2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi          1.7.8   2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr          1.4.1   2022-08-20 [1] CRAN (R 4.2.0)
#>  styler           1.8.1   2022-11-07 [1] CRAN (R 4.2.0)
#>  tensorA          0.36.2  2020-11-19 [1] CRAN (R 4.2.0)
#>  tibble           3.1.8   2022-07-22 [1] CRAN (R 4.2.0)
#>  tidyselect       1.2.0   2022-10-10 [1] CRAN (R 4.2.1)
#>  utf8             1.2.2   2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs            0.5.1   2022-11-16 [1] CRAN (R 4.2.0)
#>  withr            2.5.0   2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun             0.35    2022-11-16 [1] CRAN (R 4.2.0)
#>  yaml             2.3.6   2022-10-18 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────
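For reference, here is one way the three `???` arguments might be filled in for this specific JAGS model. This is an untested sketch under several assumptions: `x` is a hypothetical list holding a numeric `draws` matrix (one column per parameter, chains stacked in the same order as the `log_lik` matrix) plus the JAGS `data` list, and all helper names (`post_draws_jags()`, `constrain_jags()`, and so on) are made up. Because there is no `rstan::log_prob()` equivalent for a JAGS fit (as far as I know), the log posterior density is re-implemented by hand in R. The scale parameters are mapped to the real line (log for the half-normal `scale_alpha`, scaled logit for the uniform(0, 10) `scale_beta`) so that the affine moment-matching transformations cannot push draws outside the prior support, and `log_prob_upars` therefore includes the Jacobian of that mapping. As far as I can tell, the sampler (HMC or Gibbs) does not matter here: only the draws and the density are used.

```r
# Hypothetical helpers for loo::loo_moment_match() (default method) with the
# roaches JAGS model. `x` is assumed to be list(draws = <numeric matrix>,
# data = <the JAGS data list>), with draws stacked chain by chain.
par_names <- c("beta[1]", "beta[2]", "beta[3]", "intercept",
               "scale_alpha", "scale_beta")

# Constrained posterior draws (S x 6 matrix, columns in par_names order).
post_draws_jags <- function(x, ...) {
  x$draws[, par_names, drop = FALSE]
}

# Linear predictor log(lambda) of observation i for every row of `pars`.
eta_i <- function(pars, data, i) {
  drop(pars[, 1:3, drop = FALSE] %*% data$x[i, ]) + pars[, 4] + data$offset[i]
}

# Poisson log-likelihood draws of observation i.
log_lik_i_jags <- function(x, i, ...) {
  dpois(x$data$y[i], exp(eta_i(post_draws_jags(x), x$data, i)), log = TRUE)
}

# Map the scale parameters to the real line:
# scale_alpha > 0 -> log; scale_beta in (0, 10) -> scaled logit.
unconstrain_pars_jags <- function(x, pars, ...) {
  upars <- pars
  upars[, 5] <- log(pars[, 5])
  upars[, 6] <- qlogis(pars[, 6] / 10)
  upars
}

# Inverse of unconstrain_pars_jags().
constrain_jags <- function(upars) {
  pars <- upars
  pars[, 5] <- exp(upars[, 5])
  pars[, 6] <- 10 * plogis(upars[, 6])
  pars
}

# Unnormalized log posterior density on the unconstrained scale, one value per
# row of `upars`: log likelihood + log priors + log Jacobian.
log_prob_upars_jags <- function(x, upars, ...) {
  d <- x$data
  S <- nrow(upars)
  pars <- constrain_jags(upars)
  # S x N matrix of linear predictors.
  eta <- sweep(pars[, 1:3, drop = FALSE] %*% t(d$x) + pars[, 4], 2, d$offset, "+")
  # Poisson log pmf: y * eta - exp(eta) - log(y!).
  ll <- sweep(eta, 2, d$y, "*") - exp(eta) - rep(lgamma(d$y + 1), each = S)
  # beta[k] ~ normal(0, sd = scale_beta); intercept ~ normal(0, sd = scale_alpha).
  lp_beta <- rowSums(matrix(dnorm(pars[, 1:3], 0, pars[, 6], log = TRUE), nrow = S))
  lp_int <- dnorm(pars[, 4], 0, pars[, 5], log = TRUE)
  # scale_alpha: half-normal with JAGS precision 10 (sd = 1 / sqrt(10)).
  # scale_beta: uniform(0, 10), a constant that can be dropped.
  lp_scale <- dnorm(pars[, 5], 0, 1 / sqrt(10), log = TRUE)
  # log |d constrained / d unconstrained| for both scale parameters.
  log_jac <- upars[, 5] +
    log(10) + plogis(upars[, 6], log.p = TRUE) +
    plogis(upars[, 6], lower.tail = FALSE, log.p = TRUE)
  rowSums(ll) + lp_beta + lp_int + lp_scale + log_jac
}

# Log-likelihood draws of observation i at unconstrained draws `upars`.
log_lik_i_upars_jags <- function(x, upars, i, ...) {
  pars <- constrain_jags(upars)
  dpois(x$data$y[i], exp(eta_i(pars, x$data, i)), log = TRUE)
}

# Possible usage with the objects from the reprex (not run here):
# x_jags <- list(draws = as.matrix(as.data.frame(fit)[, par_names]), data = data)
# loo_mm <- loo::loo_moment_match(
#   x = x_jags, loo = loo,
#   post_draws = post_draws_jags, log_lik_i = log_lik_i_jags,
#   unconstrain_pars = unconstrain_pars_jags,
#   log_prob_upars = log_prob_upars_jags,
#   log_lik_i_upars = log_lik_i_upars_jags
# )
```

Because the likelihood is re-implemented by hand, it is worth checking that log_lik_i_jags() reproduces the log_lik columns saved by JAGS before trusting any moment-matched result.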

Are there other convenient ways to make approximate LOO more robust?

Hi, you might have some luck using the generic moment matching functions from https://github.com/topipa/iwmm
You'll need to manually specify the target function or importance weight function, but it should work on a matrix object.

Thanks, @n-kall. Is iwmm a generic implementation of https://mc-stan.org/loo/articles/loo2-moment-matching.html, or is the underlying statistical method itself different too?

Yes, it is the same underlying mechanism, just generic (i.e. not tied to importance weights for leave-one-out posteriors). Given a log_ratio_fun, the moment_match function will return transformed draws and importance weights (and Pareto-k diagnostic values). k_threshold = 0.7 and split = TRUE would match the loo_moment_match defaults.
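To make "transformed draws and importance weights" concrete, here is a deliberately simplified toy: a single mean-shift step on a one-dimensional problem where both the proposal and the target are known normal densities. It uses only psis(), weights(), and pareto_k_values() from loo; it is not the iwmm moment_match() API. The actual algorithm also matches variances and covariances, iterates until k drops below k_threshold, and has the optional split step, none of which is shown here.

```r
library(loo)

set.seed(1)
S <- 4000
theta <- rnorm(S)                                   # draws from proposal q = N(0, 1)
log_p <- function(x) dnorm(x, 2, 1, log = TRUE)     # target density (known in this toy)
log_q <- function(x) dnorm(x, 0, 1, log = TRUE)     # proposal density

# Plain PSIS: log ratios -> smoothed, normalized weights and Pareto k.
ps0 <- psis(log_p(theta) - log_q(theta), r_eff = 1)
k_before <- pareto_k_values(ps0)
w0 <- weights(ps0, log = FALSE)

# One moment-matching step (mean only): shift the draws so their plain mean
# equals the importance-weighted mean.
shift <- sum(w0 * theta) - mean(theta)
theta_new <- theta + shift

# The shifted draws follow q(. - shift); a pure shift has Jacobian 1,
# so the new log ratio is log p(theta_new) - log q(theta_new - shift).
ps1 <- psis(log_p(theta_new) - log_q(theta_new - shift), r_eff = 1)
k_after <- pareto_k_values(ps1)
w1 <- weights(ps1, log = FALSE)

# Importance-sampling estimate of the target mean (true value: 2).
est_mean <- sum(w1 * theta_new)
```

In this toy, the shifted draws should give a noticeably smaller k than the original draws, which is the effect moment matching relies on when a LOO fold has a high Pareto k.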

If you want to use it for the leave-one-out case, the log_ratio_fun should be a function that returns the negative log likelihood of the left-out observation. See the tests for an example. You'd likely need to wrap it in a loop (for each observation) and use the resulting draws+weights to calculate the elpd or other metrics you're interested in.
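The per-observation loop could look roughly like this. The weighting step is_step() is a hypothetical placeholder that just runs plain PSIS with log ratio -log p(y_i | θ), so the draws stay unchanged. A moment-matching routine such as the one in iwmm would instead return transformed draws and new weights, and the log-likelihood of y[i] would then have to be re-evaluated at those transformed draws before computing elpd_loo_i = log Σ_s w_s p(y_i | θ_s). The log-likelihood matrix below is simulated toy data.

```r
library(loo)

# Hypothetical stand-in for the per-observation importance sampling step.
# With plain PSIS the draws are unchanged; a moment-matching routine would
# return transformed draws, so log_lik would need to be recomputed at them.
is_step <- function(log_lik_i) {
  ps <- psis(-log_lik_i, r_eff = 1)          # LOO log ratio: -log p(y_i | theta)
  list(log_lik = log_lik_i,
       log_w = as.vector(weights(ps, log = TRUE)),  # normalized log weights
       k = pareto_k_values(ps))
}

# elpd_loo_i = log sum_s w_s p(y_i | theta_s), via log-sum-exp for stability.
elpd_i <- function(log_lik, log_w) {
  a <- log_lik + log_w
  max(a) + log(sum(exp(a - max(a))))
}

# Toy S x N log-likelihood matrix: normal model with known sd = 1.
set.seed(2)
S <- 2000; N <- 20
y <- rnorm(N)
mu <- rnorm(S, mean(y), 1 / sqrt(N))
log_lik <- sapply(y, function(yi) dnorm(yi, mu, 1, log = TRUE))

# Loop over observations, then sum the pointwise contributions.
res <- lapply(seq_len(N), function(i) is_step(log_lik[, i]))
pointwise <- vapply(res, function(r) elpd_i(r$log_lik, r$log_w), numeric(1))
elpd_loo <- sum(pointwise)
```

With plain PSIS the total should agree with loo::loo() on the same matrix, which is a convenient sanity check before swapping in moment-matched draws and weights.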