Visualization of Model
Kashu edited this page Aug 9, 2023 · 4 revisions
We can obtain "visual explanations" for a model's decisions using Grad-CAM. You can install the package with pip install grad-cam (it is imported as pytorch_grad_cam).
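Grad-CAM weights each feature map of a chosen layer by the spatial average of the gradient of the target score with respect to that map, sums the weighted maps, and keeps only positive values. The NumPy sketch below illustrates that core step for one image, in simplified form. The function name grad_cam_map and the toy arrays are hypothetical, and it assumes you already have the layer's activations and gradients; the package computes these with hooks and additionally resizes the map to the input resolution.

```python
import numpy as np


def grad_cam_map(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Simplified Grad-CAM for a single image (illustrative sketch).

    activations: feature maps A of the target layer, shape (K, H, W)
    gradients:   d(score)/dA for the chosen target, same shape (K, H, W)
    Returns an (H, W) map min-max scaled to [0, 1].
    """
    # Channel weights: global-average-pool the gradients over height and width.
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum of feature maps; ReLU keeps only positive evidence.
    cam = np.maximum((weights[:, None, None] * activations).sum(axis=0), 0)
    # Scale to [0, 1] for display, guarding against an all-zero map.
    cam = cam - cam.min()
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Toy example: 2 channels on a 2x2 grid (hypothetical values).
A = np.array([[[1.0, 0.0], [0.0, 1.0]],
              [[0.0, 2.0], [2.0, 0.0]]])
dA = np.array([[[1.0, 1.0], [1.0, 1.0]],       # channel 0 gets weight +1
               [[-1.0, -1.0], [-1.0, -1.0]]])  # channel 1 gets weight -1
print(grad_cam_map(A, dA))  # values: [[1, 0], [0, 1]]
```

Channel 1 has a negative weight, so its activations suppress the map, and only the diagonal highlighted by channel 0 survives the ReLU.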
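This package assumes the model outputs a single tensor. If the model outputs a dict, tuple, list, etc., wrap it so that its forward() returns a tensor. For example, torchvision's segmentation models return a dict whose "out" entry holds the per-pixel class scores: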
import torch.nn as nn

class ModelOutputWrapper(nn.Module):
    """Wraps a model that returns a dict and exposes only its "out" tensor."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)["out"]
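ModelOutputWrapper handles only dict outputs keyed by "out". For models returning a tuple or list, a minimal sketch of a more general wrapper follows, assuming you know which key or index holds the tensor you want to explain. SelectOutputWrapper, its key parameter, and the toy TwoHeadNet are hypothetical names for illustration, not part of grad-cam.

```python
import torch
import torch.nn as nn


class SelectOutputWrapper(nn.Module):
    """Hypothetical generalization of ModelOutputWrapper.

    Returns one tensor from a model whose forward() yields a dict, tuple,
    or list. `key` is a dict key or a sequence index (assumption: you know
    which entry holds the scores to explain).
    """

    def __init__(self, model: nn.Module, key="out"):
        super().__init__()
        self.model = model
        self.key = key

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.model(x)
        if isinstance(output, torch.Tensor):
            return output  # already a tensor: pass it through unchanged
        return output[self.key]  # dict key, or tuple/list index


# Toy model (hypothetical) that returns a (logits, aux) tuple.
class TwoHeadNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        logits = self.fc(x)
        return logits, logits.detach() * 0


wrapped = SelectOutputWrapper(TwoHeadNet(), key=0)
print(wrapped(torch.randn(2, 4)).shape)  # torch.Size([2, 3])
```

Passing tensors through unchanged lets the same wrapper be applied to any model without checking its output type first.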
The following annotated example shows Class Activation Maps for semantic segmentation, using deeplabv3_resnet50 from torchvision. (Full Tutorial)
import numpy as np
import torch
from PIL import Image
# GradCAM is used below; the other CAM classes can be swapped in for it.
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision.models.segmentation import deeplabv3_resnet50
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

# Prepare the input image (can be a batch of images).
# IMAGE_PATH is the path to your image; show_cam_on_image expects RGB float32 values in [0, 1].
image = np.asarray(Image.open(IMAGE_PATH).convert("RGB"), dtype=np.float32) / 255
input_tensor = preprocess_image(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Define a model and load weights (wrap the model if the output is not a tensor).
# Newer torchvision releases deprecate pretrained= in favour of weights=.
model = deeplabv3_resnet50(pretrained=True, progress=False)
model = ModelOutputWrapper(model)
model = model.eval()

class SemanticSegmentationTarget:
    """Target score: sum of one class's output over the pixels in `mask`."""
    def __init__(self, category, mask):
        self.category = category
        self.mask = torch.from_numpy(mask)
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        # model_output has shape (num_classes, H, W) for a single image
        return (model_output[self.category, :, :] * self.mask).sum()

# Specify which layer to visualize
target_layers = [model.model.backbone.layer4]

CLASS_IDX = ...  # index of the class
SEG_MASK = ...   # pred mask for the class (np.float32(pred_mask.argmax(axis=0) == CLASS_IDX))
targets = [SemanticSegmentationTarget(CLASS_IDX, SEG_MASK)]

# Construct the CAM object once, and then re-use it on many images.
# use_cuda is accepted by older grad-cam releases; check the signature of your installed version.
with GradCAM(model=model,
             target_layers=target_layers,
             use_cuda=torch.cuda.is_available()) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    cam_image = show_cam_on_image(image, grayscale_cam, use_rgb=True)

# Displays inline in a notebook; in a script, call .save(...) on the result instead
Image.fromarray(cam_image)
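The segmentation target turns a per-pixel output into the single scalar that Grad-CAM differentiates: the chosen class's scores summed over the pixels where that class is predicted. The NumPy sketch below mirrors the SEG_MASK comment and SemanticSegmentationTarget.__call__ on a hypothetical 2-class, 2x2 output; the helper names class_mask and masked_class_score are illustrative only, and the real target operates on torch tensors rather than arrays.

```python
import numpy as np


def class_mask(scores: np.ndarray, class_idx: int) -> np.ndarray:
    """Float32 mask of pixels whose argmax class equals class_idx.

    scores: per-class output for one image, shape (num_classes, H, W).
    """
    return np.float32(scores.argmax(axis=0) == class_idx)


def masked_class_score(scores: np.ndarray, class_idx: int, mask: np.ndarray) -> float:
    """NumPy analogue of SemanticSegmentationTarget.__call__:
    sum the class_idx scores over the pixels selected by mask."""
    return float((scores[class_idx] * mask).sum())


# Hypothetical 2-class output on a 2x2 image.
scores = np.array([[[2.0, 0.1], [0.3, 1.5]],   # class 0
                   [[0.5, 0.9], [0.8, 0.2]]])  # class 1
mask = class_mask(scores, class_idx=1)  # class 1 wins on the off-diagonal pixels
print(mask)
print(masked_class_score(scores, 1, mask))  # 0.9 + 0.8
```

Restricting the sum to the predicted region means the resulting CAM highlights what drives the model to label those particular pixels with the class, rather than the class score over the whole image.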