@@ -590,16 +590,19 @@ def multiplies_by_inputs(self) -> bool:
590
590
# for Lime child implementation.
591
591
592
592
593
- # pyre-fixme[3]: Return type must be annotated.
594
- # pyre-fixme[2]: Parameter must be annotated.
595
- def default_from_interp_rep_transform (curr_sample , original_inputs , ** kwargs ):
593
+ def default_from_interp_rep_transform (
594
+ curr_sample : Tensor ,
595
+ original_inputs : TensorOrTupleOfTensorsGeneric ,
596
+ ** kwargs : Any ,
597
+ ) -> TensorOrTupleOfTensorsGeneric :
598
+
596
599
assert (
597
600
"feature_mask" in kwargs
598
601
), "Must provide feature_mask to use default interpretable representation transform"
599
602
assert (
600
603
"baselines" in kwargs
601
604
), "Must provide baselines to use default interpretable representation transform"
602
- feature_mask = kwargs ["feature_mask" ]
605
+ feature_mask : TensorOrTupleOfTensorsGeneric = kwargs ["feature_mask" ]
603
606
if isinstance (feature_mask , Tensor ):
604
607
binary_mask = curr_sample [0 ][feature_mask ].bool ()
605
608
return (
@@ -610,10 +613,15 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
610
613
binary_mask = tuple (
611
614
curr_sample [0 ][feature_mask [j ]].bool () for j in range (len (feature_mask ))
612
615
)
613
- return tuple (
614
- binary_mask [j ].to (original_inputs [j ].dtype ) * original_inputs [j ]
615
- + (~ binary_mask [j ]).to (original_inputs [j ].dtype ) * kwargs ["baselines" ][j ]
616
- for j in range (len (feature_mask ))
616
+
617
+ return cast (
618
+ TensorOrTupleOfTensorsGeneric ,
619
+ tuple (
620
+ binary_mask [j ].to (original_inputs [j ].dtype ) * original_inputs [j ]
621
+ + (~ binary_mask [j ]).to (original_inputs [j ].dtype )
622
+ * kwargs ["baselines" ][j ]
623
+ for j in range (len (feature_mask ))
624
+ ),
617
625
)
618
626
619
627
@@ -652,9 +660,12 @@ def get_exp_kernel_similarity_function(
652
660
similarity_fn for Lime or LimeBase.
653
661
"""
654
662
655
- # pyre-fixme[3]: Return type must be annotated.
656
- # pyre-fixme[2]: Parameter must be annotated.
657
- def default_exp_kernel (original_inp , perturbed_inp , __ , ** kwargs ):
663
+ def default_exp_kernel (
664
+ original_inp : TensorOrTupleOfTensorsGeneric ,
665
+ perturbed_inp : TensorOrTupleOfTensorsGeneric ,
666
+ __ : Any ,
667
+ ** kwargs : Any ,
668
+ ) -> float :
658
669
flattened_original_inp = _flatten_tensor_or_tuple (original_inp ).float ()
659
670
flattened_perturbed_inp = _flatten_tensor_or_tuple (perturbed_inp ).float ()
660
671
if distance_mode == "cosine" :
0 commit comments