pytorch/vision

[RFC] How to handle BC breaking changes on Model weights or hyper-parameters

datumbox opened this issue · 3 comments

🚀 Feature

In order to fix bugs we are sometimes forced to introduce BC breaking changes. While the process of such introductions is clear when it comes to code changes, it's not when it comes to model weights or hyper-parameters. Thus we should define when, why and how to introduce BC-breaking changes when it comes to model weights or model hyper-parameters.

Motivation

We have recently bumped into a few issues that motivate this. Here are a few examples:

  • On #2326 we discovered a bug in the initialization of some weights of all detection models. If we fix the bug in the code, we should probably retrain the models. What happens if their accuracy improves? How do we make the new weights available to our users?
  • How do we handle cases such as #2599 where in order to fix a bug we need to update the hyper-parameters of the model?

Approaches

There are quite a few different approaches for this:

  1. Replace the old parameters and inform the community about the BC-breaking changes. Example: #2942
    • Reasonable approach when the accuracy improvement is substantial or the effect on the model behaviour is negligible.
    • Keeps the code-base clean from workarounds and minimizes the number of weights we provide.
    • Can potentially cause issues for users who rely on transfer learning.
  2. Write code/workarounds to minimize the effect of the changes on existing models (a rough sketch follows this list). Example: #2940
    • Reasonable approach when the changes lead to a slight decrease in accuracy.
    • Minimizes the effect on users who used the pre-trained models.
    • Introduces ugly workarounds in the code and increases the number of weights we provide.
  3. Introduce versioning on model weights:
    • Appropriate when introducing significant changes to the models.
    • Keeps the code-base clean from workarounds.
    • Forces us to maintain multiple versions of weights and model configs.
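
As a rough illustration of the second approach, a builder could keep the old behaviour reachable behind a flag, so that users of the already-published weights see no change. This is only a sketch; the class, the builder, the flag and the threshold values below are hypothetical and not an existing torchvision API.

import torch.nn as nn

class MyDetectionModel(nn.Module):
    """Stand-in for a detection model whose post-processing uses `score_thresh`."""
    def __init__(self, score_thresh: float = 0.5):
        super().__init__()
        self.score_thresh = score_thresh

def my_detection_model(legacy_eval_params: bool = True, **kwargs) -> MyDetectionModel:
    # `legacy_eval_params` keeps the hyper-parameter value the released weights
    # were evaluated with, so existing users see no behaviour change.
    kwargs.setdefault("score_thresh", 0.05 if legacy_eval_params else 0.5)  # illustrative values
    return MyDetectionModel(**kwargs)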

It's worth discussing whether we want to adapt our approach to the characteristics of each problem or go with a single approach for all cases. Moreover, it's worth investigating whether changes to weights should be handled differently from changes to hyper-parameters used at inference.

cc @fmassa @cpuhrsch @vfdev-5 @mthrok

I think the future-proof way of handling this is option 3, with a factory function for versioning. I think we need at least two version numbers: one for the code (model version) and one for the parameters (param version).

A new model version can be introduced when there is a significant change to the code which is BC-breaking for the previous model.
Then the question narrows down to what we want to do when there is a code change that is backward compatible with the existing parameters but would change the performance. If we want to be perfect about accessibility/availability, then we need to introduce a third version number, essentially a minor version of each model, but I am not sure how often that would happen.

import torch


class MyModel1(torch.nn.Module):
    def __init__(self, model_configurations):
        ...

class MyModel2(torch.nn.Module):
    def __init__(self, model_configurations):
        ...

def get_my_model(model_version, param_version, model_configurations=None):
    # Select the model code based on the model version...
    if model_version < 1:
        model = MyModel1(model_configurations)
        # ...and the pre-trained weights based on the param version.
        if param_version < 1:
            url = parameter_for_1_1
        elif param_version < 2:
            url = parameter_for_1_2
        else:
            raise ValueError('Unexpected param_version')
    elif model_version < 2:
        model = MyModel2(model_configurations)
        if param_version < 1:
            url = parameter_for_2_1
        elif param_version < 2:
            url = parameter_for_2_2
        else:
            raise ValueError('Unexpected param_version')
    else:
        raise ValueError('Unexpected model_version')

    # parameter_for_*, _get_param and load_param are placeholders in this sketch.
    param = _get_param(url)
    model.load_param(param)
    return model

Here is an example of a parameter we don't currently store in the module's state, and it's unclear whether it should be considered part of the code or part of the params:

>>> from torch.nn.modules.dropout import Dropout
>>> x = Dropout(p=0.12345)
>>> x.state_dict()
OrderedDict()

It's worth discussing whether it makes sense to store these inside the state of the module.
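
If we decided we do want such hyper-parameters to travel with the weights, one possible mechanism, sketched below purely as an illustration rather than a proposal for the final API, is to register them as buffers so they get serialized and restored with the state_dict:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DropoutWithState(nn.Module):
    """Sketch: a Dropout variant whose probability lives in the state_dict."""
    def __init__(self, p: float = 0.5):
        super().__init__()
        # Buffers are included in state_dict() but are not trainable parameters.
        self.register_buffer("p", torch.tensor(p))

    def forward(self, x):
        return F.dropout(x, p=float(self.p), training=self.training)

m = DropoutWithState(p=0.12345)
print(m.state_dict())  # now contains the dropout probability instead of being empty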

On a general note I also agree with option 3: Since the weights affect the behavior of the model, you could, for the purposes of this discussion, think of them as code. With Python code we use versioning (git), so why not also with model weights? From that perspective we should rigorously version every model (whether a change is planned or not) and tag it with the version of the code that was used to generate it.
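
One lightweight way to do that, sketched below (the checkpoint layout and key names are made up for illustration, not an agreed format), is to bundle the versioning metadata into the checkpoint we publish:

import torch
import torchvision

def save_versioned_checkpoint(model, path, git_commit, param_version):
    # Hypothetical format: the weights plus metadata identifying the code that produced them.
    torch.save({
        "state_dict": model.state_dict(),
        "meta": {
            "torchvision_version": torchvision.__version__,
            "git_commit": git_commit,        # commit used to train/export these weights
            "param_version": param_version,  # weights version, as discussed above
        },
    }, path)

def load_versioned_checkpoint(model, path):
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    return checkpoint["meta"]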

The Dropout probability parameter is an interesting one, since it doesn't affect inference but will affect fine-tuning. We should make a decision as to the level of BC-compatibility we provide.

Also, on another note, this affects all of the PyTorch domain libraries and also projects such as TorchServe.

From a technical perspective, a low-tech way of associating model weights with versions is to use a content hash such as MD5. We can further encode that mapping into the link we use to store the model.
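
For instance, something along these lines could work (a sketch only; torch.hub's existing check_hash convention is similar, using a truncated SHA256 prefix in the filename):

import hashlib

def hashed_filename(weights_path: str, model_name: str, digest_len: int = 8) -> str:
    # Derive a short content hash and embed it in the published filename, so the
    # URL itself identifies the exact version of the weights it points to.
    digest = hashlib.sha256()
    with open(weights_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return f"{model_name}-{digest.hexdigest()[:digest_len]}.pth"

# e.g. hashed_filename("checkpoint.pth", "my_model") -> "my_model-1a2b3c4d.pth"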