
Visualization of Model

Kashu edited this page Aug 9, 2023 · 4 revisions

GradCAM

GradCAM produces "visual explanations" for a model's decisions: a heatmap of the input regions that contributed most to a chosen output. Install the package with pip install grad-cam.
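As a rough illustration of what plain GradCAM computes, the NumPy sketch below combines one layer's activations with the gradients of the target score. Each channel is weighted by its spatially averaged gradient, the weighted maps are summed, and a ReLU keeps only positive evidence. The function name grad_cam_map and the toy arrays are illustrative assumptions. The package itself collects activations and gradients through hooks on target_layers, and it also rescales the map and resizes it to the input size.

```python
import numpy as np


def grad_cam_map(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Combine one layer's activations and gradients into a Grad-CAM map.

    activations, gradients: arrays of shape (K, H, W) for K channels, where
    gradients holds d(target score) / d(activations).
    """
    # One weight per channel: average the gradients over the spatial dims.
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum of the activation maps, then ReLU (keep positive evidence).
    cam = np.maximum((weights[:, None, None] * activations).sum(axis=0), 0)
    # Scale to [0, 1] for display, guarding against an all-zero map.
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Toy example: channel 0 gets weight +1, channel 1 gets weight -1.
acts = np.array([[[1.0, 0.0], [0.0, 0.0]],
                 [[0.0, 0.0], [0.0, 2.0]]])
grads = np.array([[[1.0, 1.0], [1.0, 1.0]],
                  [[-1.0, -1.0], [-1.0, -1.0]]])
cam = grad_cam_map(acts, grads)  # only the top-left pixel survives the ReLU
```

The negative weight on channel 1 turns its bottom-right activation into negative evidence, which the ReLU then discards.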

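The package assumes the model's forward pass returns a single tensor. If the model returns a dict, tuple, list, etc., wrap it so that forward returns only the tensor you want to explain. For example, torchvision segmentation models return a dict whose "out" entry holds the per-pixel class scores: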

import torch.nn as nn

class ModelOutputWrapper(nn.Module):
    """Expose only the "out" tensor of a model that returns a dict."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)["out"]
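ModelOutputWrapper hard-codes the "out" key. A minimal sketch of a more general wrapper, assuming you know which dict key or tuple/list index holds the tensor to explain, could look like this. SelectOutputWrapper and the toy DictModel are hypothetical names for illustration, not part of grad-cam.

```python
import torch
import torch.nn as nn


class SelectOutputWrapper(nn.Module):
    """Hypothetical generalisation of ModelOutputWrapper: pick one tensor out
    of a dict / tuple / list model output so grad-cam receives a tensor."""

    def __init__(self, model: nn.Module, key="out"):
        super().__init__()
        self.model = model
        self.key = key  # dict key (str) or sequence index (int)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.model(x)
        if isinstance(output, torch.Tensor):
            return output  # already a tensor: pass it through unchanged
        # Indexing works for dicts (by key) and tuples/lists (by position).
        return output[self.key]


class DictModel(nn.Module):
    """Toy model returning a dict, in the style of torchvision segmentation models."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=1)

    def forward(self, x):
        y = self.conv(x)
        return {"out": y, "aux": y.detach()}


wrapped = SelectOutputWrapper(DictModel())
out = wrapped(torch.zeros(1, 3, 4, 4))  # a plain tensor of shape (1, 2, 4, 4)
```

For a model returning a tuple such as (logits, features), pass key=0.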

Class Activation Maps for Semantic Segmentation

The following annotated example shows Class Activation Maps for semantic segmentation, using deeplabv3_resnet50 from torchvision. (Full Tutorial)

import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights

# Prepare the input image (can be a batch of images).
# Use float32 in [0, 1]: a float64 array would give a float64 tensor that
# does not match the model's float32 weights.
image = np.asarray(Image.open(IMAGE_PATH).convert("RGB"), dtype=np.float32) / 255
input_tensor = preprocess_image(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Define a model and load weights (wrap the model if its output is not a tensor).
# weights= replaces the deprecated pretrained=True (torchvision >= 0.13).
model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT, progress=False)
model = ModelOutputWrapper(model)
model = model.eval()

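class SemanticSegmentationTarget:
    """Score for one class: its output channel summed over the pixels in mask."""

    def __init__(self, category, mask):
        self.category = category
        self.mask = torch.from_numpy(mask)
        # Keep the mask on the GPU when one is available, where the model runs.
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        # model_output: (num_classes, H, W) scores for a single image
        return (model_output[self.category, :, :] * self.mask).sum()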

# Specify which layer to visualize
target_layers = [model.model.backbone.layer4]
CLASS_IDX = ...  # index of the class to explain
SEG_MASK = ...   # float32 (H, W) mask of pixels predicted as that class, e.g.
                 # np.float32(pred_mask.argmax(axis=0) == CLASS_IDX),
                 # where pred_mask is the (num_classes, H, W) output for the image

targets = [SemanticSegmentationTarget(CLASS_IDX, SEG_MASK)]

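# Run on the GPU when available (SemanticSegmentationTarget puts its mask there
# too). Recent grad-cam releases take the device from the model and no longer
# accept use_cuda, so move the model and the input explicitly instead.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Construct the CAM object once, and then re-use it on many images:
with GradCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor.to(device), targets=targets)[0, :]
    cam_image = show_cam_on_image(image, grayscale_cam, use_rgb=True)

# Displays inline in a notebook; in a script, call .save(path) or .show() on it
Image.fromarray(cam_image)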
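The SEG_MASK comment and SemanticSegmentationTarget boil down to two array operations. The NumPy sketch below uses a made-up 2-class, 2x2 output and hypothetical helper names (class_mask, masked_class_score). It builds the float32 mask from the per-pixel argmax, then computes the scalar whose gradients GradCAM uses: the chosen class channel summed over the masked pixels. Taking the argmax of raw scores or of softmax probabilities gives the same mask.

```python
import numpy as np


def class_mask(logits: np.ndarray, class_idx: int) -> np.ndarray:
    """float32 (H, W) mask that is 1 where class_idx wins the per-pixel argmax.

    logits: (C, H, W) scores for one image (batch dimension removed).
    """
    return np.float32(logits.argmax(axis=0) == class_idx)


def masked_class_score(logits: np.ndarray, class_idx: int, mask: np.ndarray) -> float:
    """Same quantity SemanticSegmentationTarget returns: class channel summed over the mask."""
    return float((logits[class_idx] * mask).sum())


# Toy 2-class, 2x2 "segmentation output".
logits = np.array([[[2.0, 0.0], [1.0, 0.0]],    # class 0
                   [[1.0, 3.0], [0.0, 5.0]]])   # class 1
mask = class_mask(logits, 1)                    # [[0., 1.], [0., 1.]]
score = masked_class_score(logits, 1, mask)     # 3.0 + 5.0 = 8.0
```

Restricting the score to the predicted pixels means the heatmap explains why those pixels were labeled as the class, rather than the class score over the whole image.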
