From 73dc20f7bf71e467c32e2cbfd339a63d63d89bae Mon Sep 17 00:00:00 2001 From: Sultan <92447824+SulRash@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:16:43 +0300 Subject: [PATCH 1/5] Update llm_attr.py Fixed an issue where, when trying to attribute a single target token for models that don't generate a BOS token, you end up with an empty target list --- captum/attr/_core/llm_attr.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 194a910765..8341cf695b 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -327,6 +327,7 @@ def attribute( inp: InterpretableInput, target: Union[str, torch.Tensor, None] = None, num_trials: int = 1, + skip_bos: bool = True, # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict[, ]` to avoid runtime subscripting # errors. @@ -382,8 +383,11 @@ def attribute( assert gen_args is None, "gen_args must be None when target is given" if type(target) is str: - # exclude sos - target_tokens = self.tokenizer.encode(target)[1:] + # exclude sos / bos + if skip_bos: + target_tokens = self.tokenizer.encode(target)[1:] + else: + target_tokens = self.tokenizer.encode(target) target_tokens = torch.tensor(target_tokens) elif type(target) is torch.Tensor: target_tokens = target From 3b8b6aa17dccb64209aeb6e40f263e06045aced5 Mon Sep 17 00:00:00 2001 From: Sultan <92447824+SulRash@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:20:15 +0300 Subject: [PATCH 2/5] Update llm_attr.py Added error handling that directs users to the skip_bos argument when the target token list is empty --- captum/attr/_core/llm_attr.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 8341cf695b..b037dbcb0c 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -290,9 +290,12 @@ def _forward_func( # 1st element is the total prob, rest are the target 
tokens # add a leading dim for batch even we only support single instance for now if self.include_per_token_attr: - target_log_probs = torch.stack( - [total_log_prob, *log_prob_list], dim=0 - ).unsqueeze(0) + try: + target_log_probs = torch.stack( + [total_log_prob, *log_prob_list], dim=0 + ).unsqueeze(0) + except TypeError: + print("It seems like you got an empty list of target tokens. If you are attributing only one target token (a single character / word) try using the skip_bos argument in the attribute function.") else: target_log_probs = total_log_prob # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int, From 136b7da58251d7b7e3a258567360e6d7a3b440f9 Mon Sep 17 00:00:00 2001 From: Sultan <92447824+SulRash@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:22:42 +0300 Subject: [PATCH 3/5] Update llm_attr.py Added an exit call to the try/except block --- captum/attr/_core/llm_attr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index b037dbcb0c..dedbc5a6f5 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -296,6 +296,7 @@ def _forward_func( ).unsqueeze(0) except TypeError: print("It seems like you got an empty list of target tokens. 
If you are attributing only one target token (a single character / word) try using the skip_bos argument in the attribute function.") + exit() else: target_log_probs = total_log_prob # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int, From 333852c59b43ac0cd5b3b2e5cefed07c69d6bc38 Mon Sep 17 00:00:00 2001 From: Sultan <92447824+SulRash@users.noreply.github.com> Date: Tue, 13 Aug 2024 14:10:32 +0300 Subject: [PATCH 4/5] Changed error handling --- captum/attr/_core/llm_attr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index dedbc5a6f5..61fcc51126 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -295,8 +295,7 @@ def _forward_func( [total_log_prob, *log_prob_list], dim=0 ).unsqueeze(0) except TypeError: - print("It seems like you got an empty list of target tokens. If you are attributing only one target token (a single character / word) try using the skip_bos argument in the attribute function.") - exit() + raise TypeError("It seems like you got an empty list of target tokens. 
If you are attributing only one target token (a single character or word) try using the skip_bos argument in the attribute function.") else: target_log_probs = total_log_prob # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int, From 90c13113af0f5a7257c6e1dad6d5794f2dd047f5 Mon Sep 17 00:00:00 2001 From: Sultan Date: Fri, 30 Aug 2024 00:47:31 +0300 Subject: [PATCH 5/5] Fixed ufmt and linting check --- captum/attr/_core/llm_attr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 2733dc647b..f90492e39d 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -291,10 +291,10 @@ def _forward_func( if self.include_per_token_attr: try: target_log_probs = torch.stack( - [total_log_prob, *log_prob_list], dim=0 # type: ignore - ).unsqueeze(0) + [total_log_prob, *log_prob_list], dim=0 # type: ignore + ).unsqueeze(0) except TypeError: - raise TypeError("It seems like you got an empty list of target tokens. If you are attributing only one target token (a single character or word) try using the skip_bos argument in the attribute function.") + raise TypeError("Try using the skip_bos argument.") else: target_log_probs = total_log_prob # type: ignore target_probs = torch.exp(target_log_probs)