Skip to content

Commit cc5f468

Browse files
dzenanz authored and facebook-github-bot committed
Add split_channels parameter to LayerGradCam.attribute (#1086)
Summary: This allows examination of each channel's contribution. That is useful if channels are something other than standard RGB — for example, multi-spectral input, potentially with many spectral channels.

Pull Request resolved: #1086
Reviewed By: vivekmig
Differential Revision: D42221000
Pulled By: NarineK
fbshipit-source-id: 1b04276d68e4a22a1d7338bd80436b118268d787
1 parent b398d52 commit cc5f468

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

captum/attr/_core/layer/grad_cam.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def attribute(
8282
additional_forward_args: Any = None,
8383
attribute_to_layer_input: bool = False,
8484
relu_attributions: bool = False,
85+
attr_dim_summation: bool = True,
8586
) -> Union[Tensor, Tuple[Tensor, ...]]:
8687
r"""
8788
Args:
@@ -149,6 +150,10 @@ def attribute(
149150
otherwise, by default, both positive and negative
150151
attributions are returned.
151152
Default: False
153+
attr_dim_summation (bool, optional): Indicates whether to
154+
sum attributions along dimension 1 (usually channel).
155+
The default (True) means to sum along dimension 1.
156+
Default: True
152157
153158
Returns:
154159
*Tensor* or *tuple[Tensor, ...]* of **attributions**:
@@ -208,10 +213,17 @@ def attribute(
208213
for layer_grad in layer_gradients
209214
)
210215

211-
scaled_acts = tuple(
212-
torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
213-
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
214-
)
216+
if attr_dim_summation:
217+
scaled_acts = tuple(
218+
torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
219+
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
220+
)
221+
else:
222+
scaled_acts = tuple(
223+
summed_grad * layer_eval
224+
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
225+
)
226+
215227
if relu_attributions:
216228
scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
217229
return _format_output(len(scaled_acts) > 1, scaled_acts)

tests/attr/layer/test_grad_cam.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,23 @@ def test_simple_input_conv(self) -> None:
3333
net, net.conv1, inp, [[[[11.25, 13.5], [20.25, 22.5]]]]
3434
)
3535

36+
def test_simple_input_conv_split_channels(self) -> None:
37+
net = BasicModel_ConvNet_One_Conv()
38+
inp = torch.arange(16).view(1, 1, 4, 4).float()
39+
expected_result = [
40+
[
41+
[[-3.7500, 3.0000], [23.2500, 30.0000]],
42+
[[15.0000, 10.5000], [-3.0000, -7.5000]],
43+
]
44+
]
45+
self._grad_cam_test_assert(
46+
net,
47+
net.conv1,
48+
inp,
49+
expected_activation=expected_result,
50+
attr_dim_summation=False,
51+
)
52+
3653
def test_simple_input_conv_no_grad(self) -> None:
3754
net = BasicModel_ConvNet_One_Conv()
3855

@@ -100,6 +117,7 @@ def _grad_cam_test_assert(
100117
additional_input: Any = None,
101118
attribute_to_layer_input: bool = False,
102119
relu_attributions: bool = False,
120+
attr_dim_summation: bool = True,
103121
):
104122
layer_gc = LayerGradCam(model, target_layer)
105123
self.assertFalse(layer_gc.multiplies_by_inputs)
@@ -109,6 +127,7 @@ def _grad_cam_test_assert(
109127
additional_forward_args=additional_input,
110128
attribute_to_layer_input=attribute_to_layer_input,
111129
relu_attributions=relu_attributions,
130+
attr_dim_summation=attr_dim_summation,
112131
)
113132
assertTensorTuplesAlmostEqual(
114133
self, attributions, expected_activation, delta=0.01

0 commit comments

Comments
 (0)