Skip to content

Commit aa52b26

Browse files
fy-meng and facebook-github-bot
authored and committed
Allow more variable input to LLMAttributionResult (#1627)
Summary: Pull Request resolved: #1627 1. Allow `LLMAttributionResult` to be initialized with generic array-like data (lists, np.ndarray) and perform sanity checks on their shapes; 2. During visualization, the text tokens are now `repr`'d to make sure that non-word characters (e.g. newlines) are visualized correctly. Reviewed By: craymichael Differential Revision: D78197863
1 parent cc5b9ce commit aa52b26

File tree

1 file changed

+102
-15
lines changed

1 file changed

+102
-15
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 102 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from abc import ABC
66
from copy import copy
77
from dataclasses import dataclass
8-
from textwrap import shorten
8+
from textwrap import dedent, shorten
99
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
1010

1111
import matplotlib.colors as mcolors
1212

1313
import matplotlib.pyplot as plt
1414
import numpy as np
15+
import numpy.typing as npt
1516

1617
import torch
1718
from captum._utils.typing import TokenizerLike
@@ -51,11 +52,92 @@ class LLMAttributionResult:
5152
It also provides utilities to help present and plot the result in different forms.
5253
"""
5354

54-
seq_attr: Tensor
55-
token_attr: Optional[Tensor]
5655
input_tokens: List[str]
5756
output_tokens: List[str]
58-
output_probs: Optional[Tensor] = None
57+
# pyre-ignore[13]: initialized via a property setter
58+
_seq_attr: Tensor
59+
_token_attr: Optional[Tensor] = None
60+
_output_probs: Optional[Tensor] = None
61+
62+
def __init__(
63+
self,
64+
*,
65+
input_tokens: List[str],
66+
output_tokens: List[str],
67+
seq_attr: npt.ArrayLike,
68+
token_attr: Optional[npt.ArrayLike] = None,
69+
output_probs: Optional[npt.ArrayLike] = None,
70+
) -> None:
71+
self.input_tokens = input_tokens
72+
self.output_tokens = output_tokens
73+
self.seq_attr = seq_attr
74+
self.token_attr = token_attr
75+
self.output_probs = output_probs
76+
77+
@property
78+
def seq_attr(self) -> Tensor:
79+
return self._seq_attr
80+
81+
@seq_attr.setter
82+
def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
83+
if isinstance(seq_attr, Tensor):
84+
self._seq_attr = seq_attr
85+
else:
86+
self._seq_attr = torch.tensor(seq_attr)
87+
# IDEA: in the future we might want to support higher dim seq_attr
88+
# (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
89+
assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
90+
assert (
91+
len(self.input_tokens) == self._seq_attr.shape[0]
92+
), "seq_attr and input_tokens must have the same length"
93+
94+
@property
95+
def token_attr(self) -> Optional[Tensor]:
96+
return self._token_attr
97+
98+
@token_attr.setter
99+
def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
100+
if token_attr is None:
101+
self._token_attr = None
102+
elif isinstance(token_attr, Tensor):
103+
self._token_attr = token_attr
104+
else:
105+
self._token_attr = torch.tensor(token_attr)
106+
107+
if self._token_attr is not None:
108+
# IDEA: in the future we might want to support higher dim token_attr
109+
assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
110+
assert self._token_attr.shape == (
111+
len(self.output_tokens),
112+
len(self.input_tokens),
113+
), dedent(
114+
f"""\
115+
Expect token_attr to have shape
116+
{len(self.output_tokens), len(self.input_tokens)},
117+
got {self._token_attr.shape}
118+
"""
119+
)
120+
121+
@property
122+
def output_probs(self) -> Optional[Tensor]:
123+
return self._output_probs
124+
125+
@output_probs.setter
126+
def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None:
127+
if output_probs is None:
128+
self._output_probs = None
129+
elif isinstance(output_probs, Tensor):
130+
self._output_probs = output_probs
131+
else:
132+
self._output_probs = torch.tensor(output_probs)
133+
134+
if self._output_probs is not None:
135+
assert (
136+
len(self._output_probs.shape) == 1
137+
), "output_probs must be a 1D tensor"
138+
assert (
139+
len(self.output_tokens) == self._output_probs.shape[0]
140+
), "output_probs and output_tokens must have the same length"
59141

60142
@property
61143
def seq_attr_dict(self) -> Dict[str, float]:
@@ -125,10 +207,14 @@ def plot_token_attr(
125207

126208
# Show all ticks and label them with the respective list entries.
127209
shortened_tokens = [
128-
shorten(t, width=50, placeholder="...") for t in self.input_tokens
210+
shorten(repr(t)[1:-1], width=50, placeholder="...")
211+
for t in self.input_tokens
129212
]
130213
ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
131-
ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
214+
ax.set_yticks(
215+
np.arange(data.shape[0]),
216+
labels=[repr(token)[1:-1] for token in self.output_tokens],
217+
)
132218

133219
# Let the horizontal axes labeling appear on top.
134220
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
@@ -176,7 +262,8 @@ def plot_seq_attr(
176262
fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))
177263

178264
shortened_tokens = [
179-
shorten(t, width=50, placeholder="...") for t in self.input_tokens
265+
shorten(repr(t)[1:-1], width=50, placeholder="...")
266+
for t in self.input_tokens
180267
]
181268
ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
182269

@@ -690,12 +777,12 @@ def attribute(
690777
attr = inp.format_attr(attr)
691778

692779
return LLMAttributionResult(
693-
attr[0],
694-
(
780+
seq_attr=attr[0],
781+
token_attr=(
695782
attr[1:] if self.include_per_token_attr else None
696783
), # shape(n_output_token, n_input_features)
697-
inp.values,
698-
_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
784+
input_tokens=inp.values,
785+
output_tokens=_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
699786
)
700787

701788
def attribute_future(self) -> Callable[[], LLMAttributionResult]:
@@ -830,10 +917,10 @@ def attribute(
830917
seq_attr = attr.sum(0)
831918

832919
return LLMAttributionResult(
833-
seq_attr,
834-
attr, # shape(n_output_token, n_input_features)
835-
inp.values,
836-
_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
920+
seq_attr=seq_attr,
921+
token_attr=attr, # shape(n_output_token, n_input_features)
922+
input_tokens=inp.values,
923+
output_tokens=_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
837924
)
838925

839926
def attribute_future(self) -> Callable[[], LLMAttributionResult]:

0 commit comments

Comments
 (0)