Commit ed5daa3

craymichael authored and facebook-github-bot committed
Address remaining mypy errors (#1383)
Summary:
Pull Request resolved: #1383

Misc fixes to remaining mypy errors

Reviewed By: vivekmig

Differential Revision: D64518879

fbshipit-source-id: 7245cbd3b49fb3a8a4b5ddbeecedc0dc53c0ef6a
1 parent fd985a3 commit ed5daa3

File tree: 5 files changed (+9, -6)

captum/attr/_core/llm_attr.py (2 additions, 0 deletions)

@@ -508,6 +508,7 @@ def attribute(
             skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
         else:
             skip_tokens = []
+        skip_tokens = cast(List[int], skip_tokens)

         if isinstance(target, str):
             encoded = self.tokenizer.encode(target)

@@ -700,6 +701,7 @@ def attribute(
             skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
         else:
             skip_tokens = []
+        skip_tokens = cast(List[int], skip_tokens)

         if isinstance(target, str):
             encoded = self.tokenizer.encode(target)
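
Both hunks close the same narrowing gap: after the if/else, mypy joins the branch types, and since `convert_tokens_to_ids` is loosely annotated, `skip_tokens` is not inferred as `List[int]`. A minimal sketch of the pattern, using a hypothetical loosely annotated stand-in rather than Captum's actual tokenizer protocol:

from typing import List, Optional, Union, cast

def convert_tokens_to_ids(tokens: List[str]) -> Union[int, List[int]]:
    # Hypothetical stand-in: loosely annotated, like many tokenizer APIs.
    return [len(t) for t in tokens]

def resolve_skip_tokens(skip_tokens: Optional[List[str]]) -> List[int]:
    if skip_tokens:
        ids = convert_tokens_to_ids(skip_tokens)
    else:
        ids = []
    # mypy infers a union of the branch types for `ids`; cast() narrows it
    # to List[int] at no runtime cost, which is what the hunks above do.
    return cast(List[int], ids)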

captum/attr/_core/occlusion.py (1 addition, 1 deletion)

@@ -384,7 +384,7 @@ def _occlusion_mask(
     def _get_feature_range_and_mask(
         self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any
     ) -> Tuple[int, int, Union[None, Tensor, Tuple[Tensor, ...]]]:
-        feature_max = np.prod(kwargs["shift_counts"])
+        feature_max = int(np.prod(kwargs["shift_counts"]))
         return 0, feature_max, None

     def _get_feature_counts(
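
The `int()` wrapper is needed because `np.prod` returns a NumPy scalar (e.g. `np.int64`), which is not a `builtins.int` and so fails the declared `Tuple[int, int, ...]` return type. A quick demonstration:

import numpy as np

shift_counts = (2, 3, 4)
raw = np.prod(shift_counts)               # NumPy scalar (np.int64), not a builtin int
feature_max = int(np.prod(shift_counts))  # plain Python int

assert not isinstance(raw, int)
assert isinstance(feature_max, int) and feature_max == 24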

captum/attr/_utils/attribution.py (3 additions, 1 deletion)

@@ -367,7 +367,9 @@ def multiplies_by_inputs(self) -> bool:
         return True


-class InternalAttribution(Attribution, Generic[ModuleOrModuleList]):
+# mypy false positive "Free type variable expected in Generic[...]" but
+# ModuleOrModuleList is a TypeVar
+class InternalAttribution(Attribution, Generic[ModuleOrModuleList]):  # type: ignore
     r"""
     Shared base class for LayerAttrubution and NeuronAttribution,
     attribution types that require a model and a particular layer.
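
The blanket ignore works around a mypy report that the new comment attributes to a false positive, since `ModuleOrModuleList` really is a `TypeVar`. A self-contained sketch of the shape involved; the TypeVar definition here is an assumption for illustration, not copied from Captum:

from typing import Generic, List, TypeVar

from torch.nn import Module

# Assumed definition: a TypeVar constrained to a Module or a list of
# Modules; the exact declaration lives elsewhere in the library.
ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])

class Attribution:
    pass

class InternalAttribution(Attribution, Generic[ModuleOrModuleList]):  # type: ignore
    # The ignore mirrors the commit: mypy can flag the Generic[...] use
    # even though ModuleOrModuleList is a genuine type variable.
    pass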

captum/influence/_core/tracincp_fast_rand_proj.py (2 additions, 3 deletions)

@@ -189,7 +189,7 @@ def __init__(
         self.vectorize = vectorize

         # TODO: restore prior state
-        self.final_fc_layer = final_fc_layer  # type: ignore
+        self.final_fc_layer = cast(Module, final_fc_layer)
         for param in self.final_fc_layer.parameters():
             param.requires_grad = True

@@ -212,8 +212,7 @@ def final_fc_layer(self) -> Module:
         return self._final_fc_layer

     @final_fc_layer.setter
-    # pyre-fixme[3]: Return type must be annotated.
-    def final_fc_layer(self, layer: Union[Module, str]):
+    def final_fc_layer(self, layer: Union[Module, str]) -> None:
         if isinstance(layer, str):
             try:
                 self._final_fc_layer = _get_module_from_name(self.model, layer)
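
Both hunks tighten the typing around a property whose setter accepts more types than its getter returns: the `__init__` assignment swaps a line-wide `# type: ignore` for a targeted `cast`, and the setter gains an explicit `-> None`. A condensed, runnable sketch of the pattern, with `_get_module_from_name` approximated rather than taken from Captum:

from typing import Union, cast

from torch.nn import Linear, Module

def _get_module_from_name(model: Module, name: str) -> Module:
    # Approximation of Captum's helper: resolve a (possibly dotted) attribute path.
    module = model
    for part in name.split("."):
        module = getattr(module, part)
    return module

class Influence:
    def __init__(self, model: Module, final_fc_layer: Union[Module, str]) -> None:
        self.model = model
        # cast() keeps the rest of the statement type-checked, unlike the
        # blanket "# type: ignore" it replaces; at runtime it is a no-op,
        # and the property setter below resolves a string to a Module.
        self.final_fc_layer = cast(Module, final_fc_layer)

    @property
    def final_fc_layer(self) -> Module:
        return self._final_fc_layer

    @final_fc_layer.setter
    def final_fc_layer(self, layer: Union[Module, str]) -> None:
        if isinstance(layer, str):
            self._final_fc_layer = _get_module_from_name(self.model, layer)
        else:
            self._final_fc_layer = layer

class Net(Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = Linear(4, 2)

influence = Influence(Net(), "fc")
assert isinstance(influence.final_fc_layer, Linear)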

tests/attr/test_interpretable_input.py (1 addition, 1 deletion)

@@ -25,7 +25,7 @@ def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
     # pyre-fixme[43]: Incompatible overload. The implementation of
     # `DummyTokenizer.encode` does not accept all possible arguments of overload.
     # pyre-ignore[11]: Annotation `pt` is not defined as a type
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
+    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...  # type: ignore # noqa: E501 line too long

     def encode(
         self, text: str, return_tensors: Optional[str] = "pt"
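
The new trailing comment silences mypy on the `Literal["pt"]` overload while keeping the existing pyre suppressions. For context, a simplified, self-contained version of the overload pattern the test's `DummyTokenizer` uses; the toy encoding and the `None` default here are illustrative, not the test's actual behavior:

from typing import List, Literal, Optional, Union, overload

import torch
from torch import Tensor

class DummyTokenizer:
    @overload
    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
    @overload
    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

    def encode(
        self, text: str, return_tensors: Optional[str] = None
    ) -> Union[List[int], Tensor]:
        ids = [len(token) for token in text.split()]  # toy word-length encoding
        if return_tensors == "pt":
            return torch.tensor(ids)
        return ids

ids = DummyTokenizer().encode("hello world")          # inferred as List[int]
tens = DummyTokenizer().encode("hello world", "pt")   # inferred as Tensor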
