zennit.canonizers
Functions to produce a canonical form of models fit for LRP
Classes
- AttributeCanonizer – Canonizer to set an attribute of module instances.
- Canonizer – Canonizer base class.
- CompositeCanonizer – A Composite of Canonizers, which applies all supplied canonizers.
- MergeBatchNorm – Abstract Canonizer to merge the parameters of batch norms into linear modules.
- NamedMergeBatchNorm – Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.
- SequentialMergeBatchNorm – Canonizer to merge the parameters of all batch norms that appear sequentially right after a linear module.
- class zennit.canonizers.AttributeCanonizer(attribute_map)[source]
Bases:
Canonizer
Canonizer to set an attribute of module instances. Note that when this Canonizer is removed, the overloaded attributes are deleted, so attributes of the same name that were set before applying the Canonizer are also removed.
- Parameters
attribute_map (function) – A function that returns either None, if the module is not applicable, or a dict whose keys describe which attributes to overload for the module. The function signature is (name: string, module: type) -> None or dict.
- apply(root_module)[source]
Overload the attributes for all applicable modules.
- Parameters
root_module (torch.nn.Module) – Root module whose underlying modules will have their attributes overloaded.
- Returns
instances – The applied canonizer instances, which may be removed by calling .remove.
- Return type
list of Canonizer
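To illustrate the expected shape of the attribute_map argument, here is a minimal, self-contained sketch. BasicBlock is a hypothetical stand-in for a torch.nn.Module subclass (e.g. a residual block); a real map would dispatch on actual module types:

```python
# Sketch of an attribute_map function as consumed by AttributeCanonizer.
# 'BasicBlock' is a hypothetical stand-in for a torch.nn.Module subclass;
# a real attribute_map would test against actual module types.
class BasicBlock:
    pass


def attribute_map(name, module):
    if isinstance(module, BasicBlock):
        # overload 'forward' on matching modules; here simply the identity
        return {'forward': lambda x: x}
    # None signals that this module is not applicable
    return None


print(attribute_map('layer1.0', BasicBlock()))  # dict with key 'forward'
print(attribute_map('fc', object()))            # None
```

Returning None for non-matching modules lets the canonizer skip them while a single map covers the whole module tree.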
- class zennit.canonizers.Canonizer[source]
Bases:
object
Canonizer base class. Canonizers temporarily modify modules so that certain attribution rules can be applied properly.
- class zennit.canonizers.CompositeCanonizer(canonizers)[source]
Bases:
Canonizer
A Composite of Canonizers, which applies all supplied canonizers.
- Parameters
canonizers (list of Canonizer) – Canonizers from which to build the composite.
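The behaviour can be sketched without zennit itself: a composite applies each supplied canonizer in turn and concatenates the returned instances, so they can later be removed together. The classes below are hypothetical stand-ins, not zennit's implementation:

```python
# Minimal sketch of the composite pattern behind CompositeCanonizer.
class DummyCanonizer:
    """Hypothetical stand-in for a Canonizer with an apply method."""
    def apply(self, root_module):
        # a real canonizer would modify root_module and return the
        # applied instances; here we simply return ourselves
        return [self]


class SimpleComposite:
    def __init__(self, canonizers):
        self.canonizers = canonizers

    def apply(self, root_module):
        # apply every supplied canonizer and collect all instances
        instances = []
        for canonizer in self.canonizers:
            instances += canonizer.apply(root_module)
        return instances


composite = SimpleComposite([DummyCanonizer(), DummyCanonizer()])
print(len(composite.apply(root_module=None)))  # 2
```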
- class zennit.canonizers.MergeBatchNorm[source]
Bases:
Canonizer
Abstract Canonizer to merge the parameters of batch norms into linear modules.
- static merge_batch_norm(modules, batch_norm)[source]
Update parameters of a linear layer to additionally include a Batch Normalization operation and update the batch normalization layer to instead compute the identity.
- Parameters
modules (list of torch.nn.Module) – Linear layers with mandatory attributes weight and bias.
batch_norm (torch.nn.Module) – Batch Normalization module with mandatory attributes running_mean, running_var, weight, bias and eps.
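The merge rests on the fact that a batch norm applied after a linear layer is again an affine map, so both can be fused into a single linear operation. A scalar sketch of the arithmetic (not zennit's code, which operates on tensors):

```python
import math

# Scalar sketch of fusing a batch norm into a preceding linear layer.
# Linear:    y = w * x + b
# BatchNorm: z = weight * (y - running_mean) / sqrt(running_var + eps) + bias
w, b = 2.0, 1.0
running_mean, running_var, eps = 0.5, 4.0, 1e-5
bn_weight, bn_bias = 1.5, 0.2

scale = bn_weight / math.sqrt(running_var + eps)
# merged parameters: the new linear computes batch_norm(linear(x)) directly
w_merged = w * scale
b_merged = (b - running_mean) * scale + bn_bias

x = 3.0
original = bn_weight * ((w * x + b) - running_mean) / math.sqrt(running_var + eps) + bn_bias
merged = w_merged * x + b_merged
print(abs(original - merged) < 1e-12)  # True
```

After the merge, setting the batch norm's weight to one and its remaining parameters to zero (mean zero, variance one, bias zero) reduces it to the identity, which is what the canonizer does.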
- register(linears, batch_norm)[source]
Store the parameters of the linear modules and the batch norm module and apply the merge.
- Parameters
linears (list of torch.nn.Module) – List of linear layers with mandatory attributes weight and bias.
batch_norm (torch.nn.Module) – Batch Normalization module with mandatory attributes running_mean, running_var, weight, bias and eps.
- class zennit.canonizers.NamedMergeBatchNorm(name_map)[source]
Bases:
MergeBatchNorm
Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.
- Parameters
name_map (list of tuple(tuple of str, str)) – List associating each tuple of linear layer names with the name of the batch norm whose parameters will be merged into them.
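The expected structure of name_map can be illustrated with hypothetical module names; the actual names depend on the model (e.g. as reported by its named_modules):

```python
# Hypothetical name_map for NamedMergeBatchNorm: each entry pairs a tuple
# of linear/convolution module names with the name of the batch norm
# whose parameters should be merged into them. The names below are
# placeholders, not taken from a specific model.
name_map = [
    (('features.0',), 'features.1'),
    (('features.3',), 'features.4'),
]

for linear_names, batch_norm_name in name_map:
    print(linear_names, '->', batch_norm_name)
```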
- class zennit.canonizers.SequentialMergeBatchNorm[source]
Bases:
MergeBatchNorm
Canonizer to merge the parameters of all batch norms that appear sequentially right after a linear module.
- apply(root_module)[source]
Finds each batch norm that directly follows a linear layer and creates a copy of this instance to merge them, fusing the batch norm parameters into the linear layer and reducing the batch norm to the identity.
- Parameters
root_module (torch.nn.Module) – A module whose leaves will be searched; whenever a batch norm is found directly after a linear layer, the two will be merged.
- Returns
instances – A list of instances of this class which modified the appropriate leaves.
- Return type
list
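The detection step can be sketched independently of torch: walk the leaf modules in order and record every batch norm that directly follows a linear module. The classes here are hypothetical stand-ins for torch.nn modules:

```python
# Sketch of the pairing logic behind SequentialMergeBatchNorm.apply:
# a batch norm counts as mergeable only if the directly preceding leaf
# module is linear. Stand-in classes replace torch.nn modules.
class Linear:
    pass


class BatchNorm:
    pass


class ReLU:
    pass


def find_merge_pairs(leaves):
    pairs = []
    last = None
    for module in leaves:
        if isinstance(module, BatchNorm) and isinstance(last, Linear):
            pairs.append((last, module))
        last = module
    return pairs


# only the first batch norm follows a linear module directly
leaves = [Linear(), BatchNorm(), ReLU(), Linear(), ReLU()]
print(len(find_merge_pairs(leaves)))  # 1
```

Note that this relies on the order in which leaves are visited matching the execution order, which is why models with non-sequential control flow may instead need NamedMergeBatchNorm.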