Make `augment.merMod()` more consistent with `predict.merMod()` when using `newdata`
This is related to #125, but I felt like it deserved its own issue.
The behaviour (and documentation) of `augment.merMod()` when making predictions on new data could use some love. The current behaviour is inconsistent with `predict.merMod()` and leads to unexpected results that can be misleading or unclear.
Here's a reprex covering some of the issues I found. I think the function needs a rewrite to handle augmenting the original data used to fit the model differently from making predictions on new data, perhaps dropping its dependence on `broom::augment_columns()` given some of its behaviour (or at least adding some error checking for the cases where it should fail).
Regarding documentation, it isn't documented anywhere that you can use the `re.form` argument with `augment.merMod()`/`broom::augment_columns()`; I tried it on a whim while trying to make predictions and it just happened to (partially) work.
library(tibble)
library(lme4)
#> Loading required package: Matrix
library(broom)
library(broom.mixed)
lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)
# When you just want to augment the data used to fit the model everything is
# good and the results are what you'd expect. However, things go wrong once you
# want to make predictions with new data.
augment(lmm1)
#> # A tibble: 180 × 14
#> Reaction Days Subject .fitted .resid .hat .cooksd .fixed .mu .offset
#> <dbl> <dbl> <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 250. 0 308 254. -4.10 0.229 0.00496 251. 254. 0
#> 2 259. 1 308 273. -14.6 0.170 0.0402 262. 273. 0
#> 3 251. 2 308 293. -42.2 0.127 0.226 272. 293. 0
#> 4 321. 3 308 313. 8.78 0.101 0.00731 283. 313. 0
#> 5 357. 4 308 332. 24.5 0.0910 0.0506 293. 332. 0
#> 6 415. 5 308 352. 62.7 0.0981 0.362 304. 352. 0
#> 7 382. 6 308 372. 10.5 0.122 0.0134 314. 372. 0
#> 8 290. 7 308 391. -101. 0.162 1.81 325. 391. 0
#> 9 431. 8 308 411. 19.6 0.219 0.106 335. 411. 0
#> 10 466. 9 308 431. 35.7 0.293 0.571 346. 431. 0
#> # … with 170 more rows, and 4 more variables: .sqrtXwt <dbl>, .sqrtrwt <dbl>,
#> # .weights <dbl>, .wtres <dbl>
# For context, first let's cover predict.merMod()'s behaviour. ----------------
# If you want to make predictions conditioned on random effects, you need to
# provide data for the random effects groups you want to make predictions for.
# For example, here we make predictions for Subjects 308 and 310 on Days 0-3.
# The resulting vector is the same length as `newdata`.
predict(lmm1, newdata = expand.grid(Days = 0:3, Subject = c(308, 310)))
#> 1 2 3 4 5 6 7 8
#> 253.6637 273.3299 292.9962 312.6624 212.4447 217.4631 222.4816 227.5000
# If we don't provide data for any Subjects, we get an error. This is expected
# and fine, but as we'll see later, augment.merMod() ignores this convention.
predict(lmm1, newdata = tibble(Days = 0:3))
#> Error in eval(predvars, data, env): object 'Subject' not found
# If we don't want to condition on the random effects, and instead want fixed
# effect predictions, we need to be explicit about that with `re.form = NA`.
# Here too the resulting vector is the same length as `newdata`.
predict(lmm1, newdata = tibble(Days = 0:3), re.form = NA)
#> 1 2 3 4
#> 251.4051 261.8724 272.3397 282.8070
# For more context, next let's cover augment_columns()'s behaviour ------------
# augment_columns() is a developer-facing function intended for use in the
# internals of augment methods. It is used as a starting point in
# augment.merMod(), then wrangled further later in the function. The wrangling
# causes issues later on.
# augment_columns() has consistent behaviour with predict.merMod() in some but
# not all cases. Here I've simply repeated the same three predict() examples
# from above.
# Consistent with predict.merMod(), results are what you'd expect.
augment_columns(lmm1, newdata = expand.grid(Days = 0:3, Subject = c(308, 310)))
#> # A tibble: 8 × 3
#> Days Subject .fitted
#> <int> <dbl> <dbl>
#> 1 0 308 254.
#> 2 1 308 273.
#> 3 2 308 293.
#> 4 3 308 313.
#> 5 0 310 212.
#> 6 1 310 217.
#> 7 2 310 222.
#> 8 3 310 227.
# No error thrown this time, even though no Subject data was provided. This
# appears to be the predictions for all subjects, with the Days vector recycled
# to the total number of observations. There is no mention of this, with no
# Subject column to cross-reference against, and the Days column now has no
# correspondence to the .fitted column.
augment_columns(lmm1, newdata = tibble(Days = 0:3))
#> # A tibble: 180 × 2
#> Days .fitted
#> <int> <dbl>
#> 1 0 254.
#> 2 1 273.
#> 3 2 293.
#> 4 3 313.
#> 5 0 332.
#> 6 1 352.
#> 7 2 372.
#> 8 3 391.
#> 9 0 411.
#> 10 1 431.
#> # … with 170 more rows
predict(lmm1, newdata = tibble(Days = 0:9, Subject = 308))
#> 1 2 3 4 5 6 7 8
#> 253.6637 273.3299 292.9962 312.6624 332.3287 351.9950 371.6612 391.3275
#> 9 10
#> 410.9937 430.6600
# It also doesn't matter if you're explicit about the `re.form`; it still doesn't
# throw an error.
augment_columns(lmm1, newdata = tibble(Days = 0:3), re.form = ~ (Days | Subject))
#> # A tibble: 180 × 2
#> Days .fitted
#> <int> <dbl>
#> 1 0 254.
#> 2 1 273.
#> 3 2 293.
#> 4 3 313.
#> 5 0 332.
#> 6 1 352.
#> 7 2 372.
#> 8 3 391.
#> 9 0 411.
#> 10 1 431.
#> # … with 170 more rows
# Consistent with predict.merMod(), results are what you'd expect.
augment_columns(lmm1, newdata = data.frame(Days = 0:3), re.form = NA)
#> # A tibble: 4 × 2
#> Days .fitted
#> <int> <dbl>
#> 1 0 251.
#> 2 1 262.
#> 3 2 272.
#> 4 3 283.
# Now let's look at augment.merMod()'s behaviour ------------------------------
## Predictions conditioned on random effects:
# This throws a warning due to some of the aforementioned wrangling that happens
# inside augment.merMod() after getting the data from augment_columns().
# Specifically, the `respCols` (.mu, .offset, etc.) that are getting bound to
# the augment_columns() data frame come from the original data used to fit the
# model, rather than the new data.
augment(lmm1, newdata = expand.grid(Days = 0:3, Subject = c(308, 310)))
#> Warning in indices[which(stats::complete.cases(original))] <- seq_len(nrow(x)):
#> number of items to replace is not a multiple of replacement length
#> # A tibble: 8 × 9
#> Days Subject .fitted .mu .offset .sqrtXwt .sqrtrwt .weights .wtres
#> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0 308 254. 254. 0 1 1 1 -4.10
#> 2 1 308 273. 273. 0 1 1 1 -14.6
#> 3 2 308 293. 293. 0 1 1 1 -42.2
#> 4 3 308 313. 313. 0 1 1 1 8.78
#> 5 0 310 212. 332. 0 1 1 1 24.5
#> 6 1 310 217. 352. 0 1 1 1 62.7
#> 7 2 310 222. 372. 0 1 1 1 10.5
#> 8 3 310 227. 391. 0 1 1 1 -101.
# As a consequence, the respCols don't actually correspond to the new data. For
# example, this is clear if you look at the .mu and .wtres columns above. All
# the values come from the first 8 values in the model, which in this case
# means they actually all come from Subject 308. This is bad.
lmm1@resp$mu[1:8]
#> [1] 253.6637 273.3299 292.9962 312.6624 332.3287 351.9950 371.6612 391.3275
lmm1@resp$wtres[1:8]
#> [1] -4.103656 -14.625218 -42.195579 8.777359 24.523197 62.695136
#> [7] 10.542574 -101.178888
## Leaving Subject out of newdata:
# This has the same problems as augment_columns(). The respCols at least
# correspond to .fitted and .fixed now, but this should really throw an error
# instead.
augment(lmm1, newdata = tibble(Days = 0:3))
#> # A tibble: 180 × 9
#> Days .fitted .fixed .mu .offset .sqrtXwt .sqrtrwt .weights .wtres
#> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0 254. 251. 254. 0 1 1 1 -4.10
#> 2 1 273. 262. 273. 0 1 1 1 -14.6
#> 3 2 293. 272. 293. 0 1 1 1 -42.2
#> 4 3 313. 283. 313. 0 1 1 1 8.78
#> 5 0 332. 293. 332. 0 1 1 1 24.5
#> 6 1 352. 304. 352. 0 1 1 1 62.7
#> 7 2 372. 314. 372. 0 1 1 1 10.5
#> 8 3 391. 325. 391. 0 1 1 1 -101.
#> 9 0 411. 335. 411. 0 1 1 1 19.6
#> 10 1 431. 346. 431. 0 1 1 1 35.7
#> # … with 170 more rows
## Fixed effect predictions:
# Similar to the problem with the predictions conditioned on random effects,
# the `respCols` have no correspondence to the new data. This is obvious if
# you make predictions on different days.
augment(lmm1, newdata = data.frame(Days = 0:3), re.form = NA)
#> Warning in indices[which(stats::complete.cases(original))] <- seq_len(nrow(x)):
#> number of items to replace is not a multiple of replacement length
#> # A tibble: 4 × 8
#> Days .fitted .mu .offset .sqrtXwt .sqrtrwt .weights .wtres
#> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0 251. 254. 0 1 1 1 -4.10
#> 2 1 262. 273. 0 1 1 1 -14.6
#> 3 2 272. 293. 0 1 1 1 -42.2
#> 4 3 283. 313. 0 1 1 1 8.78
augment(lmm1, newdata = data.frame(Days = 6:9), re.form = NA)
#> Warning in indices[which(stats::complete.cases(original))] <- seq_len(nrow(x)):
#> number of items to replace is not a multiple of replacement length
#> # A tibble: 4 × 8
#> Days .fitted .mu .offset .sqrtXwt .sqrtrwt .weights .wtres
#> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 6 314. 254. 0 1 1 1 -4.10
#> 2 7 325. 273. 0 1 1 1 -14.6
#> 3 8 335. 293. 0 1 1 1 -42.2
#> 4 9 346. 313. 0 1 1 1 8.78
Created on 2023-05-30 with reprex v2.0.2
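The warning in the `newdata` cases above can be reproduced without lme4. Here is a minimal base R sketch, assuming the internal helper does roughly what the warning message shows: it builds an index vector as long as the augmented data, then fills it with the row numbers of the response columns. `original` and `resp_cols` are hypothetical stand-ins for the 8-row `newdata` and the 180-row `respCols`; they are not objects from broom.mixed.

```r
# Stand-in for the 8-row newdata (Days 0-3 for Subjects 308 and 310).
original <- expand.grid(Days = 0:3, Subject = c(308, 310))
# Stand-in for the respCols: one row per observation in the fitted model.
resp_cols <- data.frame(.mu = seq_len(180))

# Same shape as the assignment quoted in the warning message.
indices <- rep(NA_integer_, nrow(original))
warn_msg <- NULL
withCallingHandlers(
  indices[which(stats::complete.cases(original))] <- seq_len(nrow(resp_cols)),
  warning = function(w) {
    # Record the warning, then let the assignment finish.
    warn_msg <<- conditionMessage(w)
    invokeRestart("muffleWarning")
  }
)

# Only the first nrow(original) row numbers are kept, so the subset is just
# the first 8 rows of the original-data columns, whatever newdata contains.
kept <- resp_cols[indices, , drop = FALSE]
print(warn_msg)
print(indices)
```

If that assumption holds, it would explain why `.mu` and `.wtres` always match the first rows of the fitted data regardless of which Days or Subjects are requested.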
Thanks! Given the number of open issues and their heterogeneity I really think I need to go through and tag them with 'feature-request'/'enhancement' etc. so I can prioritize them and fix the ones that really need to be fixed (I would put this one in that category ... maybe I'll add an 'infelicity' tag [that's Bill Venables's neutral term for "it's not technically a bug but it's definitely bad behaviour"] that's just below 'bug' in priority ...)
Welcome! Haha, I like the infelicity tag. It's probably also worth checking if this issue applies to any of the other functions that rely on `augment_columns()`.
Here's a very rough rewrite of `augment.merMod()`, basically just adding some conditionals to the existing code. It might be as simple as something like this.
library(tibble)
library(lme4)
#> Loading required package: Matrix
library(broom)
library(broom.mixed)
lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)
augment.merMod <- function(x, data = stats::model.frame(x), newdata, ...) {
  # Augment the original data used to fit the model
  if (missing(newdata)) {
    # move rownames if necessary
    newdata <- NULL
    ret <- suppressMessages(augment_columns(x, data, newdata, se.fit = NULL, ...))
    # add predictions with no random effects (population means)
    predictions <- stats::predict(x, re.form = NA)
    # some cases, such as values returned from nlmer, return more than one
    # prediction per observation. Not clear how those cases would be tidied
    if (length(predictions) == nrow(ret)) {
      ret$.fixed <- predictions
    }
    # columns to extract from resp reference object
    # these include relevant ones that could be present in lmResp, glmResp,
    # or nlsResp objects
    respCols <- c(
      "mu", "offset", "sqrtXwt", "sqrtrwt", "weights",
      "wtres", "gam", "eta"
    )
    cols <- lapply(respCols, function(cc) x@resp[[cc]])
    names(cols) <- paste0(".", respCols)
    ## remove too-long fields and empty fields
    n_vals <- vapply(cols, length, 1L)
    min_n <- min(n_vals[n_vals > 0])
    cols <- dplyr::bind_cols(cols[n_vals == min_n])
    cols <- broom.mixed:::insert_NAs(cols, ret)
    if (length(cols) > 0) {
      ret <- dplyr::bind_cols(ret, cols)
    }
    return(broom.mixed:::unrowname(ret))
  # Make predictions on new data
  } else {
    ret <- suppressMessages(augment_columns(x, data, newdata, se.fit = NULL, ...))
    # Throw an error when re.form isn't specified, and there's no grouping
    # variable in newdata. This is fragile but just intended for demonstration.
    # Note: Can't use missing() since re.form comes from the ... args.
    if (!hasArg(re.form) && ncol(stats::model.frame(x)) != ncol(ret)) {
      stop("No data provided for grouping variable.")
    }
    # add predictions on newdata with no random effects (population means)
    predictions <- stats::predict(x, newdata, re.form = NA)
    # some cases, such as values returned from nlmer, return more than one
    # prediction per observation. Not clear how those cases would be tidied
    if (length(predictions) == nrow(ret)) {
      ret$.fixed <- predictions
    }
    # placeholder: this creates a column literally named `etc.`, standing in
    # for the remaining respCols
    tibble::tibble(ret, .mu = NA, .offset = NA, etc. = NA)
  }
}
augment(lmm1)
#> # A tibble: 180 × 14
#> Reaction Days Subject .fitted .resid .hat .cooksd .fixed .mu .offset
#> <dbl> <dbl> <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 250. 0 308 254. -4.10 0.229 0.00496 251. 254. 0
#> 2 259. 1 308 273. -14.6 0.170 0.0402 262. 273. 0
#> 3 251. 2 308 293. -42.2 0.127 0.226 272. 293. 0
#> 4 321. 3 308 313. 8.78 0.101 0.00731 283. 313. 0
#> 5 357. 4 308 332. 24.5 0.0910 0.0506 293. 332. 0
#> 6 415. 5 308 352. 62.7 0.0981 0.362 304. 352. 0
#> 7 382. 6 308 372. 10.5 0.122 0.0134 314. 372. 0
#> 8 290. 7 308 391. -101. 0.162 1.81 325. 391. 0
#> 9 431. 8 308 411. 19.6 0.219 0.106 335. 411. 0
#> 10 466. 9 308 431. 35.7 0.293 0.571 346. 431. 0
#> # … with 170 more rows, and 4 more variables: .sqrtXwt <dbl>, .sqrtrwt <dbl>,
#> # .weights <dbl>, .wtres <dbl>
augment(lmm1, newdata = expand.grid(Days = 0:3, Subject = c(308, 310)))
#> # A tibble: 8 × 7
#> Days Subject .fitted .fixed .mu .offset etc.
#> <int> <dbl> <dbl> <dbl> <lgl> <lgl> <lgl>
#> 1 0 308 254. 251. NA NA NA
#> 2 1 308 273. 262. NA NA NA
#> 3 2 308 293. 272. NA NA NA
#> 4 3 308 313. 283. NA NA NA
#> 5 0 310 212. 251. NA NA NA
#> 6 1 310 217. 262. NA NA NA
#> 7 2 310 222. 272. NA NA NA
#> 8 3 310 227. 283. NA NA NA
augment(lmm1, newdata = tibble(Days = 0:3))
#> Error in augment.merMod(lmm1, newdata = tibble(Days = 0:3)): No data provided for grouping variable.
augment(lmm1, newdata = data.frame(Days = 0:3), re.form = NA)
#> # A tibble: 4 × 6
#> Days .fitted .fixed .mu .offset etc.
#> <int> <dbl> <dbl> <lgl> <lgl> <lgl>
#> 1 0 251. 251. NA NA NA
#> 2 1 262. 262. NA NA NA
#> 3 2 272. 272. NA NA NA
#> 4 3 283. 283. NA NA NA
augment(lmm1, newdata = data.frame(Days = 6:9), re.form = NA)
#> # A tibble: 4 × 6
#> Days .fitted .fixed .mu .offset etc.
#> <int> <dbl> <dbl> <lgl> <lgl> <lgl>
#> 1 6 314. 314. NA NA NA
#> 2 7 325. 325. NA NA NA
#> 3 8 335. 335. NA NA NA
#> 4 9 346. 346. NA NA NA
Created on 2023-05-30 with reprex v2.0.2
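The column-count check in the rewrite above is fragile, as noted in its comments. One possible alternative is to compare the grouping variable names against the columns of `newdata`. The sketch below is only an illustration: `check_grouping_vars()` and its arguments are hypothetical names, not part of broom.mixed. For a fitted model the grouping names could probably be taken from `names(lme4::getME(x, "flist"))`, though interaction or nested terms such as `a:b` would need extra handling.

```r
# Hypothetical helper: error if newdata lacks any grouping variable, unless
# the caller explicitly supplied re.form (e.g. re.form = NA).
check_grouping_vars <- function(newdata, group_vars, re_form_supplied = FALSE) {
  if (re_form_supplied) {
    return(invisible(TRUE))
  }
  missing_vars <- setdiff(group_vars, names(newdata))
  if (length(missing_vars) > 0) {
    stop(
      "`newdata` is missing grouping variable(s): ",
      paste(missing_vars, collapse = ", "),
      call. = FALSE
    )
  }
  invisible(TRUE)
}

# Usage with the sleepstudy grouping variable, passed by name so that the
# example does not need a fitted model.
ok <- check_grouping_vars(
  expand.grid(Days = 0:3, Subject = c(308, 310)),
  group_vars = "Subject"
)
err <- tryCatch(
  check_grouping_vars(data.frame(Days = 0:3), group_vars = "Subject"),
  error = function(e) conditionMessage(e)
)
fixed_only <- check_grouping_vars(
  data.frame(Days = 0:3),
  group_vars = "Subject",
  re_form_supplied = TRUE
)
print(err)
```

Checking names rather than column counts keeps working when `newdata` carries extra columns, which the `ncol()` comparison would misread.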