Max margin interval trees

TODOs

For the journal paper:

  • For what kinds of data does MMIT with the squared loss learn a more accurate model than Interval-CART? Alex’s simulation figure: with interval-censored and left-censored data, Interval-CART does worse as the proportion of left-censored data increases. What about 0% interval-censored data? Plot test accuracy as a function of % left censored – hypothesis: Interval-CART does fine for 50% left / 50% right censored data (see the simulation sketch after this list).
  • Worst case for the squared loss (empirical + theoretical): Guillem’s example of the quadratic worst case (also show that reshuffling the data does not avoid it).
  • Uncensored outputs (survival data sets): mention how to implement this, as a generalization of CART.
  • HEATMAP: figure showing validation error as a function of the regularization parameters, tree depth and margin size (either as a 2d heatmap or two line plots). Especially for the margin, it would be nice to show that it has a regularizing effect (show for both train and validation).
  • Related works: cite Robin Girard; give a better description of trafotree.
  • Smoother models: make a simple implementation of forests and see how smooth it is. If we get interesting results, we will pursue this in another paper.
  • Mention the submission to CRAN.
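
A minimal sketch, not the actual experiment code, of how such a censoring mix could be simulated; all names and distributions below are illustrative assumptions.

sim.censored <- function(n, prop.left, prop.right){
  x <- matrix(rnorm(n * 10), n, 10)  # feature matrix
  y <- x[, 1]^2 + rnorm(n)           # nonlinear latent output
  type <- sample(
    c("left", "right", "interval"), n, replace = TRUE,
    prob = c(prop.left, prop.right, 1 - prop.left - prop.right))
  ## left censored: (-Inf, upper]; right censored: [lower, Inf)
  lower <- ifelse(type == "left", -Inf, y - runif(n, 0.5, 2))
  upper <- ifelse(type == "right", Inf, y + runif(n, 0.5, 2))
  list(features = x, targets = cbind(lower, upper))
}
set.seed(1)
sim <- sim.censored(200, prop.left = 0.5, prop.right = 0.5)  # 0% interval censored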

Probably can leave for future work:

  • Motivate nonlinear patterns using Toby’s univariate experiment: train a linear model and an MMIT on each feature individually, and compute the difference in accuracy. The feature with the biggest difference should have a really nonlinear pattern.
  • Forests.
  • Non-constant leaves: smooth the model after learning, using interpolation between leaves, so we can use the same solver.

22 Nov 2017 (Alex)

  • Demonstrated that Interval-CART and MMIT do not give the same solution, even with all interval-censored data. A consequence of this is that the Interval-CART models tend to have many more splits, even though those splits are not useful in practice.
  • Toby thinks I should add a fixed test set and cross-validation (for hyperparameter selection) to the figure that shows how Interval-CART fails with right/left-censored data.

21 Nov 2017 (Alex)

  • Increasing the margin hyperparameter will encourage the partitioning of leaves. So if the margin is increased, the other regularization hyperparameters must compensate. We should see this in the heatmaps that were proposed above.
  • Can we make a figure showing the relationship between tree depth and margin, assuming that tree depth and min samples split are unbounded?
  • Interval-CART is more likely to partition leaves, since it does not understand the structure of the interval outputs. Example:
*-------|------|----
        |   *--|---o
        |   *--|--o
        |      |
        | CART | MMIT

In this case, Interval-CART would find it beneficial to split, but MMIT knows that this is a perfect leaf (see the sketch below). Are Interval-CART trees generally more complex than MMIT trees? The interpretability of the models would make a good argument for favouring MMIT.
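
As a concrete check, here is a minimal sketch (mine, not the MMIT solver) of the squared hinge leaf cost with a margin parameter; the three intervals loosely mimic the outputs drawn above.

leaf.cost <- function(mu, lower, upper, margin = 0){
  sum(pmax(0, lower - mu + margin)^2 +  # penalty for predicting below a lower limit
      pmax(0, mu - upper + margin)^2)   # penalty for predicting above an upper limit
}
lower <- c(0, 8, 8)
upper <- c(Inf, 13, 12)
c(max(lower), min(upper))    # zero-cost region [8, 12] (for margin = 0)
leaf.cost(10, lower, upper)  # 0: MMIT sees a perfect leaf, so no split is needed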

  • We know that MMIT finds the optimal solution for a variant of the squared loss that respects the structure of intervals. We can solve an equation to find the conditions under which the Interval-CART solution is outside of the optimal region according to the MMIT solver.
  • In any given leaf, the minimum of the Interval-CART loss function is unique, whereas MMIT defines an optimal region. All we need to do is quantify when the Interval-CART minimum is outside of this target region. I believe we can test this quickly using our dynamic programming algorithm.
  • We could then use this test to scan the trees learned by Interval-CART and detect where this occurs on our data sets (see the sketch after this list).
  • Assuming that all feature vectors are unique, fully grown MMIT and Interval-CART trees should yield the same (or very comparable) solutions. We can test this easily.
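
A hypothetical helper for this test (the name, and treating the optimal region as [max lower + margin, min upper - margin], are my assumptions; the real check would use the dynamic programming solver):

outside.optimal <- function(pred, lower, upper, margin = 0){
  lo <- max(lower) + margin
  up <- min(upper) - margin
  if(lo > up) return(NA)  # no zero-cost prediction exists in this leaf
  pred < lo || pred > up  # TRUE if the Interval-CART minimum is outside the region
}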

1 Sept 2017

Modified the trafotreeCV function in trafotree.predictions.R: MSE, rather than the number of incorrect labels, is now used to select the best mc value on the validation set. For the CV models (those without the 0.95 suffix) this modification results in lower test MSE:

figure-evaluate-predictions-one-H3K27ac-H3K4me3_TDHAM_BP_FPOP.png

Added trafotree.predictions.corrected.R: these models may still learn trees with nodes whose outputs are all of one censoring type (e.g. all right censored). In these nodes the trafotree function returns an unconverged estimate, because the predicted value is undefined (e.g. Inf). In this case we correct the predicted value by taking the least extreme output limit (e.g. the largest lower limit), as sketched below.
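
A minimal sketch of this correction (function and argument names are hypothetical; the real code is in trafotree.predictions.corrected.R):

correct.pred <- function(pred, lower, upper){
  if(is.finite(pred)) return(pred)          # converged estimate: keep it
  if(all(upper == Inf)) return(max(lower))  # all right censored: largest lower limit
  min(upper)                                # all left censored: smallest upper limit
}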

27 July 2017

Torsten said on 2017-04-07: “You need mlt 0.1-3 and basefun 0.0-37 from CRAN. I updated trtf/DESCRIPTION accordingly. A fresh install of partykit devel and trtf would be nice as well.” My next email was on 2017-06-09 (I think I found a bug); at that time I was using library(trtf) # svn up -r 734.

Trying to figure out how to reproducibly get good results with TTreeIntOnly0.95 – it is just as accurate as IntervalRegressionCV in 2-fold CV on my laptop with these package versions.

> devtools::session_info()
Session info --------------------------------------------------------------------------------------
 setting  value                       
 version  R version 3.3.3 (2017-03-06)
 system   i686, linux-gnu             
 ui       X11                         
 language en_US                       
 collate  en_US.UTF-8                 
 tz       Canada/Eastern              
 date     2017-07-27                  

Packages ------------------------------------------------------------------------------------------
 package         * version    date       source                                  
 alabama           2015.3-1   2015-03-06 CRAN (R 3.3.3)                          
 base            * 3.3.3      2017-03-07 local                                   
 basefun         * 0.0-38     2017-06-16 local                                   
 BB                2015.12-1  2016-08-11 R-Forge (R 3.3.3)                       
 bitops            1.0-6      2013-08-17 CRAN (R 3.2.2)                          
 caTools           1.17.1     2014-09-10 CRAN (R 3.2.2)                          
 codetools         0.2-15     2016-10-05 CRAN (R 3.3.3)                          
 colorspace        1.3-0      2016-11-01 R-Forge (R 3.3.1)                       
 compiler          3.3.3      2017-03-07 local                                   
 data.table      * 1.10.5     2017-04-21 local                                   
 datasets        * 3.3.3      2017-03-07 local                                   
 devtools          1.13.2     2017-06-02 cran (@1.13.2)                          
 digest            0.6.12     2017-01-27 cran (@0.6.12)                          
 directlabels      2017.03.31 2017-04-01 Github (tdhock/directlabels@bc15a4f)    
 Formula           1.2-1      2016-02-17 R-Forge (R 3.3.3)                       
 future            1.4.0      2017-03-13 CRAN (R 3.3.3)                          
 geometry          0.3-6      2015-09-09 CRAN (R 3.2.2)                          
 ggplot2           2.1.0      2017-03-25 Github (faizan-khan-iit/ggplot2@5fb99d0)
 globals           0.10.0     2017-04-17 CRAN (R 3.3.3)                          
 graphics        * 3.3.3      2017-03-07 local                                   
 grDevices       * 3.3.3      2017-03-07 local                                   
 grid            * 3.3.3      2017-03-07 local                                   
 gtable            0.2.0      2016-02-26 CRAN (R 3.2.2)                          
 labeling          0.3        2014-08-23 R-Forge (R 3.2.2)                       
 lattice         * 0.20-34    2016-09-06 CRAN (R 3.3.3)                          
 libcoin         * 0.9-2      2017-04-05 R-Forge (R 3.3.3)                       
 listenv           0.6.0      2015-12-28 CRAN (R 3.3.3)                          
 magic             1.5-6      2013-11-20 CRAN (R 3.2.2)                          
 Matrix            1.2-8      2017-01-20 CRAN (R 3.3.3)                          
 memoise           1.1.0      2017-04-21 cran (@1.1.0)                           
 methods         * 3.3.3      2017-03-07 local                                   
 mlt             * 0.1-4      2017-06-16 local                                   
 munsell           0.4.3      2016-02-13 CRAN (R 3.2.2)                          
 mvtnorm         * 1.0-5      2016-02-02 R-Forge (R 3.2.2)                       
 namedCapture    * 2017.01.15 2017-04-29 Github (tdhock/namedCapture@1da425b)    
 numDeriv          2016.8-1   2016-08-21 R-Forge (R 3.3.3)                       
 orthopolynom      1.0-5      2013-02-04 CRAN (R 3.3.3)                          
 parallel          3.3.3      2017-03-07 local                                   
 partykit        * 1.2-0      2017-04-24 R-Forge (R 3.3.3)                       
 penaltyLearning * 2017.06.14 2017-06-22 local                                   
 plyr              1.8.4      2016-06-08 CRAN (R 3.2.2)                          
 polynom           1.3-9      2016-12-08 CRAN (R 3.3.3)                          
 quadprog          1.5-5      2013-04-17 CRAN (R 3.2.2)                          
 RColorBrewer    * 1.1-2      2014-12-07 CRAN (R 3.2.2)                          
 Rcpp              0.12.11    2017-05-22 cran (@0.12.11)                         
 RCurl           * 1.96-0     2016-08-07 local                                   
 requireGitHub     2017.03.16 2017-04-29 Github (tdhock/requireGitHub@5de2020)   
 RJSONIO         * 1.3-0      2014-07-28 CRAN (R 3.2.2)                          
 RSelenium       * 1.3.6      2016-11-09 Github (ropensci/RSelenium@22f06b9)     
 rstudioapi        0.6        2016-06-27 cran (@0.6)                             
 sandwich          2.3-4      2015-09-24 CRAN (R 3.2.2)                          
 scales            0.4.1      2016-11-09 CRAN (R 3.3.1)                          
 splines           3.3.3      2017-03-07 local                                   
 stats           * 3.3.3      2017-03-07 local                                   
 survival        * 2.41-3     2017-04-04 CRAN (R 3.3.3)                          
 tools             3.3.3      2017-03-07 local                                   
 trtf            * 0.2-1      2017-06-16 local                                   
 utils           * 3.3.3      2017-03-07 local                                   
 variables       * 0.0-30     2017-06-16 local                                   
 withr             1.0.2      2016-06-20 cran (@1.0.2)                           
 XML             * 3.99-0     2016-08-07 local                                   
 zoo               1.7-13     2015-12-15 R-Forge (R 3.2.2)                       
> 

16 June 2017

  • Emailed Torsten Hothorn, author of trafotree, with the bug description in figure-trafotree-bug.R: test error was bigger than for the constant model on one data set.
  • He emailed back with code that solves the problem, figure-trafotree-bug-response.R. Apparently, when we set up the basis function for the output, we need to force the variance to be positive (a positive slope, achieved via the ui/ci arguments to as.basis), and we need to partition only on the intercept (not on the slope/variance, achieved by parm=1 in trafotree). After redoing the test error figure, this does indeed fix the issue (trafotree now learns as well as IntervalRegressionCV); a rough sketch follows.
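
A rough sketch of the fix on synthetic data (the real code is in figure-trafotree-bug-response.R; the constraint matrix, basis support and response set-up below are my guesses at the details):

library(trtf)      # also loads mlt, basefun and partykit
library(survival)
set.seed(1)
d <- data.frame(x = rnorm(100))
lo <- d$x + rnorm(100)
d$y <- Surv(lo, lo + runif(100, 1, 3), type = "interval2")  # interval censored outputs
## force a positive slope (variance) on the output basis via ui/ci:
yb <- as.basis(~ y, data = data.frame(y = c(-5, 8)),
               ui = matrix(c(0, 1), nrow = 1), ci = 0)
m <- ctm(yb, todistr = "Normal")
## partition only on the intercept (parm = 1), not on the slope/variance:
tree <- trafotree(m, formula = y ~ x, data = d, parm = 1)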

10 May 2017

Lots more algorithms and data sets; prediction accuracy figure: http://bl.ocks.org/tdhock/raw/75751a85d2766cd43be4c36ee3fa58a1/

31 Mar 2017

Created a new filter when making data sets (the number of observations must be greater than 13), which removed H3K36me3_TDH_other_joint (it had one fold with no negative labels, so we cannot compute AUC). So now there are 24 data sets in http://cbio.ensmp.fr/~thocking/data/penalty-learning-interval-regression-problems.tgz

figure-evaluate-predictions.R creates figures that compare prediction accuracy. Right now I have only computed IntervalRegressionCV (a linear model trained by minimizing the squared hinge loss with an L1 penalty) and constant (a baseline model that just learns the constant penalty with the minimum number of incorrect target intervals; see the sketch below). We can see that IntervalRegressionCV does better in the tall data setting, and about the same in the fat data setting.
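
For reference, a minimal sketch of that constant baseline (the function name is hypothetical): try each finite target limit as the constant prediction and keep the one with the fewest incorrect target intervals.

best.constant <- function(targets){  # targets: n x 2 matrix of (lower, upper) limits
  cand <- sort(unique(targets[is.finite(targets)]))
  errors <- sapply(cand, function(mu)
    sum(mu < targets[, 1] | targets[, 2] < mu))  # intervals missed by mu
  cand[which.min(errors)]
}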

29 Mar 2017

penaltyLearning.predictions.R creates prediction files for IntervalRegressionCV (a linear model with squared hinge loss + L1 regularization).

To make it easy to compare models fit in either R or Python, I suggest that we save model predictions in the following format. Create a separate directory called “predictions”, inside of which there is one sub-directory for each model. Each model sub-directory has another sub-directory for each data set, containing a predictions.csv file (n x 1 – predicted values for each observation in 5-fold CV). For example:

project/data/lymphoma.mkatayama/features.csv
project/predictions/mmit.linear.hinge/lymphoma.mkatayama/predictions.csv
project/predictions/mmit.squared.hinge/lymphoma.mkatayama/predictions.csv
etc.
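
For example, an R model could save its predictions like this (pred.vec stands in for the model’s predicted values, in the same row order as features.csv):

pred.vec <- rnorm(10)  # stand-in for the real n x 1 predictions
pred.dir <- file.path("project", "predictions", "mmit.squared.hinge",
                      "lymphoma.mkatayama")
dir.create(pred.dir, recursive = TRUE, showWarnings = FALSE)
write.csv(data.frame(pred = pred.vec),
          file.path(pred.dir, "predictions.csv"), row.names = FALSE)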

26 penalty learning data sets were created via data.sets.R (but one has fewer than 10 observations, so we ignore it, leaving a total of 25 data sets). The script creates a data directory with a subdirectory for each data set. Inside each of those are three files:

  1. targets.csv is the n x 2 matrix of target intervals (outputs).
  2. features.csv is the n x p matrix of features (inputs). p is different for each data set.
  3. folds.csv is an n x 1 vector of fold IDs – for comparing model predictions using 5-fold cross-validation (see the sketch below).
  • R pkg neuroblastoma + labels.
  • http://members.cbio.ensmp.fr/~thocking/neuroblastoma/signal.list.annotation.sets.RData this data contains many different types of microarrays – maybe create a data set that groups them all together?
  • thocking@guillimin:PeakSegFPOP/ChIPseq.wholeGenome.rds contains features + targets for genome wide ChIP-seq segmentation models (PeakSegFPOP and PeakSegJoint).
  • TODO: copy the 7 benchmark data sets from my work computer to the benchmark web page, along with scripts to compute targets and features.
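
A minimal sketch of reading one data set in this format and computing test predictions for one fold (the data set name is an example, and CSV headers are assumed):

data.dir <- "data/lymphoma.mkatayama"
features <- as.matrix(read.csv(file.path(data.dir, "features.csv")))  # n x p
targets  <- as.matrix(read.csv(file.path(data.dir, "targets.csv")))   # n x 2
folds    <- read.csv(file.path(data.dir, "folds.csv"))[[1]]           # n fold IDs
is.test  <- folds == 1  # hold out fold 1, train on folds 2-5
fit <- penaltyLearning::IntervalRegressionCV(
  features[!is.test, ], targets[!is.test, ])
pred <- predict(fit, features[is.test, ])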

figure-data-set-sizes.R shows a summary of the dimensions of the 25 data sets, each of which should be treated as a separate learning problem.

  • the number of features varies from 26 to 259.
  • the number of observations varies from 13 to 3418.
  • some data sets are “fat” (n < p) and others are “tall” (p < n).
  • some data sets have more upper limits, others have more lower limits.
  • the penalty functions are for four types of segmentation models.

figure-data-set-sizes.png

22 March 2017

Interactive figure: select a threshold on the total cost curves, and see the updated prediction, margin, and slack.

16 March 2017

figure-penaltyLearning.R visualizes cost as a function of feature value.