jku-vds-lab/paradime

Create TrainingPhase class and/or add "add_training_phase" method to ParametricDR

Closed this issue · 1 comments

This class/method should allow users to set all the parameters required for a training phase. These are:

  • number of epochs
  • sampling method
  • batch-wise relations (specified as a dict, so custom loss functions can access them)
  • loss function (which is passed everything [model, global relations, batch-wise relations, batch] and accesses the parts it needs through keys)
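As a rough sketch of what such a container could look like (not existing paradime code): the field names `n_epochs`, `sampling`, `batch_relations` and `loss`, the `LossFn` alias, and the toy `example_loss` are all assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict

# Assumed signature: the loss is handed everything and picks what it needs by key.
# Arguments: (model, global_relations, batch_relations, batch) -> loss value
LossFn = Callable[[Any, Dict[str, Any], Dict[str, Any], Dict[str, Any]], float]


@dataclass
class TrainingPhase:
    """Hypothetical container for the settings of one training phase."""
    n_epochs: int
    sampling: str  # placeholder name for a sampling method, e.g. "standard"
    loss: LossFn
    # Batch-wise relations keyed by name, so custom losses can look them up.
    batch_relations: Dict[str, Any] = field(default_factory=dict)


def example_loss(model, global_relations, batch_relations, batch):
    """Toy loss: compares a transformed model output with a stored target."""
    target = global_relations["target"]
    prediction = batch_relations["scale"](model(batch["data"]))
    return abs(prediction - target)


phase = TrainingPhase(
    n_epochs=10,
    sampling="standard",
    loss=example_loss,
    batch_relations={"scale": lambda x: 2 * x},
)

# Call the loss the way a training loop might, with a dummy model and batch.
value = phase.loss(lambda x: x + 1, {"target": 5.0}, phase.batch_relations, {"data": 1.0})
```

Keeping the relations in dicts means a custom loss only depends on the keys it reads, not on the order or number of relations a phase defines.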

There are three possibilities for the API:

  1. specify phases at training time: `dr.train(phases=<list of TrainingPhase objects>)`
  2. specify phases in `__init__`: `dr = ParametricDR(..., training_phases=)`
  3. add phases one by one before training: `dr.add_training_phase(**kwargs)`

Choose one or allow multiple?
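
One way all three could coexist is sketched below. This is a stand-in class, not the real `ParametricDR`; representing phases as plain dicts, the `name` key, the `log` attribute, and the rule that phases passed to `train` take precedence over stored ones are all assumptions.

```python
from typing import Any, Dict, List, Optional

Phase = Dict[str, Any]  # stand-in for a TrainingPhase object


class ParametricDR:
    """Stand-in class illustrating the three API options side by side."""

    def __init__(self, training_phases: Optional[List[Phase]] = None):
        # Option 2: phases supplied at construction time.
        self.training_phases: List[Phase] = list(training_phases or [])
        self.log: List[tuple] = []

    def add_training_phase(self, **kwargs: Any) -> None:
        # Option 3: phases appended one by one before training.
        self.training_phases.append(dict(kwargs))

    def train(self, phases: Optional[List[Phase]] = None) -> None:
        # Option 1: phases given at training time take precedence (assumed rule).
        to_run = self.training_phases if phases is None else phases
        for phase in to_run:
            for epoch in range(phase["n_epochs"]):
                # Placeholder for the actual per-epoch training step.
                self.log.append((phase["name"], epoch))


dr = ParametricDR(training_phases=[{"name": "init", "n_epochs": 1}])
dr.add_training_phase(name="main", n_epochs=2)
dr.train()  # runs the stored phases in order
dr.train(phases=[{"name": "extra", "n_epochs": 1}])  # one-off override
```

In this sketch, options 2 and 3 both fill the same stored list, so supporting both costs little; option 1 is the only one that needs a precedence rule.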