# DNN pruning algorithms

# Function explanation
prune_main.py
    |
    | --> prune_init(args, model) # initialize the pruning utilities
    |
    | --> prune_update(epoch=0, batch_idx=0) # update the ADMM variables (U, Z) during training
    |
    | --> prune_update_combined_loss(loss) # add the ADMM penalty term to the training loss
    |
    | --> prune_harden() # hard-threshold the smallest-magnitude weights to 0 (done after the last epoch of ADMM training)
    |
    | --> prune_apply_masks() # apply binary masks to the weights (done during retraining to keep pruned weights at 0)
    |
    | --> prune_store_weights() # store the weights to files
    |
    | --> prune_print_sparsity(model) # print the sparsity of a model
    |
    | --> prune_update_learning_rate(optimizer, epoch, args) # update the learning rate of the ADMM trainer. ADMM can use a different LR schedule because each call to prune_update() starts a new optimization subproblem.
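
For reference, prune_update_combined_loss() and prune_update() correspond to the two halves of the standard ADMM pruning formulation (the exact form used in admm.py may differ in details). The combined loss adds an augmented-Lagrangian penalty to the task loss f(W):

    L(W) = f(W) + \sum_i (\rho_i / 2) \|W_i - Z_i^k + U_i^k\|_F^2

and each call to prune_update() refreshes the auxiliary and dual variables:

    Z_i^{k+1} = \Pi_{S_i}(W_i^{k+1} + U_i^k)    # project onto layer i's sparsity constraint
    U_i^{k+1} = U_i^k + W_i^{k+1} - Z_i^{k+1}   # dual ascent

where \Pi_{S_i} keeps the largest-magnitude entries allowed by the constraint and zeroes the rest. Each refresh starts a new subproblem, which is why prune_update_learning_rate() may run its own LR schedule.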

# Dependencies

prune_main.py
    |
    | --> prune_base.py (base class for pruning) --> admm.py (ADMM pruning algorithm) --> multi_level_admm.py (pruning while keeping part of the model fixed)
    |           |
    |           | --> L_1_reweighted.py (reweighted L1 pruning algorithm, TODO)
    |
    | --> retrain.py (retrain/fine-tune the network after hard pruning) --> multi_level_retrain.py (retraining while keeping part of the model fixed)
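
A hypothetical skeleton of the inheritance implied by this tree (class and method names are illustrative only, not the repo's actual API):

# illustrative only -- mirrors the dependency tree above
class PruneBase:               # prune_base.py: shared mask/sparsity bookkeeping
    def __init__(self, args, model):
        self.args, self.model = args, model

    def update(self, epoch=0, batch_idx=0):
        raise NotImplementedError

class ADMM(PruneBase):         # admm.py: Z/U refresh and the combined loss
    def update(self, epoch=0, batch_idx=0):
        pass                   # refresh Z (projection) and U (dual ascent)

class MultiLevelADMM(ADMM):    # multi_level_admm.py: prune while keeping
    pass                       # part of the model fixed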


# Example:
import argparse

import torch
from prune_utils import *

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', help='path to dataset')
prune_parse_arguments(parser)  # register the pruning command-line arguments
args = parser.parse_args()

def main():
    # define dataloader
    trainDataLoader = torch.utils.data.DataLoader(...)
    testDataLoader = torch.utils.data.DataLoader(...)

    # define DNN model
    model = ...
    model = torch.nn.DataParallel(model).cuda() # if using multiple GPUs

    # define loss function
    criterion = ...


    # define optimizer
    optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.7)

    # ADMM PRUNE
    prune_init(args, model)

    prune_apply_masks() # optionally make sure the masks are applied before retraining starts

    prune_print_sparsity(model) # check the sparsity before retraining

    start_epoch = 0  # or restore from a checkpoint when resuming
    for epoch in range(start_epoch, args.epochs):
        
        # ADMM PRUNE
        prune_update(epoch)

        # ADMM PRUNE
        prune_update_learning_rate(optimizer, epoch, args)

        for batch_idx, (input, target) in enumerate(trainDataLoader):
            model.train()
            input, target = input.cuda(), target.cuda()

            output = model(input)
            loss = criterion(output, target) # regular task loss, e.g., cross-entropy, MSE, ...

            # ADMM PRUNE: add the ADMM penalty term to the task loss
            loss = prune_update_combined_loss(loss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # ADMM PRUNE
            prune_apply_masks()

        scheduler.step() # step the LR schedule once per epoch, after the optimizer updates

        if epoch == args.epochs - 1:
            # ADMM PRUNE: hard-prune the smallest weights after the last epoch
            prune_harden()

        # save the model
        save_path = 'path_name.pth'
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        torch.save(state, save_path)


if __name__ == '__main__':
    main()


# Training scripts

# regular train
python3 imagenet_main.py <data_folder> --arch=resnet50 --worker=16 --batch-size=256 --gpu_id=1 --epochs=180 --learning-rate=0.1 --resume=existing_ckpt.pth

# admm
python3 imagenet_main.py <data_folder> --arch=resnet50 --worker=16 --batch-size=256 --gpu_id=1 --epochs=120 --learning-rate=0.1 --resume=pretrained.ckpt --sp-admm --sp-config-file=./profile/config_resnet50.yaml --sp-admm-update-epoch=30 --sp-admm-sparsity-type=irregular --sp-admm-lr=0.01
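
The pruning and retraining runs read layer-wise sparsity targets from the --sp-config-file YAML. The exact schema is defined by this repo's config loader; as an illustration only (the keys and layer names below are assumptions), ADMM pruning configs typically map each weight tensor's name to a target pruning ratio:

# illustrative sketch of ./profile/config_resnet50.yaml -- check the shipped
# profile/ directory for the exact schema this loader expects
prune_ratios:
  module.conv1.weight: 0.6
  module.layer1.0.conv1.weight: 0.8
  module.fc.weight: 0.9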

# retrain
python3 imagenet_main.py <data_folder> --arch=resnet50 --worker=16 --batch-size=256 --gpu_id=1 --epochs=120 --learning-rate=0.001 --resume=hard_pruned.ckpt --sp-retrain  --sp-config-file=./profile/config_resnet50.yaml

# evaluate
python3 imagenet_main.py <data_folder> --arch=resnet50 --worker=16 --batch-size=256 --gpu_id=1 --evaluate --resume=existing_ckpt.pth

# 2:4 structured pruning using admm
python3 imagenet_main.py <data_folder> --arch=resnet50 --worker=16 --batch-size=256 --gpu_id=1 --epochs=120 --learning-rate=0.1 --resume=pretrained.ckpt --sp-admm --sp-config-file=./profile/config_resnet50.yaml --sp-admm-update-epoch=30 --sp-admm-lr=0.01 --sp-admm-sparsity-type=N:M-prune-pattern --sp-admm-select-number 2 --sp-admm-pattern-row-sub 1 --sp-admm-pattern-col-sub 4
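
Here --sp-admm-select-number 2 with a 1x4 sub-pattern means every group of 4 consecutive weights keeps only its 2 largest-magnitude entries (the 2:4 pattern accelerated by NVIDIA sparse tensor cores). A standalone sketch of that projection, not the repo's implementation:

import torch

def project_n_of_m(weight, n=2, m=4):
    # keep the n largest-magnitude entries in each group of m consecutive
    # weights along the flattened filter dimension (assumes it divides by m)
    rows = weight.reshape(weight.shape[0], -1)
    groups = rows.reshape(rows.shape[0], -1, m)
    order = groups.abs().argsort(dim=-1, descending=True)
    mask = torch.zeros_like(groups)
    mask.scatter_(-1, order[..., :n], 1.0)  # 1 = kept, 0 = pruned
    return (groups * mask).reshape(weight.shape)

w = torch.randn(8, 16)
print((project_n_of_m(w).reshape(8, -1, 4) != 0).sum(-1))  # 2 nonzeros per group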