stephenjfox/Morph.py

how can I use this?

danFromTelAviv opened this issue · 13 comments

It's awesome that you implemented this in PyTorch.
I am really not sure how I can use this on my net.
I see there's a demo.py, but unless I'm missing something obvious, it doesn't really show how to use this package. Why are you running morph once if it doesn't do anything?
Thanks,
Dan

Same problem here.
I really don't understand what happens.

The automated functional implementation isn't published yet.
If you got here from PyPI, the version should indicate clear instability.
If you looked at the Projects page, it should be clear that this isn't ready for general use.

I will be plugging together the final bells and whistles in the next few days, unless life happens again. Then, @Filos92, you should have your tool.

@danFromTelAviv I'm in the middle of updating the main branch to show active status. There will soon be a presentation I've put together that explains:

I'm kind of getting caught out in the open, because I haven't pushed changes in a while and Google decided to announce that they've been doing cool things in secret (for over a year) just a week ago. I thought I would get more time, but the hype wave is what it is...

@danFromTelAviv To answer your question more directly: Google's paper (and old implementation) targeted only FLOPs.
This work stresses the alternate route of regularizing model size, in a rapid-compression fashion. Many of the comments you'll see reflect that mindset.
Nevertheless, the goal is the same: get the benefits of MorphNet while always reducing the model's parameter count. A future "major" release will include the alternative regularizers.
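
To make the FLOPs-versus-size distinction concrete, here is a minimal sketch for an unpadded, stride-1 nn.Conv1d. The helper name conv1d_cost is hypothetical. Parameter count depends only on channel counts and kernel length, while the multiply-accumulate count (roughly half the FLOPs) also scales with the output length, so the two targets can rank the same layers differently.

```python
def conv1d_cost(c_in: int, c_out: int, kernel: int, length_in: int, bias: bool = True):
    """Return (params, macs) for an unpadded, stride-1 Conv1d (hypothetical helper)."""
    length_out = length_in - kernel + 1  # valid convolution: no padding, stride 1
    params = c_in * c_out * kernel + (c_out if bias else 0)
    # every weight is used once per output position: one multiply-accumulate each
    macs = c_in * c_out * kernel * length_out
    return params, macs


# Same layer, two input lengths: size is unchanged, compute grows with length.
print(conv1d_cost(100, 100, 9, 1000))  # (90100, 89280000)
print(conv1d_cost(100, 100, 9, 100))   # (90100, 8280000)
```

A size regularizer therefore weights a channel by c_in * kernel, whereas a FLOPs regularizer would additionally multiply by the sequence length.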

@stephenjfox
OK, thanks again for taking this on.
As a quick, simple alternative, I was thinking of adding weights + a sigmoid to every layer's output. I can add a small L1 loss that pushes those weights toward zero, and weight the loss per layer depending on its number of params/ops. This would just be much easier to implement. What do you think?
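
A minimal sketch of the sigmoid-gate idea, assuming (batch, channels, length) activations. SigmoidGate and cost_per_channel are hypothetical names; the per-channel cost stands in for whichever params/ops weighting one picks.

```python
import torch
import torch.nn as nn


class SigmoidGate(nn.Module):
    """Hypothetical per-channel gate: scales channel c of a (B, C, L) tensor by sigmoid(w_c)."""

    def __init__(self, num_channels: int, cost_per_channel: float, init: float = 3.0):
        super().__init__()
        # sigmoid(3.0) is about 0.95, so every channel starts almost fully open
        self.logits = nn.Parameter(torch.full((1, num_channels, 1), init))
        self.cost_per_channel = cost_per_channel

    def gates(self) -> torch.Tensor:
        return torch.sigmoid(self.logits)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gates()

    def l1_loss(self) -> torch.Tensor:
        # sigmoid outputs are positive, so their sum is their L1 norm; scale by layer cost
        return self.gates().sum() * self.cost_per_channel


# Usage: gate a conv, weighting each channel by the weights that produce it.
conv = nn.Conv1d(16, 32, kernel_size=9)
gate = SigmoidGate(32, cost_per_channel=16 * 9)
x = torch.randn(4, 16, 100)
out = gate(conv(x))
reg = 1e-4 * gate.l1_loss()  # add to the task loss before backward()
```

One trade-off: the sigmoid keeps gates in (0, 1), but its gradient vanishes as a gate approaches 0, so gates shrink slowly and a threshold is still needed to decide what counts as pruned. A raw factor under L1 gets a constant-magnitude gradient instead.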

@danFromTelAviv That sounds like roughly what I'm doing already. Keep an eye out for the pull request ~Monday and see what you think.
One of the reasons I wanted this in the open was for conversations like this, so stay in touch.

Awesome! Please let me know when you post it :)

Here is what I hope is a simple working example.
I am now training my actual net (a much more complex version of the one below). The pruning seems to be happening at a reasonable rate per layer. I hope it will improve my net once I expand it, though I haven't gotten to that yet.
What do you think? Did I miss something?

import torch
import torch.nn as nn


class SimpleExampleOriginal(nn.Module):
    def __init__(self, c_in, c_out):
        super(SimpleExampleOriginal, self).__init__()

        # Conv1d(in_channels, out_channels, kernel_size, stride)
        self.num_classes = c_out
        self.num_features_in = c_in
        self.num_filters_list = [c_in, 100, 100, 100, 100, 100]
        self.filter_len_list = [9] * (len(self.num_filters_list) - 1)
        self.conv_layers = nn.ModuleList([nn.Conv1d(self.num_filters_list[idx], self.num_filters_list[idx + 1],
                                                    self.filter_len_list[idx]) for idx in
                                          range(len(self.filter_len_list))])
        # a kernel-size-1 conv needs no padding, so a plain nn.Conv1d replaces the custom Conv1dPad
        self.classifier = nn.Conv1d(self.num_filters_list[-1], c_out, 1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, target_logits=None, device=None):
        for idx, conv_layer in enumerate(self.conv_layers):
            x = torch.relu(conv_layer(x))  # nonlinearity, otherwise the stack is one linear conv
        x = self.classifier(x)  # map the 100 filters to c_out classes
        log_probs_ctc = self.softmax(x)

        return log_probs_ctc

This turns into:

class SimpleExample(nn.Module):
    def __init__(self, c_in, c_out):
        super(SimpleExample, self).__init__()

        # Conv1d(in_channels, out_channels, kernel_size, stride)
        self.num_classes = c_out
        self.num_features_in = c_in
        self.num_filters_list = [c_in, 100, 100, 100, 100, 100]
        self.filter_len_list = [9] * (len(self.num_filters_list) - 1)
        self.conv_layers = nn.ModuleList([nn.Conv1d(self.num_filters_list[idx], self.num_filters_list[idx + 1],
                                                    self.filter_len_list[idx]) for idx in
                                          range(len(self.filter_len_list))])
        # add this to any layer that you want to prune: one learnable factor per output channel,
        # shaped (1, C, 1) to broadcast over (batch, channels, length).
        # Raw Parameters must go in an nn.ParameterList; nn.ModuleList only accepts Modules.
        self.pruning = nn.ParameterList([nn.Parameter(torch.ones((1, self.num_filters_list[idx + 1], 1)))
                                         for idx in range(len(self.filter_len_list))])

        # this param should be calculated differently depending on the architecture of the net.
        # Here: total conv weights (biases and classifier ignored), used to normalize the loss.
        self.max_num_params = sum([inc * outc * filter_len for inc, outc, filter_len in
                                   zip(self.num_filters_list[:-1], self.num_filters_list[1:],
                                       self.filter_len_list)])

        self.classifier = nn.Conv1d(self.num_filters_list[-1], c_out, 1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, target_logits=None, device=None):
        modulations_loss = 0
        percent_information_per_layer = []  # (mean |factor|, fraction alive) per layer, for logging only
        prev_layer_num_features = self.num_features_in
        for idx, (conv_layer, pruning_layer) in enumerate(zip(self.conv_layers, self.pruning)):
            x = conv_layer(x)
            x = x * pruning_layer
            x = torch.relu(x)
            # L1 on the factors, weighted by what each output channel costs
            # (surviving input channels * kernel length)
            modulations_loss += pruning_layer.abs().sum() * prev_layer_num_features * self.filter_len_list[idx]
            percent_information_per_layer.append(
                [pruning_layer.abs().mean(), (pruning_layer.abs() > 1e-2).float().mean()])
            # channels with |factor| <= 1e-2 count as pruned when costing the next layer
            prev_layer_num_features = (pruning_layer.abs() > 1e-2).sum()
        # normalize by the unpruned size; 50 * 4 is a hand-tuned regularization strength
        modulations_loss = modulations_loss / self.max_num_params * 50 * 4
        x = self.classifier(x)
        log_probs_ctc = self.softmax(x)

        return log_probs_ctc, modulations_loss
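
Once the factors are trained, channels whose factor is near zero can be removed for real. A minimal sketch, assuming the conv → factor → activation → next conv pattern used in SimpleExample, with plain (groups=1) Conv1d layers and at least one surviving channel; shrink_gated_pair is a hypothetical helper. It folds each surviving factor into the conv's weights and bias, because conv(x) * g equals a conv whose output-channel weights and bias are multiplied by g.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def shrink_gated_pair(conv: nn.Conv1d, gate: torch.Tensor, next_conv: nn.Conv1d, thresh: float = 1e-2):
    """Hypothetical helper: drop channels with |gate| <= thresh and fold the gate into conv."""
    g = gate.reshape(-1)
    keep = (g.abs() > thresh).nonzero(as_tuple=True)[0]
    n = int(keep.numel())

    new_conv = nn.Conv1d(conv.in_channels, n, conv.kernel_size, stride=conv.stride,
                         padding=conv.padding, dilation=conv.dilation,
                         bias=conv.bias is not None, padding_mode=conv.padding_mode)
    # folding: (W x + b) * g == (g W) x + g b, per output channel
    new_conv.weight.copy_(conv.weight[keep] * g[keep].view(-1, 1, 1))
    if conv.bias is not None:
        new_conv.bias.copy_(conv.bias[keep] * g[keep])

    new_next = nn.Conv1d(n, next_conv.out_channels, next_conv.kernel_size, stride=next_conv.stride,
                         padding=next_conv.padding, dilation=next_conv.dilation,
                         bias=next_conv.bias is not None, padding_mode=next_conv.padding_mode)
    # the next layer keeps only the input channels that survived
    new_next.weight.copy_(next_conv.weight[:, keep])
    if next_conv.bias is not None:
        new_next.bias.copy_(next_conv.bias)
    return new_conv, new_next


torch.manual_seed(0)
conv, nxt = nn.Conv1d(3, 6, 3), nn.Conv1d(6, 2, 1)
gate = torch.tensor([1.0, 0.0, 0.5, 0.0, -2.0, 0.0]).view(1, 6, 1)
x = torch.randn(4, 3, 10)
small_conv, small_next = shrink_gated_pair(conv, gate, nxt)
same = torch.allclose(nxt(torch.relu(conv(x) * gate)), small_next(torch.relu(small_conv(x))), atol=1e-5)
print(small_conv.out_channels, small_next.in_channels, same)  # 3 3 True
```

The outputs match exactly only when the dropped factors are exactly zero; with factors merely below the threshold the smaller net is an approximation. MorphNet pairs this kind of shrink step with an expand step that scales all widths by a common factor; that step is not shown here.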

I actually wrote this code as a quick (non-generic) fix to optimize one of my nets.
I think one way to make this more generic is to turn it into a wrapper for layers:

class PruningWrapper(nn.Module):
    def __init__(self, layer, c_in, c_out, num_params, prune_thresh=1e-2):  # or num_ops
        super(PruningWrapper, self).__init__()
        self.layer = layer
        self.c_in = c_in
        self.c_out = c_out
        self.num_params = num_params
        # per-connection cost, e.g. the kernel length for a Conv1d
        self.num_params_no_input_output = self.num_params / self.c_out / self.c_in
        # a bias-free 1x1 conv applied to a constant input of ones: its weights act as one factor per channel
        self.pruning = nn.Conv1d(1, c_out, 1, bias=False)
        nn.init.ones_(self.pruning.weight)  # start with every channel fully on
        self.prune_thresh = prune_thresh

    def forward(self, x, prev_layer_num_features_not_pruned):
        out = self.layer(x)
        # size the ones from `out`, not `x`: an unpadded conv shortens the sequence
        dummy_ones = torch.ones((out.shape[0], 1, out.shape[2]), dtype=out.dtype, device=out.device)
        modulation = self.pruning(dummy_ones)  # (batch, c_out, length), constant along length
        out = out * modulation
        loss = modulation[0, :, 0].abs().sum() * prev_layer_num_features_not_pruned * self.num_params_no_input_output
        num_features_not_pruned = (modulation[0, :, 0].abs() > self.prune_thresh).sum()
        return out, loss, num_features_not_pruned
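
A sketch of how the wrapper's three return values could be threaded through a stack of layers, so that each layer's penalty is weighted by how many of its inputs are still alive. To stay short it assumes Conv1d layers and uses a Parameter for the factors instead of the 1x1-conv-on-ones trick; GatedConv1d and GatedStack are hypothetical names.

```python
import torch
import torch.nn as nn


class GatedConv1d(nn.Module):
    """Hypothetical Parameter-based take on the wrapper: conv output times per-channel factors."""

    def __init__(self, conv: nn.Conv1d, prune_thresh: float = 1e-2):
        super().__init__()
        self.conv = conv
        self.gate = nn.Parameter(torch.ones(1, conv.out_channels, 1))
        self.kernel = conv.kernel_size[0]  # plays the role of num_params_no_input_output
        self.prune_thresh = prune_thresh

    def forward(self, x, live_inputs):
        out = self.conv(x) * self.gate
        # each output channel costs (alive inputs * kernel) weights
        loss = self.gate.abs().sum() * live_inputs * self.kernel
        live_outputs = int((self.gate.abs() > self.prune_thresh).sum())
        return out, loss, live_outputs


class GatedStack(nn.Module):
    """Hypothetical stack that passes the live-channel count from layer to layer."""

    def __init__(self, channels, kernel=9):
        super().__init__()
        self.c_in = channels[0]
        self.layers = nn.ModuleList([GatedConv1d(nn.Conv1d(c_in, c_out, kernel))
                                     for c_in, c_out in zip(channels[:-1], channels[1:])])

    def forward(self, x):
        total_loss, live = 0.0, self.c_in
        for layer in self.layers:
            x, loss, live = layer(x, live)
            x = torch.relu(x)
            total_loss = total_loss + loss
        return x, total_loss


net = GatedStack([4, 8, 8], kernel=3)
y, reg = net(torch.randn(2, 4, 20))
print(y.shape, reg.item())  # torch.Size([2, 8, 16]) 288.0
```

With all factors at 1, the penalty is 8 * 4 * 3 + 8 * 8 * 3 = 288; as factors in the first layer fall below the threshold, the second layer's weight in the penalty drops with them.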

I feel like this would not work for everything, but it would work in many situations. We also need to figure out how to deal with the input shape and how to know which dim is the features dim.
Something like this is also a bit easier to unit test:

  1. Make sure that it does prune over time.
  2. We can split parts of it into one-line functions and check that they give the expected output for some known inputs.
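
The two testing ideas above could look roughly like this, assuming the penalty and the pruned-channel count are pulled out into one-line helpers (hypothetical names l1_gate_loss and count_live). The "prunes over time" test uses only the penalty: with SGD at lr=0.005, every factor walks from 1 toward 0 and ends within one step size of zero, below the 1e-2 threshold.

```python
import torch


def l1_gate_loss(gate: torch.Tensor, live_inputs: int, kernel: int) -> torch.Tensor:
    # cost-weighted L1 penalty on a layer's per-channel factors
    return gate.abs().sum() * live_inputs * kernel


def count_live(gate: torch.Tensor, thresh: float = 1e-2) -> int:
    # number of channels whose factor is still above the pruning threshold
    return int((gate.abs() > thresh).sum())


def test_helpers():
    gate = torch.tensor([1.0, -0.5, 0.0, 0.005])
    assert count_live(gate) == 2
    assert abs(l1_gate_loss(gate, live_inputs=3, kernel=2).item() - 9.03) < 1e-4


def test_prunes_over_time():
    gate = torch.nn.Parameter(torch.ones(16))
    opt = torch.optim.SGD([gate], lr=0.005)
    assert count_live(gate) == 16
    for _ in range(300):
        opt.zero_grad()
        l1_gate_loss(gate, live_inputs=1, kernel=1).backward()
        opt.step()
    # the penalty alone should have driven every factor below the threshold
    assert count_live(gate) == 0


if __name__ == "__main__":
    test_helpers()
    test_prunes_over_time()
```

A fuller version of the second test would add a task loss and check that only some channels die, but the penalty-only case is the quickest sanity check that the L1 term is wired up.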

I have a buddy (@AvivSham) who is also interested in developing this a bit. We can all brainstorm together and try to improve on this.

I would like to start with a very simple example, like MNIST with 5 conv layers, just to see for sure that it's working.

@danFromTelAviv That's ironically exactly what I had in local testing.

I must point out that a lot of what you structured out in PruningWrapper is done with sparsify and the other utilities. Programmatically computing those values (i.e. c_in, c_out, num_params, prev_layer_num_features_not_pruned) for the user is exactly what I was working on before I took a hiatus from the project.

Could you explain the purpose of modulation, dummy_ones, and perhaps why the aforementioned utilities don't meet the same need? You are obviously more fluent in PyTorch than I am, so it may just be my ignorance. Nevertheless, your use of 1-D convolution for pruning is rather opaque, and one of my goals for this project was a clear implementation.


Hopefully you don't find my ignorance or style off-putting. As you said (in so many words), it's something you threw together. Maybe you (@danFromTelAviv), @AvivSham, and I should get on a call to better see how our goals align.

Tbh I didn't look into your code in depth, but it sounds like we are in the same mindset. Yeah, I would also like it to be clearer.
So basically, I imagine creating a few wrappers for common layers (Conv2d, Conv1d, an RNN step, etc.). People wouldn't even need to know how it works; it would just work. All they'd have to do is wrap each relevant layer in the matching wrapper.
We could put those behind a single wrapper, like your code does (I think): it would find out what kind of layer it's dealing with and go from there.
Now that I think about it, we could also traverse the net, as you (again, I think) are doing in your code, find all the layers we know how to deal with, and wrap them automatically.
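
A minimal sketch of the "walk the net and wrap what we recognize" idea. ChannelGate, KNOWN, wrap_known_layers and gate_l1 are hypothetical names; it assumes the channel dim is dim 1, handles only Conv1d and Conv2d, and uses an unweighted penalty for brevity.

```python
import torch
import torch.nn as nn


class ChannelGate(nn.Module):
    """Hypothetical wrapper: multiplies a layer's output channels (dim 1) by learnable factors."""

    def __init__(self, layer: nn.Module, out_channels: int, ndim: int):
        super().__init__()
        self.layer = layer
        # shape (1, C, 1, ..., 1) so it broadcasts over batch and spatial dims
        self.gate = nn.Parameter(torch.ones((1, out_channels) + (1,) * ndim))

    def forward(self, x):
        return self.layer(x) * self.gate


# layer types we know how to deal with -> number of spatial dims after the channel dim
KNOWN = {nn.Conv1d: 1, nn.Conv2d: 2}


def wrap_known_layers(module: nn.Module) -> nn.Module:
    """Recursively replace known layers with gated versions, in place; returns module."""
    for name, child in module.named_children():
        ndim = KNOWN.get(type(child))
        if ndim is not None:
            setattr(module, name, ChannelGate(child, child.out_channels, ndim))
        else:
            wrap_known_layers(child)
    return module


def gate_l1(module: nn.Module) -> torch.Tensor:
    # unweighted L1 over every gate in the model
    return sum(m.gate.abs().sum() for m in module.modules() if isinstance(m, ChannelGate))


net = nn.Sequential(nn.Conv1d(3, 8, 3), nn.ReLU(), nn.Conv1d(8, 5, 3))
wrap_known_layers(net)  # call once; a second call would wrap the inner layers again
out = net(torch.randn(2, 3, 10))
reg = gate_l1(net)  # add a scaled version of this to the task loss
```

One caveat for this design: if a wrapped conv is followed by BatchNorm, the normalization largely divides the factor back out, so the L1 penalty can shrink the factors without changing the output. In that case the factor belongs after the BatchNorm layer.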

You are right, the ones + conv trick is really ugly. I switched it to a plain parameter set instead (I edited SimpleExample above to reflect that). In case it's not clear what I'm doing there: modulation is a set of factors (one per feature) that scale the previous layer's output. By using an L1 loss (abs().sum()), I push some of the modulation values to 0, thereby pruning some of the previous layer's outputs. This is simpler, more generic, and more explicit than finding the BatchNorm layers and applying L1 to their gamma weights.
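
For comparison, the BatchNorm-gamma approach that MorphNet builds on can be sketched as follows. BN_TYPES, bn_gamma_l1 and live_channels are hypothetical names; affine BatchNorm layers are assumed, and the penalty is unweighted for brevity.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def bn_gamma_l1(model: nn.Module) -> torch.Tensor:
    """L1 penalty on the scale (gamma, stored as .weight) of every affine BatchNorm layer."""
    terms = [m.weight.abs().sum() for m in model.modules()
             if isinstance(m, BN_TYPES) and m.affine]
    return torch.stack(terms).sum() if terms else torch.zeros(())


def live_channels(model: nn.Module, thresh: float = 1e-2) -> dict:
    """Per BatchNorm layer, how many channels still have |gamma| above the threshold."""
    return {name: int((m.weight.abs() > thresh).sum())
            for name, m in model.named_modules()
            if isinstance(m, BN_TYPES) and m.affine}


net = nn.Sequential(nn.Conv1d(3, 8, 3), nn.BatchNorm1d(8), nn.ReLU(),
                    nn.Conv1d(8, 4, 3), nn.BatchNorm1d(4))
with torch.no_grad():
    net[1].weight[:3] = 0.0  # pretend training drove three gammas to zero
print(bn_gamma_l1(net).item(), live_channels(net))  # 9.0 {'1': 5, '4': 4}
```

This variant only reaches channels that feed a BatchNorm layer, which is the generality argument for the explicit per-feature factor.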

OK, I looked into sparse.py. That's a really nice way to go about it, but how do you know when and where to apply it? Do you want to apply it to BatchNorm layers during training? I couldn't quite work out how your code can be used. Can you please edit your demo.py so it shows how to use the code you currently have? Have you used it to optimize a net yet?

We are both in Israel, which is about 7 to 10 hours ahead of the US, so maybe we can set up a Google Hangouts call or something like that for your morning and our evening? Which day works for you? (Please include a date so we know we are talking about the same day.)

BTW, did you look at the TF implementation?
I couldn't really follow what they were doing there when I briefly looked at it, but it may be a good idea to copy the way they implemented it.