EIDOSLAB/simplify

Simplify with SqueezeNet

pedrohenriqp opened this issue · 7 comments

Hi guys, congrats on the work, nice library.

I am trying to use Simplify with SqueezeNet.
I fine-tuned a pre-trained SqueezeNet from torchvision, fused some layers, and pruned it in structured mode.

But when I use the library, nothing seems to change.
For example, after structured pruning (50%), the model has 361_216 parameters equal to zero. After simplify, it still has 361_216.

Do you have any advice on using the library with SqueezeNet?

@pedrohenriqp Hi, thank you for showing interest in our library. Could you please provide us with some code to reproduce the issue?

Thanks.

Hi @AndreaBrg. Yes, sure

import torch
from torchvision import models
import torch.nn.utils.prune as prune
from torch import nn

from simplify import simplify
import copy


def apply_prune(model_r):
    model = copy.deepcopy(model_r)
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
            # prune.ln_structured(module, name="weight", amount=0.3, n=1, dim=1)
            prune.l1_unstructured(module, name="weight", amount=0.8)
            prune.remove(module, 'weight')

    return model

model_raw = models.squeezenet1_0(pretrained=True) 
model_raw.eval()

model_pruned = apply_prune(model_raw)
dummy_input = torch.zeros(1, 3, 224, 224)
model_sim = copy.deepcopy(model_pruned)
simplified_model = simplify(model_sim, dummy_input)

This gives the following results:
Structured (30%)

| Metric | dense (model_raw) | pruned (model_pruned) | simplified (simplified_model) |
| --- | --- | --- | --- |
| total_parameters | 1_248_424 | 1_248_424 | 1_248_424 |
| total_non_zero | 1_248_424 | 875_496 | 875_496 |
| total_zero | 0 | 372_928 | 372_928 |
| size | 5.01 MB | 5.01 MB | 5.01 MB |

Unstructured (80%)

| Metric | dense (model_raw) | pruned (model_pruned) | simplified (simplified_model) |
| --- | --- | --- | --- |
| total_parameters | 1_248_424 | 1_248_424 | 1_246_553 |
| total_non_zero | 1_248_424 | 252_865 | 252_815 |
| total_zero | 0 | 995_559 | 993_738 |
| size | 5.01 MB | 5.01 MB | 5.00 MB |

As you can see, simplification makes no difference after structured pruning, and only a small one after unstructured pruning.

Thanks, I will check and let you know ASAP.

@pedrohenriqp sorry for the wait.
I believe you have mistyped the dim in prune.ln_structured(module, name="weight", amount=0.3, n=1, dim=1); it should be dim=0.

This snippet should give you the expected results (sorry for the rather rough prints):

import copy

import torch
import torch.nn.utils.prune as prune
from torch import nn
from torchvision import models

from simplify import simplify


def apply_prune(model_r):
    model = copy.deepcopy(model_r)
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
            prune.ln_structured(module, name="weight", amount=0.3, n=1, dim=0)
            # prune.l1_unstructured(module, name="weight", amount=0.8)
            prune.remove(module, 'weight')

    return model


def count_parameters(model):
    total_neurons = 0
    total_weights = 0
    total_biases = 0
    zeroed_neurons = 0
    zeroed_weights = 0
    zeroed_biases = 0

    for n, m in model.named_modules():

        if isinstance(m, nn.Conv2d):
            total_weights += torch.numel(m.weight)
            zeroed_weights += torch.numel(m.weight) - torch.count_nonzero(m.weight)

            if hasattr(m, "bias") and m.bias is not None:
                total_biases += torch.numel(m.bias)
                zeroed_biases += torch.numel(m.bias) - torch.count_nonzero(m.bias)

            total_neurons += m.weight.shape[0]
            zeroed_neurons += m.weight.shape[0] - torch.count_nonzero(torch.sum(torch.abs(m.weight), dim=(1, 2, 3)))

    return total_neurons, total_weights, total_biases, zeroed_neurons, zeroed_weights, zeroed_biases


if __name__ == '__main__':
    dummy_input = torch.zeros(1, 3, 224, 224)

    model_raw = models.squeezenet1_0(pretrained=True)
    model_raw.eval()
    model_stats = count_parameters(model_raw)
    print("MODEL RAW")
    print(f"total_neurons {model_stats[0]}\n"
          f"total_weights {model_stats[1]}\n"
          f"total_biases {model_stats[2]}\n"
          f"zeroed_neurons {model_stats[3]}\n"
          f"zeroed_weights {model_stats[4]}\n"
          f"zeroed_biases {model_stats[5]}")

    model_pruned = apply_prune(model_raw)
    model_stats = count_parameters(model_pruned)
    print("MODEL PRUNED")
    print(f"total_neurons {model_stats[0]}\n"
          f"total_weights {model_stats[1]}\n"
          f"total_biases {model_stats[2]}\n"
          f"zeroed_neurons {model_stats[3]}\n"
          f"zeroed_weights {model_stats[4]}\n"
          f"zeroed_biases {model_stats[5]}")

    model_sim = copy.deepcopy(model_pruned)
    simplified_model = simplify(model_sim, dummy_input)
    model_stats = count_parameters(simplified_model)
    print("MODEL SIMPLIFIED")
    print(f"total_neurons {model_stats[0]}\n"
          f"total_weights {model_stats[1]}\n"
          f"total_biases {model_stats[2]}\n"
          f"zeroed_neurons {model_stats[3]}\n"
          f"zeroed_weights {model_stats[4]}\n"
          f"zeroed_biases {model_stats[5]}")

Let me know if you have further questions!
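
To illustrate the dim point: the weight of an nn.Conv2d has shape (out_channels, in_channels, kH, kW), so dim=0 zeroes whole output filters while dim=1 zeroes input channels. The count_parameters function above only counts a neuron as zeroed when its entire output filter is zero, which suggests that this is the structure simplify looks for. The following is a minimal sketch on a toy layer, not on SqueezeNet; the helper zeroed_slices is a hypothetical name introduced here for illustration.

```python
import torch
import torch.nn.utils.prune as prune
from torch import nn


def zeroed_slices(conv, dim):
    """Count slices along `dim` of conv.weight that are entirely zero (hypothetical helper)."""
    other_dims = tuple(d for d in range(conv.weight.dim()) if d != dim)
    return int((conv.weight.abs().sum(dim=other_dims) == 0).sum())


torch.manual_seed(0)
results = {}
for dim in (0, 1):
    # Toy layer: weight shape is (16, 8, 3, 3) = (out, in, kH, kW).
    conv = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
    prune.ln_structured(conv, name="weight", amount=0.5, n=1, dim=dim)
    prune.remove(conv, "weight")  # make the pruning permanent
    results[dim] = (zeroed_slices(conv, 0), zeroed_slices(conv, 1))
    print(f"dim={dim}: zeroed output filters={results[dim][0]}, "
          f"zeroed input channels={results[dim][1]}")
```

With dim=1, half of the weights are zero but no output filter is entirely zero, which matches the unchanged parameter count in the structured table above.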

Hi @AndreaBrg, thank you for your answer.
You are right. With dim=0, the difference in the number of parameters becomes visible.

I also noticed a difference in how we count the parameters.
I count them like this:

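def get_num_parameters(model, count_nonzero_only=False):
    """Count all parameter elements, or only the nonzero ones."""
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            # count_nonzero() returns a tensor; convert so the result is a plain int
            num_counted_elements += int(param.count_nonzero())
        else:
            num_counted_elements += param.numel()
    return num_counted_elements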

For pruned_model we get the same numbers, but for simplified_model I see a big difference in the total number of parameters. After applying Simplify, is there a recommended way to count the parameters?

@pedrohenriqp I believe that is due to the "expansion" parameters needed to perform the residual connections, which you include when you iterate over all the parameters.
You can read the reasoning here (Section C1 of the Appendix). While this introduces a slight overhead, it should be unnoticeable given a sufficiently high sparsity rate.
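
One way to see where the extra parameters come from is to break the total down by module class and compare the pruned and simplified models. The following is a minimal sketch using only standard PyTorch. The function name parameters_by_module_type is hypothetical, the demo model is a stand-in, and no assumption is made about which module classes the simplified model actually contains.

```python
from collections import Counter

import torch
from torch import nn


def parameters_by_module_type(model):
    """Return {module class name: number of parameter elements it owns directly}."""
    counts = Counter()
    for module in model.modules():
        # recurse=False counts only parameters registered on this module itself,
        # so nothing is counted twice through parent containers.
        n = sum(p.numel() for p in module.parameters(recurse=False))
        if n:
            counts[type(module).__name__] += n
    return dict(counts)


# Stand-in model; in the thread this would be model_pruned and simplified_model.
demo = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.ReLU(), nn.Linear(4, 2))
breakdown = parameters_by_module_type(demo)
print(breakdown)
```

Any class that appears only in the simplified model's breakdown, or whose count grows, accounts for the gap between the two counting methods.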

@AndreaBrg Thank you for your help.