Adaptive Weight Decay Regularization Based on the PyTorch Framework
Create an object from `AdaptiveWeightDecay`, for example:

    model = AdaptiveWeightDecay(...)
To create the object, pass the following inputs:
- The network that you want to train (e.g. `VGG()`).
- The loss function module (e.g. `nn.MSELoss()`).
- The optimizer (e.g. `torch.optim.Adam`).
- The increasing factor of the weight decay coefficient.
- The decreasing factor of the weight decay coefficient.
- The overfitting threshold.
- The underfitting threshold.
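The inputs above suggest a rule that raises the weight decay coefficient when the model overfits and lowers it when the model underfits. The exact rule used by `AdaptiveWeightDecay` is not described here, so the following is only a minimal sketch under stated assumptions: the gap between test loss and train loss is taken as the overfitting signal, both factors are applied multiplicatively, and the names `adapt_weight_decay`, `set_weight_decay`, and all default values are hypothetical.

```python
def adapt_weight_decay(weight_decay, train_loss, test_loss,
                       increase_factor=2.0, decrease_factor=0.5,
                       overfit_threshold=0.1, underfit_threshold=0.01):
    """Return an updated weight decay coefficient.

    Assumption (not taken from the library): the generalization gap
    test_loss - train_loss is compared against the two thresholds.
    """
    gap = test_loss - train_loss
    if gap > overfit_threshold:
        # Large gap: treat as overfitting, regularize more strongly.
        return weight_decay * increase_factor
    if gap < underfit_threshold:
        # Small gap: treat as underfitting, relax the regularization.
        return weight_decay * decrease_factor
    # Gap is between the thresholds: keep the coefficient unchanged.
    return weight_decay


def set_weight_decay(param_groups, weight_decay):
    """Write the coefficient into each parameter group.

    `param_groups` is a list of dicts; a PyTorch optimizer exposes such a
    list as `optimizer.param_groups`, each with a "weight_decay" entry.
    """
    for group in param_groups:
        group["weight_decay"] = weight_decay


# Usage with plain dicts standing in for optimizer.param_groups.
groups = [{"lr": 1e-3, "weight_decay": 1.0}]
new_wd = adapt_weight_decay(1.0, train_loss=0.2, test_loss=0.5)
set_weight_decay(groups, new_wd)
```

With a real optimizer, the same helper could be called once per epoch as `set_weight_decay(optimizer.param_groups, new_wd)`.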
After creating the object, call its `fit` function to train on the dataset:

    model.fit(train_loader, test_loader, max_epoch)
To run this scheme, you need to install `numpy`, `pytorch`, and `tqdm`.
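A quick way to confirm that the dependencies are available before training is to check whether each module can be found. This is a small sketch using only the standard library; the helper name `missing_packages` is hypothetical, and note that PyTorch is imported under the module name `torch`.

```python
import importlib.util

# Module names to look up; PyTorch is imported as "torch".
REQUIRED_MODULES = ("numpy", "torch", "tqdm")


def missing_packages(modules=REQUIRED_MODULES):
    """Return the module names that cannot be found in this environment."""
    return [name for name in modules
            if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```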