zennit.core
Core functions and classes
Functions
- collect_leaves – Generator function to collect all leaf modules of a module.
- mod_params – Context manager to temporarily modify parameter attributes (all by default) of a module.
- stabilize – Stabilize input for safe division.
Classes
- BasicHook – A hook to compute the layerwise attribution of the module it is attached to.
- Composite – A Composite to apply canonizers and register hooks to modules.
- CompositeContext – A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.
- Hook – Base class for hooks to be used to compute layerwise attributions.
- Identity – Identity to add a grad_fn to a tensor, so a backward hook can be applied.
- RemovableHandle – Create a weak reference to call .remove on some instance.
- RemovableHandleList – A list to hold handles, with the ability to call remove on all of its members.
- class zennit.core.BasicHook(input_modifiers=None, param_modifiers=None, output_modifiers=None, gradient_mapper=None, reducer=None, param_keys=None, require_params=True)
Bases: Hook
A hook to compute the layerwise attribution of the module it is attached to. A BasicHook instance may only be registered with a single module.
- Parameters
input_modifiers (list[callable], optional) – A list of functions to produce multiple inputs. Default is a single input which is the identity.
param_modifiers (list[callable], optional) – A list of functions to temporarily modify the parameters of the attached module for each input produced with input_modifiers. Default is unmodified parameters for each input.
output_modifiers (list[callable], optional) – A list of functions to modify the module’s output, computed using the modified parameters, before gradient computation for each input produced with input_modifiers. Default is the identity for each output.
gradient_mapper (callable, optional) – Function to modify upper relevance. Call signature is of form (grad_output, outputs) and a tuple of the same size as outputs is expected to be returned. outputs has the same size as input_modifiers and param_modifiers. Default is a stabilized normalization by each of the outputs, multiplied with the output gradient.
reducer (callable, optional) – Function to reduce all the inputs and gradients produced through input_modifiers and param_modifiers. Call signature is of form (inputs, gradients), where inputs and gradients have the same length as input_modifiers and param_modifiers. Default is the sum of the products of each input and its corresponding gradient.
param_keys (list[str], optional) – A list of parameters that shall be modified. If None (default), all parameters are modified (which may be none). If [], no parameters are modified and param_modifiers is ignored.
require_params (bool, optional) – Whether the existence of the module’s parameters is mandatory (True by default). If the attribute exists but is None, it is not considered missing, and the modifier is not applied.
- backward(module, grad_input, grad_output)
Backward hook to compute LRP based on the class attributes.
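To make the default gradient_mapper and reducer concrete, the plain-Python sketch below works through them by hand for one bias-free linear layer, with a single input modifier and unmodified parameters. It is not zennit's implementation: lists stand in for tensors, the backward pass through the layer is written out instead of using autograd, and the helper names (linear, default_gradient_mapper, default_reducer) are hypothetical. The mapper here handles a single output, whereas BasicHook passes tuples of outputs.

```python
def stabilize(values, epsilon=1e-6):
    # Replace exact zeros with epsilon, as described for zennit.core.stabilize.
    return [epsilon if v == 0 else v for v in values]


def linear(weight, x):
    # z_j = sum_i w_ji * x_i (bias omitted to keep the example small).
    return [sum(w_ji * x_i for w_ji, x_i in zip(row, x)) for row in weight]


def default_gradient_mapper(grad_output, output):
    # Stabilized normalization by the output, multiplied with the output gradient.
    return [g / z for g, z in zip(grad_output, stabilize(output))]


def default_reducer(inputs, gradients):
    # Sum over all (input, gradient) pairs of their elementwise products.
    total = [0.0] * len(inputs[0])
    for inp, grad in zip(inputs, gradients):
        total = [t + x_i * g_i for t, x_i, g_i in zip(total, inp, grad)]
    return total


weight = [[1.0, 2.0], [3.0, -1.0]]
x = [1.0, 1.0]
relevance_out = [1.5, 1.0]  # upper relevance arriving at the layer output

z = linear(weight, x)
mapped = default_gradient_mapper(relevance_out, z)
# Gradient of sum_j mapped_j * z_j with respect to x_i, i.e. W^T @ mapped.
grad_x = [sum(weight[j][i] * mapped[j] for j in range(len(weight))) for i in range(len(x))]
relevance_in = default_reducer([x], [grad_x])
print(relevance_in)
```

Because the layer has no bias and no output is zero, the input relevance sums to the same total (2.5) as the output relevance.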
- class zennit.core.Composite(module_map=None, canonizers=None)
Bases: object
A Composite to apply canonizers and register hooks to modules. One Composite instance may only be applied to a single module at a time.
- Parameters
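module_map (callable) – A function (ctx: dict, name: str, module: torch.nn.Module) -> Hook or None which maps a module, given its name and a context dict, to the Hook that should be registered with it, or to None if no hook should be registered.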
canonizers (list[Canonizer]) – List of canonizer instances to be applied before applying hooks.
- context(module)
Return a CompositeContext object with this instance and the supplied module.
- Parameters
module (torch.nn.Module) – Module for which to register this composite in the context.
- register(module)
Apply all canonizers and register all hooks to a module (and its recursive children). Previous canonizers of this composite are reverted and all hooks registered by this composite are removed. The module or any of its children (recursively) may still have other hooks attached.
- Parameters
module (torch.nn.Module) – Module to which hooks and canonizers are applied recursively, according to module_map and canonizers.
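A module_map is any callable with the (ctx, name, module) -> Hook or None signature listed above. The sketch below shows one such function that selects modules by class name, plus a hypothetical assign_hooks loop that mimics how a composite might consult the map for each named module. ToyHook, the toy layer classes and assign_hooks are illustrative assumptions, not zennit API; copying each returned template follows the stated purpose of Hook.copy and the rule that a BasicHook may only be registered with a single module.

```python
class ToyHook:
    """Stand-in for a zennit Hook; it only carries a label."""

    def __init__(self, label):
        self.label = label

    def copy(self):
        # One template hook, one fresh copy per module.
        return ToyHook(self.label)


def module_map(ctx, name, module):
    # Hypothetical mapping: a rule for dense layers, another for convolutions,
    # and no hook (None) for everything else.
    kind = type(module).__name__
    if kind == 'Linear':
        return ToyHook('linear-rule')
    if kind.startswith('Conv'):
        return ToyHook('conv-rule')
    return None


def assign_hooks(named_modules, module_map):
    # Hypothetical stand-in for the lookup a composite performs:
    # query the map once per named module and copy each returned template.
    ctx = {}
    assigned = {}
    for name, module in named_modules:
        template = module_map(ctx, name, module)
        if template is not None:
            assigned[name] = template.copy()
    return assigned


class Linear:  # toy stand-ins for torch.nn layer classes
    pass


class Conv2d:
    pass


class ReLU:
    pass


named = [('features.0', Conv2d()), ('features.1', ReLU()), ('classifier', Linear())]
hooks = assign_hooks(named, module_map)
print({name: hook.label for name, hook in hooks.items()})
```

Matching on the class name keeps the sketch free of a PyTorch dependency; with real models an isinstance check against torch.nn classes is the more common choice.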
- class zennit.core.CompositeContext(module, composite)
Bases: object
A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.
- Parameters
module (torch.nn.Module) – The module to which composite should be registered.
composite (Composite) – The composite to register with module.
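CompositeContext follows the usual register-on-enter, clean-up-on-exit pattern. The plain-Python sketch below illustrates that pattern with contextlib; ToyComposite, its remove method and composite_context are hypothetical stand-ins, and yielding the module from the context is an assumption made for illustration.

```python
from contextlib import contextmanager


class ToyComposite:
    """Hypothetical stand-in that only tracks which module is registered."""

    def __init__(self):
        self.registered = None

    def register(self, module):
        self.registered = module

    def remove(self):
        self.registered = None


@contextmanager
def composite_context(module, composite):
    # Register on entry and always undo on exit, even if the body raises.
    composite.register(module)
    try:
        yield module
    finally:
        composite.remove()


composite = ToyComposite()
model = object()
with composite_context(model, composite) as modified:
    inside = composite.registered is model
print(inside, composite.registered)
```

The try/finally is what guarantees that hooks and canonizers do not leak out of the block when an exception interrupts the attribution code.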
- class zennit.core.Hook
Bases: object
Base class for hooks to be used to compute layerwise attributions.
- copy()
Return a copy of this hook. This is used to describe hooks of different modules by a single hook instance.
- post_forward(module, input, output)
Register a backward hook on the resulting tensor right after the forward pass.
- class zennit.core.Identity(*args, **kwargs)
Bases: torch.autograd.Function
Identity to add a grad_fn to a tensor, so a backward hook can be applied.
- class zennit.core.RemovableHandle(instance)
Bases: object
Create a weak reference to call .remove on some instance.
- class zennit.core.RemovableHandleList(iterable=(), /)
Bases: list
A list to hold handles, with the ability to call remove on all of its members.
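RemovableHandle and RemovableHandleList together let registrations be undone with a single call. The sketch below re-creates that idea with the standard weakref module; WeakRemoveHandle, HandleList and Registration are hypothetical names, and letting the list's own remove() shadow list.remove (so that a list can itself act as a handle) is a design assumption, not a statement about zennit's source.

```python
import weakref


class WeakRemoveHandle:
    """Sketch of a handle that calls .remove() on a weakly referenced instance."""

    def __init__(self, instance):
        self._ref = weakref.ref(instance)

    def remove(self):
        instance = self._ref()
        # If the instance was already garbage collected, there is nothing to undo.
        if instance is not None:
            instance.remove()


class HandleList(list):
    """Sketch of a list whose remove() removes every member, then empties itself."""

    def remove(self):
        for handle in self:
            handle.remove()
        self.clear()


class Registration:
    """Toy object that logs when it is removed."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def remove(self):
        self.log.append(self.name)


log = []
first, second = Registration('a', log), Registration('b', log)
handles = HandleList([WeakRemoveHandle(first), WeakRemoveHandle(second)])
handles.remove()
print(log, len(handles))
```

Holding only a weak reference means a handle never keeps an otherwise unused object alive.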
- zennit.core.collect_leaves(module)
Generator function to collect all leaf modules of a module.
- Parameters
module (torch.nn.Module) – A module for which the leaves will be collected.
- Yields
leaf (torch.nn.Module) – Either a leaf of the module structure, or the module itself if it has no children.
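The recursion behind a leaf collector fits in a few lines. This sketch of the documented behaviour relies only on a children() method, which torch.nn.Module provides; it is not zennit's source, and the Node class is a toy stand-in so the example runs without PyTorch.

```python
def collect_leaves(module):
    # A module without children is itself a leaf (this also covers the root).
    children = list(module.children())
    if not children:
        yield module
        return
    for child in children:
        yield from collect_leaves(child)


class Node:
    """Toy stand-in for torch.nn.Module exposing only children()."""

    def __init__(self, name, *children):
        self.name = name
        self._children = children

    def children(self):
        return iter(self._children)


tree = Node('net', Node('block', Node('conv'), Node('relu')), Node('fc'))
print([leaf.name for leaf in collect_leaves(tree)])
```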
- zennit.core.mod_params(module, modifier, param_keys=None, require_params=True)
Context manager to temporarily modify parameter attributes (all by default) of a module.
- Parameters
module (torch.nn.Module) – Module of which to modify parameters. If require_params is True, it must have all elements given in param_keys as attributes (attributes are allowed to be None, in which case they are ignored).
modifier (function) – A function used to modify parameter attributes. If param_keys is empty, this is not used.
param_keys (list[str], optional) – A list of parameters that shall be modified. If None (default), all parameters are modified (which may be none). If [], no parameters are modified and modifier is ignored.
require_params (bool, optional) – Whether the existence of the module’s parameters is mandatory (True by default). If the attribute exists but is None, it is not considered missing, and the modifier is not applied.
- Raises
RuntimeError – If require_params is True and module is missing an attribute listed in param_keys.
- Yields
module (torch.nn.Module) – The module with appropriate parameters temporarily modified.
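The save, modify and restore cycle described above can be sketched as a generator-based context manager. This is a simplified plain-Python illustration, not zennit's implementation: mod_params_sketch takes an explicit param_keys list (the "all parameters" default is omitted), works on ordinary attributes instead of torch parameters, and uses a hypothetical ToyLayer. The modifier keeps only the positive part of the weights, a kind of modification used by several LRP rules.

```python
from contextlib import contextmanager


@contextmanager
def mod_params_sketch(module, modifier, param_keys, require_params=True):
    # Remember original values so they can be restored afterwards.
    originals = {}
    try:
        for key in param_keys:
            if not hasattr(module, key):
                if require_params:
                    raise RuntimeError(f'module is missing attribute {key!r}')
                continue
            value = getattr(module, key)
            # Attributes that exist but are None are left untouched.
            if value is not None:
                originals[key] = value
                setattr(module, key, modifier(value))
        yield module
    finally:
        # Restore every modified attribute, even if the body raised.
        for key, value in originals.items():
            setattr(module, key, value)


class ToyLayer:
    def __init__(self):
        self.weight = [1.0, -2.0]
        self.bias = None


def positive_only(values):
    return [max(v, 0.0) for v in values]


layer = ToyLayer()
with mod_params_sketch(layer, positive_only, ['weight', 'bias']) as modified:
    inside = list(modified.weight)
print(inside, layer.weight)
```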
- zennit.core.stabilize(input, epsilon=1e-06)
Stabilize input for safe division.
- Parameters
input (torch.Tensor) – Tensor to stabilize.
epsilon (float, optional) – Value to replace zero elements with.
- Returns
New tensor copied from input with all zero elements set to epsilon.
- Return type
torch.Tensor
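The point of stabilizing is that a division by the stabilized tensor can no longer produce inf or nan. The plain-Python sketch below mirrors the behaviour described above on a list instead of a torch.Tensor; it is an illustration, not zennit's implementation.

```python
def stabilize(values, epsilon=1e-6):
    # Return a copy of the input with exact zeros replaced by epsilon.
    return [epsilon if v == 0 else v for v in values]


numerators = [1.0, 0.0, 2.0]
denominators = [2.0, 0.0, -4.0]
ratios = [n / d for n, d in zip(numerators, stabilize(denominators))]
print(ratios)
```

Without stabilization the second division would raise ZeroDivisionError here (or yield nan for tensors); the input list itself is left unchanged.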