zennit.canonizers
Functions to produce a canonical form of models fit for Layer-wise Relevance Propagation (LRP).
Classes
Class | Description
AttributeCanonizer | Canonizer to set an attribute of module instances.
Canonizer | Canonizer base class.
CompositeCanonizer | A Composite of Canonizers, which applies all supplied canonizers.
MergeBatchNorm | Abstract Canonizer to merge the parameters of batch norms into linear modules.
NamedMergeBatchNorm | Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.
SequentialMergeBatchNorm | Canonizer to merge the parameters of all batch norms that appear sequentially right after a linear module.
- class zennit.canonizers.AttributeCanonizer(attribute_map)
Bases: Canonizer
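Canonizer to set an attribute of module instances. Note that removing this canonizer deletes the overloaded attributes from the module instances, so attributes of the same name that were set before the canonizer was applied are removed as well.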
- Parameters:
attribute_map (callable) – A function that returns either None, if not applicable, or a dict whose keys name the attributes to overload for a module and whose values are the values to set. The function signature is (name: str, module: torch.nn.Module) -> None or dict.
- apply(root_module)
Overload the attributes for all applicable modules.
- Parameters:
root_module (torch.nn.Module) – Root module for which underlying modules will have their attributes overloaded.
- Returns:
instances – The applied canonizer instances, which may be removed by calling .remove.
- Return type:
list of Canonizer
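The following is a dependency-free toy sketch, not zennit's implementation, of how an attribute-overloading canonizer can behave, including the removal caveat stated above. ToyModule (a stand-in for torch.nn.Module so the example runs without PyTorch), ToyAttributeCanonizer and the rule in attribute_map are all hypothetical names chosen for illustration.

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module that only tracks named children."""

    def __init__(self, **children):
        self._children = dict(children)

    def named_modules(self, prefix=''):
        # Yield this module first, then all descendants depth-first with dotted
        # names, similar to torch.nn.Module.named_modules().
        yield prefix, self
        for name, child in self._children.items():
            child_prefix = f'{prefix}.{name}' if prefix else name
            yield from child.named_modules(child_prefix)


class ToyAttributeCanonizer:
    """Hypothetical re-implementation of the attribute-overloading idea."""

    def __init__(self, attribute_map):
        self.attribute_map = attribute_map
        self.module = None
        self.attributes = {}

    def apply(self, root_module):
        # One registered copy per module for which attribute_map returns a dict.
        instances = []
        for name, module in root_module.named_modules():
            attributes = self.attribute_map(name, module)
            if attributes is not None:
                instance = ToyAttributeCanonizer(self.attribute_map)
                instance.register(module, attributes)
                instances.append(instance)
        return instances

    def register(self, module, attributes):
        self.module = module
        self.attributes = attributes
        for key, value in attributes.items():
            setattr(module, key, value)

    def remove(self):
        # The overloaded attributes are deleted, not restored to earlier values.
        for key in self.attributes:
            delattr(self.module, key)


def attribute_map(name, module):
    # Hypothetical rule: only overload the child registered as 'act'.
    return {'inplace': False} if name == 'act' else None


act = ToyModule()
model = ToyModule(fc=ToyModule(), act=act)
act.inplace = True  # set before canonization
instances = ToyAttributeCanonizer(attribute_map).apply(model)
print(act.inplace)  # False
for instance in instances:
    instance.remove()
print(hasattr(act, 'inplace'))  # False: the value set before canonization is gone too
```

If the pre-existing value matters, it has to be saved and re-set by the caller after removal, since the removal step in this sketch simply deletes the attribute.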
- class zennit.canonizers.Canonizer
Bases: object
Canonizer base class. Canonizers modify modules temporarily such that certain attribution rules can properly be applied.
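The base class only states that canonizers modify modules temporarily. A minimal sketch of that apply-then-remove lifecycle, assuming the convention documented for the subclasses on this page (apply returns instances, each undone by calling .remove), could look like the following. SketchCanonizer, FlagCanonizer, Model and the canonized context manager are hypothetical names, not part of zennit.

```python
from abc import ABC, abstractmethod
from contextlib import contextmanager


class SketchCanonizer(ABC):
    """Hypothetical interface mirroring the documented apply/remove pattern."""

    @abstractmethod
    def apply(self, root_module):
        """Modify root_module and return a list of instances that can undo it."""

    @abstractmethod
    def remove(self):
        """Undo the modification made by this instance."""


@contextmanager
def canonized(root_module, canonizer):
    # Apply the canonizer, hand out the modified module, and always undo afterwards.
    instances = canonizer.apply(root_module)
    try:
        yield root_module
    finally:
        # Undo in reverse order of application (a design choice of this sketch).
        for instance in reversed(instances):
            instance.remove()


class FlagCanonizer(SketchCanonizer):
    """Toy canonizer: sets a `canonized` flag on a plain object."""

    def __init__(self):
        self.target = None

    def apply(self, root_module):
        instance = FlagCanonizer()
        instance.target = root_module
        root_module.canonized = True
        return [instance]

    def remove(self):
        del self.target.canonized


class Model:
    pass


model = Model()
with canonized(model, FlagCanonizer()):
    print(model.canonized)  # True
print(hasattr(model, 'canonized'))  # False
```

Removing the instances in a finally block keeps the model unmodified even if the attribution step in between raises an exception.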
- class zennit.canonizers.CompositeCanonizer(canonizers)
Bases: Canonizer
A Composite of Canonizers, which applies all supplied canonizers.
- Parameters:
canonizers (list of Canonizer) – Canonizers from which to build the Composite.
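A composite can be thought of as a canonizer whose apply delegates to each of its members and concatenates the returned instances. The toy sketch below assumes exactly that behaviour; SketchCompositeCanonizer, CountingCanonizer and Model are hypothetical, and zennit's CompositeCanonizer may order or collect its instances differently.

```python
class CountingCanonizer:
    """Toy canonizer: increments a counter attribute and decrements it on remove."""

    def __init__(self, key):
        self.key = key
        self.target = None

    def apply(self, root_module):
        instance = CountingCanonizer(self.key)
        instance.target = root_module
        setattr(root_module, self.key, getattr(root_module, self.key, 0) + 1)
        return [instance]

    def remove(self):
        setattr(self.target, self.key, getattr(self.target, self.key) - 1)


class SketchCompositeCanonizer:
    """Hypothetical composite: applies every member and concatenates the instances."""

    def __init__(self, canonizers):
        self.canonizers = canonizers

    def apply(self, root_module):
        instances = []
        for canonizer in self.canonizers:
            instances.extend(canonizer.apply(root_module))
        return instances


class Model:
    pass


model = Model()
composite = SketchCompositeCanonizer([CountingCanonizer('a'), CountingCanonizer('b')])
instances = composite.apply(model)
print(model.a, model.b, len(instances))  # 1 1 2
for instance in reversed(instances):
    instance.remove()
print(model.a, model.b)  # 0 0
```

Because the composite returns a flat list, the caller can undo everything with one loop, without knowing which member produced which instance.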
- class zennit.canonizers.MergeBatchNorm
Bases: Canonizer
Abstract Canonizer to merge the parameters of batch norms into linear modules.
- static merge_batch_norm(modules, batch_norm)
Update the parameters of the linear layers to additionally include a batch normalization operation, and update the batch normalization layer to compute the identity instead.
- Parameters:
modules (list of torch.nn.Module) – Linear layers with mandatory attributes weight and bias.
batch_norm (torch.nn.Module) – Batch Normalization module with mandatory attributes running_mean, running_var, weight, bias and eps.
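The merge rests on the fact that, at inference time, a batch norm applied after a linear layer is itself a per-channel affine map. Writing $W_c, b_c$ for the weights and bias producing output channel $c$ (for a convolution, $W_c$ is the $c$-th output filter), and $\mu_c, \sigma_c^2, \gamma_c, \beta_c, \epsilon$ for running_mean, running_var, weight, bias and eps (symbols introduced here for illustration):

```latex
\begin{aligned}
z_c &= W_c x + b_c, \\
\mathrm{BN}(z)_c
  &= \gamma_c \, \frac{z_c - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
   = \underbrace{\frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}} \, W_c}_{W'_c} \, x
   + \underbrace{\frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}} \, (b_c - \mu_c) + \beta_c}_{b'_c}.
\end{aligned}
```

After writing $W'_c$ and $b'_c$ into the linear layer, one consistent way to make the batch norm compute the identity is $\mu_c = 0$, $\sigma_c^2 = 1$, $\gamma_c = 1$, $\beta_c = 0$ and $\epsilon = 0$, since then $\mathrm{BN}(z)_c = z_c$; the exact values written back are not stated on this page.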
- register(linears, batch_norm)
Store the parameters of the linear modules and the batch norm module and apply the merge.
- Parameters:
linears (list of torch.nn.Module) – List of linear layers with mandatory attributes weight and bias.
batch_norm (torch.nn.Module) – Batch Normalization module with mandatory attributes running_mean, running_var, weight, bias and eps.
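To make the store-then-merge cycle concrete, the following pure-Python sketch uses dicts of per-channel lists in place of torch.nn.Module parameters. SketchMerge, linear and batch_norm are hypothetical helpers, and the remove method that restores the stored copies is an assumption about why the parameters are stored (this page only says they are stored). The example checks numerically that the fused linear layer followed by the identity batch norm matches the original pair.

```python
import copy
import math


def linear(weight, bias, x):
    # weight: one row per output channel; bias: one value per output channel.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm(bn, z):
    # Inference-mode batch norm, applied channel-wise with the running statistics.
    return [
        g * (zc - m) / math.sqrt(v + bn['eps']) + be
        for zc, m, v, g, be in zip(
            z, bn['running_mean'], bn['running_var'], bn['weight'], bn['bias'])
    ]


class SketchMerge:
    """Hypothetical stand-in for register (plus an assumed remove) on plain dicts."""

    def register(self, linears, bn):
        self.linears, self.bn = linears, bn
        # Keep deep copies of the original parameters so remove() can restore them.
        self.saved = ([copy.deepcopy(lin) for lin in linears], copy.deepcopy(bn))
        scale = [g / math.sqrt(v + bn['eps'])
                 for g, v in zip(bn['weight'], bn['running_var'])]
        for lin in linears:
            # b'_c = scale_c * (b_c - mu_c) + beta_c and W'_c = scale_c * W_c
            lin['bias'] = [s * (b - m) + be for s, b, m, be in
                           zip(scale, lin['bias'], bn['running_mean'], bn['bias'])]
            lin['weight'] = [[s * w for w in row] for s, row in zip(scale, lin['weight'])]
        # Reduce the batch norm to the identity.
        n = len(scale)
        bn.update(running_mean=[0.0] * n, running_var=[1.0] * n,
                  weight=[1.0] * n, bias=[0.0] * n, eps=0.0)

    def remove(self):
        saved_linears, saved_bn = self.saved
        for lin, original in zip(self.linears, saved_linears):
            lin.update(original)
        self.bn.update(saved_bn)


lin = {'weight': [[1.0, 2.0], [0.5, -1.0]], 'bias': [0.1, -0.2]}
bn = {'running_mean': [0.3, -0.1], 'running_var': [4.0, 0.25],
      'weight': [2.0, 1.5], 'bias': [0.0, 1.0], 'eps': 1e-5}
x = [1.0, -2.0]

expected = batch_norm(bn, linear(lin['weight'], lin['bias'], x))
merger = SketchMerge()
merger.register([lin], bn)
merged = batch_norm(bn, linear(lin['weight'], lin['bias'], x))
print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
          for a, b in zip(expected, merged)))  # True
merger.remove()
print(lin['bias'])  # [0.1, -0.2]
```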
- class zennit.canonizers.NamedMergeBatchNorm(name_map)
Bases: MergeBatchNorm
Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.
- Parameters:
name_map (list[tuple[tuple[str, ...], str]]) – List of (linear_names, batch_norm_name) pairs specifying which linear layers, given as a tuple of names, belong to which batch norm name.
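A name map pairs a tuple of linear-layer names with a single batch-norm name, and such a list is what would be passed as NamedMergeBatchNorm(name_map). The sketch below shows how the names could be resolved against a lookup built from a model's named_modules(); resolve_name_map is a hypothetical helper, and the layer names and placeholder values are made up for illustration.

```python
def resolve_name_map(name_map, named_modules):
    """Hypothetical helper: turn (linear names, batch norm name) pairs into objects."""
    resolved = []
    for linear_names, batch_norm_name in name_map:
        # A KeyError here means a name does not exist in the model.
        linears = [named_modules[name] for name in linear_names]
        resolved.append((linears, named_modules[batch_norm_name]))
    return resolved


# Stand-in for dict(model.named_modules()); the values are placeholder strings.
named_modules = {
    'features.0': 'conv0', 'features.1': 'bn0',
    'features.3': 'conv1', 'features.4': 'bn1',
}
# Each entry: (tuple of linear-layer names, name of the batch norm to merge into them).
name_map = [
    (('features.0',), 'features.1'),
    (('features.3',), 'features.4'),
]
print(resolve_name_map(name_map, named_modules))
# [(['conv0'], 'bn0'), (['conv1'], 'bn1')]
```

Note the trailing comma in ('features.0',): without it the first element is a plain string, and iterating over it in this sketch would yield single characters and fail with a KeyError.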
- class zennit.canonizers.SequentialMergeBatchNorm
Bases: MergeBatchNorm
Canonizer to merge the parameters of all batch norms that appear sequentially right after a linear module.
Note
SequentialMergeBatchNorm traverses the tree of children of the provided module depth-first and in-order. This means that child modules must be assigned to their parent module in the order in which they are visited in the forward pass for adjacent modules to be identified correctly. It also means that activation functions must be assigned as children of their parent module in module form (e.g. torch.nn.ReLU) rather than called functionally, so that an activation function between a linear and a batch norm module is detected.
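The effect of this traversal order can be illustrated with a toy module tree of (name, children) tuples, a hypothetical stand-in for registered torch.nn.Module children. The leaves helper flattens the tree depth-first and in-order, and would_merge checks whether a 'linear' leaf is immediately followed by a 'bn' leaf; both helpers and the three trees are made up for illustration, with the assumed forward pass of each case given in a comment.

```python
def leaves(tree):
    """Return the leaf names of a (name, children) tree, depth-first and in-order."""
    name, children = tree
    if not children:
        return [name]
    result = []
    for child in children:
        result.extend(leaves(child))
    return result


def would_merge(order):
    # True if some 'linear' leaf is immediately followed by a 'bn' leaf.
    return any(a == 'linear' and b == 'bn' for a, b in zip(order, order[1:]))


# Forward: linear -> ReLU -> bn, with the ReLU registered as a child module.
module_form = ('block', [('linear', []), ('relu', []), ('bn', [])])
# Forward: linear -> relu() -> bn, with relu called functionally, so it has no leaf.
functional_form = ('block', [('linear', []), ('bn', [])])
# Forward: linear -> bn, but bn is registered before linear.
misordered = ('block', [('bn', []), ('linear', [])])

for tree in (module_form, functional_form, misordered):
    order = leaves(tree)
    print(order, would_merge(order))
# ['linear', 'relu', 'bn'] False
# ['linear', 'bn'] True
# ['bn', 'linear'] False
```

The second case would merge across an activation, which is incorrect, and the third misses a merge that would have been valid.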
- apply(root_module)
Find every batch norm that directly follows a linear layer and, for each such pair, create a copy of this instance that merges them by fusing the batch norm parameters into the linear layer and reducing the batch norm to the identity.
- Parameters:
root_module (torch.nn.Module) – Module whose leaves are searched; each batch norm found right after a linear layer is merged into that layer.
- Returns:
instances – A list of instances of this class which modified the appropriate leaves.
- Return type:
list
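The pairing step of apply can be sketched as follows, assuming the leaves have already been flattened into a list of (name, kind) tuples in traversal order. SketchSequentialMerge and its kind strings are hypothetical; unlike the real canonizer, its register only records the pair instead of fusing any parameters.

```python
class SketchSequentialMerge:
    """Hypothetical sketch of the pairing done by SequentialMergeBatchNorm.apply."""

    LINEAR_KINDS = ('Linear', 'Conv2d')
    BATCH_NORM_KINDS = ('BatchNorm1d', 'BatchNorm2d')

    def __init__(self):
        self.linears = None
        self.batch_norm = None

    def register(self, linears, batch_norm):
        # A real canonizer would fuse parameters here; the sketch only records the pair.
        self.linears = linears
        self.batch_norm = batch_norm

    def apply(self, leaves):
        # `leaves`: (name, kind) tuples, assumed to be listed depth-first and in-order.
        instances = []
        for (lin_name, lin_kind), (bn_name, bn_kind) in zip(leaves, leaves[1:]):
            if lin_kind in self.LINEAR_KINDS and bn_kind in self.BATCH_NORM_KINDS:
                instance = type(self)()  # fresh copy per merged pair
                instance.register([lin_name], bn_name)
                instances.append(instance)
        return instances


leaves = [
    ('conv0', 'Conv2d'), ('bn0', 'BatchNorm2d'), ('relu0', 'ReLU'),
    ('conv1', 'Conv2d'), ('relu1', 'ReLU'), ('bn1', 'BatchNorm2d'),
]
instances = SketchSequentialMerge().apply(leaves)
print([(instance.linears, instance.batch_norm) for instance in instances])
# [(['conv0'], 'bn0')]
```

Creating one copy per pair gives each merge its own instance, so each one can later be undone independently through the returned list.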