zennit.core
Core functions and classes.
Functions
- collect_leaves: Generator function to collect all leaf modules of a module.
- expand: Expand a scalar value or tensor to a shape.
- stabilize: Stabilize input for safe division.
- zero_wrap: Create a function wrapper factory (i.e. a decorator).
Classes
- BasicHook: A hook to compute the layer-wise attribution of the module it is attached to.
- Composite: A Composite to apply canonizers and register hooks to modules.
- CompositeContext: A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.
- Hook: Base class for hooks to be used to compute layer-wise attributions.
- Identity: Identity to add a grad_fn to a tensor, so a backward hook can be applied.
- ParamMod: Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
- RemovableHandle: Create a weak reference to call .remove on some instance.
- RemovableHandleList: A list to hold handles, with the ability to call remove on all of its members.
- Stabilizer: Class to create a stabilizer callable.
- class zennit.core.BasicHook(input_modifiers=None, param_modifiers=None, output_modifiers=None, gradient_mapper=None, reducer=None, stabilizer=1e-06)[source]
Bases: Hook
A hook to compute the layer-wise attribution of the module it is attached to. A BasicHook instance may only be registered with a single module.
- Parameters:
input_modifiers (list[callable], optional) – A list of functions (input: torch.Tensor) -> torch.Tensor to produce multiple inputs. Default is a single input, which is the identity.
param_modifiers (list[ParamMod or callable], optional) – A list of ParamMod instances or functions (obj: torch.Tensor, name: str) -> torch.Tensor, with parameter tensor obj, registered in the root model as name, to temporarily modify the parameters of the attached module for each input produced with input_modifiers. Default is unmodified parameters for each input. Use a ParamMod instance to specify which parameters should be modified, whether they are required, and which should be set to zero.
output_modifiers (list[callable], optional) – A list of functions (input: torch.Tensor) -> torch.Tensor to modify the module's output, computed using the modified parameters, before the gradient computation for each input produced with input_modifiers. Default is the identity for each output.
gradient_mapper (callable, optional) – Function (out_grad: torch.Tensor, outputs: list[torch.Tensor]) -> list[torch.Tensor] to modify the upper relevance. A list or tuple of the same size as outputs is expected to be returned. outputs has the same size as input_modifiers and param_modifiers. Default is a stabilized normalization by each of the outputs, multiplied with the output gradient.
reducer (callable) – Function (inputs: list[torch.Tensor], gradients: list[torch.Tensor]) -> torch.Tensor to reduce all the inputs and gradients produced through input_modifiers and param_modifiers. inputs and gradients have the same size as input_modifiers and param_modifiers. Default is the sum of the multiplications of each input and its corresponding gradient.
- backward(module, grad_input, grad_output)[source]
Backward hook to compute LRP based on the class attributes.
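The documented defaults for gradient_mapper and reducer can be illustrated on plain floats. The following is a hypothetical scalar sketch of those two defaults, not zennit's actual tensor implementation:

```python
def default_gradient_mapper(out_grad, outputs, epsilon=1e-6):
    # Scalar sketch of the default gradient_mapper: normalize the output
    # gradient by each stabilized output (epsilon added with the output's
    # sign, 0 counting as positive).
    return [out_grad / (out + (epsilon if out >= 0 else -epsilon))
            for out in outputs]

def default_reducer(inputs, gradients):
    # Scalar sketch of the default reducer: sum of the products of each
    # input with its corresponding gradient.
    return sum(x * g for x, g in zip(inputs, gradients))
```

With a single output of 2.0 and an incoming gradient of 1.0, the mapper yields approximately 0.5, and the reducer then multiplies each input with that mapped gradient and sums.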
- class zennit.core.Composite(module_map=None, canonizers=None)[source]
Bases: object
A Composite to apply canonizers and register hooks to modules. One Composite instance may only be applied to a single module at a time.
- Parameters:
module_map (callable, optional) – A function (ctx: dict, name: str, module: torch.nn.Module) -> Hook or None which maps a context, name, and module to a matching Hook, or None if there is no matching Hook.
canonizers (list[zennit.canonizers.Canonizer], optional) – List of canonizer instances to be applied before applying hooks.
- context(module)[source]
Return a CompositeContext object with this instance and the supplied module.
- Parameters:
module (torch.nn.Module) – Module for which to register this composite in the context.
- Returns:
A context object which registers the composite to module on entering, and removes it on exiting.
- Return type:
CompositeContext
- inactive()[source]
Context manager to temporarily deactivate the gradient modification. This can be used to compute the gradient of the modified gradient.
- register(module)[source]
Apply all canonizers and register all hooks to a module (and its recursive children). Previous canonizers of this composite are reverted and all hooks registered by this composite are removed. The module or any of its children (recursively) may still have other hooks attached.
- Parameters:
module (torch.nn.Module) – Hooks and canonizers will be applied to this module recursively according to module_map and canonizers.
- class zennit.core.CompositeContext(module, composite)[source]
Bases: object
A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.
- Parameters:
module (torch.nn.Module) – The module to which the composite should be registered.
composite (zennit.core.Composite) – The composite which shall be registered to the module.
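The register-on-enter, remove-on-exit lifecycle can be sketched without torch. Below is a minimal stand-in, with a hypothetical DummyComposite that only records whether it is registered; the real class wraps a zennit Composite and a torch.nn.Module:

```python
class ContextSketch:
    # Minimal sketch of the CompositeContext pattern: register the composite
    # to the module on entering the context, remove it on exiting.
    def __init__(self, module, composite):
        self.module = module
        self.composite = composite

    def __enter__(self):
        self.composite.register(self.module)
        return self.composite

    def __exit__(self, exc_type, exc_value, traceback):
        self.composite.remove()
        return False

class DummyComposite:
    # Hypothetical composite that only records its lifecycle.
    def __init__(self):
        self.registered = False

    def register(self, module):
        self.registered = True

    def remove(self):
        self.registered = False
```

Inside the `with` block the composite is registered; once the block is left, even due to an exception, it is removed again.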
- class zennit.core.Hook[source]
Bases: object
Base class for hooks to be used to compute layer-wise attributions.
- copy()[source]
Return a copy of this hook. This is used to describe hooks of different modules by a single hook instance.
- post_forward(module, input, output)[source]
Register a backward-hook to the resulting tensor right after the forward.
- class zennit.core.Identity(*args, **kwargs)[source]
Bases: Function
Identity to add a grad_fn to a tensor, so a backward hook can be applied.
- class zennit.core.ParamMod(modifier, param_keys=None, zero_params=None, require_params=True)[source]
Bases: object
Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
- Parameters:
modifier (function) – A function used to modify parameter attributes. If param_keys is empty, this is not used.
param_keys (list[str], optional) – A list of parameter names that shall be modified. If None (default), all parameters are modified (which may be none). If [], no parameters are modified and modifier is ignored.
zero_params (list[str], optional) – A list of parameter names that shall be set to zero. If None (default), no parameters are set to zero.
require_params (bool, optional) – Whether existence of module’s params is mandatory (True by default). If the attribute exists but is None, it is not considered missing, and the modifier is not applied.
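The temporary-modification behavior can be sketched with a context manager. This is a hypothetical, simplified stand-in for ParamMod that ignores zero_params and require_params and works on plain attributes instead of torch parameters:

```python
import contextlib

@contextlib.contextmanager
def param_mod_sketch(module, modifier, param_keys):
    # Temporarily replace the listed attributes with modified copies and
    # restore the originals on exit, mirroring the ParamMod pattern.
    originals = {key: getattr(module, key) for key in param_keys}
    try:
        for key, value in originals.items():
            setattr(module, key, modifier(value, key))
        yield module
    finally:
        for key, value in originals.items():
            setattr(module, key, value)
```

For example, entering the context with `modifier=lambda value, name: abs(value)` exposes the absolute value of a weight, and leaving it restores the original value.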
- class zennit.core.RemovableHandle(instance)[source]
Bases: object
Create a weak reference to call .remove on some instance.
- class zennit.core.RemovableHandleList(iterable=(), /)[source]
Bases: list
A list to hold handles, with the ability to call remove on all of its members.
- class zennit.core.Stabilizer(epsilon=1e-06, clip=False, norm_scale=False, dim=None)[source]
Bases: object
Class to create a stabilizer callable.
- Parameters:
epsilon (float, optional) – Value by which to shift/clip elements of input.
clip (bool, optional) – If False (default), add epsilon multiplied by each entry's sign (+1 for 0). If True, instead clip the absolute value of input and multiply it by each entry's original sign.
norm_scale (bool, optional) – If False (default), epsilon is added to/used to clip input. If True, scale epsilon by the square root of the mean over the squared elements of the specified dimensions dim.
dim (tuple[int], optional) – If norm_scale is True, specifies the dimensions over which the scaled norm should be computed (all except dimension 0 by default).
- classmethod ensure(value)[source]
Given a value, return a stabilizer. If value is a float, a Stabilizer with that epsilon value is returned. If value is callable, it will be used directly as a stabilizer. Otherwise, a TypeError will be raised.
- Parameters:
value (float, int, or callable) – The value used to produce a valid stabilizer function.
- Returns:
A callable to be used as a stabilizer.
- Return type:
callable or Stabilizer
- Raises:
TypeError – If no valid stabilizer could be produced from value.
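The dispatch described above can be sketched in a few lines. This is a hypothetical stand-in that returns a plain function for numbers instead of an actual Stabilizer instance:

```python
def ensure_sketch(value):
    # Sketch of the ensure logic: numbers become a stabilizer with that
    # epsilon, callables pass through unchanged, anything else is an error.
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        epsilon = float(value)
        # sign-based shift, treating 0 as positive, as for clip=False
        return lambda x: x + (epsilon if x >= 0 else -epsilon)
    if callable(value):
        return value
    raise TypeError(f'no stabilizer could be produced from {value!r}')
```

A float such as 1e-6 thus yields a stabilizer with that epsilon, while an existing callable is used as-is.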
- zennit.core.collect_leaves(module)[source]
Generator function to collect all leaf modules of a module.
- Parameters:
module (torch.nn.Module) – A module for which the leaves will be collected.
- Yields:
leaf (torch.nn.Module) – Either a leaf of the module structure, or the module itself if it has no children.
- zennit.core.expand(tensor, shape, cut_batch_dim=False)[source]
Expand a scalar value or tensor to a shape. In addition to torch.Tensor.expand, this will also accept non-torch.Tensor objects, which will be used to create a new tensor. If tensor has fewer dimensions than shape, singleton dimensions will be appended to match the size of shape before expanding.
- Parameters:
tensor (int, float or torch.Tensor) – Scalar or tensor to expand to the size of shape.
shape (tuple[int]) – Shape to which tensor will be expanded.
cut_batch_dim (bool, optional) – If True, take only the first shape[0] entries along dimension 0 of the expanded tensor, if it has more entries in dimension 0 than shape. Default (False) is not to cut, which will instead cause a RuntimeError due to the size mismatch.
- Returns:
A new tensor expanded from tensor with shape shape.
- Return type:
torch.Tensor
- Raises:
RuntimeError – If tensor could not be expanded to shape due to incompatible shapes.
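For the scalar case, the broadcasting can be mimicked with nested lists. This is a toy sketch of how a non-tensor scalar is expanded to fill a shape, not zennit's implementation:

```python
def expand_scalar_sketch(value, shape):
    # Scalar case only: build a nested list of the requested shape, filled
    # with the value, mirroring how expand broadcasts non-tensor inputs.
    if not shape:
        return value
    return [expand_scalar_sketch(value, shape[1:]) for _ in range(shape[0])]
```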
- zennit.core.stabilize(input, epsilon=1e-06, clip=False, norm_scale=False, dim=None)[source]
Stabilize input for safe division.
- Parameters:
input (torch.Tensor) – Tensor to stabilize.
epsilon (float, optional) – Value by which to shift/clip elements of input.
clip (bool, optional) – If False (default), add epsilon multiplied by each entry's sign (+1 for 0). If True, instead clip the absolute value of input and multiply it by each entry's original sign.
norm_scale (bool, optional) – If False (default), epsilon is added to/used to clip input. If True, scale epsilon by the square root of the mean over the squared elements of the specified dimensions dim.
dim (tuple[int], optional) – If norm_scale is True, specifies the dimensions over which the scaled norm should be computed. Defaults to all except dimension 0.
- Returns:
New tensor copied from input with values shifted by epsilon.
- Return type:
torch.Tensor
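The shift and clip modes described above can be illustrated elementwise on plain floats. This is a sketch of the default (norm_scale=False) behavior, not the actual torch implementation:

```python
def stabilize_sketch(values, epsilon=1e-6, clip=False):
    # Shift each value by epsilon times its sign (0 counts as positive),
    # or, with clip=True, raise the magnitude to at least epsilon while
    # keeping the original sign.
    out = []
    for value in values:
        sign = 1.0 if value >= 0 else -1.0
        if clip:
            out.append(sign * max(abs(value), epsilon))
        else:
            out.append(value + sign * epsilon)
    return out
```

Note that a zero entry becomes +epsilon in both modes, so dividing by the stabilized values is always safe.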
- zennit.core.zero_wrap(zero_params)[source]
Create a function wrapper factory (i.e. a decorator), which takes a single function argument (name, param) -> tensor, such that the wrapped function is only called if name is not equal to zero_params (when zero_params is a string), or if name is not in zero_params (when it is a list). Otherwise, torch.zeros_like of that tensor is returned.
- Parameters:
zero_params (str or list[str]) – String or list of strings compared to name.
- Returns:
The function wrapper to be called on the function.
- Return type:
function
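The wrapper logic can be sketched with lists of floats standing in for tensors and explicit zeroing standing in for torch.zeros_like; this is a hypothetical simplification, not zennit's implementation:

```python
def zero_wrap_sketch(zero_params):
    # Decorator factory: wrap a (name, param) -> value function so that
    # parameters whose name matches zero_params are zeroed instead of being
    # passed to the wrapped function.
    names = [zero_params] if isinstance(zero_params, str) else list(zero_params)

    def wrapper(func):
        def wrapped(name, param):
            if name in names:
                return [0.0] * len(param)
            return func(name, param)
        return wrapped
    return wrapper
```

Applied to a function that negates its parameter, a name listed in zero_params yields zeros while other names pass through to the negation.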