huggingface/peft

Reproducibility when using a model with batch norm

BenjaminBossan opened this issue · 0 comments

System Info

Latest version of PEFT

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

import pytest
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from peft import LoraConfig, PeftModel, get_peft_model

model_id = "microsoft/resnet-18"

@pytest.fixture
def image_processor():
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    return image_processor

@pytest.fixture
def data(image_processor):
    dataset = load_dataset("huggingface/cats-image")
    image = dataset["test"]["image"][0]
    return image_processor(image, return_tensors="pt")

def test_model_with_batchnorm(tmp_path, data):
    torch.manual_seed(0)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    config = LoraConfig(target_modules=["convolution"], modules_to_save=["classifier"])
    model = get_peft_model(model, config)

    # record outputs before training
    model.eval()
    with torch.inference_mode():
        output_before = model(**data)
    model.train()

    # train for a few steps; besides the LoRA weights, this also updates the
    # batch norm running statistics of the base model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    batch_size = 4
    max_steps = 5 * batch_size
    labels = torch.zeros(1, 1000)
    labels[0, 283] = 1
    for i in range(0, max_steps, batch_size):
        optimizer.zero_grad()
        outputs = model(**data, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.inference_mode():
        output_after = model(**data)
    assert torch.isfinite(output_after.logits).all()
    atol, rtol = 1e-4, 1e-4
    # sanity check: model was updated
    assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)

    # check saving the model and loading it
    model.save_pretrained(tmp_path)
    del model
    torch.manual_seed(0)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    model = PeftModel.from_pretrained(model, tmp_path).eval()
    with torch.inference_mode():
        output_loaded = model(**data)
    # THIS FAILS
    assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)

Expected behavior

After loading a model that was trained with PEFT on a base model that contains some kind of batch norm layer, the loaded model should produce the same output as the model did before saving. Right now, this does not happen.

The reason is that during training, the batch norm buffers (running mean, running variance, etc.) are updated, but they are not saved when calling save_pretrained on the PeftModel instance. Normally in PEFT, we assume that the state of the base model is kept constant during training, which is not the case with batch norm. We only save the PEFT parameters and assume that when the user loads the base model, everything else is restored exactly as it was. As a result, the information in the buffers is lost completely.
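To illustrate: the running statistics are stored as buffers rather than parameters, so they receive no gradients but are still modified in train mode, and they do not appear in the adapter's state dict. A minimal sketch, assuming the model from the reproduction above and PEFT's get_peft_model_state_dict helper:

# running_mean / running_var are buffers of the batch norm layers,
# so they change during training even though no gradient flows into them
bn_buffers = [name for name, _ in model.named_buffers() if "running_" in name]
print(bn_buffers[:3])

# the adapter state dict only contains the LoRA weights and modules_to_save,
# so none of the batch norm buffers end up in the saved checkpoint
from peft import get_peft_model_state_dict
print(list(get_peft_model_state_dict(model).keys())[:3])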

One possible solution would be to include the buffers in the PEFT adapter, which is not very pretty. For this to work, we would need a way to identify the buffers that were updated versus those that are static. If someone knows a way to achieve this, or has a better idea for how to fix this, please let us know.

Edit: Best suggestion so far by @kashif: check the module's track_running_stats attribute and, if it is True, save that module's buffers. This will not cover all possible corner cases, but hopefully most.
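A rough sketch of what such a check could look like (an illustration of the idea, not actual PEFT code; the helper name is made up):

def collect_batchnorm_buffers(model):
    # collect the buffers of every submodule that tracks running statistics,
    # e.g. BatchNorm1d/2d/3d layers with track_running_stats=True
    buffers = {}
    for module_name, module in model.named_modules():
        if getattr(module, "track_running_stats", False):
            for buffer_name, buffer in module.named_buffers(recurse=False):
                buffers[f"{module_name}.{buffer_name}"] = buffer
    return buffers

The resulting dict could then be written out together with the adapter weights in save_pretrained and loaded back into the base model in from_pretrained.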