Using Rules, Composites, and Canonizers

Zennit implements propagation-based attribution methods by overwriting the gradient of PyTorch modules within PyTorch’s auto-differentiation engine. There are three building blocks in Zennit to achieve attributions: Rules, Composites and Canonizers. In short, Rules specify how to overwrite the gradient, Composites map rules to modules, and Canonizers transform some module types and configurations to a canonical form, necessary in some cases.

Rules

Rules are used to overwrite the gradient of individual modules by adding forward hooks to modules, which track the input and output torch.Tensors and add gradient hooks to the tensors to intercept normal gradient computation. All available built-in rules can be found in zennit.rules, some of which are:

  • Epsilon, the most basic LRP rule

  • AlphaBeta, an LRP rule that splits positive and negative contributions and weights them

  • ZPlus, an LRP rule that only takes positive contributions into account

  • Gamma, an LRP rule that amplifies the positive weights

  • ZBox, an LRP rule for bounded inputs, like pixel space, commonly used in the first layer of a network

  • ReLUGuidedBackprop, a rule for ReLU activations to only propagate positive gradients back

  • Norm, a rule which distributes the gradient weighted by each output’s fraction of the full output

  • Pass, a rule which passes the incoming gradient on without changing it

  • Flat, a rule meant for linear (dense, convolutional) layers to equally distribute relevance as if inputs and weights were constant

Rules can be instantiated, after which they may be used directly:

import torch
from torch.nn import Conv2d
from zennit.rules import Epsilon

# instantiate the rule
rule = Epsilon(epsilon=1e-5)
conv_layer = Conv2d(3, 10, 3, padding=1)

# registering a rule adds hooks to the module which temporarily overwrites
# its gradient computation; handles are returned to remove the hooks to undo
# the modification
handles = rule.register(conv_layer)

# to compute the gradient (i.e. the attribution), requires_grad must be True
input = torch.randn(1, 3, 32, 32, requires_grad=True)
output = conv_layer(input)

# torch.autograd.grad returns a tuple, the comma after `attribution`
# unpacks the single element in the tuple; the `grad_outputs` are necessary
# for non-scalar outputs, and can be used to target which output should be
# attributed for; `ones_like` targets all outputs
attribution, = torch.autograd.grad(
    output, input, grad_outputs=torch.ones_like(output)
)

# remove the hooks
handles.remove()

See Writing Custom Rules for further technical detail on how to write custom rules.

Note that some rules, in particular the ones that modify parameters (e.g. ZPlus, AlphaBeta, …) are not thread-safe in the backward-phase, because they modify the model parameters for a brief moment. For most users, this is unlikely to cause any problems, and may be avoided by using locks in appropriate locations.

Composites

For a model with multiple layers, it may be inconvenient to register each rule individually. Therefore, Composites are used to map rules to layers given various criterions. Composites also take care of registering all models, and removing their handles after use. All available Composites can be found in zennit.composites.

Some built-in composites implement rule-mappings needed for some common attribution methods, some of which are

Additionally, the Norm rule, which normalizes the gradient by output fraction, is used for Sum and AvgPool layers in all of the listed Composites except for GuidedBackprop.

Since the gradient is only overwritten by Rules, the gradient will be unchanged for layers without applicable rules. If layers should only pass their received gradient/relevance on, the Pass rule should be used (which is done for all activations in all LRP Composites, but not in GuidedBackprop or ExcitationBackprop).

Note on MaxPool: For LRP, the gradient of MaxPool assigns values only to the largest inputs (winner-takes-all), which is already the expected behaviour for LRP rules.

Composites may require arguments, e.g. EpsilonGammaBox requires keyword arguments high and low to specify the bounds of the first layer’s ZBox.

import torch
from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten
from zennit.composites import EpsilonGammaBox

