Create TrainingPhase class and/or add "add_training_phase" method to ParametricDR
einbandi commented
This class/method should allow users to set all the parameters required for a training phase. These are:
- number of epochs
- sampling method
- batch-wise relations (specified as a dict, so custom loss functions can access them)
- loss function (which is passed everything [model, global relations, batch-wise relations, batch] and accesses the parts it needs by key)
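One way to bundle these four parameters is a small dataclass. This is a minimal sketch, not the project's implementation: the field names, the string-valued `sampling`, and the assumption that `batch_relations` maps a name to a callable computing that relation per batch are all hypothetical. The loss receives a single dict and reads what it needs by key, as described above.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict

# Hypothetical loss signature: one dict holding model, global relations,
# batch-wise relations and the batch; the loss picks entries by key.
LossFn = Callable[[Dict[str, Any]], float]


@dataclass
class TrainingPhase:
    epochs: int
    sampling: str  # hypothetical: name of the sampling method, e.g. "random"
    loss: LossFn
    # Hypothetical: relation name -> function computing it for a batch.
    batch_relations: Dict[str, Callable[..., Any]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Reject obviously invalid phases early.
        if self.epochs < 1:
            raise ValueError("epochs must be a positive integer")


def mean_loss(ctx: Dict[str, Any]) -> float:
    # A toy custom loss that only looks at the "batch" key.
    batch = ctx["batch"]
    return sum(batch) / len(batch)


phase = TrainingPhase(epochs=10, sampling="random", loss=mean_loss,
                      batch_relations={"identity": lambda b: b})
```

Passing one dict rather than positional arguments means a new relation can be added to the context without changing the signature of every existing loss.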
einbandi commented
There are three possibilities for the API:
- specify phases at training time: `dr.train(phases=<list of TrainingPhase objects>)`
- specify phases in `__init__`: `dr = ParametricDR(..., training_phases=)`
- add phases one by one before training: `dr.add_training_phase(**kwargs)`
Choose one or allow multiple?
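The three options need not be mutually exclusive. The sketch below uses a stripped-down, hypothetical `TrainingPhase` and a stub `ParametricDR` (neither is the real class) to show how all three could share one list of phases. The precedence rule, where phases passed to `train()` override stored ones for that run, is an assumption, as is returning `self` from `add_training_phase` to allow chaining.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class TrainingPhase:
    # Hypothetical stand-in with only two of the fields listed in the issue.
    epochs: int
    sampling: str = "random"


class ParametricDR:
    def __init__(self, training_phases: Optional[Sequence[TrainingPhase]] = None):
        # Option 2: phases given at construction time.
        self.training_phases: List[TrainingPhase] = list(training_phases or [])

    def add_training_phase(self, **kwargs) -> "ParametricDR":
        # Option 3: build one phase from keyword arguments and append it.
        self.training_phases.append(TrainingPhase(**kwargs))
        return self

    def train(self, phases: Optional[Sequence[TrainingPhase]] = None) -> List[TrainingPhase]:
        # Option 1: phases passed here override the stored ones for this run.
        run = list(phases) if phases is not None else self.training_phases
        if not run:
            raise ValueError("no training phases specified")
        for phase in run:
            for _ in range(phase.epochs):
                pass  # placeholder for one epoch of the real training loop
        return run


dr = ParametricDR([TrainingPhase(epochs=5)])
dr.add_training_phase(epochs=20, sampling="knn")
```

Because options 2 and 3 feed the same list, supporting both costs little; option 1 is then a per-call override rather than a third storage path.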