zennit.canonizers
Functions to produce a canonical form of models fit for LRP
Classes
Class | Description
--- | ---
AttributeCanonizer | Canonizer to set an attribute of module instances.
Canonizer | Canonizer base class.
CompositeCanonizer | A Composite of Canonizers, which applies all supplied canonizers.
MergeBatchNorm | Abstract Canonizer to merge the parameters of batch norms into linear modules.
NamedMergeBatchNorm | Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.
SequentialMergeBatchNorm | Canonizer to merge the parameters of all batch norms that appear sequentially right after a linear module.
- class zennit.canonizers.Canonizer
Bases: object
Canonizer base class. Canonizers modify modules temporarily such that certain attribution rules can properly be applied.
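The interface shared by the classes below (an apply that returns registered instances, a register that modifies a module, and a remove that undoes the change) can be illustrated without PyTorch. This is a minimal sketch under stated assumptions, not zennit's implementation: ToyModule, CanonizerSketch and ScaleCanonizer are hypothetical, modules are plain Python objects, and copying the configured canonizer with copy.copy before registering is a choice of this sketch.

```python
import copy


class ToyModule:
    """Hypothetical stand-in for a torch.nn.Module with one scalar weight."""

    def __init__(self, weight):
        self.weight = weight


class CanonizerSketch:
    """Interface sketch: apply() registers copies, remove() undoes each one."""

    def apply(self, root_module):
        raise NotImplementedError

    def register(self, module):
        raise NotImplementedError

    def remove(self):
        raise NotImplementedError


class ScaleCanonizer(CanonizerSketch):
    """Hypothetical canonizer that temporarily doubles a module's weight."""

    def __init__(self):
        self.module = None
        self.original_weight = None

    def apply(self, root_module):
        # Register a fresh copy so the configured canonizer itself stays unmodified.
        instance = copy.copy(self)
        instance.register(root_module)
        return [instance]

    def register(self, module):
        # Record the original state, then modify the module.
        self.module = module
        self.original_weight = module.weight
        module.weight = 2 * module.weight

    def remove(self):
        # Restore the state recorded in register().
        self.module.weight = self.original_weight


canonizer = ScaleCanonizer()
module = ToyModule(weight=3.0)
instances = canonizer.apply(module)
try:
    modified_weight = module.weight  # attributions would be computed here
finally:
    # Undo every modification, even if computing the attribution fails.
    for instance in reversed(instances):
        instance.remove()
restored_weight = module.weight
```

The try/finally pattern keeps the model unmodified after the attribution, which is what "temporarily" requires in practice.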
- class zennit.canonizers.MergeBatchNorm
Bases: Canonizer
Abstract Canonizer to merge the parameters of batch norms into linear modules.
- register(linears, batch_norm)
Store the parameters of the linear modules and the batch norm module, then apply the merge.
- Parameters:
linears (list of torch.nn.Module) – List of linear layers with mandatory attributes weight and bias.
batch_norm (torch.nn.Module) – Batch Normalization module with mandatory attributes running_mean, running_var, weight, bias and eps.
- remove()
Undo the merge by reverting the parameters of both the linear and the batch norm modules to the state before the merge.
- static merge_batch_norm(modules, batch_norm)
Update the parameters of the linear layers to additionally include a Batch Normalization operation, and update the batch normalization layer to compute the identity instead.
- Parameters:
modules (list of torch.nn.Module) – Linear layers with mandatory attributes weight and bias.
batch_norm (torch.nn.Module) – Batch Normalization module with mandatory attributes running_mean, running_var, weight, bias and eps.
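For a dense layer whose weight rows correspond to the batch norm channels, the merge follows from writing the batch norm out per output channel c. With the scale s_c = γ_c / sqrt(σ²_c + ε), BN(Wx + b)_c = s_c (W_c x) + s_c (b_c − μ_c) + β_c, so the merged parameters are W'_c = s_c W_c and b'_c = s_c (b_c − μ_c) + β_c. The sketch below checks this numerically and mirrors the register/remove pairing with plain Python lists. It is an illustration, not zennit's code: ToyLinear, ToyBatchNorm and MergeSketch are hypothetical, only dense layers are covered (convolutions need the scale broadcast over their output-channel dimension), and reducing the batch norm to the identity by setting eps to 0 and the running variance to 1 is a choice of this sketch.

```python
import math


class ToyLinear:
    """Hypothetical dense layer: y_c = W_c . x + b_c."""

    def __init__(self, weight, bias):
        self.weight = weight  # one row per output channel
        self.bias = bias      # one entry per output channel

    def __call__(self, x):
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.weight, self.bias)]


class ToyBatchNorm:
    """Hypothetical eval-mode batch norm over channels."""

    def __init__(self, running_mean, running_var, weight, bias, eps=1e-5):
        self.running_mean = running_mean
        self.running_var = running_var
        self.weight = weight
        self.bias = bias
        self.eps = eps

    def __call__(self, x):
        return [(xi - m) / math.sqrt(v + self.eps) * g + b
                for xi, m, v, g, b in zip(x, self.running_mean, self.running_var,
                                          self.weight, self.bias)]


def merge_batch_norm(modules, batch_norm):
    # s_c = gamma_c / sqrt(var_c + eps)
    scale = [g / math.sqrt(v + batch_norm.eps)
             for g, v in zip(batch_norm.weight, batch_norm.running_var)]
    for module in modules:
        # W'_c = s_c * W_c  and  b'_c = s_c * (b_c - mean_c) + beta_c
        module.weight = [[s * w for w in row] for row, s in zip(module.weight, scale)]
        module.bias = [s * (b - m) + beta for b, m, s, beta
                       in zip(module.bias, batch_norm.running_mean, scale, batch_norm.bias)]
    # Reduce the batch norm to the identity (exactly, since eps is set to 0).
    channels = len(scale)
    batch_norm.running_mean = [0.0] * channels
    batch_norm.running_var = [1.0] * channels
    batch_norm.weight = [1.0] * channels
    batch_norm.bias = [0.0] * channels
    batch_norm.eps = 0.0


class MergeSketch:
    """register() snapshots and merges; remove() restores the snapshot."""

    def register(self, linears, batch_norm):
        self.linears, self.batch_norm = linears, batch_norm
        # merge_batch_norm rebinds attributes to new lists, so keeping references suffices.
        self.linear_params = [(m.weight, m.bias) for m in linears]
        self.batch_norm_params = {key: getattr(batch_norm, key) for key in
                                  ('running_mean', 'running_var', 'weight', 'bias', 'eps')}
        merge_batch_norm(linears, batch_norm)

    def remove(self):
        for module, (weight, bias) in zip(self.linears, self.linear_params):
            module.weight, module.bias = weight, bias
        for key, value in self.batch_norm_params.items():
            setattr(self.batch_norm, key, value)


linear = ToyLinear(weight=[[1.0, 2.0], [-1.0, 0.5]], bias=[0.1, -0.2])
batch_norm = ToyBatchNorm(running_mean=[0.3, -0.1], running_var=[4.0, 0.25],
                          weight=[1.5, -2.0], bias=[0.05, 0.4])
x = [0.7, -1.2]

reference = batch_norm(linear(x))  # BN(W x + b) before merging
merger = MergeSketch()
merger.register([linear], batch_norm)
merged = batch_norm(linear(x))      # the batch norm is now the identity
linear_only = linear(x)
merger.remove()
restored = batch_norm(linear(x))
```

After the merge, the batch norm passes its input through unchanged, so the merged linear layer alone reproduces the original output up to floating-point rounding; remove brings back the exact original parameters.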
- class zennit.canonizers.SequentialMergeBatchNorm
Bases: MergeBatchNorm
Canonizer to merge the parameters of all batch norms that appear sequentially right after a linear module.
Note
SequentialMergeBatchNorm traverses the tree of children of the provided module depth-first and in order. Child modules must therefore be assigned to their parent module in the order in which they are visited in the forward pass, so that adjacent modules are identified correctly. Likewise, activation functions must be assigned in their module form as children of their parent module (rather than only called functionally in the forward pass), so that an activation function between a linear and a batch norm module is detected.
- apply(root_module)
Finds every batch norm that directly follows a linear layer and, for each such pair, creates a copy of this instance that merges them by fusing the batch norm parameters into the linear layer and reducing the batch norm to the identity.
- Parameters:
root_module (torch.nn.Module) – A module whose leaves are searched; every batch norm found right after a linear layer is merged into that layer.
- Returns:
instances – A list of instances of this class which modified the appropriate leaves.
- Return type:
list
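The search described in the note above can be sketched as a depth-first walk over leaves in child-assignment order that pairs each batch norm with the leaf visited immediately before it. This sketch is not zennit's implementation: Node is a hypothetical stand-in for a module (Python dicts preserve insertion order, matching the order in which children are assigned), and modules are classified by a kind string instead of isinstance checks on PyTorch types. The second model shows the failure mode the note warns about: an activation applied functionally inside the forward pass is not a child module, so the walk cannot see it.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a module: a kind label plus ordered children."""
    kind: str
    children: dict = field(default_factory=dict)  # name -> Node, in assignment order


def collect_leaves(node, prefix=''):
    """Yield (qualified_name, leaf) depth-first, visiting children in assignment order."""
    if not node.children:
        yield prefix, node
        return
    for name, child in node.children.items():
        yield from collect_leaves(child, f'{prefix}.{name}' if prefix else name)


def adjacent_pairs(root, linear_kinds=('linear', 'conv'), batch_norm_kinds=('batchnorm',)):
    """Return (linear_name, batch_norm_name) for each batch norm directly after a linear leaf."""
    pairs = []
    previous_name, previous_leaf = None, None
    for name, leaf in collect_leaves(root):
        if (previous_leaf is not None and previous_leaf.kind in linear_kinds
                and leaf.kind in batch_norm_kinds):
            pairs.append((previous_name, name))
        previous_name, previous_leaf = name, leaf
    return pairs


model = Node('sequential', {
    'features': Node('sequential', {'0': Node('conv'), '1': Node('batchnorm'), '2': Node('relu')}),
    # The ReLU module sits between the linear layer and the batch norm: no merge.
    'head': Node('sequential', {'0': Node('linear'), '1': Node('relu'), '2': Node('batchnorm')}),
})

# If this block's forward applied a ReLU functionally between its two children,
# the walk could not see it and would still report them as adjacent.
functional_head = Node('sequential', {'0': Node('linear'), '1': Node('batchnorm')})

model_pairs = adjacent_pairs(model)
functional_pairs = adjacent_pairs(functional_head)
```

Each reported pair would then get its own registered copy of the canonizer, which is why apply returns a list of instances.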
- class zennit.canonizers.NamedMergeBatchNorm
Bases: MergeBatchNorm
Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.
- Parameters:
name_map (list[tuple[tuple[str, ...], str]]) – List of pairs, each mapping a tuple of linear layer names to the name of the batch norm that belongs to them.
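The name_map format can be illustrated by resolving names against a name-to-module lookup, such as the one a root module's named_modules() provides. In this sketch the lookup is a plain dict and the modules are placeholder objects; resolve_name_map is a hypothetical helper, and whether zennit resolves the names exactly this way is an assumption. Each entry pairs a tuple of linear layer names with the name of the batch norm merged into all of them.

```python
def resolve_name_map(named_modules, name_map):
    """Turn a name_map into (linears, batch_norm) groups ready for a merge.

    named_modules: dict mapping qualified names to modules, e.g. as built by
        dict(root_module.named_modules()) for a torch.nn.Module.
    name_map: list of (tuple of linear layer names, batch norm name) pairs.
    """
    groups = []
    for linear_names, batch_norm_name in name_map:
        linears = [named_modules[name] for name in linear_names]
        groups.append((linears, named_modules[batch_norm_name]))
    return groups


# Placeholder objects stand in for the actual modules.
conv, batch_norm, relu = object(), object(), object()
named_modules = {'features.0': conv, 'features.1': batch_norm, 'features.2': relu}

name_map = [(('features.0',), 'features.1')]  # one linear layer feeding one batch norm
groups = resolve_name_map(named_modules, name_map)
```

Because names are looked up directly, a misspelled name raises KeyError in this sketch instead of being skipped silently.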
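- class zennit.canonizers.AttributeCanonizer
Bases: Canonizer
Canonizer to set an attribute of module instances. Note that removing this Canonizer deletes the overloaded attributes from the module instances, so an instance attribute of the same name that was set before the Canonizer was applied is lost rather than restored.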
- Parameters:
attribute_map (Function) – A function that returns either None, if not applicable, or a dict whose keys name the attributes to overload for a module and whose values are the new attribute values. The function signature is (name: str, module: torch.nn.Module) -> None or dict.
- apply(root_module)
Overload the attributes for all applicable modules.
- Parameters:
root_module (torch.nn.Module) – Root module for which underlying modules will have their attributes overloaded.
- Returns:
instances – The applied canonizer instances, which may be removed by calling .remove.
- Return type:
list of Canonizer
- register(module, attributes)
Overload the module’s attributes.
- Parameters:
module (torch.nn.Module) – The module of which the attributes will be overloaded.
attributes (dict) – The attributes to overload for the module.
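Overloading attributes on an instance works because instance attributes shadow class attributes of the same name, and deleting them again uncovers the class attributes. The sketch below shows this with plain Python objects, including the removal behavior noted for this class: an instance attribute that existed before registration is deleted on removal rather than restored. AttributeOverload, ToyBlock and this attribute_map are hypothetical, and updating the instance __dict__ directly is an assumption of the sketch, not a statement about zennit's implementation.

```python
import types


class ToyBlock:
    """Hypothetical stand-in for a module whose forward we want to overload."""

    def forward(self, x):
        return x + 1


def attribute_map(name, module):
    # None means "not applicable"; otherwise a dict of attributes to overload.
    if isinstance(module, ToyBlock):
        return {
            'forward': types.MethodType(lambda self, x: x + 100, module),
            'tag': 'overloaded',
        }
    return None


class AttributeOverload:
    """Sketch: set instance attributes on register(), delete them on remove()."""

    def __init__(self, attribute_map):
        self.attribute_map = attribute_map
        self.module = None
        self.attributes = None

    def apply(self, named_modules):
        instances = []
        for name, module in named_modules:
            attributes = self.attribute_map(name, module)
            if attributes is not None:
                instance = AttributeOverload(self.attribute_map)
                instance.register(module, attributes)
                instances.append(instance)
        return instances

    def register(self, module, attributes):
        self.module, self.attributes = module, attributes
        # Instance attributes shadow class attributes with the same name.
        vars(module).update(attributes)

    def remove(self):
        # Deleting uncovers class attributes again, but a pre-existing instance
        # attribute with the same name is gone for good.
        for key in self.attributes:
            delattr(self.module, key)


block = ToyBlock()
block.tag = 'set before the canonizer'
instances = AttributeOverload(attribute_map).apply([('block', block), ('other', object())])
during = (block.forward(1), block.tag)
for instance in instances:
    instance.remove()
after = block.forward(1)
tag_survived = hasattr(block, 'tag')
```

While registered, the bound replacement forward is used; after removal, forward falls back to the class method, while the earlier tag value is not recovered.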
- class zennit.canonizers.CompositeCanonizer
Bases: Canonizer
A Composite of Canonizers, which applies all supplied canonizers.
- Parameters:
canonizers (list of Canonizer) – Canonizers of which to build a Composite.
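A composite only needs to forward apply to each of its canonizers and collect the returned instances. The sketch below is hypothetical (CompositeSketch and RecordingCanonizer are not zennit classes) and reverses the collected instances so that removing them in list order undoes the most recent modification first; whether zennit orders its instances this way is an assumption. Undoing in reverse matters when two canonizers touch the same module, since each restores the state it saw when it was registered.

```python
class RecordingCanonizer:
    """Hypothetical canonizer that records when it is applied and removed."""

    def __init__(self, label, log):
        self.label = label
        self.log = log

    def apply(self, root_module):
        self.log.append(f'apply {self.label}')
        return [self]  # a fuller sketch would return registered copies

    def remove(self):
        self.log.append(f'remove {self.label}')


class CompositeSketch:
    """Applies all supplied canonizers and returns every resulting instance."""

    def __init__(self, canonizers):
        self.canonizers = canonizers

    def apply(self, root_module):
        instances = []
        for canonizer in self.canonizers:
            instances.extend(canonizer.apply(root_module))
        # Last applied first: removing in list order undoes changes in reverse.
        instances.reverse()
        return instances


log = []
composite = CompositeSketch([RecordingCanonizer('merge', log),
                             RecordingCanonizer('attributes', log)])
instances = composite.apply(root_module=None)
for instance in instances:
    instance.remove()
```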