Commit bf0641d

fy-meng authored and facebook-github-bot committed
Allow more variable input to LLMAttributionResult (#1627)
Summary:
Pull Request resolved: #1627

1. Allow `LLMAttributionResult` to be initialized with generic array data (lists, np.ndarray) and perform sanity checks on their shapes.
2. During visualization, the text tokens are now `repr`'d so that non-word characters (e.g. newlines) render correctly.

Reviewed By: craymichael

Differential Revision: D78197863
1 parent cc5b9ce commit bf0641d
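
For readers skimming the diff below, here is a minimal usage sketch of the more permissive constructor this change introduces. The values and token strings are made up for illustration; the import path assumes LLMAttributionResult is exported from captum.attr as in current Captum releases.

import numpy as np

from captum.attr import LLMAttributionResult

# seq_attr may now be a plain list and token_attr a NumPy array; the new
# property setters convert both to torch tensors and validate their shapes.
result = LLMAttributionResult(
    seq_attr=[0.10, 0.45, -0.05],        # one score per input token -> shape (3,)
    token_attr=np.zeros((2, 3)),         # (n_output_tokens, n_input_tokens)
    input_tokens=["Hello", " world", "\n"],
    output_tokens=["Hi", "!"],
)
print(result.seq_attr.shape)    # torch.Size([3])
print(result.token_attr.shape)  # torch.Size([2, 3])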

File tree

1 file changed: +72 -6 lines changed


captum/attr/_core/llm_attr.py (72 additions, 6 deletions)
@@ -5,13 +5,14 @@
 from abc import ABC
 from copy import copy
 from dataclasses import dataclass
-from textwrap import shorten
+from textwrap import dedent, shorten
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

 import matplotlib.colors as mcolors

 import matplotlib.pyplot as plt
 import numpy as np
+import numpy.typing as npt

 import torch
 from captum._utils.typing import TokenizerLike
@@ -51,11 +52,71 @@ class LLMAttributionResult:
     It also provides utilities to help present and plot the result in different forms.
     """

-    seq_attr: Tensor
-    token_attr: Optional[Tensor]
     input_tokens: List[str]
     output_tokens: List[str]
     output_probs: Optional[Tensor] = None
+    # pyre-ignore[13]: initialized via a property setter
+    _seq_attr: Tensor
+    _token_attr: Optional[Tensor] = None
+
+    def __init__(
+        self,
+        seq_attr: npt.ArrayLike,
+        token_attr: Optional[npt.ArrayLike],
+        input_tokens: List[str],
+        output_tokens: List[str],
+    ) -> None:
+        self.input_tokens = input_tokens
+        self.output_tokens = output_tokens
+        self.seq_attr = seq_attr
+        self.token_attr = token_attr
+
+    @property
+    def seq_attr(self) -> Tensor:
+        return self._seq_attr
+
+    @seq_attr.setter
+    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+        if isinstance(seq_attr, Tensor):
+            self._seq_attr = seq_attr
+        else:
+            self._seq_attr = torch.tensor(seq_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
+        assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
+
+        assert (
+            len(self.input_tokens) == self._seq_attr.shape[0]
+        ), "seq_attr and input_tokens must have the same length"
+
+    @property
+    def token_attr(self) -> Optional[Tensor]:
+        return self._token_attr
+
+    @token_attr.setter
+    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+        if isinstance(token_attr, Tensor):
+            self._token_attr = token_attr
+        elif token_attr is None:
+            # can't combine with previous clause, linter unhappy ¯\_(ツ)_/¯
+            self._token_attr = None
+        else:
+            self._token_attr = torch.tensor(token_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        if self._token_attr is not None:
+            assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
+
+        if self._token_attr is not None:
+            assert self._token_attr.shape == (
+                len(self.output_tokens),
+                len(self.input_tokens),
+            ), dedent(
+                f"""\
+                Expect token_attr to have shape
+                {len(self.output_tokens), len(self.input_tokens)},
+                got {self._token_attr.shape}
+                """
+            )

     @property
     def seq_attr_dict(self) -> Dict[str, float]:
@@ -125,10 +186,14 @@ def plot_token_attr(

         # Show all ticks and label them with the respective list entries.
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
-        ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
+        ax.set_yticks(
+            np.arange(data.shape[0]),
+            labels=[repr(token)[1:-1] for token in self.output_tokens],
+        )

         # Let the horizontal axes labeling appear on top.
         ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
@@ -176,7 +241,8 @@ def plot_seq_attr(
         fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))

         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
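
A side note on the repr(t)[1:-1] idiom used for the tick labels above: repr escapes control characters and the [1:-1] slice strips the surrounding quotes, so a token that is a bare newline shows up as the visible characters \n instead of breaking the label. A small standalone illustration (not part of the commit):

# Without escaping, control-character tokens would render as invisible
# line breaks or tabs in the matplotlib tick labels.
tokens = ["Hello", "\n", "\t", "wörld"]
labels = [repr(t)[1:-1] for t in tokens]
print(labels)  # ['Hello', '\\n', '\\t', 'wörld']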
