mlr3learnertemplate

This package provides a template for adding new learners to mlr3.

Creating custom learners is covered in the mlr3book.

Instructions

This repository is a minimal working package with the randomForest learner. Fork this repository and adapt the code to your learner.

1) Rename Files

Rename the following files to suit your learner:

  • R/LearnerClassifRandomForest.R
  • tests/testthat/test_classif_randomForest.R

(For regression use the prefix "Regr" instead of "Classif". For example learners see https://github.com/mlr-org/mlr3learners)

2) Edit R/LearnerYourLearner.R

  • Adapt the documentation to suit your learner.
  • Adapt the names, the package, the learner properties, etc.

The Parameter Set

mlr                                 mlr3 (paradox)
makeIntegerLearnerParam()           ParamInt$new()
makeLogicalLearnerParam()           ParamLgl$new()
makeNumericVectorLearnerParam()     ParamUty$new()
makeIntegerVectorLearnerParam()     ParamUty$new()
makeUntypedLearnerParam()           ParamUty$new()
makeNumericLearnerParam()           ParamDbl$new()

The Training Method

The train_internal method takes a task as input and returns the trained model from your learner. It’s best to develop it against a test task, e.g.

task = mlr_tasks$get("iris")

You cannot copy and paste the mlr source code and use it for the mlr3 learner. Know your learner. Inside this method, the only information available is the task and the learner you defined.
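To get a feel for what the task provides inside the training method, the following minimal sketch pulls the target name, the feature names and the data out of the test task and fits a randomForest model on them. It assumes the randomForest package is installed; the variable names and ntree = 50L are chosen for illustration only and are not part of the template.

```r
library(mlr3)

# Test task, as in the text above.
task = mlr_tasks$get("iris")

# Everything the training method can use comes from the task:
# the target name, the feature names, and the data itself.
target = task$target_names
features = task$feature_names
data = as.data.frame(task$data())

# Build the model formula by hand and fit the underlying model.
# ntree = 50L is an arbitrary value for this illustration.
formula = reformulate(features, response = target)
model = randomForest::randomForest(formula, data = data, ntree = 50L)
```

The body of your training method would do roughly the same, using the hyperparameter values stored in the learner instead of a hard-coded ntree.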

The Prediction Method

Try to work with an example:

library(devtools)
load_all()
library(mlr3)
task = mlr_tasks$get("iris")
lrn = LearnerClassifRandomForest$new()
lrn$train(task = task)

Then you can write the prediction function and return a prediction object (PredictionClassif$new() or PredictionRegr$new()).
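As a rough sketch of that step, the snippet below predicts class labels with a randomForest model and wraps them in a PredictionClassif object. It assumes the randomForest package is installed and that PredictionClassif$new() accepts the task and response arguments; the constructor signature has changed between mlr3 releases, so check it against the version you build on. The model here is only a stand-in for the one returned by your training method.

```r
library(mlr3)

task = mlr_tasks$get("iris")
features = as.data.frame(task$data(cols = task$feature_names))

# Stand-in for the model returned by the training method.
# ntree = 50L is an arbitrary value for this illustration.
model = randomForest::randomForest(x = features, y = task$truth(), ntree = 50L)

# Predict class labels; drop the row names attached by predict().
response = unname(predict(model, newdata = features, type = "response"))

# Wrap the predicted labels in a prediction object.
prediction = PredictionClassif$new(task = task, response = response)
```

For probability predictions, the same constructor also takes a prob argument holding a matrix with one column per class.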

Learner Specific Methods

The randomForest function supports importance and oob_error. You need to write methods for both, as they will be tested.
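Both quantities can be read off a fitted randomForest object. The sketch below, which assumes the randomForest package is installed and uses the built-in iris data, shows one way to extract them; how you wire this into the learner's methods depends on the template code.

```r
# Fit directly on the built-in iris data; ntree = 50L is arbitrary.
model = randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)

# Variable importance: one score per feature, sorted in decreasing order.
imp = randomForest::importance(model)
importance = sort(imp[, "MeanDecreaseGini"], decreasing = TRUE)

# Out-of-bag error after the last tree.
oob_error = unname(model$err.rate[model$ntree, "OOB"])
```

For regression forests the object has no err.rate element, so a regression learner would need a different source for its out-of-bag error.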

Test Your Function

If you run devtools::load_all() the function run_autotest() is available in your global environment. Change the code in tests/testthat/test_classif_randomForest.R to fit your learner. Make sure that at least the following is tested in the unit test tests/testthat/test_classif_your_learner.R:

learner = LearnerClassifRandomForest$new()
expect_learner(learner)
result = run_autotest(learner)
expect_true(result, info = result$error)

Check Your Package

If this runs without errors, your learner should be fine:

devtools::check()