pytorch/captum

Uninitialized target_tokens Variable in LLMGradientAttribution.attribute Method

MSingh-CSE opened this issue · 0 comments

๐Ÿ› Bug

In LLMGradientAttribution.attribute (in llm_attr.py), the target_tokens variable is never initialized when the target argument is neither a string (str) nor a PyTorch tensor of tokens (torch.Tensor). Because target_tokens is only assigned inside the str and torch.Tensor branches, the method later fails with an UnboundLocalError when it iterates over target_tokens for any other target type.

To Reproduce

Steps to reproduce the behavior:

  1. Invoke the attribute method of the LLMGradientAttribution class with a target argument that is neither a str nor a torch.Tensor.
  2. Because target_tokens is only assigned inside the str and torch.Tensor branches, the loop that follows raises an UnboundLocalError.

Example code:

from captum.attr import LayerIntegratedGradients, LLMGradientAttribution, TextTokenInput

# model and tokenizer are assumed to be an already-loaded causal LM and its tokenizer
ig = LayerIntegratedGradients(model, model.model.embed_tokens)
llm_attr = LLMGradientAttribution(ig, tokenizer)

inp = TextTokenInput("2+2", tokenizer)
target = 4  # int: neither str nor torch.Tensor
llm_attr.attribute(inp, target=target)

Output:

    529 
    530         attr_list = []
--> 531         for cur_target_idx, _ in enumerate(target_tokens):
    532             # attr in shape(batch_size, input+output_len, emb_dim)
    533             attr = self.attr_method.attribute(

UnboundLocalError: local variable 'target_tokens' referenced before assignment

Expected behavior

The method should handle an unexpected target type either by initializing target_tokens to a sensible default or by raising a clear, descriptive error that tells the user which types are accepted for target.

Environment

 - Captum / PyTorch Version: 0.7.0 / 2.1.2
 - OS : Linux
 - How you installed Captum / PyTorch (`conda`, `pip`, source): pip
 - Python version: 3.10.13

Possible fix using assert:

Current code (llm_attr.py, lines 537-545):

else:
    assert gen_args is None, "gen_args must be None when target is given"

    if type(target) is str:
        # exclude sos
        target_tokens = self.tokenizer.encode(target)[1:]
        target_tokens = torch.tensor(target_tokens)
    elif type(target) is torch.Tensor:
        target_tokens = target

After fix:

else:
    assert gen_args is None, "gen_args must be None when target is given"
    assert isinstance(target, (str, torch.Tensor)), (
        "The target argument must be either a str or torch.Tensor."
    )

    if type(target) is str:
        # exclude sos
        target_tokens = self.tokenizer.encode(target)[1:]
        target_tokens = torch.tensor(target_tokens)
    elif type(target) is torch.Tensor:
        target_tokens = target
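As a quick sanity check, the added assert can be exercised in isolation with the integer target from the repro above. attribute_target_check is a hypothetical stand-in for the fixed branch, not Captum code:

```python
import torch


def attribute_target_check(target):
    # Hypothetical stand-in for the fixed branch in llm_attr.py:
    # the assert fires before target_tokens would ever be used.
    assert isinstance(target, (str, torch.Tensor)), (
        "The target argument must be either a str or torch.Tensor."
    )


try:
    attribute_target_check(4)  # the integer target from the repro
except AssertionError as err:
    # fails immediately with the descriptive message
    # instead of a later UnboundLocalError
    print(err)
```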