Skip to content

Commit d9bbe77

Browse files
fy-mengfacebook-github-bot
authored andcommitted
Allow more variable input to LLMAttributionResult (pytorch#1627)
Summary: Pull Request resolved: pytorch#1627 1. Allow `LLMAttributionResult` to be initialized with generic array data (lists, np.ndarray) and perform sanity checks on their shapes; 2. During visualization, the text tokens are now `repr`'d to make sure that non-word charactures (e.g. newline) are visualized correctly. Reviewed By: craymichael Differential Revision: D78197863
1 parent cc5b9ce commit d9bbe77

File tree

1 file changed

+93
-7
lines changed

1 file changed

+93
-7
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from abc import ABC
66
from copy import copy
77
from dataclasses import dataclass
8-
from textwrap import shorten
8+
from textwrap import dedent, shorten
99
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
1010

1111
import matplotlib.colors as mcolors
1212

1313
import matplotlib.pyplot as plt
1414
import numpy as np
15+
import numpy.typing as npt
1516

1617
import torch
1718
from captum._utils.typing import TokenizerLike
@@ -51,11 +52,91 @@ class LLMAttributionResult:
5152
It also provides utilities to help present and plot the result in different forms.
5253
"""
5354

54-
seq_attr: Tensor
55-
token_attr: Optional[Tensor]
5655
input_tokens: List[str]
5756
output_tokens: List[str]
58-
output_probs: Optional[Tensor] = None
57+
# pyre-ignore[13]: initialized via a property setter
58+
_seq_attr: Tensor
59+
_token_attr: Optional[Tensor] = None
60+
_output_probs: Optional[Tensor] = None
61+
62+
def __init__(
63+
self,
64+
seq_attr: npt.ArrayLike,
65+
token_attr: Optional[npt.ArrayLike],
66+
output_probs: Optional[npt.ArrayLike],
67+
input_tokens: List[str],
68+
output_tokens: List[str],
69+
) -> None:
70+
self.input_tokens = input_tokens
71+
self.output_tokens = output_tokens
72+
self.seq_attr = seq_attr
73+
self.token_attr = token_attr
74+
self.output_probs = output_probs
75+
76+
@property
77+
def seq_attr(self) -> Tensor:
78+
return self._seq_attr
79+
80+
@seq_attr.setter
81+
def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
82+
if isinstance(seq_attr, Tensor):
83+
self._seq_attr = seq_attr
84+
else:
85+
self._seq_attr = torch.tensor(seq_attr)
86+
# IDEA: in the future we might want to support higher dim seq_attr
87+
# (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
88+
assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
89+
assert (
90+
len(self.input_tokens) == self._seq_attr.shape[0]
91+
), "seq_attr and input_tokens must have the same length"
92+
93+
@property
94+
def token_attr(self) -> Optional[Tensor]:
95+
return self._token_attr
96+
97+
@token_attr.setter
98+
def token_attr(self, token_attr: npt.ArrayLike | None) -> None:
99+
if token_attr is None:
100+
self._token_attr = None
101+
elif isinstance(token_attr, Tensor):
102+
self._token_attr = token_attr
103+
else:
104+
self._token_attr = torch.tensor(token_attr)
105+
106+
if self._token_attr is not None:
107+
# IDEA: in the future we might want to support higher dim seq_attr
108+
assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
109+
assert self._token_attr.shape == (
110+
len(self.output_tokens),
111+
len(self.input_tokens),
112+
), dedent(
113+
f"""\
114+
Expect token_attr to have shape
115+
{len(self.output_tokens), len(self.input_tokens)},
116+
got {self._token_attr.shape}
117+
"""
118+
)
119+
120+
@property
121+
def output_probs(self) -> Optional[Tensor]:
122+
return self._output_probs
123+
124+
@output_probs.setter
125+
def output_probs(self, output_probs: npt.ArrayLike | None) -> None:
126+
if output_probs is None:
127+
self._output_probs = None
128+
elif isinstance(output_probs, Tensor):
129+
self._output_probs = output_probs
130+
else:
131+
self._output_probs = torch.tensor(output_probs)
132+
133+
if self._output_probs is not None:
134+
assert (
135+
len(self._output_probs.shape) == 1
136+
), "output_probs must be a 1D tensor"
137+
assert (
138+
len(self.output_tokens) == self._output_probs.shape[0]
139+
), "seq_attr and input_tokens must have the same length"
59140

60141
@property
61142
def seq_attr_dict(self) -> Dict[str, float]:
@@ -125,10 +206,14 @@ def plot_token_attr(
125206

126207
# Show all ticks and label them with the respective list entries.
127208
shortened_tokens = [
128-
shorten(t, width=50, placeholder="...") for t in self.input_tokens
209+
shorten(repr(t)[1:-1], width=50, placeholder="...")
210+
for t in self.input_tokens
129211
]
130212
ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
131-
ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
213+
ax.set_yticks(
214+
np.arange(data.shape[0]),
215+
labels=[repr(token)[1:-1] for token in self.output_tokens],
216+
)
132217

133218
# Let the horizontal axes labeling appear on top.
134219
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
@@ -176,7 +261,8 @@ def plot_seq_attr(
176261
fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))
177262

178263
shortened_tokens = [
179-
shorten(t, width=50, placeholder="...") for t in self.input_tokens
264+
shorten(repr(t)[1:-1], width=50, placeholder="...")
265+
for t in self.input_tokens
180266
]
181267
ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
182268

0 commit comments

Comments
 (0)