# setup the model
model = Sequential(
    Conv2d(3, 8, 3, padding=1),
    ReLU(),
    Conv2d(8, 16, 3, padding=1),
    ReLU(),
    Flatten(),
    Linear(16 * 32 * 32, 1024),
    ReLU(),
    Linear(1024, 10),
)
# sigma of normal distribution, just for visual purposes
sigma = 1.
# some random input data, still requires grad
input = torch.randn(1, 3, 32, 32, requires_grad=True) * sigma

# low and high values for ZBox need to be Tensors in the shape of the input
# the batch-dimension may be chosen larger, to support different sizes
composite = EpsilonGammaBox(
    low=torch.full_like(input, -3 * sigma),
    high=torch.full_like(input, 3 * sigma)
)

There are two basic ways using only the Composite to register the modules, either using register():

# register hooks for rules to all modules that apply
composite.register(model)
# execute the hooked/modified model
output = model(input)
# compute the attribution via the gradient
attribution, = torch.autograd.grad(
    output, input, grad_outputs=torch.ones_like(output)
)
# remove all hooks, undoing the modification
composite.remove()

and using context():

# register hooks for rules to all modules that apply within the context
# note that model and modified_model are the same model, the context
# variable is purely visual
# hooks are removed when the context is exited
with composite.context(model) as modified_model:
    # execute the hooked/modified model
    output = modified_model(input)
    # compute the attribution via the gradient
    attribution, = torch.autograd.grad(
        output, input, grad_outputs=torch.ones_like(output)
    )

There is a third option using zennit.attribution.Attributor, which is explained in Using Attributors.

Finally, there are abstract Composites which may be used to specify custom Composites:

For example, the built-in EpsilonPlus composite may be written like the following:

from zennit.composites import LayerMapComposite
from zennit.rules import Epsilon, ZPlus, Norm, Pass
from zennit.types import Convolution, Activation, AvgPool

# the layer map is a list of tuples, where the first element is the target
# layer type, and the second is the rule template
layer_map = [
    (Activation, Pass()),  # ignore activations
    (AvgPool, Norm()),  # normalize relevance for any AvgPool
    (Convolution, ZPlus()),  # any convolutional layer
    (Linear, Epsilon(epsilon=1e-6))  # this is the dense Linear, not any
]
composite = LayerMapComposite(layer_map=layer_map)

Note that rules used in composites are only used as templates and copied for each layer they apply to using zennit.core.Hook.copy(). If we want to map the ZBox rule to the first convolutional layer, we can use SpecialFirstLayerMapComposite instead:

from zennit.composites import SpecialFirstLayerMapComposite
from zennit.rules import ZBox
# abstract base class to describe convolutions + dense linear layers
from zennit.types import Linear as AnyLinear

# shape of our data
shape = (1, 3, 32, 32)
low = torch.full(shape, -3)
high = torch.full(shape, 3)
# the first map is only used once, to the first module which applies to the
# map, i.e. here the first layer of type AnyLinear
first_map = [
    (AnyLinear, ZBox(low, high))
]
# layer_map is used from the previous example
composite = SpecialFirstLayerMapComposite(
    layer_map=layer_map, first_map=first_map
)

If a composite is made to apply for a single model, a NameMapComposite can provide a transparent mapping from module name to rule:

from collections import OrderedDict
from zennit.composites import NameMapComposite

# setup the model, explicitly naming them
model = Sequential(OrderedDict([
    ('conv0', Conv2d(3, 8, 3, padding=1)),
    ('relu0', ReLU()),
    ('conv1', Conv2d(8, 16, 3, padding=1)),
    ('relu1', ReLU()),
    ('flatten', Flatten()),
    ('linear0', Linear(16 * 32 * 32, 1024)),
    ('relu2', ReLU()),
    ('linear1', Linear(1024, 10)),
]))

# look at the available modules
print(list(model.named_modules()))

# manually write a rule mapping:
composite = NameMapComposite([
    (['conv0'], ZBox(low, high)),
    (['conv1'], ZPlus()),
    (['linear0', 'linear1'], Epsilon()),
])

