
Image Classification with VGG

This tutorial introduces the attribution of image classifiers using VGG11 on ImageNet. Feel free to replace VGG11 with any other version of VGG.
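
For example, any other torchvision VGG constructor can be swapped in for vgg11_bn in the model-initialization step below; a minimal sketch, assuming VGG16 with batch-norm is the desired variant:

from torchvision.models import vgg16_bn

# initialize VGG16 with batch-norm in place of VGG11 with batch-norm
model = vgg16_bn(pretrained=False).eval()

The rest of the tutorial stays the same; as noted below, the canonizer is only needed for the batch-norm variants.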

First, we install Zennit. This also installs its dependencies Pillow, torch and torchvision:

[1]:
%pip install zennit
Note: you may need to restart the kernel to use updated packages.

Then, we import necessary modules, classes and functions:

[2]:
import logging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop
from torchvision.transforms import ToTensor, Normalize
from torchvision.models import vgg11_bn

from zennit.attribution import Gradient, SmoothGrad
from zennit.composites import EpsilonPlusFlat, EpsilonGammaBox
from zennit.image import imgify, imsave
from zennit.torchvision import VGGCanonizer

We download an image of the Dornbusch Lighthouse from Wikimedia Commons:

[3]:
torch.hub.download_url_to_file(
    'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_181_Leuchtturm.jpg/640px-2006_09_06_181_Leuchtturm.jpg',
    'dornbusch-lighthouse.jpg',
)
100.0%

We load and prepare the data. The image is resized such that its shorter side is 256 pixels, then center-cropped to (224, 224), converted to a torch.Tensor, and finally normalized according to the channel-wise mean and standard deviation of the ImageNet dataset:

[4]:
# define the base image transform
transform_img = Compose([
    Resize(256),
    CenterCrop(224),
])
# define the full tensor transform
transform = Compose([
    transform_img,
    ToTensor(),
    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# load the image
image = Image.open('dornbusch-lighthouse.jpg')

# transform the PIL image and insert a batch-dimension
data = transform(image)[None]

We can look at the original image and the cropped image:

[5]:
# display the original image
display(image)
# display the resized and cropped image
display(transform_img(image))
../_images/tutorial_image-classification-vgg_9_0.png
../_images/tutorial_image-classification-vgg_9_1.png

Then, we initialize the model. Set pretrained=True to load the pre-trained ImageNet weights instead of using the randomly initialized ones:

[6]:
# load the model and set it to evaluation mode
model = vgg11_bn(pretrained=False).eval()
/home/docs/checkouts/readthedocs.org/user_builds/zennit/envs/0.4.7/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
  warnings.warn(
/home/docs/checkouts/readthedocs.org/user_builds/zennit/envs/0.4.7/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)
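
The warnings above are produced by torchvision 0.13, which deprecates the pretrained argument in favor of weights. With torchvision >= 0.13, a pre-trained model can instead be requested via the weights enum; a minimal sketch:

from torchvision.models import vgg11_bn, VGG11_BN_Weights

# load the pre-trained ImageNet weights and set the model to evaluation mode;
# weights=None is equivalent to the deprecated pretrained=False
model = vgg11_bn(weights=VGG11_BN_Weights.IMAGENET1K_V1).eval()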

Compute the attribution using the EpsilonPlusFlat Composite:

[7]:
# use the VGG-specific canonizer (alias for SequentialMergeBatchNorm, only
# needed with batch-norm)
canonizer = VGGCanonizer()

# create a composite, specifying the canonizers, if any
composite = EpsilonPlusFlat(canonizers=[canonizer])

# choose a target class for the attribution (label 437 is lighthouse)
target = torch.eye(1000)[[437]]

# create the attributor, specifying model and composite
with Gradient(model=model, composite=composite) as attributor:
    # compute the model output and attribution
    output, attribution = attributor(data, target)

print(f'Prediction: {output.argmax(1)[0].item()}')
Prediction: 374
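
Since the model was created with pretrained=False, its weights are random, so the predicted class will generally not match the lighthouse target.

The EpsilonGammaBox composite imported above can be used in place of EpsilonPlusFlat. It additionally requires the lowest and highest possible values of the normalized input; the bounds below are a rough assumption for the ImageNet normalization used here, chosen only for illustration:

# a sketch with the EpsilonGammaBox composite; low and high are assumed bounds
# of the normalized input and depend on the Normalize transform above
composite_box = EpsilonGammaBox(low=-3.0, high=3.0, canonizers=[VGGCanonizer()])

with Gradient(model=model, composite=composite_box) as attributor:
    output, attribution = attributor(data, target)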

Visualize the attribution:

[8]:
# sum over the channels
relevance = attribution.sum(1)

# create an image of the visualized attribution
img = imgify(relevance, symmetric=True, cmap='coldnhot')

# show the image
display(img)
../_images/tutorial_image-classification-vgg_15_0.png

Here, imgify produces a PIL image, which can be saved with .save(). To directly save the visualized attribution, we can use imsave instead:

[9]:
# directly save the visualized attribution
imsave('attrib-1.png', relevance, symmetric=True, cmap='bwr')
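
SmoothGrad, which was imported alongside Gradient, can be swapped in as the attributor to average the attribution over noisy copies of the input. A minimal sketch, assuming the noise level and number of iterations shown are acceptable, and writing to the hypothetical file attrib-smoothgrad.png:

# a sketch using SmoothGrad in place of Gradient, with the same composite
with SmoothGrad(model=model, composite=composite, noise_level=0.1, n_iter=20) as attributor:
    output, attribution = attributor(data, target)

# save the averaged attribution in the same way as before
imsave('attrib-smoothgrad.png', attribution.sum(1), symmetric=True, cmap='bwr')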