
Commit 6c08dd6

fy-meng authored and facebook-github-bot committed
Allow more variable input to LLMAttributionResult (#1627)
Summary:
1. Allow `LLMAttributionResult` to be initialized with generic array data (lists, np.ndarray) and perform sanity checks on their shapes.
2. During visualization, the text tokens are now `repr`'d to make sure that non-word characters (e.g. newline) are visualized correctly.

Reviewed By: craymichael

Differential Revision: D78197863
1 parent cc5b9ce commit 6c08dd6
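
For illustration, here is a minimal sketch (not part of the commit) of what the relaxed constructor accepts after this change. It assumes `LLMAttributionResult` is importable from `captum.attr`; the token strings and attribution values are made up:

```python
import numpy as np
import torch

from captum.attr import LLMAttributionResult

input_tokens = ["Hello", "\n", "world"]   # 3 input tokens
output_tokens = ["Hi", "there"]           # 2 output tokens

# seq_attr may now be any array-like (list, np.ndarray, or Tensor); the
# setter converts non-tensors via torch.tensor and asserts a 1D shape
# matching len(input_tokens). token_attr, if given, must be 2D with shape
# (len(output_tokens), len(input_tokens)).
result = LLMAttributionResult(
    seq_attr=[0.5, 0.1, 0.4],        # plain list -> 1D tensor of length 3
    token_attr=np.ones((2, 3)) / 3,  # ndarray -> 2D tensor of shape (2, 3)
    input_tokens=input_tokens,
    output_tokens=output_tokens,
)
assert isinstance(result.seq_attr, torch.Tensor)
assert isinstance(result.token_attr, torch.Tensor)

# Mismatched shapes now fail fast in the property setter:
try:
    LLMAttributionResult([0.5, 0.1], None, input_tokens, output_tokens)
except AssertionError as e:
    print(e)  # seq_attr and input_tokens must have the same length
```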

File tree

1 file changed: +76 −6 lines


captum/attr/_core/llm_attr.py

Lines changed: 76 additions & 6 deletions
```diff
@@ -5,13 +5,14 @@
 from abc import ABC
 from copy import copy
 from dataclasses import dataclass
-from textwrap import shorten
+from textwrap import dedent, shorten
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

 import matplotlib.colors as mcolors

 import matplotlib.pyplot as plt
 import numpy as np
+import numpy.typing as npt

 import torch
 from captum._utils.typing import TokenizerLike
@@ -51,11 +52,75 @@ class LLMAttributionResult:
     It also provides utilities to help present and plot the result in different forms.
     """

-    seq_attr: Tensor
-    token_attr: Optional[Tensor]
     input_tokens: List[str]
     output_tokens: List[str]
     output_probs: Optional[Tensor] = None
+    # pyre-ignore[13]: initialized via a property setter
+    _seq_attr: Tensor
+    _token_attr: Optional[Tensor] = None
+
+    def __init__(
+        self,
+        seq_attr: npt.ArrayLike,
+        token_attr: Optional[npt.ArrayLike],
+        input_tokens: List[str],
+        output_tokens: List[str],
+    ) -> None:
+        self.input_tokens = input_tokens
+        self.output_tokens = output_tokens
+        self.seq_attr = seq_attr
+        self.token_attr = token_attr
+
+    @property
+    def seq_attr(self) -> Tensor:
+        return self._seq_attr
+
+    @seq_attr.setter
+    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+        if isinstance(seq_attr, Tensor):
+            self._seq_attr = seq_attr
+        else:
+            self._seq_attr = torch.tensor(seq_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
+        assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
+
+        assert (
+            len(self.input_tokens) == self._seq_attr.shape[0]
+        ), "seq_attr and input_tokens must have the same length"
+
+    @property
+    def token_attr(self) -> Optional[Tensor]:
+        return self._token_attr
+
+    @token_attr.setter
+    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+        if isinstance(token_attr, Tensor):
+            self._token_attr = token_attr
+        elif token_attr is None:
+            # can't combine with previous clause, linter unhappy ¯\_(ツ)_/¯
+            self._token_attr = None
+        else:
+            self._token_attr = torch.tensor(token_attr)
+        # IDEA: in the future we might want to support higher dim token_attr
+        if self._token_attr is not None:
+            assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
+
+        if self._token_attr is not None:
+            assert self._token_attr.shape == (
+                len(self.output_tokens),
+                len(self.input_tokens),
+            ), dedent(
+                f"""\
+                Expect token_attr to have shape
+                {len(self.output_tokens), len(self.input_tokens)},
+                got {self._token_attr.shape}
+                """
+            )

     @property
     def seq_attr_dict(self) -> Dict[str, float]:
@@ -125,10 +190,14 @@ def plot_token_attr(

         # Show all ticks and label them with the respective list entries.
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
-        ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
+        ax.set_yticks(
+            np.arange(data.shape[0]),
+            labels=[repr(token)[1:-1] for token in self.output_tokens],
+        )

         # Let the horizontal axes labeling appear on top.
         ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
@@ -176,7 +245,8 @@ def plot_seq_attr(
         fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))

         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)

```
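
To see why the `repr` change matters for plotting, consider tokens that contain control characters. The snippet below is illustrative only (plain strings, no Captum APIs): `repr` escapes the newline, and the `[1:-1]` slice drops the surrounding quotes, so a matplotlib tick label shows a visible `\n` instead of an actual line break:

```python
from textwrap import shorten

tokens = ["Hello", "\n", "\t", "wor ld"]

# repr("\n") is the 4-character string '\n' (quote, backslash, n, quote);
# slicing off the quotes with [1:-1] leaves the escaped two-character
# form \n, which renders as text rather than breaking the tick label.
labels = [shorten(repr(t)[1:-1], width=50, placeholder="...") for t in tokens]
print(labels)  # ['Hello', '\\n', '\\t', 'wor ld']
```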