Modules built using torch.nn.Sequential without explicit names will have a number string as their name. Explicitly assigning a module to a parent module as an attribute will assign the attribute as the child module’s name. Nested modules will have their names split by a dot ..

To create custom composites following more complex patterns, see Writing Custom Composites.

Canonizers

Layerwise relevance propagation (LRP) is not implementation invariant. A good example for this is that for some rules, two consecutive linear layers do not produce the same attribution as a single linear layer with its weight parameter chosen as the product of the two linear layers. The most common case this happens is when models use BatchNorm, which is commonly used directly after, or sometimes before a linear (dense, convolutional) layer. Canonizers are used to avoid this by temporarily enforcing a canonical form of the model. They differ from Rules in that the model is actively changed while the Canonizer is registered, as opposed to using hooks to modify the gradient during runtime.

All available Canonizers can be found in zennit.canonizers. Some of the available basic ones are:

  • SequentialMergeBatchNorm, which traverses the tree of submodules in-order to create a sequence of leave modules, which is then used to detect adjacent linear (dense, convolutional) and BatchNorm modules

  • NamedMergeBatchNorm, which is used to specify explicitly by module name which linear (dense, convolutional) and BatchNorm modules should be merged

  • AttributeCanonizer, which expects a function mapping from module name and type to a dict of attribute names and values which should be changed for applicable modules

  • CompositeCanonizer, which expects a list of canonizers which are then combined to a single canonizer

SequentialMergeBatchNorm traverses the module tree in-order leaves-only using zennit.core.collect_leaves(), and iterates the resulting list to detect adjacent linear (dense, convolutional) and batch-norm modules. The batch-norm’s scale and shift are merged into the adjacent linear layer’s weights and bias.

While not recommended, Canonizers can be used on their own:

import torch
from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten, BatchNorm2d
from zennit.canonizers import SequentialMergeBatchNorm

# setup the model
model = Sequential(
    Conv2d(3, 8, 3, padding=1),
    ReLU(),
    Conv2d(8, 16, 3, padding=1),
    BatchNorm2d(16),
    ReLU(),
    Flatten(),
    Linear(16 * 32 * 32, 1024),
    ReLU(),
    Linear(1024, 10),
)

# create the canonizer
canonizer = SequentialMergeBatchNorm()

# apply the canonizer to the model, which creates multiple canonizer
# instances, one per applicable case
instances = canonizer.apply(model)

# do something with the model
input = torch.randn(1, 3, 32, 32)
output = model(input)

# remove the canonizer instances to revert the model to its original state
for instance in instances:
    instance.remove()

However, the recommended way is to use them with Composites, which will apply and remove canonizer instances automatically while the Composite is active:

from zennit.composites import EpsilonPlusFlat

# create the canonizer
canonizer = SequentialMergeBatchNorm()
# create the composite, with the canonizer as an argument
composite = EpsilonPlusFlat(canonizers=[canonizer])
# create some input data
input = torch.randn(1, 3, 32, 32, requires_grad=True)
# register the composite within the context, which also applies the
# canonizer
with composite.context(model) as modified_model:
    output = modified_model(input)
    # compute the attribution
    attribution, = torch.autograd.grad(output, input, torch.eye(10)[[0]])

# print the absolute sum of the attribution
print(attribution.abs().sum().item())

Be careful not to accidentally save a model’s parameters (e.g. using model.state_dict()) while Canonizers are applied, as this will store the modified state of the model.

Some models implemented in torchvision.models have their own specific Canonizer implemented in zennit.torchvision, which currently are:

  • VGGCanonizer, which applies to torchvision’s implementation of VGG networks and currently is an alias for SequentialMergeBatchNorm

  • ResNetCanonizer, which applies to torchvision’s implementation of ResNet networks, which merges BatchNorms and replaces residual connections with the explicit Sum module, which makes it possible to assign a rule to the residual connection.

More technical detail to implement custom Canonizers may be found in Writing Custom Canonizers.