This repository contains code to train and evaluate all three DLGN models: Value Network (VN), Value Tensor (VT), and Kernel.
- `training_methods.py`: Contains code to train and evaluate all three DLGN models.
- `data_gen.py`: Contains code to generate synthetic data for training and evaluation.
- `DLGN_VN.py`: Contains code for the Value Network model.
- `DLGN_VT.py`: Contains code for the Value Tensor model.
- `DLGN_Kernel.py`: Contains code for the Kernel model.
- `DLGN_enums.py`: Contains the enums used in the code.
- `/data/`: Contains the synthetic datasets for training and evaluation; these are used for my experiments.
To train the models, import `training_methods.py` and call the `train_model` function with the following parameters:

- `data`: The dataset to train on, given as a dictionary with the following keys:
  - `train_data`: The training data.
  - `train_labels`: The training labels.
  - `test_data`: The testing data.
  - `test_labels`: The testing labels.
- `config`: The configuration for the model, given as a dictionary/namespace. Example configurations can be copy-pasted from `configs.txt`; set the values appropriately for each model and dataset. These set the hyperparameters for the model.

This function returns the trained model and the training loss. An example use case is present in `train_models.ipynb`.
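The call pattern above can be sketched as follows. This is a minimal illustration based on this README, not the notebook's exact code: the dataset shapes and the commented-out `train_model` call are assumptions, so check `train_models.ipynb` for the real usage.

```python
import numpy as np

rng = np.random.default_rng(0)
num_data, dim_in = 1000, 10

# The dataset is a plain dictionary with the four documented keys.
data = {
    "train_data": rng.standard_normal((num_data, dim_in)),
    "train_labels": rng.integers(0, 2, size=num_data),
    "test_data": rng.standard_normal((200, dim_in)),
    "test_labels": rng.integers(0, 2, size=200),
}

# With a config copied from configs.txt (using the enums from DLGN_enums.py),
# training would then look like:
#
#   from training_methods import train_model
#   model, train_loss = train_model(data, config)
```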
**Kernel config:**

- `device`: `'cuda:0'` (based on available devices)
- `model_type`: `ModelTypes.KERNEL`
- `num_data`: Number of training datapoints
- `dim_in`: Input data dimension
- `width`: Width of a layer
- `depth`: Number of layers
- `beta`: Gating parameter; multiplied with the gating score before applying the sigmoid
- `alpha_init`: Initial values for alpha: `['zero', 'random']`
- `BN`: Enable batch norm (boolean)
- `feat`: Decides how features are defined (composite/shallow): `['cf', 'sf']`
- `weight_decay`: Weight decay for the gating network
- `train_method`: `['KernelTrainMethod.PEGASOS', 'KernelTrainMethod.SVC', 'KernelTrainMethod.GD']`
- `reg`: For `KernelTrainMethod.PEGASOS` and `KernelTrainMethod.SVC`, the regularization used for fitting; its meaning differs between the two
- `loss_fn_type`: `['LossTypes.HINGE', 'LossTypes.CE']`
- `optimizer_type`: `['Optim.ADAM', 'Optim.SGD']`
- `gates_lr`: Learning rate for the gating network
- `alpha_lr`: Only used for `KernelTrainMethod.GD`; learning rate for the alpha updates
- `epochs`: Number of epochs for training
- `value_freq`: Frequency of value tensor updates
- `num_iter`: For `KernelTrainMethod.PEGASOS` and `KernelTrainMethod.SVC`, number of optimizing iterations
- `threshold`: Threshold for checking proximity of features to dths. Generally use 0.3
- `use_wandb`: Enable wandb logging (boolean)
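An illustrative Kernel config, assembled from the parameter list above. The values are assumptions for a small synthetic run, not the repo's defaults, and the enum-valued fields are shown as placeholder strings; in the repo they would use the enums from `DLGN_enums.py` (e.g. `ModelTypes.KERNEL`).

```python
from types import SimpleNamespace

# Illustrative values only; copy a real config from configs.txt.
kernel_config = SimpleNamespace(
    device="cuda:0",            # or "cpu" if no GPU is available
    model_type="KERNEL",        # ModelTypes.KERNEL in the repo
    num_data=1000,
    dim_in=10,
    width=64,
    depth=4,
    beta=10.0,                  # multiplies the gating score before the sigmoid
    alpha_init="zero",          # or "random"
    BN=False,
    feat="cf",                  # composite features; "sf" for shallow
    weight_decay=0.0,
    train_method="PEGASOS",     # KernelTrainMethod.PEGASOS in the repo
    reg=0.01,                   # regularization for PEGASOS/SVC fitting
    loss_fn_type="HINGE",       # LossTypes.HINGE in the repo
    optimizer_type="ADAM",      # Optim.ADAM in the repo
    gates_lr=1e-3,              # gating-network learning rate
    alpha_lr=1e-3,              # only used by KernelTrainMethod.GD
    epochs=100,
    value_freq=5,
    num_iter=1000,              # PEGASOS/SVC optimizing iterations
    threshold=0.3,              # the generally recommended value
    use_wandb=False,
)
```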
**VN config:**

- `device`: `'cuda:0'` (based on available devices)
- `model_type`: `ModelTypes.VN`
- `num_data`: Number of training datapoints
- `dim_in`: Input data dimension
- `num_hidden_nodes`: A list of the form `[width] * depth`
- `beta`: Gating parameter; multiplied with the gating score before applying the sigmoid
- `BN`: Enable batch norm (boolean)
- `mode`: `"pwc"`
- `reg`: For `KernelTrainMethod.PEGASOS` and `KernelTrainMethod.SVC`, the regularization used for fitting; its meaning differs between the two
- `loss_fn_type`: `'LossTypes.CE'`
- `optimizer_type`: `['Optim.ADAM', 'Optim.SGD']`
- `lr`: Learning rate for the gating network
- `lr_ratio`: Scales the value network updates by a factor: `value_lr = lr / lr_ratio`
- `epochs`: Number of epochs for training
- `use_wandb`: Enable wandb logging (boolean)
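An illustrative VN config following the same pattern. The values are assumptions, and the enum-valued fields are placeholder strings standing in for the enums in `DLGN_enums.py` (e.g. `ModelTypes.VN`).

```python
from types import SimpleNamespace

width, depth = 64, 4  # illustrative sizes

vn_config = SimpleNamespace(
    device="cuda:0",
    model_type="VN",              # ModelTypes.VN in the repo
    num_data=1000,
    dim_in=10,
    num_hidden_nodes=[width] * depth,  # the documented [width]*depth form
    beta=10.0,
    BN=False,
    mode="pwc",
    reg=0.01,
    loss_fn_type="CE",            # LossTypes.CE in the repo
    optimizer_type="ADAM",        # Optim.ADAM in the repo
    lr=1e-3,                      # gating-network learning rate
    lr_ratio=10.0,                # value_lr = lr / lr_ratio
    epochs=100,
    use_wandb=False,
)
```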
**VT config:**

- `device`: `'cuda:0'` (based on available devices)
- `model_type`: `ModelTypes.VT`
- `num_data`: Number of training datapoints
- `dim_in`: Input data dimension
- `num_hidden_nodes`: A list of the form `[width] * depth`
- `beta`: Gating parameter; multiplied with the gating score before applying the sigmoid
- `value_scale`: Scales the `randn` function for the tensor initialization
- `BN`: Enable batch norm (boolean)
- `mode`: `"pwc"`
- `prod`: Determines how the gating network output is computed: `['op', 'ip']`
- `vt_fit`: `['KernelTrainMethod.PEGASOS', 'KernelTrainMethod.SVC', 'KernelTrainMethod.LOGISTIC', 'KernelTrainMethod.PEGASOSKERNEL', 'KernelTrainMethod.NPKSVC', 'KernelTrainMethod.LINEARSVC']`
- `reg`: Regularization for the chosen fit method; its meaning differs between methods, so check the example configs
- `feat`: Decides how features are defined (composite/shallow): `['cf', 'sf']`
- `loss_fn_type`: `['LossTypes.HINGE', 'LossTypes.CE']`
- `optimizer_type`: `['Optim.ADAM', 'Optim.SGD']`
- `lr`: Learning rate for the gating network
- `epochs`: Number of epochs for training
- `value_freq`: Frequency of value tensor updates
- `save_freq`: Frequency of saving the model to disk
- `use_wandb`: Enable wandb logging (boolean)
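An illustrative VT config to close out the three models. As before, the values are assumptions for a small run and the enum-valued fields are placeholder strings for the enums in `DLGN_enums.py` (e.g. `ModelTypes.VT`, `KernelTrainMethod.LOGISTIC`).

```python
from types import SimpleNamespace

vt_config = SimpleNamespace(
    device="cuda:0",
    model_type="VT",            # ModelTypes.VT in the repo
    num_data=1000,
    dim_in=10,
    num_hidden_nodes=[64] * 4,  # [width] * depth
    beta=10.0,
    value_scale=1.0,            # scales the randn tensor initialization
    BN=False,
    mode="pwc",
    prod="op",                  # or "ip"
    vt_fit="LOGISTIC",          # KernelTrainMethod.LOGISTIC in the repo
    reg=0.01,                   # meaning depends on the fit method
    feat="cf",                  # composite features; "sf" for shallow
    loss_fn_type="CE",          # LossTypes.CE in the repo
    optimizer_type="ADAM",      # Optim.ADAM in the repo
    lr=1e-3,                    # gating-network learning rate
    epochs=100,
    value_freq=5,               # value tensor update frequency
    save_freq=20,               # how often to save the model to disk
    use_wandb=False,
)
```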