tlverse/hal9001

HAL fails to return a matrix of predictions when `cv_select=FALSE`

nhejazi opened this issue · 2 comments

#79 introduced a regression: when `cv_select = FALSE`, `predict` no longer returns a matrix of predictions along the grid of lambda values, but instead returns a vector of apparently identical values. This is clear from the README example, which used to return:

r$> set.seed(385971)
r$> packageVersion("hal9001")
[1] ‘0.2.7’
r$> n <- 100
r$> p <- 3
r$> x <- matrix(rnorm(n * p), n, p)
r$> y <- x[, 1] * sin(x[, 2]) + rnorm(n, mean = 0, sd = 0.2)
r$> hal_fit <- fit_hal(X = x, Y = y, cv_select = FALSE)
[1] "I'm sorry, Dave. I'm afraid I can't do that."
r$> hal_fit$times
                  user.self sys.self elapsed user.child sys.child
enumerate_basis       0.003        0   0.003          0         0
design_matrix         0.006        0   0.006          0         0
reduce_basis          0.000        0   0.000          0         0
remove_duplicates     0.004        0   0.004          0         0
lasso                 0.028        0   0.028          0         0
total                 0.041        0   0.042          0         0
r$> preds <- predict(hal_fit, new_data = x)

r$> preds
               s0          s1          s2          s3          s4          s5         s6
  [1,] 0.06356236  0.06478106  0.06559040  0.06310132  0.06072537  0.05844328  0.0562756
                s7          s8          s9         s10          s11         s12
  [1,]  0.05463778  0.05627511  0.05339973  0.05226898  0.051860713  0.05198880
                s13          s14         s15         s16         s17         s18
  [1,]  0.052159151  0.052403705  0.05170359  0.05074108  0.04971674  0.04872899
               s19         s20         s21         s22          s23         s24
  [1,]  0.04718589  0.04224513  0.03713774  0.03313715  0.029578804  0.02611531
               s25         s26          s27          s28          s29          s30
  [1,]  0.02220966  0.01861265  0.016745656  0.014962968  0.014016019  0.013150862
                 s31           s32          s33          s34          s35          s36
  [1,]  0.0084326033  0.0016781764 -0.004379224 -0.009756481 -0.018208534 -0.027420644
                s37          s38          s39          s40          s41           s42
  [1,] -0.035692904 -0.042651339 -0.047731797 -0.057171332 -0.068810925 -0.0814902028
                s43           s44          s45           s46          s47          s48
  [1,] -0.094145457 -0.1063458006 -0.112258682 -0.1169236800 -0.122157452 -0.126232513
                s49          s50          s51          s52          s53          s54
  [1,] -0.128004295 -0.129774524 -0.129584798 -0.129109598 -0.128860797 -0.128431573
                s55          s56           s57           s58          s59          s60
  [1,] -0.128002734 -0.127860160 -1.276471e-01 -1.263560e-01 -0.124513722 -0.123259610
                s61           s62          s63          s64          s65          s66
  [1,] -0.122217608 -0.1212981615 -0.119173537 -0.117948634 -0.118493536 -0.118916254
               s67          s68          s69          s70          s71          s72
  [1,] -0.11921098 -0.119095745 -0.118599351 -0.116668209 -0.115855176 -0.115227862
                s73          s74          s75          s76          s77          s78
  [1,] -0.114295608 -0.113795088 -0.113272632 -0.113057443 -0.113127302 -0.112956354
               s79          s80          s81          s82          s83          s84
  [1,] -0.11290637 -0.114645294 -0.116617272 -0.118362758 -0.119785603 -0.121354928
                 s85          s86           s87          s88          s89          s90
  [1,] -0.1219591307 -0.121072044 -0.1201841224 -0.119764167 -0.119247460 -0.118394374
                s91          s92          s93          s94          s95          s96
  [1,] -0.117574887 -0.117310636 -0.116898537 -0.116803168 -0.117258712 -0.117751223
                s97         s98          s99
  [1,] -0.118306951 -0.11876637 -0.119148360
 [ reached getOption("max.print") -- omitted 99 rows ]

...but now returns

r$> set.seed(385971)
r$> packageVersion("hal9001")
[1] ‘0.3.0’
r$> n <- 100
r$> p <- 3
r$> x <- matrix(rnorm(n * p), n, p)
r$> y <- x[, 1] * sin(x[, 2]) + rnorm(n, mean = 0, sd = 0.2)
r$> hal_fit <- fit_hal(X = x, Y = y, yolo = TRUE, cv_select = FALSE)
[1] "I'm sorry, Dave. I'm afraid I can't do that."
r$> hal_fit$times
                  user.self sys.self elapsed user.child sys.child
enumerate_basis       0.188    0.000   0.187          0         0
design_matrix         0.010    0.000   0.010          0         0
reduce_basis          0.000    0.000   0.000          0         0
remove_duplicates     0.000    0.000   0.000          0         0
lasso                 0.054    0.001   0.054          0         0
total                 0.253    0.001   0.253          0         0
r$> preds <- predict(hal_fit, new_data = x)

r$> preds
  [1] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
  [8] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [15] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [22] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [29] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [36] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [43] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [50] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [57] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [64] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [71] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [78] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [85] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [92] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
 [99] 0.06356236 0.06356236
 [ reached getOption("max.print") -- omitted 9900 entries ]

Resolved by #85.