@@ -5,13 +5,14 @@
 from abc import ABC
 from copy import copy
 from dataclasses import dataclass
-from textwrap import shorten
+from textwrap import dedent, shorten
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

 import matplotlib.colors as mcolors

 import matplotlib.pyplot as plt
 import numpy as np
+import numpy.typing as npt

 import torch
 from captum._utils.typing import TokenizerLike
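The two new imports support the changes below: numpy.typing.ArrayLike admits plain lists, NumPy arrays, and scalars in addition to tensors, and dedent tidies the multi-line assertion message in the token_attr setter. A minimal sketch of the coercion idiom the new setters use (the to_tensor helper is illustrative only, not part of this diff):

    import numpy as np
    import numpy.typing as npt
    import torch
    from torch import Tensor

    def to_tensor(value: npt.ArrayLike) -> Tensor:
        # Pass tensors through unchanged; torch.tensor(value) would copy
        # (and warn) if value were already a Tensor.
        if isinstance(value, Tensor):
            return value
        return torch.tensor(value)

    assert to_tensor([0.1, 0.2]).shape == (2,)
    assert to_tensor(np.zeros((2, 3))).shape == (2, 3)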
@@ -51,11 +52,69 @@ class LLMAttributionResult:
     It also provides utilities to help present and plot the result in different forms.
     """

-    seq_attr: Tensor
-    token_attr: Optional[Tensor]
     input_tokens: List[str]
     output_tokens: List[str]
     output_probs: Optional[Tensor] = None
+    # pyre-ignore[13]: initialized via a property setter
+    _seq_attr: Tensor
+    _token_attr: Optional[Tensor] = None
+
+    def __init__(
+        self,
+        seq_attr: npt.ArrayLike,
+        token_attr: Optional[npt.ArrayLike],
+        input_tokens: List[str],
+        output_tokens: List[str],
+    ) -> None:
+        self.input_tokens = input_tokens
+        self.output_tokens = output_tokens
+        self.seq_attr = seq_attr
+        self.token_attr = token_attr
+
+    @property
+    def seq_attr(self) -> Tensor:
+        return self._seq_attr
+
+    @seq_attr.setter
+    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+        if isinstance(seq_attr, Tensor):
+            self._seq_attr = seq_attr
+        else:
+            self._seq_attr = torch.tensor(seq_attr)
+        # IDEA: in the future we might want to support higher dim seq_attr
+        # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
+        assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
+
+        assert (
+            len(self.input_tokens) == self._seq_attr.shape[0]
+        ), "seq_attr and input_tokens must have the same length"
+
+    @property
+    def token_attr(self) -> Optional[Tensor]:
+        return self._token_attr
+
+    @token_attr.setter
+    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+        if isinstance(token_attr, Tensor):
+            self._token_attr = token_attr
+        elif token_attr is None:
+            # can't combine with previous clause, linter unhappy ¯\_(ツ)_/¯
+            self._token_attr = None
+        else:
+            self._token_attr = torch.tensor(token_attr)
+        # IDEA: in the future we might want to support higher dim token_attr
+        if self._token_attr is not None:
+            assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
+            assert self._token_attr.shape == (
+                len(self.output_tokens),
+                len(self.input_tokens),
+            ), dedent(
+                f"""\
+                Expect token_attr to have shape
+                {len(self.output_tokens), len(self.input_tokens)},
+                got {self._token_attr.shape}
+                """
+            )

     @property
     def seq_attr_dict(self) -> Dict[str, float]:
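With the conflict resolved as above, the constructor now accepts plain lists or NumPy arrays and validates shapes eagerly in the setters. A usage sketch (tokens and scores are made up, and the top-level import path is an assumption that may vary by Captum version):

    import numpy as np
    import torch

    # Assumed import path; LLMAttributionResult lives in captum.attr._core.llm_attr.
    from captum.attr import LLMAttributionResult

    result = LLMAttributionResult(
        seq_attr=np.array([0.1, -0.3, 0.7]),  # coerced to a 1D tensor
        token_attr=[[0.2, 0.0, 0.1],          # coerced to a 2D tensor of
                    [0.05, -0.1, 0.4]],       # shape (n_output, n_input)
        input_tokens=["The", "cat", "sat"],
        output_tokens=["on", "mat"],
    )
    assert isinstance(result.seq_attr, torch.Tensor)
    assert result.token_attr is not None
    assert result.token_attr.shape == (2, 3)

    # A mismatched shape now fails fast in the setter:
    # result.token_attr = [[1.0, 2.0]]  # AssertionError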
@@ -125,10 +184,14 @@ def plot_token_attr(

         # Show all ticks and label them with the respective list entries.
         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
-        ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
+        ax.set_yticks(
+            np.arange(data.shape[0]),
+            labels=[repr(token)[1:-1] for token in self.output_tokens],
+        )

         # Let the horizontal axes labeling appear on top.
         ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
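The repr(t)[1:-1] change deserves a note: repr escapes control characters, so a literal newline or tab token renders as a visible \n or \t instead of breaking the tick label across lines, and the [1:-1] slice strips the quotes repr wraps around a string. For example:

    tokens = ["hello", "\n", "\tworld", "naïve"]
    labels = [repr(t)[1:-1] for t in tokens]
    print(labels)  # ['hello', '\\n', '\\tworld', 'naïve']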
@@ -176,7 +239,8 @@ def plot_seq_attr(
         fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))

         shortened_tokens = [
-            shorten(t, width=50, placeholder="...") for t in self.input_tokens
+            shorten(repr(t)[1:-1], width=50, placeholder="...")
+            for t in self.input_tokens
         ]
         ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)

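End to end, both plot methods now label ticks safely even when tokens contain whitespace or control characters. A hedged sketch, assuming plot_token_attr keeps a show flag and returns (fig, ax) when show=False (check the signature in your Captum version):

    # Reuses the `result` built in the earlier sketch.
    fig_ax = result.plot_token_attr(show=False)  # assumed signature
    if fig_ax is not None:
        fig, ax = fig_ax
        fig.savefig("token_attr.png", dpi=150)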