@@ -5,13 +5,14 @@
 from abc import ABC
 from copy import copy
 from dataclasses import dataclass
-from textwrap import shorten
+from textwrap import dedent, shorten
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
 
 import matplotlib.colors as mcolors
 
 import matplotlib.pyplot as plt
 import numpy as np
+import numpy.typing as npt
 
 import torch
 from captum._utils.typing import TokenizerLike
@@ -51,11 +52,91 @@ class LLMAttributionResult:
     It also provides utilities to help present and plot the result in different forms.
     """
 
-    seq_attr: Tensor
-    token_attr: Optional[Tensor]
     input_tokens: List[str]
     output_tokens: List[str]
-    output_probs: Optional[Tensor] = None
+    # pyre-ignore[13]: initialized via a property setter
+    _seq_attr: Tensor
+    _token_attr: Optional[Tensor] = None
+    _output_probs: Optional[Tensor] = None
+
+    def __init__(
+        self,
+        seq_attr: npt.ArrayLike,
+        token_attr: Optional[npt.ArrayLike],
+        output_probs: Optional[npt.ArrayLike],
+        input_tokens: List[str],
+        output_tokens: List[str],
+    ) -> None:
+        self.input_tokens = input_tokens
+        self.output_tokens = output_tokens
+        self.seq_attr = seq_attr
+        self.token_attr = token_attr
+        self.output_probs = output_probs
+
+    @property
+    def seq_attr(self) -> Tensor:
+        return self._seq_attr
+
+    @seq_attr.setter
+    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+        if isinstance(seq_attr, Tensor):
+            self._seq_attr = seq_attr
+        else:
+            self._seq_attr = torch.tensor(seq_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
+        assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
+        assert (
+            len(self.input_tokens) == self._seq_attr.shape[0]
+        ), "seq_attr and input_tokens must have the same length"
+
+    @property
+    def token_attr(self) -> Optional[Tensor]:
+        return self._token_attr
+
+    @token_attr.setter
+    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+        if token_attr is None:
+            self._token_attr = None
+        elif isinstance(token_attr, Tensor):
+            self._token_attr = token_attr
+        else:
+            self._token_attr = torch.tensor(token_attr)
+
+        if self._token_attr is not None:
+            # IDEA: in the future we might want to support higher dim token_attr
+            assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
+            assert self._token_attr.shape == (
+                len(self.output_tokens),
+                len(self.input_tokens),
+            ), dedent(
+                f"""\
+                Expected token_attr to have shape
+                {len(self.output_tokens), len(self.input_tokens)},
+                got {self._token_attr.shape}
+                """
+            )
+
+    @property
+    def output_probs(self) -> Optional[Tensor]:
+        return self._output_probs
+
+    @output_probs.setter
+    def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None:
+        if output_probs is None:
+            self._output_probs = None
+        elif isinstance(output_probs, Tensor):
+            self._output_probs = output_probs
+        else:
+            self._output_probs = torch.tensor(output_probs)
+
+        if self._output_probs is not None:
+            assert (
+                len(self._output_probs.shape) == 1
+            ), "output_probs must be a 1D tensor"
+            assert (
+                len(self.output_tokens) == self._output_probs.shape[0]
+            ), "output_probs and output_tokens must have the same length"
 
     @property
     def seq_attr_dict(self) -> Dict[str, float]:
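
Taken together, the hunk above replaces the plain dataclass fields with properties: `seq_attr`, `token_attr`, and `output_probs` now accept any `npt.ArrayLike` (nested lists, numpy arrays, or tensors), coerce non-tensor input with `torch.tensor`, and validate shapes against the token lists on every assignment. A minimal usage sketch of the new behavior, assuming `LLMAttributionResult` is exported from `captum.attr` and using made-up tokens and scores:

    import numpy as np
    from captum.attr import LLMAttributionResult

    # Made-up tokens and scores: 2 output tokens attributed to 3 input tokens.
    res = LLMAttributionResult(
        seq_attr=np.array([0.2, 0.5, 0.3]),  # ndarray is coerced to a Tensor
        token_attr=[[0.1, 0.4, 0.2], [0.0, 0.1, 0.1]],  # nested lists work too
        output_probs=None,
        input_tokens=["Dogs", " are", " loyal"],
        output_tokens=["Yes", "!"],
    )
    assert res.seq_attr.shape == (3,)  # one score per input token
    assert res.token_attr.shape == (2, 3)  # (n_output_tokens, n_input_tokens)
    # res.token_attr = [[0.1, 0.2]]  # would raise AssertionError: wrong shape

Because the validation lives in the setters, reassigning a field later gets the same checks as the constructor.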
@@ -125,10 +206,14 @@ def plot_token_attr(
 
         # Show all ticks and label them with the respective list entries.
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
-        ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
+        ax.set_yticks(
+            np.arange(data.shape[0]),
+            labels=[repr(token)[1:-1] for token in self.output_tokens],
+        )
 
         # Let the horizontal axes labeling appear on top.
         ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
@@ -176,7 +261,8 @@ def plot_seq_attr(
         fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))
 
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
 
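
Both plotting hunks switch tick labels from the raw token strings to `repr(t)[1:-1]`: `repr` escapes control characters, and the `[1:-1]` slice strips the surrounding quotes, so a whitespace token such as a newline is drawn as a visible `\n` rather than an empty label. A quick illustration of the escaping:

    token = "\n"
    print(repr(token))        # '\n' -- escaped, but still wrapped in quotes
    print(repr(token)[1:-1])  # \n   -- the form actually used as a tick label
    assert repr("hello")[1:-1] == "hello"  # ordinary tokens are unaffected

Note that `shorten` now truncates the escaped string, so the 50-character width budget applies after escaping.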