Add TensorBoard callback
The goal here is to add a callback that logs the training loss and the validation loss (if the latter exists) so that both can be displayed in the browser via TensorBoard. The R package for TensorBoard can be found here: https://github.com/mlverse/tfevents.
For the first implementation, we can use the `torch_callback()` helper function as described here: https://mlr3torch.mlr-org.com/articles/callbacks.html
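A first version built with `torch_callback()` could look roughly like the sketch below. It is not a definitive implementation. It assumes that the torch context exposes the per-epoch scores as named lists in `ctx$last_scores_train` and `ctx$last_scores_valid` and the current epoch in `ctx$epoch`; these field names should be checked against the lines referenced below. The id `"tb"` and the parameter `path` are hypothetical. Logging uses `log_event()` and `with_logdir()` from tfevents.

```r
library(mlr3torch)
library(tfevents)

# Minimal sketch of a TensorBoard callback via the torch_callback() helper.
# ASSUMPTIONS (not confirmed by this issue): the context exposes the
# per-epoch scores as named lists in ctx$last_scores_train and
# ctx$last_scores_valid, and the current epoch in ctx$epoch.
# The id "tb" and the parameter `path` are hypothetical names.
tb_callback = torch_callback("tb",
  initialize = function(path = "logs") {
    # directory from which tensorboard will read the event files
    self$path = path
  },
  on_epoch_end = function() {
    train = self$ctx$last_scores_train
    valid = self$ctx$last_scores_valid
    # prefix the score names so train and validation curves are distinguishable
    if (length(train)) names(train) = paste0("train.", names(train))
    # the validation scores are empty when no validation task is set
    if (length(valid)) names(valid) = paste0("valid.", names(valid))
    scores = c(as.list(train), as.list(valid))
    if (length(scores)) {
      # write one scalar event per score, using the epoch as the step
      with_logdir(self$path, {
        do.call(log_event, c(scores, list(step = self$ctx$epoch)))
      })
    }
  }
)
```

After training, the event files can be inspected by running `tensorboard --logdir logs` in a shell.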
Once everything is working as expected, we can move away from this "syntactic sugar" and implement the callback directly as an R6 class, as is done e.g. here: https://github.com/mlr-org/mlr3torch/blob/main/R/CallbackSetProgress.R. This is necessary to generate proper documentation for the class.
The training and validation loss can be accessed via two fields of the torch context (lines 59 to 60 in fef4cdb).
The validation loss is only present when a validation task is set, so we need to handle both cases.
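One way to handle both cases is to isolate the branching in a small pure function, so that it can be unit-tested without training a network. The helper below is hypothetical (its name and signature are not part of mlr3torch). It takes the training and validation scores as named lists and returns one flat, named list of scalars that could be passed on to a logger such as `tfevents::log_event()`. The validation entries are simply omitted when the validation scores are `NULL` or empty.

```r
# Hypothetical helper: merge training and (optional) validation scores into
# one flat named list, prefixing the names with "train." and "valid.".
# `valid` is NULL or empty when no validation task is set.
scores_to_events = function(train, valid = NULL) {
  prefix = function(scores, tag) {
    # nothing to log for this split
    if (length(scores) == 0L) return(list())
    scores = as.list(scores)
    names(scores) = paste0(tag, ".", names(scores))
    scores
  }
  c(prefix(train, "train"), prefix(valid, "valid"))
}

# without validation task: only the training entry is returned
scores_to_events(list(loss = 0.42))
# with validation task: both entries are returned
scores_to_events(list(loss = 0.42), list(loss = 0.51))
```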
Another open question is which configuration options we want for the callback.
We can also take a look at how this is implemented in Keras: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/TensorBoard, although we don't need its whole feature set, at least not in the first iteration.