Uninitialized target_tokens Variable in LLMGradientAttribution.attribute Method
MSingh-CSE opened this issue · 0 comments
🐛 Bug
In the LLMGradientAttribution.attribute method in llm_attr.py, the target_tokens variable is never initialized when the target argument is neither a string (str) nor a PyTorch tensor of token ids (torch.Tensor). Because target_tokens is only assigned inside the conditional branches for str and torch.Tensor, this causes an UnboundLocalError further down in the method where target_tokens is used.
To Reproduce
Steps to reproduce the behavior:
- Invoke the attribute method of the LLMGradientAttribution class, passing a target argument whose type is neither str nor torch.Tensor (e.g. an int).
- Since no initialization path exists for target_tokens outside the conditional checks for str and torch.Tensor, the method raises an uninitialized-variable error.
Example code:
from captum.attr import LayerIntegratedGradients, LLMGradientAttribution, TextTokenInput
ig = LayerIntegratedGradients(model, model.model.embed_tokens)
llm_attr = LLMGradientAttribution(ig, tokenizer)
inp = TextTokenInput("2+2", tokenizer)
target = 4  # int: neither str nor torch.Tensor
llm_attr.attribute(inp, target=target)
Output:
529
530 attr_list = []
--> 531 for cur_target_idx, _ in enumerate(target_tokens):
532 # attr in shape(batch_size, input+output_len, emb_dim)
533 attr = self.attr_method.attribute(
UnboundLocalError: local variable 'target_tokens' referenced before assignment
Expected behavior
The method should handle the case where target has an unexpected type, either by initializing target_tokens to a sensible default or by raising a clear, descriptive error that tells the user which types are accepted for target.
Environment
- Captum / PyTorch Version: 0.7.0 / 2.1.2
- OS : Linux
- How you installed Captum / PyTorch (`conda`, `pip`, source): pip
- Python version: 3.10.13
Possible fix using assert:
Current code (llm_attr.py, lines 537–545):
else:
assert gen_args is None, "gen_args must be None when target is given"
if type(target) is str:
# exclude sos
target_tokens = self.tokenizer.encode(target)[1:]
target_tokens = torch.tensor(target_tokens)
elif type(target) is torch.Tensor:
target_tokens = target
After fix:
else:
assert gen_args is None, "gen_args must be None when target is given"
assert isinstance(target, (str, torch.Tensor)), (
"The target argument must be either a str or torch.Tensor."
)
if type(target) is str:
# exclude sos
target_tokens = self.tokenizer.encode(target)[1:]
target_tokens = torch.tensor(target_tokens)
elif type(target) is torch.Tensor:
target_tokens = target
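An alternative to assert (which is silently skipped when Python runs with the -O flag) is an explicit isinstance check that raises a TypeError. A minimal, self-contained sketch of that validation logic, where the standalone helper name `validate_target` is hypothetical and not part of the Captum API:

```python
import torch


def validate_target(target):
    """Reject target values that are neither str nor torch.Tensor.

    Hypothetical helper illustrating the proposed check; the real fix
    would inline this at the top of the else branch in llm_attr.py.
    """
    if not isinstance(target, (str, torch.Tensor)):
        raise TypeError(
            "target must be a str or torch.Tensor, "
            f"got {type(target).__name__}"
        )
    return target


# A str passes through unchanged; an int fails fast with a clear message
validate_target("4")
```

With this check in place, `llm_attr.attribute(inp, target=4)` would fail immediately with a readable message instead of the later UnboundLocalError.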