ndif-team/nnsight

Requesting support for input_embeds for tracer.invoke()


I'm using the LanguageModel class to wrap the vision-language model LLaVA. When executing

with tracer.invoke(inputs):

the following call at nnsight/contexts/Invoker.py#L55

self.inputs, batch_size = self.tracer._model._prepare_inputs(
    *self.inputs, **self.kwargs
)

raises errors.
FYI, LLaVA's forward() typically takes either

forward(
  input_ids: torch.LongTensor,
  images: Optional[torch.FloatTensor],
  **kwargs
)

or

forward(
  inputs_embeds,
  **kwargs
)
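To make the two calling conventions concrete, here is a minimal toy module that, like Hugging Face transformers models, accepts exactly one of input_ids or inputs_embeds. It is a hypothetical sketch, not LLaVA's code; the class name ToyLM and its sizes are illustrative assumptions.

```python
import torch
from torch import nn


class ToyLM(nn.Module):
    """Hypothetical toy LM: accepts either token ids or precomputed embeddings."""

    def __init__(self, vocab_size: int = 100, hidden: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids=None, inputs_embeds=None):
        # Exactly one of the two inputs must be given, mirroring the HF convention.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            # Token path: look up each id in the embedding matrix.
            inputs_embeds = self.embed_tokens(input_ids)
        # From here on, both paths run the same computation.
        return self.lm_head(inputs_embeds)


torch.manual_seed(0)
model = ToyLM().eval()
ids = torch.tensor([[1, 2, 3]])
embeds = model.embed_tokens(ids)  # what the token path computes internally
print(torch.allclose(model(input_ids=ids), model(inputs_embeds=embeds)))  # True
```

Because everything after the embedding lookup is shared, passing inputs_embeds lets a caller supply vectors that no token id maps to, which is exactly what LLaVA's image tokens need.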

Could you add support for passing inputs_embeds as an alternative to inputs, so that I can use the tracer like this?

with tracer.invoke(
    inputs=None,
    inputs_embeds=inputs_embeds,
):
    ...

Here's an example wrapper for getting models whose .forward() takes different arguments to work with NNsight, in this case a TransformerLens HookedTransformer.

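import functools

from transformer_lens import HookedTransformer

# Assumption: any HookedTransformer works; "gpt2" is only an example model.
tl_model = HookedTransformer.from_pretrained("gpt2")


def transformerlens_to_nnsight_wrapper(original_method):
    # Rename/drop Hugging Face-style kwargs so they fit HookedTransformer's signature.
    @functools.wraps(original_method)
    def wrapper(self, *args, **kwargs):
        # HookedTransformer names its token argument `input`, not `input_ids`.
        if "input_ids" in kwargs:
            kwargs["input"] = kwargs.pop("input_ids")
        # Discard `labels` and `attention_mask` instead of passing them through.
        kwargs.pop("labels", None)
        kwargs.pop("attention_mask", None)
        return original_method(self, *args, **kwargs)
    return wrapper


# Bind the wrapped methods to this instance only; other instances are unaffected.
tl_model.forward = transformerlens_to_nnsight_wrapper(HookedTransformer.forward).__get__(tl_model, HookedTransformer)
tl_model.generate = transformerlens_to_nnsight_wrapper(HookedTransformer.generate).__get__(tl_model, HookedTransformer)

# Also set a few attributes, so that it works with NNsight
tl_model.device = tl_model.cfg.device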
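The per-instance binding trick in the wrapper above can be checked without installing TransformerLens. In this sketch a hypothetical Dummy class stands in for HookedTransformer, and rename_kwargs_wrapper applies the same kwarg renaming and dropping logic.

```python
import functools


def rename_kwargs_wrapper(original_method):
    """Same renaming logic as the wrapper above, applied to any method."""
    @functools.wraps(original_method)
    def wrapper(self, *args, **kwargs):
        if "input_ids" in kwargs:
            kwargs["input"] = kwargs.pop("input_ids")
        kwargs.pop("labels", None)
        kwargs.pop("attention_mask", None)
        return original_method(self, *args, **kwargs)
    return wrapper


class Dummy:
    """Hypothetical stand-in for HookedTransformer: forward() expects `input`."""

    def forward(self, input, return_type="logits"):
        return ("forward", input, return_type)


obj = Dummy()
other = Dummy()
# function.__get__(instance, cls) produces a bound method; storing it as an
# instance attribute shadows Dummy.forward for `obj` only.
obj.forward = rename_kwargs_wrapper(Dummy.forward).__get__(obj, Dummy)

print(obj.forward(input_ids=[1, 2], attention_mask=[1, 1]))
# ('forward', [1, 2], 'logits')
```

Calling other.forward(input_ids=[1, 2]) still raises a TypeError, which shows the patch is scoped to a single instance rather than monkey-patching the class.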

I'll look into making this easier and better documented.

If I understand your code correctly, it doesn't solve the underlying problem: the model never receives the image as input. LLaVA only works if the input has either (1) both input_ids and images, or (2) inputs_embeds. This is because image tokens have no row in the embedding matrix; their embeddings are computed by the image encoder followed by a linear projection layer.
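A toy sketch of why image tokens can only enter through inputs_embeds: the image slot is filled with projected encoder features rather than a vocabulary lookup. This assumes a single example, one image, and a placeholder id marking the image position. LLaVA's reference code uses a special negative index for this, but treat the value here, the function name build_inputs_embeds, and all sizes as hypothetical.

```python
import torch
from torch import nn

IMAGE_TOKEN = -200  # assumed placeholder id for the image slot (hypothetical)


def build_inputs_embeds(input_ids, image_feats, embed_tokens, projector):
    """Splice projected image features into the text embeddings.

    input_ids:   (seq_len,) token ids with exactly one IMAGE_TOKEN placeholder
    image_feats: (num_patches, vision_dim) output of a vision encoder
    Returns:     (seq_len - 1 + num_patches, hidden) embedding sequence
    """
    pos = (input_ids == IMAGE_TOKEN).nonzero().item()  # index of the placeholder
    before = embed_tokens(input_ids[:pos])      # ordinary text-token lookups
    after = embed_tokens(input_ids[pos + 1:])
    image_embeds = projector(image_feats)       # vision_dim -> hidden, no vocab row
    return torch.cat([before, image_embeds, after], dim=0)


torch.manual_seed(0)
embed_tokens = nn.Embedding(100, 32)  # toy text embedding matrix
projector = nn.Linear(64, 32)         # toy stand-in for the vision-to-text projector
input_ids = torch.tensor([5, 7, IMAGE_TOKEN, 9])
image_feats = torch.randn(6, 64)      # pretend 6 patch features from an image encoder
inputs_embeds = build_inputs_embeds(input_ids, image_feats, embed_tokens, projector)
print(inputs_embeds.shape)  # torch.Size([9, 32])
```

The placeholder id never reaches the embedding table, so a wrapper that only forwards input_ids (without images) has no way to produce the image part of this sequence.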

@HuFY-dev I think the right approach here is to subclass LanguageModel and implement your own versions of its virtual methods to handle the inputs you want to pass in. It shouldn't take much effort.
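A rough sketch of the input-handling logic such a subclass could contain. nnsight's internal method names and signatures have changed between versions, so this is written as a standalone function rather than a real LanguageModel override; the name prepare_inputs and its keyword arguments are hypothetical. It only mirrors the (inputs, batch_size) pair that the _prepare_inputs call in Invoker.py unpacks.

```python
import torch


def prepare_inputs(input_ids=None, inputs_embeds=None, images=None, **kwargs):
    """Hypothetical input normalizer a LanguageModel subclass could call.

    Returns (model_kwargs, batch_size).
    """
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("Pass exactly one of input_ids or inputs_embeds")
    if inputs_embeds is not None:
        if images is not None:
            # With inputs_embeds, image features must already be spliced in.
            raise ValueError("images cannot be combined with inputs_embeds")
        model_kwargs = {"inputs_embeds": inputs_embeds}  # (batch, seq, hidden)
        batch_size = inputs_embeds.shape[0]
    else:
        model_kwargs = {"input_ids": input_ids}  # (batch, seq)
        if images is not None:
            model_kwargs["images"] = images  # LLaVA needs ids and images together
        batch_size = input_ids.shape[0]
    # Any other keyword arguments (e.g. attention_mask) pass through unchanged.
    model_kwargs.update(kwargs)
    return model_kwargs, batch_size


model_kwargs, batch_size = prepare_inputs(inputs_embeds=torch.zeros(2, 5, 16))
print(list(model_kwargs), batch_size)  # ['inputs_embeds'] 2
```

Normalizing both forms into one kwargs dict keeps the rest of the tracing path unaware of which input form the caller used.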

I'll try that. Still, inputs_embeds is a standard argument to .forward() in the transformers pretrained model classes (not only for VLMs but for most LLMs as well), so supporting it would make nnsight more flexible. I may write some code on my side and submit a PR if it works properly.