Reproducibility when using a model with batch norm
BenjaminBossan opened this issue · 0 comments
System Info
Latest version of PEFT
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
import pytest
import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoImageProcessor, AutoModelForImageClassification

model_id = "microsoft/resnet-18"


@pytest.fixture
def image_processor():
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    return image_processor


@pytest.fixture
def data(image_processor):
    dataset = load_dataset("huggingface/cats-image")
    image = dataset["test"]["image"][0]
    return image_processor(image, return_tensors="pt")


def test_model_with_batchnorm(tmp_path, data):
    torch.manual_seed(0)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    config = LoraConfig(target_modules=["convolution"], modules_to_save=["classifier"])
    model = get_peft_model(model, config)

    # record outputs before training
    model.eval()
    with torch.inference_mode():
        output_before = model(**data)

    # train for a few steps; train mode also updates the batch norm running stats
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    batch_size = 4
    max_steps = 5 * batch_size
    # dummy target: one-hot vector for class 283
    labels = torch.zeros(1, 1000)
    labels[0, 283] = 1
    for i in range(0, max_steps, batch_size):
        optimizer.zero_grad()
        outputs = model(**data, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    # record outputs after training
    model.eval()
    with torch.inference_mode():
        output_after = model(**data)
    assert torch.isfinite(output_after.logits).all()

    atol, rtol = 1e-4, 1e-4
    # sanity check: model was updated
    assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)

    # check saving the model and loading it
    model.save_pretrained(tmp_path)
    del model

    torch.manual_seed(0)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    model = PeftModel.from_pretrained(model, tmp_path).eval()
    with torch.inference_mode():
        output_loaded = model(**data)

    # THIS FAILS
    assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)
Expected behavior
After saving and loading a model that was trained with PEFT on a base model with some kind of batch norm layer, the loaded model should produce the same output as the trained model did before saving. Right now, this does not happen.
The reason is that during training, buffers such as the running mean are updated, but they are not saved when calling save_pretrained on the PeftModel instance. Normally in PEFT, we assume that the base model parameters are kept constant during training, which is not the case with batch norm. We only save the PEFT parameters and assume that when the user loads the base model, all other parameters are restored exactly. As a result, the updated values in the buffers are lost completely.
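To illustrate the mechanism described above, here is a minimal sketch using plain PyTorch only (no PEFT involved, and the tensor shapes are arbitrary choices for illustration). It shows that a batch norm layer in train mode updates its running statistics on a forward pass even when all of its parameters are frozen and no gradients are computed, and that these statistics live in buffers rather than parameters.

```python
import torch
from torch import nn

torch.manual_seed(0)

# A standalone batch norm layer standing in for one inside a base model.
bn = nn.BatchNorm2d(3)

# Freeze the learnable parameters, as is typically done for base model weights.
for p in bn.parameters():
    p.requires_grad_(False)

running_mean_before = bn.running_mean.clone()

# A single forward pass in train mode is enough to update the running stats.
bn.train()
with torch.no_grad():
    bn(torch.randn(8, 3, 4, 4) + 5.0)

buffers_changed = not torch.equal(running_mean_before, bn.running_mean)

# The running stats are buffers, not parameters.
buffer_names = sorted(name for name, _ in bn.named_buffers())
param_names = sorted(name for name, _ in bn.named_parameters())

print(buffers_changed)
print(buffer_names)
print(param_names)
```

Anything that only stores trainable or adapter parameters will therefore miss these values, since they never show up in named_parameters().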
One possible solution would be to try to include the buffers in the PEFT adapter, which is not very pretty. For this to work, we would need to have a way to identify buffers that were updated vs those that are static. If someone knows a way to achieve this, or has a better idea how to fix this, please let us know.
Edit: Best suggestion so far by @kashif: Check for track_running_stats and, if it's True, save the module's buffers. This will not cover all possible corner cases, but hopefully most.
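The track_running_stats suggestion could be sketched roughly as follows. This is not how PEFT implements it; collect_tracked_buffers is a hypothetical helper written for this illustration, and the small nn.Sequential model is a stand-in for a real base model. The idea is to walk over all modules, pick those whose track_running_stats attribute is True, and copy their own buffers into a dict that could be stored next to the adapter weights.

```python
import torch
from torch import nn


def collect_tracked_buffers(model: nn.Module) -> dict:
    """Hypothetical helper: copy the buffers of every module that tracks running stats."""
    tracked = {}
    for module_name, module in model.named_modules():
        # Batch norm layers expose track_running_stats; most other modules do not.
        if getattr(module, "track_running_stats", False):
            for buffer_name, buffer in module.named_buffers(recurse=False):
                key = f"{module_name}.{buffer_name}" if module_name else buffer_name
                tracked[key] = buffer.detach().clone()
    return tracked


torch.manual_seed(0)

# Toy stand-in for a base model with a batch norm layer.
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
model.train()
with torch.no_grad():
    model(torch.randn(2, 3, 8, 8))  # updates the running stats

saved = collect_tracked_buffers(model)

# Restore the buffers into a freshly created model; strict=False because
# the dict intentionally contains only the buffers, not the weights.
fresh = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
missing, unexpected = fresh.load_state_dict(saved, strict=False)

print(sorted(saved))
```

As noted above, this would not cover every corner case: a custom module that updates buffers during training without defining track_running_stats would still be missed.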