mlr-org/mlr3torch

Add tensorboard callback


The goal is to add a callback that logs the training error and, if available, the validation error, so that both can be viewed in the browser via TensorBoard. TensorBoard event files can be written from R with the tfevents package: https://github.com/mlverse/tfevents.

For the first implementation, we can use the `torch_callback()` helper function described here: https://mlr3torch.mlr-org.com/articles/callbacks.html

Once everything works as expected, we can move away from this "syntactic sugar" and implement the callback directly as an R6 class, as is done e.g. here: https://github.com/mlr-org/mlr3torch/blob/main/R/CallbackSetProgress.R. This is necessary to generate proper documentation for the class.

The training and validation scores can be accessed via these two fields of the torch context:

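```r
self$last_scores_train = structure(list(), names = character(0))  # named list of training scores, empty initially
self$last_scores_valid = structure(list(), names = character(0))  # stays empty when no validation task is set
```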

The validation scores are only present when a validation task is set; otherwise `last_scores_valid` remains an empty list. The callback therefore needs to handle both cases.
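A minimal base-R sketch of how the epoch-end logic could handle both cases. Everything here is hypothetical: `collect_scores()` and `log_epoch_scores()` are not part of mlr3torch, the `train.`/`valid.` tag prefixes are an arbitrary choice, and `ctx` is a plain list standing in for the torch context (only `epoch`, `last_scores_train` and `last_scores_valid` are read). The actual writing is delegated to an injected `log_fn(tag, value, step)`, so the logic can be checked without torch or tfevents.

```r
# Hypothetical helper (not part of mlr3torch): merge the two score lists from
# the torch context into one flat, named list of TensorBoard tags.
# An empty list, such as structure(list(), names = character(0)), contributes
# nothing, which covers the case where no validation task is set.
collect_scores = function(last_scores_train, last_scores_valid) {
  with_prefix = function(scores, prefix) {
    if (length(scores) == 0L) return(list())
    stats::setNames(scores, paste0(prefix, names(scores)))
  }
  c(with_prefix(last_scores_train, "train."),
    with_prefix(last_scores_valid, "valid."))
}

# Hypothetical epoch-end hook: write every available score at step = epoch.
# `log_fn(tag, value, step)` does the actual writing (e.g. via tfevents);
# it is a parameter here so the logic can be tested without any logger.
log_epoch_scores = function(ctx, log_fn) {
  scores = collect_scores(ctx$last_scores_train, ctx$last_scores_valid)
  for (tag in names(scores)) {
    log_fn(tag, scores[[tag]], ctx$epoch)
  }
  # return the logged tags invisibly, mainly for inspection
  invisible(names(scores))
}

# Usage with a stand-in context (no validation scores) and a printing logger
ctx = list(
  epoch = 1L,
  last_scores_train = list(classif.ce = 0.31),
  last_scores_valid = structure(list(), names = character(0))
)
log_epoch_scores(ctx, function(tag, value, step) cat(tag, value, step, "\n"))
```

Inside a `torch_callback()`, such a hook would presumably be called from `on_epoch_end` as `log_epoch_scores(self$ctx, log_fn)`. A tfevents-backed `log_fn` could wrap the value in a one-element named list, pass it to `tfevents::log_event()` with `step = step`, and run that inside `tfevents::with_logdir()` to target the callback's log directory; which stage has both score lists populated should be confirmed against the callback documentation.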
Another open question is which configuration options the callback should offer.
For inspiration, we can look at how this is implemented in Keras: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/TensorBoard. We don't need to replicate the whole feature set, at least not in the first iteration.
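As one possible option, the Keras `TensorBoard` callback has an `update_freq` argument that accepts `"epoch"`, `"batch"`, or an integer N (log every N batches). A hypothetical R counterpart could be validated and evaluated with a small helper like the one below; the option name, its accepted values, and the running batch counter `iteration` are assumptions modeled on Keras, not an existing mlr3torch interface.

```r
# Hypothetical option modeled on Keras' `update_freq`:
#   "epoch" -> only log at the end of each epoch (never from the batch hook)
#   "batch" -> log after every batch (same as 1)
#   N       -> log after every N-th batch
# `iteration` is a 1-based batch counter across epochs that the callback
# would have to maintain itself.
should_log_batch = function(update_freq, iteration) {
  if (identical(update_freq, "epoch")) return(FALSE)
  n = if (identical(update_freq, "batch")) 1L else update_freq
  if (!is.numeric(n) || length(n) != 1L || is.na(n) || n < 1 || n != round(n)) {
    stop("update_freq must be \"epoch\", \"batch\" or a positive integer")
  }
  iteration %% n == 0
}

should_log_batch("epoch", 10)  # never log from the batch hook
should_log_batch(100, 200)     # 200th batch with N = 100
```

Keeping `"epoch"` as the default, as Keras does, would cover the first iteration described above; batch-level logging could be added later without changing that default.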