huggingface/peft

Loading trained peft model results in random adapter weights each time

ambisinister opened this issue · 4 comments

System Info

  • peft==0.10.0
  • transformers==4.28.2
  • accelerate==0.28.0

python 3.9.5
ubuntu 20.05.6

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from peft import (
    PeftConfig,
    PeftModel,
    AutoPeftModelForCausalLM,
    PeftModelForCausalLM,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
import torch

def load_model(checkpoint_path, adapter=True, merge=True, use_config=False):
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        load_in_4bit=True,
        torch_dtype=torch.float16,
        device_map={"":0},
    )

    if adapter:
        if not use_config:
            model = PeftModel.from_pretrained(model, checkpoint_path)
            if merge:
                model = model.merge_and_unload()
        else:
            peft_config = PeftConfig.from_pretrained(checkpoint_path)
            peft_config.init_lora_weights = False
            model.add_adapter(peft_config)
            model.enable_adapters()

    return model.state_dict(), model
    

def compare_weights(model1_weights, model2_weights):
    diffs = {}
    for key in model1_weights.keys():
        if torch.equal(model1_weights[key], model2_weights[key]):
            diffs[key] = "No change"
        else:
            diffs[key] = "Changed"
    return diffs

path1 = "./sanity_test_f16"
path2 = "./results-sanity-2/checkpoint-50"

merge = False
use_config = False

weights1, model1 = load_model(path1, adapter=True, merge=merge, use_config=use_config)
weights2, model2 = load_model(path2, adapter=True, merge=merge, use_config=use_config)

i = 0
differences = compare_weights(weights1, weights2)
for key, value in differences.items():
    if value == "Changed":
        print(f"{key}: {value}")
        i += 1
        if i == 6:
            break

tokenizer = AutoTokenizer.from_pretrained(path2)
input_ids = tokenizer.encode("### Instruction: Write code to reverse a linked list.",
                             return_tensors="pt")
input_ids = input_ids.to("cuda")
logits1 = model1(input_ids).logits
logits2 = model2(input_ids).logits

print(torch.isclose(logits1, logits2, atol=1e-5))

if merge:
    if use_config:
        print(weights1["model.layers.0.mlp.fc1.lora_A.default.weight"])
        print(weights2["model.layers.0.mlp.fc1.lora_A.default.weight"])
    else:
        print(weights1["model.layers.0.mlp.fc1.weight"])
        print(weights2["model.layers.0.mlp.fc1.weight"])
else:
    print(weights1["base_model.model.model.layers.0.mlp.fc1.lora_A.default.weight"])
    print(weights2["base_model.model.model.layers.0.mlp.fc1.lora_A.default.weight"])

In python REPL:

>>> exec(open('./sanity_check.py').read())                                                                                                                                     
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.18it/s]
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.32it/s]
base_model.model.model.layers.0.mlp.fc1.lora_A.default.weight: Changed
base_model.model.model.layers.0.mlp.fc2.lora_A.default.weight: Changed
base_model.model.model.layers.1.mlp.fc1.lora_A.default.weight: Changed
base_model.model.model.layers.1.mlp.fc2.lora_A.default.weight: Changed
base_model.model.model.layers.2.mlp.fc1.lora_A.default.weight: Changed
base_model.model.model.layers.2.mlp.fc2.lora_A.default.weight: Changed
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
/home/ubuntu/.local/lib/python3.9/site-packages/bitsandbytes/nn/modules.py:391: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.
  warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
2024-04-24 19:22:10.481219: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-04-24 19:22:11.658579: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]], device='cuda:0')
tensor([[-0.0082, -0.0111, -0.0069,  ..., -0.0124,  0.0130,  0.0169],
        [ 0.0142,  0.0110,  0.0036,  ..., -0.0032, -0.0185,  0.0088],
        [-0.0195, -0.0056,  0.0152,  ...,  0.0174, -0.0127, -0.0168],
        ...,
        [ 0.0059,  0.0039, -0.0192,  ...,  0.0130,  0.0146,  0.0146],
        [ 0.0038, -0.0133, -0.0140,  ..., -0.0176, -0.0058, -0.0104],
        [-0.0138,  0.0053, -0.0041,  ...,  0.0020, -0.0151,  0.0133]],
       device='cuda:0')
tensor([[-1.7025e-02,  2.4742e-03, -1.0115e-02,  ...,  1.4603e-02,
         -8.0345e-03,  1.0960e-02],
        [-1.1581e-02,  1.8048e-02,  1.4734e-02,  ...,  6.3604e-03,
         -1.6242e-02,  6.1431e-03],
        [ 1.0976e-02, -4.0579e-03,  4.3461e-03,  ...,  6.3524e-03,
         -1.8358e-02, -1.0410e-02],
        ...,
        [ 7.3278e-03, -1.0956e-02,  6.2436e-03,  ..., -4.5395e-03,
          1.2014e-02,  1.5762e-02],
        [-1.1824e-02,  1.2552e-02,  2.3365e-03,  ...,  6.9467e-04,
         -1.5712e-02,  5.9579e-03],
        [-1.9212e-02, -1.9194e-02,  1.4454e-02,  ...,  4.2492e-05,
         -1.2649e-02, -1.0431e-02]], device='cuda:0')
>>> exec(open('./sanity_check.py').read())
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.15it/s]
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.11it/s]
base_model.model.model.layers.0.mlp.fc1.lora_A.default.weight: Changed
base_model.model.model.layers.0.mlp.fc2.lora_A.default.weight: Changed
base_model.model.model.layers.1.mlp.fc1.lora_A.default.weight: Changed
base_model.model.model.layers.1.mlp.fc2.lora_A.default.weight: Changed
base_model.model.model.layers.2.mlp.fc1.lora_A.default.weight: Changed
base_model.model.model.layers.2.mlp.fc2.lora_A.default.weight: Changed
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]], device='cuda:0')
tensor([[-1.7242e-02,  5.6506e-03,  1.7260e-02,  ...,  1.1011e-02,
          1.8103e-02,  1.8189e-02],
        [-6.4992e-03, -1.4380e-02,  5.5341e-03,  ...,  1.6408e-02,
         -6.3714e-03,  1.2028e-02],
        [ 1.7429e-02,  7.9016e-03,  3.9765e-03,  ...,  5.9452e-03,
         -1.0812e-02, -9.6804e-05],
        ...,
        [-1.7027e-02,  1.8371e-02, -1.8877e-02,  ..., -1.1631e-03,
         -8.0132e-03,  1.7408e-02],
        [-1.0086e-02, -1.4720e-02,  1.2310e-02,  ..., -6.3289e-03,
         -1.7902e-02, -4.9857e-03],
        [ 1.3719e-02, -1.2506e-02, -1.6492e-02,  ..., -1.3931e-02,
          1.4905e-02, -1.1610e-02]], device='cuda:0')
tensor([[-0.0068,  0.0117,  0.0159,  ...,  0.0016,  0.0090,  0.0123],
        [ 0.0113, -0.0073,  0.0107,  ...,  0.0194,  0.0076,  0.0180],
        [ 0.0103,  0.0119, -0.0124,  ..., -0.0075,  0.0112, -0.0071],
        ...,
        [-0.0004, -0.0047, -0.0083,  ...,  0.0002, -0.0165, -0.0164],
        [-0.0068,  0.0153,  0.0113,  ..., -0.0020, -0.0062, -0.0197],
        [-0.0114, -0.0089,  0.0135,  ...,  0.0017,  0.0148,  0.0079]],
       device='cuda:0')

This behavior is the same whether I load via PeftModel.from_pretrained or via add_adapter. merge_and_unload seemingly does nothing: all models are identical after it is called, no matter what the LoRA weights are.

Expected behavior

The trained lora weights are loaded and applied to the base model, which should be the same each time. Or, at the very least, there could be some indication that this is not my trained model.

It is possible I'm horribly misunderstanding something, but this code appears to be initializing the LoRA weights from scratch, as if preparing them for training (e.g. init_lora_weights=False still initializes the LoRA weights, just randomly rather than in the standard way). I have completed training, where everything seemed normal (training loss went down, TensorBoard logs look fine), and now I just want to use the weights I've trained.

I have been following this documentation, and the issue seems to exist no matter how I load the LoRA weights: https://huggingface.co/docs/transformers/main/en/peft. I would love to be pointed to my misunderstanding here, since I'm sure it's something basic.

In general, your loading code looks correct. I tried to replicate your issue but for me, everything looks right. I made these changes:

from peft import LoraConfig, get_peft_model

MODEL_ID = "microsoft/phi-2"
# create adapters
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map={"":0},
)
lora_config = LoraConfig(init_lora_weights=False)
model = get_peft_model(model, lora_config)
model.add_adapter("other", lora_config)
model.save_pretrained("/tmp/peft/issue-1678")

path1 = "/tmp/peft/issue-1678"
path2 = "/tmp/peft/issue-1678/other"

Otherwise the code is identical. I get different weights and different logits for the two adapters. One notable thing is that the output of compare_weights is the following for me:

base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: Changed
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: Changed
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: Changed
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: Changed
base_model.model.model.layers.0.mlp.fc1.lora_A.default.weight: Changed
base_model.model.model.layers.0.mlp.fc1.lora_B.default.weight: Changed

As you can see, for me, the shown weights are lora_A and lora_B in alternating fashion. This is what should be expected. For you, there is only lora_A. This indicates to me that your lora_B weights are all zeros -- could you please check that?
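A quick way to run this check on the output of compare_weights is to count the "Changed" entries per LoRA matrix. The sketch below is only an illustration: `count_changed_by_matrix` is a hypothetical helper, and the "No change" entries for lora_B are assumed, since the original output only listed changed keys.

```python
from collections import Counter


def count_changed_by_matrix(diffs):
    """Count 'Changed' entries per LoRA matrix (lora_A / lora_B)."""
    counts = Counter()
    for key, status in diffs.items():
        if status != "Changed":
            continue
        for name in ("lora_A", "lora_B"):
            if f".{name}." in key:
                counts[name] += 1
    return counts


# Illustrative diffs dict in the format returned by compare_weights.
# The lora_B statuses are an assumption made for this example.
diffs = {
    "base_model.model.model.layers.0.mlp.fc1.lora_A.default.weight": "Changed",
    "base_model.model.model.layers.0.mlp.fc1.lora_B.default.weight": "No change",
    "base_model.model.model.layers.0.mlp.fc2.lora_A.default.weight": "Changed",
    "base_model.model.model.layers.0.mlp.fc2.lora_B.default.weight": "No change",
}

counts = count_changed_by_matrix(diffs)
print(f"lora_A changed: {counts['lora_A']}, lora_B changed: {counts['lora_B']}")
```

If lora_A differs between two adapters while no lora_B entry does, that matches the pattern described above and suggests inspecting the lora_B tensors directly.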

lora_B is expected to be 0 at initialization (except if we set init_lora_weights=False as in my code). Therefore, it appears that your adapters were not correctly trained or not correctly saved. This cannot be further investigated with the code you provided though.
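To illustrate why an all-zero lora_B would be consistent with identical logits and with merge_and_unload appearing to do nothing, here is a minimal, dependency-free sketch of the LoRA weight update. It assumes the standard scaling `alpha / r`; the helper names and the toy shapes are hypothetical and chosen only for this example.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_delta(lora_A, lora_B, alpha, r):
    """Return the LoRA weight update (alpha / r) * B @ A."""
    scale = alpha / r
    return [[scale * v for v in row] for row in matmul(lora_B, lora_A)]


r, alpha = 2, 4
# lora_A has shape (r, in_features); here (2, 3) with arbitrary non-zero values.
lora_A = [[0.5, -0.25, 0.1], [0.3, 0.2, -0.4]]
# lora_B has shape (out_features, r); here (2, 2).
lora_B_zero = [[0.0, 0.0], [0.0, 0.0]]
lora_B_identity = [[1.0, 0.0], [0.0, 1.0]]

delta_zero = lora_delta(lora_A, lora_B_zero, alpha, r)
delta_nonzero = lora_delta(lora_A, lora_B_identity, alpha, r)

# With lora_B == 0 the update vanishes, whatever lora_A contains.
print(all(v == 0 for row in delta_zero for v in row))
# With a non-zero lora_B the update depends on lora_A.
print(any(v != 0 for row in delta_nonzero for v in row))
```

Under these assumptions, two adapters with different random lora_A but zero lora_B add the same (zero) update to the base weights, so their outputs would coincide.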

I have confirmed that your code does consistently load the same initially randomized adapter weights in my environment, which is a good step, thank you for this!

I wrote some simplified trainer code to demonstrate something similar to my current training loop, which reproduces my error:

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer

base_model = "microsoft/phi-2"
new_model = "sanity_test_issue"

dataset = load_dataset("sahil2801/CodeAlpaca-20k",split="train")

tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
tokenizer.pad_token=tokenizer.eos_token
tokenizer.padding_side="right"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules= ["Wqkv", "fc1", "fc2" ]
)

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    trust_remote_code=True,
    flash_attn=True,
    flash_rotary=True,
    fused_dense=True,
    low_cpu_mem_usage=True,
    device_map={"": 0},
    revision="refs/pr/23",
    torch_dtype=torch.float16,
)

model.config.use_cache = False
model.config.pretraining_tp = 1

model = get_peft_model(model, peft_config)
model.add_adapter("default", peft_config)

model.print_trainable_parameters()

training_arguments = TrainingArguments(
        output_dir="./results-sanity-issue",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=32,
        evaluation_strategy="steps",
        eval_steps=2000,
        logging_steps=25,
        save_steps=50,
        optim="paged_adamw_8bit",
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_steps=10,
        warmup_ratio=0.05,
        report_to="tensorboard",
        weight_decay=0.01,
        max_steps=10 #-1, 
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    eval_dataset=dataset,
    dataset_text_field="instruction",
    max_seq_length=2048,
    tokenizer=tokenizer,
    args=training_arguments,
)

trainer.train()

model.save_pretrained(new_model)
#trainer.model.save_pretrained(new_model)

This leads to some interesting behaviors saving and loading the model, where there seems to be some sort of mismatch.
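One way to make such a mismatch visible is to compare the parameter names of the in-memory trained model with those of the reloaded model. The sketch below is a minimal illustration: `unmatched_keys` is a hypothetical helper, and the two key names are the ones printed via state_dict() in this thread (training session and first loading script). Whether a naming difference is the actual cause here is not established by this sketch.

```python
def unmatched_keys(trained_keys, loaded_keys):
    """Return keys of the trained model that have no same-named key in the loaded model."""
    loaded = set(loaded_keys)
    return sorted(k for k in trained_keys if k not in loaded)


# Key as printed from the trained model's state_dict() in the training session.
trained = ["base_model.model.transformer.h.0.mlp.fc1.lora_A.default.weight"]
# Key as printed from the reloaded model's state_dict() in the first script.
loaded = ["base_model.model.model.layers.0.mlp.fc1.lora_A.default.weight"]

missing = unmatched_keys(trained, loaded)
for key in missing:
    print(f"no counterpart in loaded model: {key}")
```

With real models, the two lists would come from `model.state_dict().keys()` on each side; a non-empty result means some trained parameters have no identically named target in the reloaded model.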

>>> exec(open("./train_sanity.py").read())                                                                                                                                 
2024-04-25 23:34:40.349566: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.                                               
2024-04-25 23:34:41.510985: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT                                                
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.                                                      
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.06it/s]
trainable params: 36,700,160 || all params: 2,816,384,000 || trainable%: 1.3030950324955688                                                                                
tensor([[-1.7239e-02,  6.7696e-03,  1.1712e-02,  ...,  2.0782e-03,                                                                                                         
          8.6474e-03,  1.1055e-02],                                                                                                                                        
        [-5.1573e-03, -1.2886e-02,  7.1626e-03,  ...,  1.8335e-02,                                                                                                         
         -7.8717e-03,  1.0091e-02],                                                                                                                                        
        [-3.3074e-03,  9.1777e-04, -1.2094e-02,  ...,  1.1925e-02,                                                                                                         
         -1.1942e-02,  1.7323e-02],                                                                                                                                        
        ...,                                                                                                                                                               
        [-6.1534e-03,  5.3593e-03, -1.1107e-02,  ...,  1.7856e-02,                                                                                                         
          7.2733e-03,  1.0433e-04],                                                                                                                                        
        [ 1.3691e-02, -4.8460e-03,  1.7432e-02,  ..., -4.9365e-03,                                                                                                         
          1.1269e-02,  1.8518e-02],                                                                                                                                        
        [ 4.2621e-03,  1.9739e-02,  8.6471e-05,  ...,  7.4409e-03,                                                                                                         
         -4.0094e-03, -1.1024e-02]], device='cuda:0')                                                                                                                      
/home/ubuntu/.local/lib/python3.9/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead:
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)                                      
  warnings.warn(                                                                                                                                                           
{'train_runtime': 101.1944, 'train_samples_per_second': 12.649, 'train_steps_per_second': 0.099, 'train_loss': 2.082432174682617, 'epoch': 0.06}                           
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:41<00:00, 10.12s/it]
You are using a model of type phi to instantiate a model of type phi-msft. This is not supported for all configurations of models and can yield errors.                                                                                                                                              
>>> trained_model_10steps = model                                                                                                                                          
>>> print(trained_model_10steps.state_dict()["base_model.model.transformer.h.0.mlp.fc1.lora_A.default.weight"])                                                            
tensor([[-0.0166,  0.0072,  0.0110,  ...,  0.0024,  0.0080,  0.0114],                                                                                                      
        [-0.0049, -0.0132,  0.0065,  ...,  0.0184, -0.0078,  0.0106],                                                                                                      
        [-0.0027,  0.0015, -0.0128,  ...,  0.0112, -0.0125,  0.0171],                                                                                                      
        ...,                                                                                                                                                               
        [-0.0056,  0.0060, -0.0117,  ...,  0.0180,  0.0066,  0.0007],                                                                                                      
        [ 0.0143, -0.0042,  0.0167,  ..., -0.0044,  0.0107,  0.0190],                                                                                                      
        [ 0.0036,  0.0191,  0.0003,  ...,  0.0069, -0.0036, -0.0109]],                                                                                                     
       device='cuda:0')                                                                                                                                                    
>>>   
>>> loaded_new_model = AutoModelForCausalLM.from_pretrained(base_model, load_in_4bit=True, torch_dtype=torch.float16, device_map={"":0})                                   
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.15it/s]
>>> loaded_new_model = PeftModel.from_pretrained(loaded_new_model, './sanity_test_issue')                                                                                  
>>>                                                                                                                                                                        
>>> print(trained_model_10steps)                                                                                                                                           
PeftModelForCausalLM(                                                                                                                                                      
  (base_model): LoraModel(                                                                                                                                                 
    (model): PhiForCausalLM(                                                                                                                                               
      (transformer): PhiModel(                                                                                                                                             
        (embd): Embedding(                                                                                                                                                 
          (wte): Embedding(51200, 2560)                                                                                                                                    
          (drop): Dropout(p=0.0, inplace=False)                                                                                                                            
        )                                                                                                                                                                  
        (h): ModuleList(                                                                                                                                                   
          (0-31): 32 x ParallelBlock(                                                                                                                                      
            (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)                                                                                                   
            (resid_dropout): Dropout(p=0.1, inplace=False)                                                                                                                 
            (mixer): MHA(                                                                                                                                                  
              (rotary_emb): RotaryEmbedding()                                                                                                                              
              (Wqkv): lora.Linear4bit(                                                                                                                                     
                (base_layer): Linear4bit(in_features=2560, out_features=7680, bias=True)                                                                                   
                (lora_dropout): ModuleDict(                                                                                                                                
                  (default): Dropout(p=0.05, inplace=False)                                                                                                                
                )                                                                                                                                                          
                (lora_A): ModuleDict(                                                                                                                                      
                  (default): Linear(in_features=2560, out_features=32, bias=False)                                                                                         
                )                                                                                                                                                          
                (lora_B): ModuleDict(                                                                                                                                      
                  (default): Linear(in_features=32, out_features=7680, bias=False)                                                                                         
                )                                                                                                                                                          
                (lora_embedding_A): ParameterDict()                                                                                                                        
                (lora_embedding_B): ParameterDict()                                                                                                                        
              )                                                                                                                                                            
              (out_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)                                                                                       
              (inner_attn): SelfAttention(                                                                                                                                 
                (drop): Dropout(p=0.0, inplace=False)                                                                                                                      
              )                                                                                                                                                            
              (inner_cross_attn): CrossAttention(                                                                                                                          
                (drop): Dropout(p=0.0, inplace=False)                                                                                                                      
              )                                                                                                                                                            
            )                                                                                                                                                              
            (mlp): MLP(                                                                                                                                                    
              (fc1): lora.Linear4bit(                                                                                                                                      
                (base_layer): Linear4bit(in_features=2560, out_features=10240, bias=True)                                                                                  
                (lora_dropout): ModuleDict(                                                                                                                                
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=10240, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (fc2): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=10240, out_features=2560, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=10240, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (act): NewGELUActivation()
            )
          )
        )
      )
      (lm_head): CausalLMHead(
        (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (linear): Linear(in_features=2560, out_features=51200, bias=True)
      )
      (loss): CausalLMLoss(
        (loss_fct): CrossEntropyLoss()
      )
    )
  )
)
>>> print(loaded_new_model)
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): PhiForCausalLM(
      (model): PhiModel(
        (embed_tokens): Embedding(51200, 2560)
        (embed_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-31): 32 x PhiDecoderLayer(
            (self_attn): PhiAttention(
              (q_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)
              (k_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)
              (v_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)
              (dense): Linear4bit(in_features=2560, out_features=2560, bias=True)
              (rotary_emb): PhiRotaryEmbedding()
            )
            (mlp): PhiMLP(
              (activation_fn): NewGELUActivation()
              (fc1): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2560, out_features=10240, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=10240, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (fc2): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=10240, out_features=2560, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=10240, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
            )
            (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (final_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
      )
      (lm_head): Linear(in_features=2560, out_features=51200, bias=True)
    )
  )
)
>>> lora0_model1 = trained_model_10steps.state_dict()["base_model.model.transformer.h.0.mlp.fc1.lora_A.default.weight"]
>>> lora0_model2 = loaded_new_model.state_dict()["base_model.model.model.layers.0.mlp.fc1.lora_A.default.weight"]
>>> 
>>> lora0_model1
tensor([[-0.0166,  0.0072,  0.0110,  ...,  0.0024,  0.0080,  0.0114],
        [-0.0049, -0.0132,  0.0065,  ...,  0.0184, -0.0078,  0.0106],
        [-0.0027,  0.0015, -0.0128,  ...,  0.0112, -0.0125,  0.0171],
        ...,
        [-0.0056,  0.0060, -0.0117,  ...,  0.0180,  0.0066,  0.0007],
        [ 0.0143, -0.0042,  0.0167,  ..., -0.0044,  0.0107,  0.0190],
        [ 0.0036,  0.0191,  0.0003,  ...,  0.0069, -0.0036, -0.0109]],
       device='cuda:0')
>>> lora0_model2
tensor([[-0.0046,  0.0158, -0.0063,  ..., -0.0154, -0.0120, -0.0161],
        [-0.0189, -0.0069, -0.0012,  ..., -0.0189, -0.0027,  0.0193],
        [ 0.0148, -0.0178,  0.0023,  ..., -0.0124,  0.0009,  0.0190],
        ...,
        [-0.0091,  0.0163,  0.0035,  ...,  0.0194, -0.0172, -0.0062],
        [ 0.0102,  0.0064, -0.0009,  ...,  0.0104, -0.0039,  0.0132],
        [-0.0089,  0.0070, -0.0086,  ..., -0.0174,  0.0104,  0.0089]],
       device='cuda:0')

You can see that this particular LoRA layer changes between before and after training, but the loaded weights do not match the trained ones. In addition, several layer names differ between the two models (for example, transformer.h.0 in the trained model versus model.layers.0 in the loaded one). The behavior is the same with trainer.model.save_pretrained and model.save_pretrained.

In addition, I checked lora_B and it's nonzero in the trained model and zero in the loaded model (and before training).

Any chance this could be related to the save_pretrained function? The training looks like it's working, and the loading looks correct using the example you provided. Not sure what else could be causing this!

I could indeed reproduce your results. The issue seems to be that when you first load the model, you load a specific revision, revision="refs/pr/23", but when you try to load the saved model, you load the base model without revision. If you load the base model exactly the same as you did before training, you will see that the loaded weights are identical.
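As a minimal sketch of that fix (not the exact script from this thread), the base model can be loaded through a single helper so that training and reloading always use the same revision. The model id microsoft/phi-2 and the revision refs/pr/23 come from this thread; the helper name load_base_and_adapter and the two constants are hypothetical, and any other loading arguments (quantization, dtype, device_map) are assumed to be passed through unchanged from the training script.

```python
# Sketch only: the helper and constant names are hypothetical, not part of peft.
BASE_MODEL_ID = "microsoft/phi-2"  # base model used in this thread
BASE_REVISION = "refs/pr/23"       # revision the adapter was trained against


def load_base_and_adapter(adapter_path, revision=BASE_REVISION, **base_kwargs):
    """Load the base model at a pinned revision, then attach the saved adapter."""
    # Imported lazily so this file can be imported without the heavy dependencies.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        revision=revision,  # must match the revision used before training
        **base_kwargs,      # e.g. the same quantization / dtype / device_map settings
    )
    return PeftModel.from_pretrained(base, adapter_path)
```

Calling the same helper before training and when reloading keeps the module names in the base model consistent with the names stored in the adapter checkpoint.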

After that change, I get:

sd0 = trained_model_10steps.state_dict()
sd1 = loaded_new_model.state_dict()
out = [torch.allclose(sd0[k], sd1[k]) for k in sd0 if "lora_" in k]
print(all(out))  # True

This was it, thank you so much for your help!!

Is it intended behavior for PeftModel.from_pretrained(wrong_revision, right_revision_checkpoint) to silently reinitialize the adapter like this? Intuitively, it seems like an error or warning would be helpful here.
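A user-side check along these lines is possible by comparing adapter key names between the trained and the reloaded model before comparing any values. The sketch below is an illustration only: find_unmatched_adapter_keys is a hypothetical helper, not a peft function, and the two example keys are the ones printed earlier in this thread.

```python
def find_unmatched_adapter_keys(trained_keys, loaded_keys, marker="lora_"):
    """Return adapter keys that appear on only one side (hypothetical helper)."""
    trained = {k for k in trained_keys if marker in k}
    loaded = {k for k in loaded_keys if marker in k}
    return sorted(trained - loaded), sorted(loaded - trained)


# Key names copied from the two state dicts shown in this thread.
trained_keys = ["base_model.model.transformer.h.0.mlp.fc1.lora_A.default.weight"]
loaded_keys = ["base_model.model.model.layers.0.mlp.fc1.lora_A.default.weight"]

only_trained, only_loaded = find_unmatched_adapter_keys(trained_keys, loaded_keys)
if only_trained or only_loaded:
    # Differing names suggest the two base models do not share an architecture.
    print("Adapter key mismatch between trained and loaded model:")
    for key in only_trained:
        print("  only in trained:", key)
    for key in only_loaded:
        print("  only in loaded: ", key)
```

With real models, the inputs would be trained_model_10steps.state_dict().keys() and loaded_new_model.state_dict().keys(); two empty result lists mean the adapter key names line up and the values can then be compared.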

Thanks again, I will close the issue with this comment.