huggingface/peft

Updating only one adapter when using multiple adapters

LameloBally opened this issue · 2 comments

I am curious if I can update only one adapter in a multi-adapter situation.

For example, I have two adapters, A and B.

In the forward pass, I run the LLM + adapter A + adapter B to get the loss,

and then I want to update only adapter A using peft.

I found _mixed_batch_forward in your code. Is it related to my question?

    def _mixed_batch_forward(
        self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
    ) -> torch.Tensor:
        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
        # extra argument that allows mixing different adapters in the same batch at inference time.
        result = self.base_layer(x, *args, **kwargs)
        torch_result_dtype = result.dtype

        unique_adapters = set(adapter_names)
        sub_batch_indices_list = []
        for adapter in unique_adapters:
            sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

        for i, active_adapter in enumerate(unique_adapters):
            if active_adapter == "__base__":
                continue
            if active_adapter not in self.lora_A.keys():
                continue

            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]

            # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
            # layer output
            sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype)
            lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling
            result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype)

        return result

Mixed batches are not related to this; you can find a description of mixed adapter batches in the docs.
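For reference, mixed adapter batches are an inference-time feature: you pass adapter_names so that different samples in the same batch are routed through different adapters. A rough sketch (model, inputs, and the adapter names adapter_a / adapter_b are placeholders, assuming both adapters are already loaded):

# One adapter name per sample in the batch; "__base__" means
# "use the base model without any adapter" for that sample.
adapter_names = ["adapter_a", "adapter_b", "__base__"]
output = model.generate(**inputs, adapter_names=adapter_names)

So this mechanism never touches gradients or training.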

For what you want, I think you would need to manually set requires_grad=False on the parameters of adapter B (see the sketch below). Also, just in case you haven't done so already, you have to ensure that adapters A and B are really active at the same time.
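Something along these lines might work (a rough, untested sketch; it assumes LoRA adapters named adapter_a and adapter_b have already been loaded into a PeftModel called model):

# Activate both adapters at the same time; the underlying LoRA model's
# set_adapter accepts a list of adapter names on recent PEFT versions.
model.base_model.set_adapter(["adapter_a", "adapter_b"])

# Freeze adapter B so only adapter A receives gradient updates. Do this
# after set_adapter, since set_adapter re-enables grads for the adapters
# it activates.
for name, param in model.named_parameters():
    if ".adapter_b." in name:
        param.requires_grad = False

# Sanity check: only adapter A's parameters should remain trainable.
model.print_trainable_parameters()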

Note that we normally only train with a single adapter at a time, so your specific use case is a bit of an outlier and may run into some problems I haven't thought of yet.