HAL fails to return a matrix of predictions when `cv_select=FALSE`
nhejazi opened this issue · 2 comments
nhejazi commented
#79 introduced a regression: with `cv_select = FALSE`, `predict()` no longer returns a matrix of predictions along the grid of lambda values, and instead returns a vector of apparently identical values. This is clear from the README example, which used to return:
r$> set.seed(385971)
r$> packageVersion("hal9001")
[1] ‘0.2.7’
r$> n <- 100
r$> p <- 3
r$> x <- matrix(rnorm(n * p), n, p)
r$> y <- x[, 1] * sin(x[, 2]) + rnorm(n, mean = 0, sd = 0.2)
r$> hal_fit <- fit_hal(X = x, Y = y, cv_select = FALSE)
[1] "I'm sorry, Dave. I'm afraid I can't do that."
r$> hal_fit$times
                  user.self sys.self elapsed user.child sys.child
enumerate_basis       0.003        0   0.003          0         0
design_matrix         0.006        0   0.006          0         0
reduce_basis          0.000        0   0.000          0         0
remove_duplicates     0.004        0   0.004          0         0
lasso                 0.028        0   0.028          0         0
total                 0.041        0   0.042          0         0
r$> preds <- predict(hal_fit, new_data = x)
r$> preds
s0 s1 s2 s3 s4 s5 s6
[1,] 0.06356236 0.06478106 0.06559040 0.06310132 0.06072537 0.05844328 0.0562756
s7 s8 s9 s10 s11 s12
[1,] 0.05463778 0.05627511 0.05339973 0.05226898 0.051860713 0.05198880
s13 s14 s15 s16 s17 s18
[1,] 0.052159151 0.052403705 0.05170359 0.05074108 0.04971674 0.04872899
s19 s20 s21 s22 s23 s24
[1,] 0.04718589 0.04224513 0.03713774 0.03313715 0.029578804 0.02611531
s25 s26 s27 s28 s29 s30
[1,] 0.02220966 0.01861265 0.016745656 0.014962968 0.014016019 0.013150862
s31 s32 s33 s34 s35 s36
[1,] 0.0084326033 0.0016781764 -0.004379224 -0.009756481 -0.018208534 -0.027420644
s37 s38 s39 s40 s41 s42
[1,] -0.035692904 -0.042651339 -0.047731797 -0.057171332 -0.068810925 -0.0814902028
s43 s44 s45 s46 s47 s48
[1,] -0.094145457 -0.1063458006 -0.112258682 -0.1169236800 -0.122157452 -0.126232513
s49 s50 s51 s52 s53 s54
[1,] -0.128004295 -0.129774524 -0.129584798 -0.129109598 -0.128860797 -0.128431573
s55 s56 s57 s58 s59 s60
[1,] -0.128002734 -0.127860160 -1.276471e-01 -1.263560e-01 -0.124513722 -0.123259610
s61 s62 s63 s64 s65 s66
[1,] -0.122217608 -0.1212981615 -0.119173537 -0.117948634 -0.118493536 -0.118916254
s67 s68 s69 s70 s71 s72
[1,] -0.11921098 -0.119095745 -0.118599351 -0.116668209 -0.115855176 -0.115227862
s73 s74 s75 s76 s77 s78
[1,] -0.114295608 -0.113795088 -0.113272632 -0.113057443 -0.113127302 -0.112956354
s79 s80 s81 s82 s83 s84
[1,] -0.11290637 -0.114645294 -0.116617272 -0.118362758 -0.119785603 -0.121354928
s85 s86 s87 s88 s89 s90
[1,] -0.1219591307 -0.121072044 -0.1201841224 -0.119764167 -0.119247460 -0.118394374
s91 s92 s93 s94 s95 s96
[1,] -0.117574887 -0.117310636 -0.116898537 -0.116803168 -0.117258712 -0.117751223
s97 s98 s99
[1,] -0.118306951 -0.11876637 -0.119148360
[ reached getOption("max.print") -- omitted 99 rows ]
...but now returns
r$> set.seed(385971)
r$> packageVersion("hal9001")
[1] ‘0.3.0’
r$> n <- 100
r$> p <- 3
r$> x <- matrix(rnorm(n * p), n, p)
r$> y <- x[, 1] * sin(x[, 2]) + rnorm(n, mean = 0, sd = 0.2)
r$> hal_fit <- fit_hal(X = x, Y = y, yolo = TRUE, cv_select = FALSE)
[1] "I'm sorry, Dave. I'm afraid I can't do that."
r$> hal_fit$times
                  user.self sys.self elapsed user.child sys.child
enumerate_basis       0.188    0.000   0.187          0         0
design_matrix         0.010    0.000   0.010          0         0
reduce_basis          0.000    0.000   0.000          0         0
remove_duplicates     0.000    0.000   0.000          0         0
lasso                 0.054    0.001   0.054          0         0
total                 0.253    0.001   0.253          0         0
r$> preds <- predict(hal_fit, new_data = x)
r$> preds
[1] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[8] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[15] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[22] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[29] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[36] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[43] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[50] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[57] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[64] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[71] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[78] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[85] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[92] 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236 0.06356236
[99] 0.06356236 0.06356236
[ reached getOption("max.print") -- omitted 9900 entries ]
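For anyone who wants to guard against this in a test, here is a minimal sketch of a shape check on the predictions. It uses base R only, and `check_lambda_grid_preds()` is a hypothetical helper, not part of hal9001. The `good` and `bad` objects are made-up stand-ins that only mimic the shapes of the two outputs above (a 100 x 100 matrix versus a length-10000 vector); they are not real HAL predictions.

```r
# Hypothetical helper (not part of hal9001): checks that predictions made
# without CV selection look like one column per lambda value.
check_lambda_grid_preds <- function(preds, n_obs) {
  is_mat <- is.matrix(preds)
  rows_match <- is_mat && nrow(preds) == n_obs
  # With more than one lambda, at least one row should vary across columns.
  columns_vary <- is_mat && ncol(preds) > 1 &&
    any(apply(preds, 1, function(row) diff(range(row)) > 0))
  list(is_matrix = is_mat, rows_match = rows_match, columns_vary = columns_vary)
}

# Stand-ins mimicking the shapes reported above (assumed, illustrative data).
good <- matrix(seq_len(100 * 100) / 1e4, nrow = 100, ncol = 100)
bad <- rep(0.06356236, 100 * 100)

str(check_lambda_grid_preds(good, n_obs = 100))
str(check_lambda_grid_preds(bad, n_obs = 100))
```

With hal9001 installed, the same check could presumably be applied to the real output, e.g. `check_lambda_grid_preds(predict(hal_fit, new_data = x), n_obs = nrow(x))`.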
nhejazi commented
Specifically, the regression was introduced by 3a55aab#diff-1133b4ed618525dc31392ae5cf9ca97832d30b2a926d2c4d007e1d3166d56fa4R132