sahirbhatnagar/casebase

error in absoluteRisk for family glmnet

Closed this issue · 2 comments

There's an error when calculating absolute risks with family="glmnet". It occurs when the model includes at least one categorical variable, and it arises when the newdata matrix is created in hazard_estimation. I think it's a simple fix, as shown below. I should also note that the newer version of glmnet has a glmnet::prepareX function which might be useful here: it converts a data.frame into an analysis-ready matrix (encoding all categorical predictors with "one-hot" encoding).

pacman::p_load(casebase)
data("brcancer")
mod_cb_glmnet <- fitSmoothHazard(cens ~ estrec*time + 
                                     horTh + 
                                     age + 
                                     menostat + 
                                     tsize + 
                                     tgrade + 
                                     pnodes + 
                                     progrec,
                                 data = brcancer,
                                 time = "time", 
                                 ratio = 1, 
                                 family = "glmnet")

# these all give errors
absoluteRisk(object = mod_cb_glmnet)
#> Error in cbind2(1, newx) %*% nbeta: Cholmod error 'X and/or Y have wrong dimensions' at file ../MatrixOps/cholmod_sdmult.c, line 90
absoluteRisk(object = mod_cb_glmnet, newdata = brcancer)
#> Error in cbind2(1, newx) %*% nbeta: Cholmod error 'X and/or Y have wrong dimensions' at file ../MatrixOps/cholmod_sdmult.c, line 90
absoluteRisk(object = mod_cb_glmnet, newdata = brcancer[1:2,])
#> Error in cbind2(1, newx) %*% nbeta: Cholmod error 'X and/or Y have wrong dimensions' at file ../MatrixOps/cholmod_sdmult.c, line 90
absoluteRisk(object = mod_cb_glmnet, newdata = brcancer[1:2,], time = c(500, 2000))
#> Error in cbind2(1, newx) %*% nbeta: Cholmod error 'X and/or Y have wrong dimensions' at file ../MatrixOps/cholmod_sdmult.c, line 90
absoluteRisk(object = mod_cb_glmnet, newdata = "typical", time = c(500, 2000))
#> Error in cbind2(1, newx) %*% nbeta: Cholmod error 'X and/or Y have wrong dimensions' at file ../MatrixOps/cholmod_sdmult.c, line 90
absoluteRisk(object = mod_cb_glmnet, time = c(500, 2000))
#> Error in cbind2(1, newx) %*% nbeta: Cholmod error 'X and/or Y have wrong dimensions' at file ../MatrixOps/cholmod_sdmult.c, line 90

# this is what casebase is doing internally in hazard_estimation
formula_pred <- formula(delete.response(terms(mod_cb_glmnet$formula)))
newdata_matrix <- model.matrix(formula_pred, brcancer)
newdata_matrix <- newdata_matrix[,which(colnames(newdata_matrix) != "(Intercept)")]

# this is missing horThno column
head(newdata_matrix)
#>   estrec time horThyes age menostatPost tsize      tgrade.L   tgrade.Q pnodes
#> 1     66 1814        0  70            1    21 -7.850462e-17 -0.8164966      3
#> 2     77 2018        1  56            1    12 -7.850462e-17 -0.8164966      7
#> 3    271  712        1  58            1    35 -7.850462e-17 -0.8164966      9
#> 4     29 1807        1  59            1    17 -7.850462e-17 -0.8164966      4
#> 5     65  772        0  73            1    35 -7.850462e-17 -0.8164966      1
#> 6     13  448        0  32            0    57  7.071068e-01  0.4082483     24
#>   progrec estrec:time
#> 1      48      119724
#> 2      61      155386
#> 3      52      192952
#> 4      60       52403
#> 5      26       50180
#> 6       0        5824
predict(mod_cb_glmnet, newdata_matrix, s = "lambda.min", newoffset = 0)
#> Error in cbind2(1, newx) %*% nbeta: Cholmod error 'X and/or Y have wrong dimensions' at file ../MatrixOps/cholmod_sdmult.c, line 90

# inspecting the glmnet coefficients: there is a coef for horThno,
# which causes the mismatch in the matrix multiplication x %*% beta
coef(mod_cb_glmnet)
#> 13 x 1 sparse Matrix of class "dgCMatrix"
#>                          1
#> (Intercept)  -8.151282e+00
#> estrec        .           
#> time          5.762542e-05
#> horThno       .           
#> horThyes      .           
#> age           .           
#> menostatPost  .           
#> tsize         .           
#> tgrade.L      .           
#> tgrade.Q      .           
#> pnodes        5.911265e-02
#> progrec      -4.984255e-04
#> estrec:time   .

# this works, and is actually what is used in casebase:::cv.glmnet.formula
newdata_matrix_2 <- model.matrix(update.formula(formula_pred, ~.-1), brcancer)
head(newdata_matrix_2)
#>   estrec time horThno horThyes age menostatPost tsize      tgrade.L   tgrade.Q
#> 1     66 1814       1        0  70            1    21 -7.850462e-17 -0.8164966
#> 2     77 2018       0        1  56            1    12 -7.850462e-17 -0.8164966
#> 3    271  712       0        1  58            1    35 -7.850462e-17 -0.8164966
#> 4     29 1807       0        1  59            1    17 -7.850462e-17 -0.8164966
#> 5     65  772       1        0  73            1    35 -7.850462e-17 -0.8164966
#> 6     13  448       1        0  32            0    57  7.071068e-01  0.4082483
#>   pnodes progrec estrec:time
#> 1      3      48      119724
#> 2      7      61      155386
#> 3      9      52      192952
#> 4      4      60       52403
#> 5      1      26       50180
#> 6     24       0        5824
head(predict(mod_cb_glmnet, newdata_matrix_2, s = "lambda.min", newoffset = 0))
#>           1
#> 1 -7.399919
#> 2 -7.260514
#> 3 -7.438852
#> 4 -7.651667
#> 5 -7.936725
#> 6 -5.825432

Created on 2020-06-23 by the reprex package (v0.3.0)
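
The column mismatch can also be seen without casebase or glmnet. The following is a minimal base R sketch on a made-up data frame (the `toy` data and its values are assumptions for illustration, not taken from brcancer). It shows that dropping the intercept column after the fact loses the reference level of every factor, while removing the intercept in the formula keeps all levels of the first factor only:

```r
# Toy data (hypothetical): two unordered factors and one numeric predictor.
toy <- data.frame(
  horTh    = factor(c("no", "yes", "yes", "no")),
  menostat = factor(c("Pre", "Post", "Post", "Pre"), levels = c("Pre", "Post")),
  age      = c(70, 56, 58, 59)
)

# Intercept in the formula, then dropped afterwards:
# the reference level of every factor is missing.
x_int <- model.matrix(~ horTh + menostat + age, toy)
x_int <- x_int[, colnames(x_int) != "(Intercept)", drop = FALSE]
colnames(x_int)
#> [1] "horThyes"     "menostatPost" "age"

# Intercept removed in the formula:
# the first factor keeps all of its levels, later factors do not.
x_noint <- model.matrix(~ horTh + menostat + age - 1, toy)
colnames(x_noint)
#> [1] "horThno"      "horThyes"     "menostatPost" "age"
```

Note that the second matrix still has only one column for menostat, which matches the coefficient names printed by coef(mod_cb_glmnet) above.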

Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 3.6.2 (2019-12-12)
#>  os       Pop!_OS 19.10               
#>  system   x86_64, linux-gnu           
#>  ui       X11                         
#>  language en_US:en                    
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       America/Toronto             
#>  date     2020-06-23                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date       lib
#>  assertthat    0.2.1      2019-03-21 [1]
#>  backports     1.1.8      2020-06-17 [1]
#>  callr         3.4.3      2020-03-28 [1]
#>  casebase    * 0.2.1.9001 2020-06-23 [1]
#>  cli           2.0.2      2020-02-28 [1]
#>  codetools     0.2-16     2018-12-24 [4]
#>  colorspace    1.4-1      2019-03-18 [1]
#>  crayon        1.3.4      2017-09-16 [1]
#>  data.table    1.12.8     2019-12-09 [1]
#>  desc          1.2.0      2018-05-01 [1]
#>  devtools      2.2.2      2020-02-17 [1]
#>  digest        0.6.25     2020-02-23 [1]
#>  dplyr         0.8.5      2020-03-07 [1]
#>  ellipsis      0.3.1      2020-05-15 [1]
#>  evaluate      0.14       2019-05-28 [1]
#>  fansi         0.4.1      2020-01-08 [1]
#>  foreach       1.5.0      2020-03-30 [1]
#>  fs            1.3.2      2020-03-05 [1]
#>  ggplot2       3.3.2.9000 2020-06-23 [1]
#>  glmnet        4.0-2      2020-06-16 [1]
#>  glue          1.4.1      2020-05-13 [1]
#>  gtable        0.3.0      2019-03-25 [1]
#>  highr         0.8        2019-03-20 [1]
#>  htmltools     0.5.0      2020-06-16 [1]
#>  iterators     1.0.12     2019-07-26 [1]
#>  knitr         1.29       2020-06-23 [1]
#>  lattice       0.20-38    2018-11-04 [4]
#>  lifecycle     0.2.0      2020-03-06 [1]
#>  magrittr      1.5        2014-11-22 [1]
#>  Matrix        1.2-18     2019-11-27 [4]
#>  memoise       1.1.0      2017-04-21 [1]
#>  mgcv          1.8-31     2019-11-09 [4]
#>  munsell       0.5.0      2018-06-12 [1]
#>  nlme          3.1-143    2019-12-10 [4]
#>  pacman        0.5.1      2019-03-11 [1]
#>  pillar        1.4.4      2020-05-05 [1]
#>  pkgbuild      1.0.8      2020-05-07 [1]
#>  pkgconfig     2.0.3      2019-09-22 [1]
#>  pkgload       1.1.0      2020-05-29 [1]
#>  prettyunits   1.1.1      2020-01-24 [1]
#>  processx      3.4.2      2020-02-09 [1]
#>  ps            1.3.3      2020-05-08 [1]
#>  purrr         0.3.3      2019-10-18 [1]
#>  R6            2.4.1      2019-11-12 [1]
#>  Rcpp          1.0.4.6    2020-04-09 [1]
#>  remotes       2.1.1      2020-02-15 [1]
#>  rlang         0.4.6      2020-05-02 [1]
#>  rmarkdown     2.3        2020-06-18 [1]
#>  rprojroot     1.3-2      2018-01-03 [1]
#>  scales        1.1.1      2020-05-11 [1]
#>  sessioninfo   1.1.1      2018-11-05 [1]
#>  shape         1.4.4      2018-02-07 [1]
#>  stringi       1.4.6      2020-02-17 [1]
#>  stringr       1.4.0      2019-02-10 [1]
#>  survival      3.1-8      2019-12-03 [4]
#>  testthat      2.3.2      2020-03-02 [1]
#>  tibble        3.0.1      2020-04-20 [1]
#>  tidyselect    1.0.0      2020-01-27 [1]
#>  usethis       1.5.1      2019-07-04 [1]
#>  vctrs         0.3.1      2020-06-05 [1]
#>  VGAM          1.1-3      2020-04-28 [1]
#>  withr         2.2.0      2020-04-20 [1]
#>  xfun          0.15       2020-06-21 [1]
#>  yaml          2.2.1      2020-02-01 [1]
#>  source                                  
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  Github (sahirbhatnagar/casebase@e5bcd69)
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.0)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  Github (tidyverse/ggplot2@7d05fa3)      
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.0)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.1)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.1)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#> 
#> [1] /home/sahir/R/x86_64-pc-linux-gnu-library/3.6
#> [2] /usr/local/lib/R/site-library
#> [3] /usr/lib/R/site-library
#> [4] /usr/lib/R/library

For completeness, we should add a categorical variable to the tests. Here is what I suggest:

library(casebase)
#> Warning: replacing previous import 'data.table:::=' by 'ggplot2:::=' when
#> loading 'casebase'
#> See example usage at http://sahirbhatnagar.com/casebase/
library(data.table)

n = 100; alpha = 0.05

lambda_t0 <- 1
lambda_t1 <- 3

times <- c(rexp(n = n, rate = lambda_t0),
           rexp(n = n, rate = lambda_t1))
censor <- rexp(n = 2*n, rate = -log(alpha))

times_c <- pmin(times, censor)
event_c <- 1 * (times < censor)

DF <- data.frame("ftime" = times_c,
                 "event" = event_c,
                 "Z" = c(rep(0,n), rep(1,n)))
DT <- data.table("ftime" = times_c,
                 "event" = event_c,
                 "Z" = c(rep(0,n), rep(1,n)))

extra_vars <- matrix(rnorm(9 * n), ncol = 9)

# include a categorical predictor
cat_predictor <- factor(sample(c("no","yes"), n, TRUE))

DT_ext <- cbind(DT, data.table(extra_vars,V10=cat_predictor))

formula_glmnet <- formula(paste(c("event ~ ftime", "Z",
                                  paste0("V", 1:10)),
                                collapse = " + "))
fitDT <- fitSmoothHazard(formula_glmnet, data = DT_ext, time = "ftime", family = "glmnet")
riskDT <- absoluteRisk(fitDT, time = 0.5)
#> Error in cbind2(1, newx) %*% nbeta: Cholmod error 'X and/or Y have wrong dimensions' at file ../MatrixOps/cholmod_sdmult.c, line 90

Created on 2020-06-23 by the reprex package (v0.3.0)

Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 3.6.2 (2019-12-12)
#>  os       Pop!_OS 19.10               
#>  system   x86_64, linux-gnu           
#>  ui       X11                         
#>  language en_US:en                    
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       America/Toronto             
#>  date     2020-06-23                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date       lib
#>  assertthat    0.2.1      2019-03-21 [1]
#>  backports     1.1.8      2020-06-17 [1]
#>  callr         3.4.3      2020-03-28 [1]
#>  casebase    * 0.2.1.9001 2020-06-23 [1]
#>  cli           2.0.2      2020-02-28 [1]
#>  codetools     0.2-16     2018-12-24 [4]
#>  colorspace    1.4-1      2019-03-18 [1]
#>  crayon        1.3.4      2017-09-16 [1]
#>  data.table  * 1.12.8     2019-12-09 [1]
#>  desc          1.2.0      2018-05-01 [1]
#>  devtools      2.2.2      2020-02-17 [1]
#>  digest        0.6.25     2020-02-23 [1]
#>  dplyr         0.8.5      2020-03-07 [1]
#>  ellipsis      0.3.1      2020-05-15 [1]
#>  evaluate      0.14       2019-05-28 [1]
#>  fansi         0.4.1      2020-01-08 [1]
#>  foreach       1.5.0      2020-03-30 [1]
#>  fs            1.3.2      2020-03-05 [1]
#>  ggplot2       3.3.2.9000 2020-06-23 [1]
#>  glmnet        4.0-2      2020-06-16 [1]
#>  glue          1.4.1      2020-05-13 [1]
#>  gtable        0.3.0      2019-03-25 [1]
#>  highr         0.8        2019-03-20 [1]
#>  htmltools     0.5.0      2020-06-16 [1]
#>  iterators     1.0.12     2019-07-26 [1]
#>  knitr         1.29       2020-06-23 [1]
#>  lattice       0.20-38    2018-11-04 [4]
#>  lifecycle     0.2.0      2020-03-06 [1]
#>  magrittr      1.5        2014-11-22 [1]
#>  Matrix        1.2-18     2019-11-27 [4]
#>  memoise       1.1.0      2017-04-21 [1]
#>  mgcv          1.8-31     2019-11-09 [4]
#>  munsell       0.5.0      2018-06-12 [1]
#>  nlme          3.1-143    2019-12-10 [4]
#>  pillar        1.4.4      2020-05-05 [1]
#>  pkgbuild      1.0.8      2020-05-07 [1]
#>  pkgconfig     2.0.3      2019-09-22 [1]
#>  pkgload       1.1.0      2020-05-29 [1]
#>  prettyunits   1.1.1      2020-01-24 [1]
#>  processx      3.4.2      2020-02-09 [1]
#>  ps            1.3.3      2020-05-08 [1]
#>  purrr         0.3.3      2019-10-18 [1]
#>  R6            2.4.1      2019-11-12 [1]
#>  Rcpp          1.0.4.6    2020-04-09 [1]
#>  remotes       2.1.1      2020-02-15 [1]
#>  rlang         0.4.6      2020-05-02 [1]
#>  rmarkdown     2.3        2020-06-18 [1]
#>  rprojroot     1.3-2      2018-01-03 [1]
#>  scales        1.1.1      2020-05-11 [1]
#>  sessioninfo   1.1.1      2018-11-05 [1]
#>  shape         1.4.4      2018-02-07 [1]
#>  stringi       1.4.6      2020-02-17 [1]
#>  stringr       1.4.0      2019-02-10 [1]
#>  survival      3.1-8      2019-12-03 [4]
#>  testthat      2.3.2      2020-03-02 [1]
#>  tibble        3.0.1      2020-04-20 [1]
#>  tidyselect    1.0.0      2020-01-27 [1]
#>  usethis       1.5.1      2019-07-04 [1]
#>  vctrs         0.3.1      2020-06-05 [1]
#>  VGAM          1.1-3      2020-04-28 [1]
#>  withr         2.2.0      2020-04-20 [1]
#>  xfun          0.15       2020-06-21 [1]
#>  yaml          2.2.1      2020-02-01 [1]
#>  source                                  
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  Github (sahirbhatnagar/casebase@e5bcd69)
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.0)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  Github (tidyverse/ggplot2@7d05fa3)      
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.0)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.1)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.1)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#>  CRAN (R 3.6.2)                          
#> 
#> [1] /home/sahir/R/x86_64-pc-linux-gnu-library/3.6
#> [2] /usr/local/lib/R/site-library
#> [3] /usr/lib/R/site-library
#> [4] /usr/lib/R/library

I just had another look at this. The issue is that fitSmoothHazard and estimate_hazard don't create the data matrix the same way, and this is independent of glmnet (maybe this was obvious to you but not to me). The other observation that escaped me at first is that, with glmnet, you want as many dummy variables as levels (as opposed to basic linear regression, where you want one fewer dummy variable). For these reasons, it makes sense to follow what glmnet::prepareX does.
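
To illustrate "as many dummy variables as levels", here is a minimal base R sketch on a made-up data frame (the `toy` data is an assumption for illustration). It does not call glmnet. Instead it passes full indicator contrasts to model.matrix through `contrasts.arg`, which is one way to get one-hot columns for every factor rather than only the first:

```r
# Toy data (hypothetical): two unordered factors and one numeric predictor.
toy <- data.frame(
  horTh    = factor(c("no", "yes", "yes", "no")),
  menostat = factor(c("Pre", "Post", "Post", "Pre"), levels = c("Pre", "Post")),
  age      = c(70, 56, 58, 59)
)

# For every factor column, build an identity ("full indicator") contrast matrix.
is_fac <- vapply(toy, is.factor, logical(1))
full_contrasts <- lapply(toy[is_fac], contrasts, contrasts = FALSE)

# No intercept, and every factor expanded to one column per level.
x_onehot <- model.matrix(~ horTh + menostat + age - 1, data = toy,
                         contrasts.arg = full_contrasts)
colnames(x_onehot)
#> [1] "horThno"      "horThyes"     "menostatPre"  "menostatPost" "age"
```

Whichever encoding is chosen, the same one has to be used when fitting and when predicting, otherwise the columns will not line up with the coefficients.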

One main difference between glmnet::prepareX and the solution above is how it handles ordered factors. By default, model.matrix uses polynomial contrasts for ordered factors, which explains the suffixes you get with tgrade.L and tgrade.Q. On the other hand, glmnet::prepareX ignores the ordered nature and just uses one-hot encoding for all levels.
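
The difference for ordered factors can be sketched in base R as follows. The `toy` data is made up, and the code assumes R's default `options("contrasts")`, where ordered factors use `contr.poly`. The one-hot version is built with model.matrix and `contrasts.arg` as a stand-in, not by calling glmnet:

```r
# Toy data (hypothetical): one ordered factor with three levels.
toy <- data.frame(
  tgrade = factor(c("I", "II", "III", "II"),
                  levels = c("I", "II", "III"), ordered = TRUE)
)

# Default behaviour: polynomial contrasts give linear (.L) and quadratic (.Q) columns.
x_poly <- model.matrix(~ tgrade, toy)
colnames(x_poly)
#> [1] "(Intercept)" "tgrade.L"    "tgrade.Q"

# One-hot alternative: ignore the ordering and use one indicator column per level.
x_onehot <- model.matrix(
  ~ tgrade - 1, toy,
  contrasts.arg = list(tgrade = contrasts(toy$tgrade, contrasts = FALSE))
)
colnames(x_onehot)
#> [1] "tgradeI"   "tgradeII"  "tgradeIII"
```

For the middle level, the polynomial coding gives a linear term of zero (up to floating-point error) and a quadratic term of about -0.8165, which is the pattern visible in the tgrade.L and tgrade.Q columns of the output above.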

I guess we should make sure that the behaviour is consistent between casebase and glmnet.
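
One way to make such an inconsistency fail early, with a readable message instead of the Cholmod error, is to compare the column names of the prediction matrix against the coefficient names before predicting. The helper below is a hypothetical sketch: `check_design_columns` is not part of casebase or glmnet, and the coefficient names are typed in by hand rather than taken from a fitted model (in practice they could come from something like `rownames(coef(fit))`).

```r
# Hypothetical helper: verify that a design matrix has exactly the columns
# the fitted coefficients expect, and return it in coefficient order.
check_design_columns <- function(coef_names, newx) {
  expected <- setdiff(coef_names, "(Intercept)")
  missing_cols <- setdiff(expected, colnames(newx))
  extra_cols <- setdiff(colnames(newx), expected)
  if (length(missing_cols) > 0 || length(extra_cols) > 0) {
    stop("Design matrix does not match the fitted coefficients. ",
         "Missing: ", paste(missing_cols, collapse = ", "), "; ",
         "unexpected: ", paste(extra_cols, collapse = ", "),
         call. = FALSE)
  }
  newx[, expected, drop = FALSE]
}

# Usage on toy data (hypothetical values).
coef_names <- c("(Intercept)", "horThno", "horThyes", "age")
toy <- data.frame(horTh = factor(c("no", "yes")), age = c(70, 56))

# Intercept dropped afterwards: horThno is missing, so the check fails.
x_bad <- model.matrix(~ horTh + age, toy)[, -1, drop = FALSE]
res_bad <- tryCatch(check_design_columns(coef_names, x_bad),
                    error = function(e) conditionMessage(e))

# Intercept removed in the formula: the columns match.
x_good <- model.matrix(~ horTh + age - 1, toy)
res_good <- check_design_columns(coef_names, x_good)
```

A check like this only compares names, so it would not catch two encodings that share column names but differ in values.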