Skip to content

Commit af89779

Browse files
Dave Feltenbergerfacebook-github-bot
authored andcommitted
Fix some pyre linter errors (#1615)
Summary: Pull Request resolved: #1615 Fix missing type information in lime.py so pyre doesn't complain :) Reviewed By: vivekmig, Ayush-Warikoo Differential Revision: D76843933 fbshipit-source-id: ba8105df4eacd91b3d6b1f22f7ab9fc4a8de686e
1 parent a293a56 commit af89779

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

captum/attr/_core/lime.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -590,16 +590,19 @@ def multiplies_by_inputs(self) -> bool:
590590
# for Lime child implementation.
591591

592592

593-
# pyre-fixme[3]: Return type must be annotated.
594-
# pyre-fixme[2]: Parameter must be annotated.
595-
def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
593+
def default_from_interp_rep_transform(
594+
curr_sample: Tensor,
595+
original_inputs: TensorOrTupleOfTensorsGeneric,
596+
**kwargs: Any,
597+
) -> TensorOrTupleOfTensorsGeneric:
598+
596599
assert (
597600
"feature_mask" in kwargs
598601
), "Must provide feature_mask to use default interpretable representation transform"
599602
assert (
600603
"baselines" in kwargs
601604
), "Must provide baselines to use default interpretable representation transform"
602-
feature_mask = kwargs["feature_mask"]
605+
feature_mask: TensorOrTupleOfTensorsGeneric = kwargs["feature_mask"]
603606
if isinstance(feature_mask, Tensor):
604607
binary_mask = curr_sample[0][feature_mask].bool()
605608
return (
@@ -610,10 +613,15 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
610613
binary_mask = tuple(
611614
curr_sample[0][feature_mask[j]].bool() for j in range(len(feature_mask))
612615
)
613-
return tuple(
614-
binary_mask[j].to(original_inputs[j].dtype) * original_inputs[j]
615-
+ (~binary_mask[j]).to(original_inputs[j].dtype) * kwargs["baselines"][j]
616-
for j in range(len(feature_mask))
616+
617+
return cast(
618+
TensorOrTupleOfTensorsGeneric,
619+
tuple(
620+
binary_mask[j].to(original_inputs[j].dtype) * original_inputs[j]
621+
+ (~binary_mask[j]).to(original_inputs[j].dtype)
622+
* kwargs["baselines"][j]
623+
for j in range(len(feature_mask))
624+
),
617625
)
618626

619627

@@ -652,9 +660,12 @@ def get_exp_kernel_similarity_function(
652660
similarity_fn for Lime or LimeBase.
653661
"""
654662

655-
# pyre-fixme[3]: Return type must be annotated.
656-
# pyre-fixme[2]: Parameter must be annotated.
657-
def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
663+
def default_exp_kernel(
664+
original_inp: TensorOrTupleOfTensorsGeneric,
665+
perturbed_inp: TensorOrTupleOfTensorsGeneric,
666+
__: Any,
667+
**kwargs: Any,
668+
) -> float:
658669
flattened_original_inp = _flatten_tensor_or_tuple(original_inp).float()
659670
flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp).float()
660671
if distance_mode == "cosine":

0 commit comments

Comments
 (0)