zennit.canonizers

Functions to produce a canonical form of models fit for LRP.

Classes

AttributeCanonizer

Canonizer to set an attribute of module instances.

Canonizer

Canonizer Base class.

CompositeCanonizer

A Composite of Canonizers, which applies all supplied canonizers.

MergeBatchNorm

Abstract Canonizer to merge the parameters of batch norms into linear modules.

NamedMergeBatchNorm

Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.

SequentialMergeBatchNorm

Canonizer to merge the parameters of all batch norms that appear sequentially right after a linear module.

class zennit.canonizers.AttributeCanonizer(attribute_map)

Bases: Canonizer

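Canonizer to set an attribute of module instances. Note that removing this Canonizer deletes the overloaded attributes from the module instances, so any instance attribute with the same name that was set before the Canonizer was applied is lost after removal.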

Parameters:

attribute_map (Function) – A function that returns either None, if not applicable, or a dict mapping the names of the attributes to overload for a module to their new values. The function signature is (name: str, module: torch.nn.Module) -> None or dict.
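As an illustration of the expected signature, the following sketch shows an attribute_map that selects modules by type. The classes ReLU and Linear and the function passthrough_forward are hypothetical stand-ins for torch.nn modules and a replacement value; only the (name, module) -> None or dict contract comes from the description above.

```python
# Hypothetical stand-ins for torch.nn.Module subclasses; during apply() the
# canonizer would pass real module instances together with their names.
class ReLU:
    pass


class Linear:
    pass


def passthrough_forward(*args):
    """Hypothetical value that replaces the `forward` attribute."""
    return args


def attribute_map(name, module):
    """Return the attributes to overload for `module`, or None to skip it."""
    # `name` (e.g. 'features.1') could also be used to select modules.
    if isinstance(module, ReLU):
        return {'forward': passthrough_forward}
    return None


print(attribute_map('features.0', Linear()))  # None
print(attribute_map('features.1', ReLU()) == {'forward': passthrough_forward})  # True
```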

apply(root_module)

Overload the attributes for all applicable modules.

Parameters:

root_module (torch.nn.Module) – Root module for which underlying modules will have their attributes overloaded.

Returns:

instances – The applied canonizer instances, which may be removed by calling .remove.

Return type:

list of Canonizer
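The apply, register and remove methods fit together roughly as follows. This is a minimal pure-Python sketch, not zennit's source: the classes SketchAttributeCanonizer and Dummy are hypothetical, and a plain list of (name, object) pairs stands in for the output of torch.nn.Module.named_modules().

```python
class SketchAttributeCanonizer:
    """Sketch of the apply/register/remove cycle; not zennit's source."""

    def __init__(self, attribute_map):
        self.attribute_map = attribute_map
        self.module = None
        self.attribute_keys = []

    def copy(self):
        # A fresh, unregistered instance sharing the same attribute map.
        return self.__class__(self.attribute_map)

    def apply(self, named_modules):
        instances = []
        for name, module in named_modules:
            attributes = self.attribute_map(name, module)
            if attributes is not None:
                # One registered copy per applicable module.
                instance = self.copy()
                instance.register(module, attributes)
                instances.append(instance)
        return instances

    def register(self, module, attributes):
        self.module = module
        self.attribute_keys = list(attributes)
        for key, value in attributes.items():
            setattr(module, key, value)

    def remove(self):
        # Deleting the instance attributes undoes the overload.
        for key in self.attribute_keys:
            delattr(self.module, key)


class Dummy:
    pass


modules = [('a', Dummy()), ('b', Dummy())]
canonizer = SketchAttributeCanonizer(
    lambda name, module: {'tag': name} if name == 'b' else None
)
instances = canonizer.apply(modules)
print(len(instances), modules[1][1].tag)  # 1 b
for instance in instances:
    instance.remove()
print(hasattr(modules[1][1], 'tag'))  # False
```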

copy()

Copy this Canonizer.

Returns:

A copy of this Canonizer.

Return type:

Canonizer

register(module, attributes)

Overload the module’s attributes.

Parameters:
  • module (torch.nn.Module) – The module of which the attributes will be overloaded.

  • attributes (dict) – The attributes to overload on the module, mapping attribute names to their new values.

remove()

Remove the overloaded attributes. Note that functions are descriptors, and therefore not direct attributes of the instance, which is why deleting instance attributes with the same name reverts them to the original function.
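The descriptor behaviour, and the caveat from the class description, can be reproduced with plain Python. The Module class below is a hypothetical stand-in for a torch.nn.Module subclass.

```python
class Module:
    def forward(self):
        return 'original'


module = Module()

# Overload: the instance attribute shadows the function defined on the class.
module.forward = lambda: 'overloaded'
print(module.forward())  # overloaded

# Removal: deleting the instance attribute exposes the class function again.
del module.forward
print(module.forward())  # original

# Caveat: an instance attribute that existed before the overload is not
# restored, because removal simply deletes the attribute.
other = Module()
other.forward = lambda: 'user-set'
other.forward = lambda: 'overloaded'  # overload by the canonizer
del other.forward                     # removal by the canonizer
print(other.forward())  # original (the user-set value is gone)
```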

class zennit.canonizers.Canonizer

Bases: object

Canonizer Base class. Canonizers modify modules temporarily such that certain attribution rules can properly be applied.

abstract apply(root_module)

Apply this canonizer recursively on all applicable modules.

Parameters:

root_module (torch.nn.Module) – Root module to which to apply the canonizers.

Returns:

A list of all applied instances of this class.

Return type:

list

copy()

Return a copy of this instance.

abstract register()

Apply the changes of this canonizer.

abstract remove()

Revert the changes introduced by this canonizer.
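A canonizer subclass is expected to implement apply, register and remove so that every change it makes can be undone. The sketch below mirrors that contract with Python's abc module; SketchCanonizer, ScaleCanonizer, Toy and the scale attribute are hypothetical and not part of zennit.

```python
from abc import ABC, abstractmethod


class SketchCanonizer(ABC):
    """Sketch of the base-class contract; not zennit's source."""

    @abstractmethod
    def apply(self, root_module):
        """Register copies on all applicable modules and return them."""

    @abstractmethod
    def register(self, *args, **kwargs):
        """Apply the changes of this canonizer."""

    @abstractmethod
    def remove(self):
        """Revert the changes introduced by this canonizer."""

    def copy(self):
        return self.__class__()


class ScaleCanonizer(SketchCanonizer):
    """Hypothetical canonizer that temporarily doubles a `scale` attribute."""

    def __init__(self):
        self.module = None
        self.original = None

    def apply(self, root_module):
        instance = self.copy()
        instance.register(root_module)
        return [instance]

    def register(self, module):
        self.module = module
        self.original = module.scale  # remembered for remove()
        module.scale = module.scale * 2

    def remove(self):
        self.module.scale = self.original


class Toy:
    def __init__(self):
        self.scale = 1.5


toy = Toy()
instances = ScaleCanonizer().apply(toy)
print(toy.scale)  # 3.0
for instance in instances:
    instance.remove()
print(toy.scale)  # 1.5
```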

class zennit.canonizers.CompositeCanonizer(canonizers)

Bases: Canonizer

A Composite of Canonizers, which applies all supplied canonizers.

Parameters:

canonizers (list of Canonizer) – Canonizers from which to build a Composite.

apply(root_module)

Apply all canonizers.

Parameters:

root_module (torch.nn.Module) – Root module for which underlying modules will have canonizers applied.

Returns:

instances – The applied canonizer instances, which may be removed by calling .remove.

Return type:

list of Canonizer

register()

Register this Canonizer. Nothing to do for a CompositeCanonizer.

remove()

Remove this Canonizer. Nothing to do for a CompositeCanonizer.
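The composite only forwards apply to its children and collects the returned instances; those instances, not the composite, are what gets removed. In this sketch SketchComposite and Recorder are hypothetical, and removing in reverse order is a design choice of the sketch rather than something the documentation specifies.

```python
class SketchComposite:
    """Sketch of a composite canonizer; not zennit's source."""

    def __init__(self, canonizers):
        self.canonizers = canonizers

    def apply(self, root_module):
        # Concatenate the instances returned by every child canonizer.
        instances = []
        for canonizer in self.canonizers:
            instances += canonizer.apply(root_module)
        return instances

    def register(self):
        pass  # the composite itself changes nothing

    def remove(self):
        pass  # removal happens through the returned instances


class Recorder:
    """Hypothetical child canonizer that logs application and removal."""

    def __init__(self, label, log):
        self.label = label
        self.log = log

    def apply(self, root_module):
        self.log.append(('apply', self.label))
        return [self]

    def remove(self):
        self.log.append(('remove', self.label))


log = []
composite = SketchComposite([Recorder('first', log), Recorder('second', log)])
instances = composite.apply(root_module=None)
# Reverse order undoes later changes first, which matters if two
# canonizers touch the same attribute.
for instance in reversed(instances):
    instance.remove()
print(log)
```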

class zennit.canonizers.MergeBatchNorm

Bases: Canonizer

Abstract Canonizer to merge the parameters of batch norms into linear modules.

static merge_batch_norm(modules, batch_norm)

Update the parameters of the linear layers to additionally include a Batch Normalization operation, and update the batch normalization layer to instead compute the identity.
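In evaluation mode the fusion follows from writing the batch norm as a per-channel affine map. The derivation below maps the mandatory attributes to symbols (running_mean to \mu, running_var to \sigma^2, weight to \gamma, bias to \beta, eps to \epsilon) and indexes output channels by c; for a convolution, W_{c,:} is the kernel of output channel c. The reset values in the last line are one choice that makes the batch norm compute the identity, assumed here rather than taken from the implementation.

```latex
\begin{aligned}
s_c &= \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}} \\
\mathrm{BN}(Wx + b)_c &= s_c \left( (Wx)_c + b_c - \mu_c \right) + \beta_c = (W'x)_c + b'_c, \\
W'_{c,:} &= s_c \, W_{c,:}, \qquad b'_c = s_c \left( b_c - \mu_c \right) + \beta_c \\
\mathrm{BN}_{\mu=0,\,\sigma^2=1,\,\gamma=1,\,\beta=0,\,\epsilon=0}(z)_c &= \frac{1 \cdot (z_c - 0)}{\sqrt{1 + 0}} + 0 = z_c
\end{aligned}
```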

Parameters:
  • modules (list of torch.nn.Module) – Linear layers with mandatory attributes weight and bias.

  • batch_norm (torch.nn.Module) – Batch Normalization module with mandatory attributes running_mean, running_var, weight, bias and eps.
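The fusion can be checked numerically. This sketch uses plain Python lists instead of tensors, made-up parameter values and a single fully connected layer; it only verifies the arithmetic of folding a batch norm into the preceding layer and does not call zennit.

```python
import math

# Hypothetical 2-output, 2-input linear layer and eval-mode batch norm.
weight = [[1.0, -2.0], [0.5, 3.0]]
bias = [0.1, -0.4]
running_mean = [0.3, -1.0]
running_var = [4.0, 0.25]
bn_weight = [2.0, -1.0]
bn_bias = [0.5, 0.0]
eps = 1e-5


def linear(w, b, x):
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def batch_norm(z, mean, var, gamma, beta, eps):
    return [g * (zi - m) / math.sqrt(v + eps) + bt
            for zi, m, v, g, bt in zip(z, mean, var, gamma, beta)]


# Fold the batch norm into the linear layer, one scale per output channel.
scale = [g / math.sqrt(v + eps) for g, v in zip(bn_weight, running_var)]
merged_weight = [[s * wi for wi in row] for s, row in zip(scale, weight)]
merged_bias = [s * (b - m) + bt
               for s, b, m, bt in zip(scale, bias, running_mean, bn_bias)]

x = [0.7, -1.2]
reference = batch_norm(linear(weight, bias, x),
                       running_mean, running_var, bn_weight, bn_bias, eps)
merged = linear(merged_weight, merged_bias, x)

# With mean 0, var 1, weight 1, bias 0 and eps 0 the batch norm is the identity.
identity = batch_norm(merged, [0.0, 0.0], [1.0, 1.0], [1.0, 1.0], [0.0, 0.0], 0.0)

print(all(math.isclose(a, b) for a, b in zip(reference, merged)))  # True
print(identity == merged)  # True
```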

register(linears, batch_norm)

Store the parameters of the linear modules and the batch norm module and apply the merge.

Parameters:
  • linears (list of torch.nn.Module) – List of linear layers with mandatory attributes weight and bias.

  • batch_norm (torch.nn.Module) – Batch Normalization module with mandatory attributes running_mean, running_var, weight, bias and eps.

remove()

Undo the merge by reverting the parameters of both the linear and the batch norm modules to the state before the merge.
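Because the merge changes parameters in place, undoing it requires a snapshot taken before the merge. The sketch below shows that store-merge-restore cycle on plain dicts; SketchMerge and the dict layout are hypothetical, and the deep copies are a simplification of however the real class keeps the original tensors.

```python
import copy
import math


class SketchMerge:
    """Store-merge-restore cycle on plain dicts; not zennit's source."""

    def __init__(self):
        self.linears = []
        self.batch_norm = None
        self.linear_params = []
        self.batch_norm_params = None

    def register(self, linears, batch_norm):
        self.linears = linears
        self.batch_norm = batch_norm
        # Snapshot first, so the in-place merge cannot alter the stored state.
        self.linear_params = [copy.deepcopy(linear) for linear in linears]
        self.batch_norm_params = copy.deepcopy(batch_norm)
        self.merge(linears, batch_norm)

    @staticmethod
    def merge(linears, batch_norm):
        # Per output channel: scale = weight / sqrt(running_var + eps).
        scale = [g / math.sqrt(v + batch_norm['eps'])
                 for g, v in zip(batch_norm['weight'], batch_norm['running_var'])]
        for linear in linears:
            linear['weight'] = [[s * w for w in row]
                                for s, row in zip(scale, linear['weight'])]
            linear['bias'] = [s * (b - m) + beta for s, b, m, beta in
                              zip(scale, linear['bias'],
                                  batch_norm['running_mean'], batch_norm['bias'])]
        # Reset the batch norm so that it computes the identity.
        n = len(scale)
        batch_norm.update(running_mean=[0.0] * n, running_var=[1.0] * n,
                          weight=[1.0] * n, bias=[0.0] * n, eps=0.0)

    def remove(self):
        # Restore in place so that other references see the original values.
        for linear, params in zip(self.linears, self.linear_params):
            linear.clear()
            linear.update(params)
        self.batch_norm.clear()
        self.batch_norm.update(self.batch_norm_params)


linear = {'weight': [[1.0, 2.0]], 'bias': [0.5]}
batch_norm = {'running_mean': [1.0], 'running_var': [3.0],
              'weight': [4.0], 'bias': [0.0], 'eps': 1.0}
before = copy.deepcopy((linear, batch_norm))

merge = SketchMerge()
merge.register([linear], batch_norm)
print(linear)  # {'weight': [[2.0, 4.0]], 'bias': [-1.0]}
merge.remove()
print((linear, batch_norm) == before)  # True
```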

class zennit.canonizers.NamedMergeBatchNorm(name_map)

Bases: MergeBatchNorm

Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.

Parameters:

name_map (list[tuple[tuple[str, ...], str]]) – List of pairs, each assigning a tuple of linear layer names to the name of the batch norm whose parameters are merged into them.

apply(root_module)

Create appropriate merges given by the name map.

Parameters:

root_module (torch.nn.Module) – Root module for which underlying modules will be merged.

Returns:

instances – A list of merge instances.

Return type:

list
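To make the name_map layout concrete, the sketch below resolves it against a name lookup. The module names and the string stand-ins for modules are hypothetical, and resolve is an illustrative helper, not part of zennit; for a real model the lookup could be built with dict(model.named_modules()).

```python
# Hypothetical names as they might appear in named_modules(); the string
# values stand in for the actual module objects.
named_modules = {
    'features.0': 'conv0',
    'features.1': 'bn0',
    'features.3': 'conv1',
    'features.4': 'bn1',
}

# Each entry pairs a tuple of linear-layer names with one batch norm name.
name_map = [
    (('features.0',), 'features.1'),
    (('features.3',), 'features.4'),
]


def resolve(name_map, lookup):
    """Turn names into (linears, batch_norm) pairs; unknown names raise KeyError."""
    return [([lookup[name] for name in linear_names], lookup[batch_norm_name])
            for linear_names, batch_norm_name in name_map]


print(resolve(name_map, named_modules))
# [(['conv0'], 'bn0'), (['conv1'], 'bn1')]
```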

copy()

Return a copy of this instance.

class zennit.canonizers.SequentialMergeBatchNorm

Bases: MergeBatchNorm

Canonizer to merge the parameters of all batch norms that appear sequentially right after a linear module.

Note

SequentialMergeBatchNorm traverses the tree of children of the provided module depth-first and in-order. This means that child-modules must be assigned to their parent module in the order they are visited in the forward pass to correctly identify adjacent modules. This also means that activation functions must be assigned in their module-form as a child to their parent-module to properly detect when there is an activation function between linear and batch-norm modules.
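The consequences of this traversal can be illustrated with a toy module tree. In this sketch each node is a hypothetical (name, kind, children) tuple rather than a torch.nn.Module, and leaves and find_merges are illustrative helpers that only approximate the adjacency check; they are not zennit functions.

```python
def leaves(node):
    """Yield leaf modules depth-first, visiting children in assignment order."""
    name, kind, children = node
    if not children:
        yield name, kind
    for child in children:
        yield from leaves(child)


def find_merges(root):
    """Return (linear, batch_norm) name pairs where a batch norm directly follows a linear leaf."""
    pairs = []
    previous = None
    for name, kind in leaves(root):
        if kind == 'batch_norm' and previous is not None and previous[1] == 'linear':
            pairs.append((previous[0], name))
        previous = (name, kind)
    return pairs


model = ('model', 'container', [
    ('block1', 'container', [
        ('block1.conv', 'linear', []),
        ('block1.bn', 'batch_norm', []),
        ('block1.relu', 'activation', []),
    ]),
    ('block2', 'container', [
        ('block2.conv', 'linear', []),
        # A ReLU registered as a child module separates conv and bn,
        # so this pair is correctly skipped.
        ('block2.relu', 'activation', []),
        ('block2.bn', 'batch_norm', []),
    ]),
])
print(find_merges(model))  # [('block1.conv', 'block1.bn')]

# If block2 applied the ReLU functionally inside forward(), no child module
# would exist and the two leaves would wrongly look adjacent.
functional = ('model', 'container', [
    ('block2', 'container', [
        ('block2.conv', 'linear', []),
        ('block2.bn', 'batch_norm', []),
    ]),
])
print(find_merges(functional))  # [('block2.conv', 'block2.bn')]
```

The second tree shows why functional activations are a problem: a traversal over registered children cannot see them.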

apply(root_module)

Find every batch norm that directly follows a linear layer and, for each such pair, create a copy of this instance that merges them by fusing the batch norm parameters into the linear layer and reducing the batch norm to the identity.

Parameters:

root_module (torch.nn.Module) – Module whose leaves are searched; every batch norm found directly after a linear layer is merged into that layer.

Returns:

instances – A list of instances of this class which modified the appropriate leaves.

Return type:

list