Visualizing Results
While not limited to any domain in particular, attribution methods are most
commonly applied on 2-dimensional image data. For this reason, Zennit implements
a few functions to aid in the visualization of attributions of image data as
heatmaps. These methods may be found in zennit.image. To simply save
tensors that can be represented as images (1 or 3 channels, 2 dimensions), with
or without heatmap, zennit.image.imsave() may be used.
Let us consider the following setting which simulates image data:
import torch
from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten
from zennit.attribution import Gradient
# setup the model
model = Sequential(
Conv2d(3, 8, 3, padding=1),
ReLU(),
Conv2d(8, 16, 3, padding=1),
ReLU(),
Flatten(),
Linear(16 * 32 * 32, 1024),
ReLU(),
Linear(1024, 10),
)
# some random input data
input = torch.randn(8, 3, 32, 32, requires_grad=True)
# compute the gradient and output using the Gradient attributor
with Gradient(model) as attributor:
output, relevance = attributor(input)
The relevance has the same shape as the input, which here is (8, 3, 32, 32).
We can save the output and relevance, with all color-information intact, by
simply doing:
from zennit.image import imsave
for n, (inp, rel) in enumerate(zip(input, relevance)):
imsave(f'input_{n:03d}.png', inp.detach())
imsave(f'relevance_{n:03d}.png', rel)
Alternatively, the images may be composed as a grid, and saved as a single image:
imsave('input_grid.png', input.detach(), grid=True)
imsave('relevance_grid.png', relevance, grid=(2, 4))
The keyword argument grid may either be boolean, or the 2d shape of the image grid.
While this works well for the input, it is hard to interpret the attribution from the resulting images. Be aware that commonly input images are pre-processed before they are fed into networks. While clipping and scaling the image pose no problem for its visibility, normalization will change the look of the image greatly. Therefore, when saving images during training or inference, it is recommended to visualize input images either before applying the normalization, or after applying the inverse of the normalization.
imsave() uses zennit.image.imgify(), which,
given a numpy.ndarray or a torch.Tensor, will return a
Pillow image, which can also be used to quickly look at the image without saving
it:
from zennit.image import imgify
image = imgify(input.detach(), grid=True)
image.show()
Heatmap Normalization
Commonly, a heatmap of the attribution is produced by removing the color-channel
either by taking the (absolute) sum and normalizing to fit into an interval.
imsave() (through imgify()) will
shift and scale the input such that the full range of colors is used, using the
input’s minimum and maximum respectively. This can be tweaked by supplying the
vmin and vmax keyword arguments:
absrel = relevance.abs().sum(1)
# vmin and vmax works for both imsave and imgify
imsave('relevance_abs_0.png', absrel[0], vmin=0, vmax=absrel[0].amax())
image = imgify(absrel[0], vmin=0, vmax=absrel[0].amax())
image.show()
Another way to normalize the attribution which can be used with both
imsave() and imgify() is to use
the symmetric keyword argument, which provides two normalization strategies:
symmetric=False (default) and symmetric=True. Keep in mind that the
normalization of the attribution can greatly change how interpretable the
heatmap will be.
Let us consider a more interesting image to compare the two normalization strategies with signed and unsigned data:
from itertools import product
grid = torch.stack(torch.meshgrid(*((torch.linspace(-1, 1, 128),) * 2), indexing='xy'))
dist = ((grid + 0.25) ** 2).sum(0, keepdims=True) ** .5
ripples = (dist * 5 * torch.pi).cos().clip(-.5, 1.) * (-dist).exp()
for norm, sign in product(('symmetric', 'unaligned'), ('signed', 'absolute')):
array = ripples.abs() if sign == 'absolute' else ripples
symmetric = norm == 'symmetric'
imsave(f'ripples_{norm}_{sign}_bwr.png', array, symmetric=symmetric)
imsave(f'ripples_{norm}_{sign}_wred.png', array, symmetric=symmetric, cmap='wred')
The keyword argument cmap is used to control the color map.
|
|
|
|
|
|---|---|---|---|---|
signed |
absolute |
signed |
absolute |
|
|
|
|
|
|
|
|
|
|
|
Negative values were clipped to better see how the normalization modes work.
The default color map is 'bwr', which maps 0.0 to blue, 0.5 to white and 1.0 to
red, which means it is a signed color map, as the center of 0.5 is a
neutral point, with color intensities rising for values below and above.
Color map 'wred' maps 0.0 to white and 1.0 to red, which makes it
an unsigned color map, as its color intensity is monotonically increasing.
Using symmetric=False will simply map [min, max] to [0., 1.], i.e the
minimum value to 0.0, and the maximum value to 1.0. This works best with
unsigned color maps, when relevance is assumed to be monotonically increasing
and a value of 0.0 does not have any special meaning.
symmetric=True will find the absolute maximum per image, and will map the
input range [-absmax, absmax] to [0., 1.]. This means that the result
will be centered around 0.5, which works best with signed color maps (like
'bwr'), as positive (here red) and negative (here blue) intensities in the
produced heatmap are made comparable.
In the example above, our input is in the range [-0.5, 1.0]. If the negative
and positive values are meaningful (generally the case for attribution methods),
and the color map has a meaningful value at 0.5 (i.e. is signed),
symmetric=True is usually the best choice for normalization.
For symmetric=False the example above shows that with 'bwr' gives the
illusion of a shifted center, which makes it look like the attribution is
predominantly negative. Using the monotonic wred is normally the better
choice for the symmetric=False, but with signed attributions
the results are not as clear as they can be.
Finally, the example above shows the different outcomes when the input is
signed or its absolute is taken.
Using vmin and vmax overrides the minimum and maximum values
respectively determined by the normalization mode.
This means that, for example, using vmin=0 (and not setting vmax) with
symmetric=True will clip all values below 0.
Another useful setting is when the input is positive (or its absolute value was
taken) to use vmin=0 with symmetric=False, as this will give the full
range from 0 to the maximum value, since the smallest value may be larger than 0
when in cases where it is known that 0 would be the smallest possible value.
This shows the importance of the choice of the normalization and the color map.
Color Maps
Color maps play an essential role in the production of heatmaps which highlight
points of interest best. With the normalization modes we have seen the built-in
signed color map bwr (blue-white-red) and unsigned color map wred
(white-red).
All built-in color maps are defined in zennit.image.CMAPS.
The built-in unsigned color maps are:
Identifier |
CMSL-Source |
Visualization |
|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
and the built-in signed color maps are:
Identifier |
CMSL-Source |
Visualization |
|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CMSL-Source is the source code of the color map in Color-Map Specification
Language (CMSL). The color map for imsave() and
imgify() may be specified in one of three ways: the
identifier of a built-in color map (wred, coldnhot, …), a string
containing CMSL source-code, or a zennit.cmap.ColorMap instance.
from zennit.cmap import ColorMap
bar = torch.arange(256)[None].repeat((32, 1))
imsave('bar_wred.png', bar, cmap='wred')
imsave('bar_string.png', bar, cmap='ff0,00f')
cmap = ColorMap('000,f00,fff')
imsave('bar_cmap.png', bar, cmap=cmap)
Color-Map Specification Language
Color-Map Specification Language (CMSL) is a domain-specific language to
describe color maps in a quick and compact manner. It is implemented in
zennit.cmap. Color-maps can be compiled using the
zennit.cmap.ColorMap class, of which the constructor expects
CMSL source code as a string. Alternatively, a ColorMap instance may be obtained
using zennit.image.get_cmap(), which first looks up its argument string
in the built-in color-map dictionary zennit.image.CMAPS, and, if it
fails, tries to compile the string as CMSL source code.
from zennit.cmap import ColorMap
from zennit.image import get_cmap
bar = torch.arange(256)[None].repeat((32, 1))
cmap1 = ColorMap('000,a0:f00,fff')
cmap2 = get_cmap('1f:fff,f0f,000')
img1 = imgify(bar, cmap=cmap1)
img2 = imgify(bar, cmap=cmap2)
img1.show()
img2.show()
CMSL follows a simple grammar:
cmsl_cmap ::= color_node ("," color_node)+
color_node ::= [index ":"] rgb_color
index ::= half | full
rgb_color ::= half half half | full full full
full ::= half half
half ::= <single hex digit 0-9a-fA-F>
Values for both index and rgb_color are specified as hexadecimal values
with either one (half) or two (full) digits, where index consists of
a single value 0-255 (or half 0-15) and rgb_color consists of 3 values 0-255
(or half 0-15).
The index of all color_nodes must be in ascending order.
It describes the color-index of the color-map, where 00 (or half 0) is
the lowest value and ff (i.e. decimal 255, or half f) is the highest
value.
The same value of index may be repeated to produce hard color-transitions,
however, using the same value of index more than twice will only use the two
outermost color values.
If the indices of the first or last color_nodes are omitted, they will be
assumed as 00 and ff respectively.
Two additional color_nodes with the same color as the ones with lowest and
highest index will be implicitly created at indices 00 and ff
respectively, which means that if the lowest and/or highest specified color node
indices are larger or smaller than 00 or ff respectively, the colors
between 00 and the lowest index, and the highest index and ff will be
constant.
A color map needs at least two color_nodes (i.e., a useless single-color
color-map cannot be created by specifying a single color_node).
A color node will produce a color of its rgb_color for the value of its index.
Colors for values between two color nodes will be linearly interpolated between
their two colors, weighted by their respective proximity. Color nodes without
indices will evenly spaced between color nodes with indices. The first and last
color nodes, if not equipped with an index, will be assumed as 00 and ff
respectively.
While technically there does not exist a syntactic difference between signed
and unsigned color maps, signed color maps often require a color node at the
central index 80, while unsigned color maps should have monotonically
increasing or decreasing intensities, which can be most easily done by only
specifying two color nodes.
The built-in color map cold could be understood as a signed color map,
since it has an explicit color node blue at its center. Visually, however,
due to its monotonicity, it is hard to interpret as such.
The following shows a few examples of color maps along their CMSL source code:
CMSL-Source |
Visualization |
|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
Additionally, zennit.cmap.LazyColorMapCache may be used to define
color maps in bulk, and lazily compile them when they are accessed the first
time. This is the way the built-in color maps are defined in
zennit.image.CMAPS.
from zennit.cmap import LazyColorMapCache
cmaps = LazyColorMapCache({
'reds': '111,f11',
'blues': '111,11f',
'greens': '111,1f1',
})
img = imgify(ripples, cmap=cmaps['greens'])
img.show()
LazyColorMapCache stores the specified source code for
each key, and if accessed with cmaps[key], it will either compile the
ColorMap, cache it if it has not been accessed
before and return it, or it will return the previously cached
ColorMap.
Changing Palettes
When using imgify() (or
imsave()), arrays with a single channel are converted
to PIL images in palette mode (P), where the palette specifies the color
map. This means that the color map of an image may be changed later without
modifying its values. The palette for a color map can be generated using its
zennit.cmap.ColorMap.palette() method.
palette() accepts an optional argument level
(default 1.0), with which the resulting palette can be either stretched or
compressed, resulting in heatmaps where either the maximum value threshold is
moved closer to the center (level > 1.0) or farther away from it (0.0 < level
< 1.0). A value of level=2.0 proved to better highlight high values of
a heatmap in print.
img = imgify(ripples, symmetric=True)
img.show()
cmap = ColorMap('111,1f1')
pal = cmap.palette(level=1.0)
img.putpalette(pal)
img.show()
The convenience function zennit.image.palette() may also be used to
directly get the palette from a built-in color map name or CMSL source code.
This way, existing PNG-files of heatmaps may thus also be modified to use different color maps by changing their palette:
from PIL import Image
from zennit.image import palette
# store a heatmap
fname = 'newheatmap.png'
imsave(fname, ripples, symmetric=True)
# load the heatmap, change the palette and write it to the same file
img = Image.open(fname)
img = img.convert('P')
pal = palette('f1f,111,ff1', level=1.0)
img.putpalette(pal)
img.save(fname)
A utility CLI script which changes the color map is provided in share/scripts/palette_swap.py, which can be used in the following way:
$ python share/scripts/palette_swap.py newheatmap \
--cmap 'f1f,111,ff1' \
--level 1.0
























