Writing Custom Attributors

Attributors provide an additional layer of abstraction over the context of Composites and are used to directly produce attributions. These attributions may or may not be computed with gradients modified by Composites, depending on whether a Composite is used. More information on Attributors, examples and their use can be found in Using Attributors.

Attributors can be used to implement non-layer-wise or only partly layer-wise attribution methods. For this, it is enough to define a subclass of zennit.attribution.Attributor and implement its forward() and optionally its __init__() methods.

forward() takes 2 arguments. The first is input, the tensor with respect to which the attribution shall be computed. The second is attr_output_fn, a function that, given the output of the attributed model, computes the gradient output for the gradient computation, for example a one-hot encoding of the target label of the attributed input. When an Attributor is called, its __call__() method ensures that forward() receives a valid function which transforms the output of the analyzed model into a tensor that can be used as the grad_output argument of torch.autograd.grad(). The user supplies either a constant tensor or a function, passing it either to __init__() or to __call__(). forward() is expected to return a tuple containing, in order, the model output and the attribution.
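To illustrate what forward() receives, the following sketch (assuming only PyTorch) builds an attr_output_fn that produces a one-hot encoding of a target label, and shows how either a constant tensor or a function could be turned into such a function. The helper names one_hot_attr_output and as_attr_output_fn are hypothetical, and treating None as the identity is an assumption of this sketch; it mirrors the behavior described above but is not zennit's implementation.

```python
import torch


def one_hot_attr_output(target):
    '''Hypothetical helper: return a function mapping a model output of shape
    (batch, classes) to a one-hot grad_output that selects ``target``.'''
    def attr_output_fn(output):
        labels = torch.full((output.shape[0],), target, dtype=torch.long)
        one_hot = torch.nn.functional.one_hot(labels, num_classes=output.shape[1])
        # match the dtype of the output so it is usable as grad_output
        return one_hot.to(output.dtype)
    return attr_output_fn


def as_attr_output_fn(attr_output):
    '''Hypothetical sketch of the conversion described above: a function is
    used as-is, a constant tensor is wrapped into a function returning it, and
    None (an assumption of this sketch) falls back to the identity.'''
    if attr_output is None:
        return lambda output: output
    if callable(attr_output):
        return attr_output
    return lambda output: attr_output


output = torch.tensor([[0.5, -1.0, 2.0]])
fn_from_function = as_attr_output_fn(one_hot_attr_output(2))
fn_from_tensor = as_attr_output_fn(torch.tensor([[0.0, 0.0, 1.0]]))
print(fn_from_function(output))  # tensor([[0., 0., 1.]])
print(fn_from_tensor(output))    # tensor([[0., 0., 1.]])
```

Either way, forward() only ever sees a function, so it can call attr_output_fn(output) without checking what the user originally passed.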

As an example, we can implement gradient times input in the following way:

import torch
from torchvision.models import vgg11

from zennit.attribution import Attributor


class GradientTimesInput(Attributor):
    '''Model-agnostic gradient times input.'''
    def forward(self, input, attr_output_fn):
        '''Compute gradient times input.'''
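        # detach the input and make it a leaf tensor that requires a gradient
        input_detached = input.detach().requires_grad_(True)
        output = self.model(input_detached)
        # attr_output_fn maps the model output to the grad_output,
        # e.g. a one-hot encoding of the target label
        gradient, = torch.autograd.grad(
            (output,), (input_detached,), (attr_output_fn(output.detach()),)
        )
        relevance = gradient * input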
        return output, relevance

model = vgg11()
data = torch.randn((1, 3, 224, 224))

with GradientTimesInput(model) as attributor:
    output, relevance = attributor(data)
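To check the logic of forward() without torchvision or zennit, its body can be run as a plain function on a small linear model, which is a stand-in chosen here for illustration. For a linear model y = Wx + b with a one-hot grad_output selecting class c, the gradient is the weight row W[c], so gradient times input sums to y[c] - b[c]. The function name gradient_times_input is hypothetical.

```python
import torch


def gradient_times_input(model, input, attr_output_fn):
    '''Same computation as GradientTimesInput.forward, as a free function.'''
    input_detached = input.detach().requires_grad_(True)
    output = model(input_detached)
    gradient, = torch.autograd.grad(
        (output,), (input_detached,), (attr_output_fn(output.detach()),)
    )
    return output, gradient * input


torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
data = torch.randn(1, 4)

# constant one-hot grad_output selecting class 1
target = torch.tensor([[0.0, 1.0, 0.0]])
output, relevance = gradient_times_input(model, data, lambda out: target)

# for a linear model, the relevance of class 1 sums to its output minus its bias
expected = (output[0, 1] - model.bias[1]).detach()
print(torch.allclose(relevance.sum(), expected, atol=1e-5))  # True
```

The relevance has the same shape as the input, one value per input feature.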

Attributor accepts an optional Composite which, if supplied, is always used to create a context around forward() in __call__(). For the GradientTimesInput class above, using a Composite will probably not produce anything useful, although more involved combinations of custom Rules and a custom Attributor can be used to implement complex attribution methods with both model-agnostic and layer-wise parts.
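The following pure-Python sketch illustrates the pattern of wrapping forward() in a Composite's context inside __call__(). Both classes are hypothetical stand-ins: RecordingComposite only records when its context is entered and left instead of registering rules, and SketchAttributor is not zennit's Attributor, whose actual implementation may differ in its details.

```python
from contextlib import contextmanager

# shared list recording the order of events, for illustration only
log = []


class RecordingComposite:
    '''Hypothetical stand-in for a Composite: instead of registering rules on
    the model, it only records when its context is entered and left.'''
    @contextmanager
    def context(self, model):
        log.append('register rules')
        try:
            yield model
        finally:
            log.append('remove rules')


class SketchAttributor:
    '''Hypothetical sketch (not zennit's code) of an attributor whose
    __call__ runs forward inside the composite's context, if one is given.'''
    def __init__(self, model, composite=None):
        self.model = model
        self.composite = composite

    def forward(self, input, attr_output_fn):
        log.append('forward')
        output = self.model(input)
        # placeholder "attribution": just the transformed output
        return output, attr_output_fn(output)

    def __call__(self, input, attr_output_fn):
        if self.composite is None:
            return self.forward(input, attr_output_fn)
        # the composite's rules are only active while forward runs
        with self.composite.context(self.model):
            return self.forward(input, attr_output_fn)


attributor = SketchAttributor(lambda x: 2 * x, composite=RecordingComposite())
result = attributor(3, lambda output: output)
print(result)  # (6, 6)
print(log)     # ['register rules', 'forward', 'remove rules']
```

Because the context is exited in a finally clause, the stand-in rules would be removed even if forward() raised an exception.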

The following shows an example of sensitivity analysis, which is the absolute value of the gradient, with a custom __init__() that accepts the argument sum_channels to specify whether the Attributor should sum over the channel dimension:

import torch
from torchvision.models import vgg11

from zennit.attribution import Attributor


class SensitivityAnalysis(Attributor):
    '''Model-agnostic sensitivity analysis which optionally sums over color
    channels.
    '''
    def __init__(
        self, model, sum_channels=False, composite=None, attr_output=None
    ):
        super().__init__(
            model, composite=composite, attr_output=attr_output
        )

        self.sum_channels = sum_channels

    def forward(self, input, attr_output_fn):
        '''Compute the absolute gradient (or the sensitivity) and
        optionally sum over the color channels.
        '''
        input_detached = input.detach().requires_grad_(True)
        output = self.model(input_detached)
        gradient, = torch.autograd.grad(
            (output,), (input_detached,), (attr_output_fn(output.detach()),)
        )
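        # the sensitivity is the absolute value of the gradient
        relevance = gradient.abs()
        if self.sum_channels:
            # dimension 1 is the channel dimension of an (N, C, H, W) input
            relevance = relevance.sum(1)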
        return output, relevance

model = vgg11()
data = torch.randn((1, 3, 224, 224))

with SensitivityAnalysis(model, sum_channels=True) as attributor:
    output, relevance = attributor(data)
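To see the effect of sum_channels on the shape of the attribution, the body of SensitivityAnalysis.forward can be run as a plain function on a small convolutional model, used here as a stand-in for vgg11. The function name sensitivity is hypothetical, and using the model output itself as grad_output is an assumption of this sketch.

```python
import torch


def sensitivity(model, input, attr_output_fn, sum_channels=False):
    '''Same computation as SensitivityAnalysis.forward, as a free function.'''
    input_detached = input.detach().requires_grad_(True)
    output = model(input_detached)
    gradient, = torch.autograd.grad(
        (output,), (input_detached,), (attr_output_fn(output.detach()),)
    )
    relevance = gradient.abs()
    if sum_channels:
        # summing over dimension 1 removes the channel dimension
        relevance = relevance.sum(1)
    return output, relevance


torch.manual_seed(0)
# small stand-in for vgg11: one convolution followed by a linear classifier
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(4 * 8 * 8, 5),
)
data = torch.randn(2, 3, 8, 8)


def identity(output):
    # use the model output itself as grad_output
    return output


_, per_channel = sensitivity(model, data, identity)
_, summed = sensitivity(model, data, identity, sum_channels=True)
print(per_channel.shape)  # torch.Size([2, 3, 8, 8])
print(summed.shape)       # torch.Size([2, 8, 8])
```

Since relevance.sum(1) drops the channel dimension, relevance.sum(1, keepdim=True) would be the variant to use if a (N, 1, H, W) heatmap is needed instead.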