zennit.core
Core functions and classes
Functions
- collect_leaves – Generator function to collect all leaf modules of a module.
- expand – Expand a scalar value or tensor to a shape.
- stabilize – Stabilize input for safe division.
- uncompress – Generator which, given a compressed iterable produced by itertools.compress and (some iterable similar to) the original data and selector used for compress, yields values from compressed or data depending on selector.
- zero_wrap – Create a function wrapper factory (i.e. a decorator), which takes a single function argument (name, param) -> tensor such that the function is only called if name is not in zero_params.
Classes
- BasicHook – A hook to compute the layer-wise attribution of the module it is attached to.
- Composite – A Composite to apply canonizers and register hooks to modules.
- CompositeContext – A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.
- Hook – Base class for hooks to be used to compute layer-wise attributions.
- Identity – Identity to add a grad_fn to a tensor, so a backward hook can be applied.
- ParamMod – Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
- RemovableHandle – Create a weak reference to call .remove on some instance.
- RemovableHandleList – A list to hold handles, with the ability to call remove on all of its members.
- Stabilizer – Class to create a stabilizer callable.
- class zennit.core.Stabilizer
Bases: object
Class to create a stabilizer callable.
- Parameters:
epsilon (float, optional) – Value by which to shift/clip elements of input.
clip (bool, optional) – If False (default), add epsilon multiplied by each entry’s sign (+1 for 0). If True, instead clip the absolute value of input to a minimum of epsilon and multiply it by each entry’s original sign.
norm_scale (bool, optional) – If False (default), epsilon is added to/used to clip input. If True, scale epsilon by the square root of the mean over the squared elements of the specified dimensions dim.
dim (tuple[int], optional) – If norm_scale is True, specifies the dimensions over which the scaled norm should be computed (all except dimension 0 by default).
- classmethod ensure(value)
Given a value, return a stabilizer. If value is a float (or an int), a Stabilizer with epsilon set to value is returned. If value is callable, it will be used directly as a stabilizer. Otherwise, a TypeError is raised.
- Parameters:
value (float, int, or callable) – The value used to produce a valid stabilizer function.
- Returns:
A callable to be used as a stabilizer.
- Return type:
callable or Stabilizer
- Raises:
TypeError – If no valid stabilizer could be produced from value.
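The dispatch performed by Stabilizer.ensure can be sketched in plain Python. This is not zennit's code: ensure_stabilizer and shift_stabilizer are hypothetical names, and lists of floats stand in for tensors so the example runs without torch.

```python
from functools import partial


def shift_stabilizer(values, epsilon=1e-6):
    """Hypothetical list-based stand-in for a Stabilizer instance."""
    # Shift each entry away from zero by epsilon, using sign +1 for 0.
    return [v + (epsilon if v >= 0 else -epsilon) for v in values]


def ensure_stabilizer(value):
    """Sketch of the dispatch documented for Stabilizer.ensure."""
    if isinstance(value, (float, int)):
        # Numbers become a stabilizer with that epsilon.
        return partial(shift_stabilizer, epsilon=float(value))
    if callable(value):
        # Callables are used directly as the stabilizer.
        return value
    raise TypeError(f'Value {value!r} is not a valid stabilizer!')


stabilizer = ensure_stabilizer(0.5)
print(stabilizer([0.0, -1.0, 2.0]))  # [0.5, -1.5, 2.5]
```

Accepting either a number or a callable lets rule parameters such as BasicHook's stabilizer stay short in the common case while still allowing a fully custom denominator.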
- zennit.core.stabilize(input, epsilon=1e-6, clip=False, norm_scale=False, dim=None)
Stabilize input for safe division.
- Parameters:
input (torch.Tensor) – Tensor to stabilize.
epsilon (float, optional) – Value by which to shift/clip elements of input.
clip (bool, optional) – If False (default), add epsilon multiplied by each entry’s sign (+1 for 0). If True, instead clip the absolute value of input to a minimum of epsilon and multiply it by each entry’s original sign.
norm_scale (bool, optional) – If False (default), epsilon is added to/used to clip input. If True, scale epsilon by the square root of the mean over the squared elements of the specified dimensions dim.
dim (tuple[int], optional) – If norm_scale is True, specifies the dimensions over which the scaled norm should be computed. Defaults to all except dimension 0.
- Returns:
New Tensor copied from input with values shifted (or, if clip is True, clipped) by epsilon.
- Return type:
torch.Tensor
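The arithmetic described above can be sketched on a single sample given as a list of floats. stabilize_list is a hypothetical helper, not zennit's implementation; in this sketch, "all dimensions except dimension 0" collapses to "all entries of the sample".

```python
import math


def stabilize_list(values, epsilon=1e-6, clip=False, norm_scale=False):
    """List-based sketch of the documented stabilize() semantics for one sample."""
    # Sign of each entry, with +1 for 0 as documented.
    signs = [1.0 if v >= 0 else -1.0 for v in values]
    if norm_scale:
        # Scale epsilon by the root mean square over the sample's entries.
        rms = math.sqrt(sum(v * v for v in values) / len(values))
        epsilon = epsilon * rms
    if clip:
        # Clip magnitudes to at least epsilon, keeping the original sign.
        return [s * max(abs(v), epsilon) for s, v in zip(signs, values)]
    # Default: shift every entry away from zero by epsilon.
    return [v + s * epsilon for s, v in zip(signs, values)]


print(stabilize_list([0.0, -0.5, 2.0], epsilon=0.25))             # [0.25, -0.75, 2.25]
print(stabilize_list([0.0, -0.1, 2.0], epsilon=0.25, clip=True))  # [0.25, -0.25, 2.0]
```

Shifting changes every entry slightly, whereas clipping leaves entries with magnitude above epsilon untouched; both guarantee a nonzero denominator.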
- zennit.core.expand(tensor, shape, cut_batch_dim=False)
Expand a scalar value or tensor to a shape. In addition to torch.Tensor.expand, this also accepts non-torch.Tensor objects, which are used to create a new tensor. If tensor has fewer dimensions than shape, singleton dimensions are appended to match the number of dimensions of shape before expanding.
- Parameters:
tensor (int, float or torch.Tensor) – Scalar or tensor to expand to the size of shape.
shape (tuple[int]) – Shape to which tensor will be expanded.
cut_batch_dim (bool, optional) – If True, take only the first shape[0] entries along dimension 0 of the expanded tensor if it has more entries in dimension 0 than shape. The default (False) is not to cut, which instead causes a RuntimeError due to the size mismatch.
- Returns:
A new tensor expanded from tensor with shape shape.
- Return type:
torch.Tensor
- Raises:
RuntimeError – If tensor could not be expanded to shape due to incompatible shapes.
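The shape rules above can be traced without tensors. expand_shape is a hypothetical helper that only computes the resulting shape (or raises), as a sketch of the documented behaviour; it assumes, like torch.Tensor.expand, that only singleton dimensions may grow.

```python
def expand_shape(tensor_shape, shape, cut_batch_dim=False):
    """Hypothetical shape-only sketch of the documented expand() rules."""
    if len(tensor_shape) > len(shape):
        raise RuntimeError('tensor has more dimensions than shape')
    # Append singleton dimensions until the number of dimensions matches.
    padded = tuple(tensor_shape) + (1,) * (len(shape) - len(tensor_shape))
    if cut_batch_dim and padded and padded[0] > shape[0]:
        # Keep only the first shape[0] entries along dimension 0.
        padded = (shape[0],) + padded[1:]
    for have, want in zip(padded, shape):
        # As with torch.Tensor.expand, only size-1 dimensions can be expanded.
        if have != want and have != 1:
            raise RuntimeError(f'cannot expand size {have} to {want}')
    return tuple(shape)


print(expand_shape((), (2, 3)))                        # a scalar: (2, 3)
print(expand_shape((4,), (2, 3), cut_batch_dim=True))  # (4,) -> (4, 1) -> (2, 1) -> (2, 3)
```

Appending (rather than prepending) singleton dimensions is what lets a per-sample vector of shape (batch,) broadcast over all remaining dimensions of a batched tensor.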
- zennit.core.zero_wrap(zero_params)
Create a function wrapper factory (i.e. a decorator), which takes a single function argument (name, param) -> tensor such that the function is only called if name is not equal to zero_params (if zero_params is a string) or is not contained in zero_params (if it is a list). Otherwise, torch.zeros_like of param is returned.
- Parameters:
zero_params (str or list[str]) – String or list of strings compared to name.
- Returns:
The function wrapper to be called on the function.
- Return type:
function
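A minimal sketch of the decorator described above, following the documented (name, param) argument order. zero_wrap_sketch and double are hypothetical names, and a list of zeros stands in for torch.zeros_like.

```python
def zero_wrap_sketch(zero_params):
    """Hypothetical re-creation of the documented zero_wrap behaviour."""
    def wrapper(func):
        def wrapped(name, param):
            # A string is compared for equality, a list is checked for membership.
            if isinstance(zero_params, str):
                zeroed = name == zero_params
            else:
                zeroed = name in zero_params
            if zeroed:
                # Stand-in for torch.zeros_like(param).
                return [0.0 for _ in param]
            return func(name, param)
        return wrapped
    return wrapper


@zero_wrap_sketch(['bias'])
def double(name, param):
    return [2.0 * p for p in param]


print(double('weight', [1.0, -2.0]))  # [2.0, -4.0]
print(double('bias', [1.0, -2.0]))    # [0.0, 0.0]
```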
- zennit.core.uncompress(data, selector, compressed) → Generator
Generator which, given a compressed iterable produced by itertools.compress and (some iterable similar to) the original data and selector used for compress, yields values from compressed or data depending on selector. True values in selector skip data one ahead and yield a value from compressed, while False values yield one value from data.
- Parameters:
data (iterable) – An iterable similar to the original data. False values in the selector will be filled with values from this iterable, while True values will cause this iterable to be skipped one ahead.
selector (iterable of bool) – The original selector used to produce compressed. Chooses whether elements from data or from compressed will be yielded.
compressed (iterable) – The results of itertools.compress. Will be yielded for each True element in selector.
- Yields:
object – An element of data if the associated element of selector is False, otherwise an element of compressed while skipping data one ahead.
- Return type:
Generator
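The round trip described above (compress, process only the selected elements, then merge them back) can be sketched with the standard library. uncompress_sketch is a hypothetical name, not zennit's code.

```python
from itertools import compress


def uncompress_sketch(data, selector, compressed):
    """Re-merge the output of itertools.compress with the original data."""
    compressed = iter(compressed)
    # zip advances data once per selector entry, whether or not it is used.
    for selected, item in zip(selector, data):
        yield next(compressed) if selected else item


data = ['a', 'b', 'c', 'd']
selector = [True, False, True, False]
# Process only the selected elements ...
processed = [s.upper() for s in compress(data, selector)]
# ... then put them back into their original positions.
print(list(uncompress_sketch(data, selector, processed)))  # ['A', 'b', 'C', 'd']
```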
- class zennit.core.ParamMod
Bases: object
Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
- Parameters:
modifier (function) – A function used to modify parameter attributes. If param_keys is empty, this is not used.
param_keys (list[str], optional) – A list of parameter names that shall be modified. If None (default), all parameters are modified (which may be none). If [], no parameters are modified and modifier is ignored.
zero_params (list[str], optional) – A list of parameter names that shall be set to zero. If None (default), no parameters are set to zero.
require_params (bool, optional) – Whether the existence of the module’s parameters is mandatory (True by default). If the attribute exists but is None, it is not considered missing, and the modifier is not applied.
- state_dicts(module)
Return the original and the modified state dicts of the module parameters.
- Parameters:
module (torch.nn.Module) – The module for which parameters shall be modified.
- Returns:
original_state (dict of torch.Tensor) – The original, unmodified parameters.
modified_state (dict of torch.Tensor) – The modified parameters.
- Raises:
RuntimeError – If parameters are missing and self.require_params has been set to True.
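The "temporarily modify, then restore" pattern behind ParamMod can be sketched with contextlib. This is not ParamMod's implementation: param_mod_sketch and keep_positive are hypothetical, lists stand in for tensors, require_params is omitted, and the modifier uses the (obj, name) argument order documented for BasicHook's param_modifiers.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def param_mod_sketch(module, modifier, param_keys, zero_params=()):
    """Temporarily replace attributes of module and restore them on exit."""
    keys = list(param_keys) + [key for key in zero_params if key not in param_keys]
    original = {key: getattr(module, key) for key in keys}
    try:
        for key, value in original.items():
            if key in zero_params:
                # Stand-in for torch.zeros_like on a flat list.
                setattr(module, key, [0.0 for _ in value])
            else:
                setattr(module, key, modifier(value, key))
        yield module
    finally:
        # Always restore the originals, even if the body raised.
        for key, value in original.items():
            setattr(module, key, value)


def keep_positive(obj, name):
    # Example modifier: keep only the positive part of a parameter.
    return [max(v, 0.0) for v in obj]


layer = SimpleNamespace(weight=[1.0, -2.0], bias=[0.5])
with param_mod_sketch(layer, keep_positive, ['weight'], zero_params=['bias']):
    print(layer.weight, layer.bias)  # [1.0, 0.0] [0.0]
print(layer.weight, layer.bias)      # [1.0, -2.0] [0.5]
```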
- zennit.core.collect_leaves(module) → Iterator[Module]
Generator function to collect all leaf modules of a module.
- Parameters:
module (torch.nn.Module) – A module for which the leaves will be collected.
- Yields:
leaf (torch.nn.Module) – Either a leaf of the module structure, or the module itself if it has no children.
- Return type:
Iterator[Module]
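The recursion behind collecting leaves can be sketched on a toy tree. Node is a hypothetical stand-in for torch.nn.Module that only provides a children() method (the same method name torch.nn.Module uses); collect_leaves_sketch is likewise hypothetical.

```python
class Node:
    """Hypothetical stand-in for torch.nn.Module exposing only children()."""

    def __init__(self, name, *children):
        self.name = name
        self._children = list(children)

    def children(self):
        return iter(self._children)


def collect_leaves_sketch(module):
    """Yield all leaves, or the module itself if it has no children."""
    children = list(module.children())
    if not children:
        yield module
        return
    for child in children:
        yield from collect_leaves_sketch(child)


model = Node('root', Node('features', Node('conv'), Node('relu')), Node('fc'))
print([leaf.name for leaf in collect_leaves_sketch(model)])  # ['conv', 'relu', 'fc']
```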
- class zennit.core.Identity
Bases: torch.autograd.Function
Identity to add a grad_fn to a tensor, so a backward hook can be applied.
- class zennit.core.Hook
Bases: object
Base class for hooks to be used to compute layer-wise attributions.
- pre_forward(module, args, kwargs)
Apply an Identity to the input before the module to register a backward hook.
- Parameters:
module (torch.nn.Module) – The module to which this hook is attached.
args (tuple of torch.Tensor) – The input tensors passed to module.forward.
kwargs (dict) – The keyword arguments passed to module.forward.
- Returns:
A tuple of the modified input tensors.
- Return type:
tuple of torch.Tensor, optional
- post_forward(module, args, kwargs, output)
Register a backward hook to the resulting tensor right after the forward pass.
- Parameters:
module (torch.nn.Module) – The module to which this hook is attached.
args (tuple of torch.Tensor) – The input tensors passed to module.forward.
kwargs (dict) – The keyword arguments passed to module.forward.
output (torch.Tensor) – The output tensor.
- Returns:
A tuple of the modified output tensors.
- Return type:
tuple of torch.Tensor, optional
- pre_backward(module, grad_input, grad_output)
Store the grad_output for the backward hook.
- Parameters:
module (torch.nn.Module) – The module to which this hook is attached.
grad_input (torch.Tensor) – The input gradient tensor.
grad_output (torch.Tensor) – The output gradient tensor.
- forward(module, args, kwargs, output)
Hook applied during the forward pass.
- Parameters:
module (torch.nn.Module) – The module to which this hook is attached.
args (tuple of torch.Tensor) – The input tensors passed to module.forward.
kwargs (dict) – The keyword arguments passed to module.forward.
output (torch.Tensor) – The output tensor.
- backward(module, grad_input, grad_output)
Hook applied during the backward pass.
- Parameters:
module (torch.nn.Module) – The module to which this hook is attached.
grad_input (torch.Tensor) – The input gradient tensor.
grad_output (torch.Tensor) – The output gradient tensor.
- copy()
Return a copy of this hook. This is used to describe hooks of different modules by a single hook instance.
- Returns:
A copy of this hook.
- Return type:
Hook
- class zennit.core.BasicHook
Bases: Hook
A hook to compute the layer-wise attribution of the module it is attached to. A BasicHook instance may only be registered with a single module.
- Parameters:
input_modifiers (list[callable], optional) – A list of functions (input: torch.Tensor) -> torch.Tensor to produce multiple inputs. Default is a single input which is the identity.
param_modifiers (list[ParamMod or callable], optional) – A list of ParamMod instances or functions (obj: torch.Tensor, name: str) -> torch.Tensor, with parameter tensor obj, registered in the root model as name, to temporarily modify the parameters of the attached module for each input produced with input_modifiers. Default is unmodified parameters for each input. Use a ParamMod instance to specify which parameters should be modified, whether they are required, and which should be set to zero.
output_modifiers (list[callable], optional) – A list of functions (input: torch.Tensor) -> torch.Tensor to modify the module’s output computed using the modified parameters before gradient computation for each input produced with input_modifiers. Default is the identity for each output.
gradient_mapper (callable, optional) – Function (out_grad: torch.Tensor, outputs: list[torch.Tensor]) -> list[torch.Tensor] to modify upper relevance. A list or tuple of the same size as outputs is expected to be returned. outputs has the same size as input_modifiers and param_modifiers. Default is a stabilized normalization by each of the outputs, multiplied with the output gradient.
reducer (callable, optional) – Function (inputs: list[torch.Tensor], gradients: list[torch.Tensor]) -> torch.Tensor to reduce all the inputs and gradients produced through input_modifiers and param_modifiers. inputs and gradients have the same length as input_modifiers and param_modifiers. Default is the sum of the multiplications of each input and its corresponding gradient.
stabilizer (callable or float, optional) – Stabilization parameter for rules other than Epsilon. If stabilizer is a float, it will be added to the denominator with the same sign as each respective entry. If it is callable, a function (input: torch.Tensor) -> torch.Tensor is expected, of which the output corresponds to the stabilized denominator.
- __init__(input_modifiers=None, param_modifiers=None, output_modifiers=None, gradient_mapper=None, reducer=None, stabilizer=1e-6)
- forward(module, args, kwargs, output)
Forward hook to save module in-/outputs.
- Parameters:
module (torch.nn.Module) – The module to which this hook is attached.
args (tuple of torch.Tensor) – The input tensors passed to module.forward.
kwargs (dict) – The keyword arguments passed to module.forward.
output (torch.Tensor) – The output tensor.
- backward(module, grad_input, grad_output)
Backward hook to compute LRP based on the class attributes.
- Parameters:
module (torch.nn.Module) – The module to which this hook is attached.
grad_input (torch.Tensor) – The input gradient tensor.
grad_output (torch.Tensor) – The output gradient tensor.
- Returns:
The modified input gradient tensors.
- Return type:
tuple of torch.Tensor
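To show what the BasicHook defaults compute, the following sketch works through a single linear layer in plain Python, assuming one identity input modifier, unmodified parameters, an identity output modifier and a float stabilizer. Under these assumptions the defaults reduce to R_i = x_i · Σ_j w_ji · R_j / stab(z_j). basic_hook_linear_sketch and stabilize_scalar are hypothetical names, not zennit functions.

```python
def stabilize_scalar(value, epsilon=1e-6):
    # Shift away from zero by epsilon, using sign +1 for 0.
    return value + (epsilon if value >= 0 else -epsilon)


def basic_hook_linear_sketch(x, weight, bias, relevance_out, epsilon=1e-6):
    """Relevance of a linear layer under the documented BasicHook defaults."""
    # Forward pass: z[j] = sum_i weight[j][i] * x[i] + bias[j]
    z = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]
    # Default gradient_mapper: output gradient divided by the stabilized output.
    s = [r / stabilize_scalar(zj, epsilon) for r, zj in zip(relevance_out, z)]
    # Gradient of sum_j z[j] * s[j] with respect to x[i], with s held constant.
    grad = [sum(weight[j][i] * s[j] for j in range(len(weight))) for i in range(len(x))]
    # Default reducer: each input times its gradient (a single input here).
    return [xi * gi for xi, gi in zip(x, grad)]


x = [1.0, 2.0]
weight = [[1.0, 1.0], [0.5, 2.0]]
bias = [0.0, 0.0]
relevance_in = basic_hook_linear_sketch(x, weight, bias, relevance_out=[3.0, 4.5])
print(relevance_in)  # approximately [1.5, 6.0]
```

With a zero bias, the input relevances sum (up to the stabilizer) to the output relevance, here 7.5, which is the conservation property LRP rules aim for.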
- class zennit.core.RemovableHandle
Bases: object
Create a weak reference to call .remove on some instance.
- Parameters:
instance (object) – The instance to which to create the reference.
- class zennit.core.RemovableHandleList
Bases: list
A list to hold handles, with the ability to call remove on all of its members.
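The two handle classes can be sketched with Python's weakref module. RemovableHandleSketch, RemovableHandleListSketch and Registration are hypothetical names; clearing the list after removal is a choice made in this sketch, not documented behaviour.

```python
import weakref


class RemovableHandleSketch:
    """Keep a weak reference and forward remove() while the target is alive."""

    def __init__(self, instance):
        self.instance_ref = weakref.ref(instance)

    def remove(self):
        instance = self.instance_ref()
        # The handle does not keep the instance alive; skip it if it is gone.
        if instance is not None:
            instance.remove()


class RemovableHandleListSketch(list):
    """A list whose remove() (taking no argument) removes every member."""

    def remove(self):
        for handle in self:
            handle.remove()
        # This sketch also forgets the handles afterwards.
        self.clear()


class Registration:
    """Hypothetical object with a remove() method, e.g. a registered hook."""

    def __init__(self):
        self.removed = False

    def remove(self):
        self.removed = True


first, second = Registration(), Registration()
handles = RemovableHandleListSketch([RemovableHandleSketch(first), RemovableHandleSketch(second)])
handles.remove()
print(first.removed, second.removed, len(handles))  # True True 0
```

Using a weak reference means a handle never prolongs the lifetime of the object it removes, so forgotten hooks or canonizers can still be garbage-collected.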
- class zennit.core.CompositeContext
Bases: object
A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.
- Parameters:
module (torch.nn.Module) – The module to which the composite should be registered.
composite (zennit.core.Composite) – The composite which shall be registered to module.
- class zennit.core.Composite
Bases: object
A Composite to apply canonizers and register hooks to modules. One Composite instance may only be applied to a single module at a time.
- Parameters:
module_map (callable, optional) – A function (ctx: dict, name: str, module: torch.nn.Module) -> Hook or None which maps a context, name and module to a matching Hook, or None if there is no matching Hook.
canonizers (list[zennit.canonizers.Canonizer], optional) – List of canonizer instances to be applied before applying hooks.
- register(module)
Apply all canonizers and register all hooks to a module (and its recursive children). Previous canonizers of this composite are reverted and all hooks registered by this composite are removed. The module or any of its children (recursively) may still have other hooks attached.
- Parameters:
module (torch.nn.Module) – Hooks and canonizers will be applied to this module recursively according to module_map and canonizers.
- remove()
Remove all handles for hooks and canonizers. Hooks will simply be removed from their corresponding modules. Canonizers will revert the state of the modules they changed.
- context(module)
Return a CompositeContext object with this instance and the supplied module.
- Parameters:
module (torch.nn.Module) – Module for which to register this composite in the context.
- Returns:
A context object which registers the composite to module on entering, and removes it on exiting.
- Return type:
CompositeContext
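How a module_map drives registration, and how the context removes the hooks again, can be sketched with toy classes. Everything here is hypothetical (ToyModule, Linear, ReLU, CompositeSketch, and hooks represented as plain strings); canonizers are omitted, and it is not zennit's implementation.

```python
from contextlib import contextmanager


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module with named children."""

    def __init__(self, **children):
        self._children = children
        self.hooks = []

    def named_modules(self, prefix=''):
        # Like torch.nn.Module.named_modules: the root first, with name ''.
        yield prefix, self
        for child_name, child in self._children.items():
            full_name = f'{prefix}.{child_name}' if prefix else child_name
            yield from child.named_modules(full_name)


class Linear(ToyModule):
    pass


class ReLU(ToyModule):
    pass


class CompositeSketch:
    def __init__(self, module_map):
        self.module_map = module_map
        self.registered = []

    def register(self, module):
        # Drop anything registered by a previous call first.
        self.remove()
        ctx = {}
        for name, child in module.named_modules():
            hook = self.module_map(ctx, name, child)
            if hook is not None:
                child.hooks.append(hook)
                self.registered.append((child, hook))

    def remove(self):
        for child, hook in self.registered:
            child.hooks.remove(hook)
        self.registered.clear()

    @contextmanager
    def context(self, module):
        self.register(module)
        try:
            yield module
        finally:
            self.remove()


def module_map(ctx, name, module):
    # Hypothetical mapping: a hook label for each Linear layer, none elsewhere.
    return f'hook for {name}' if isinstance(module, Linear) else None


model = ToyModule(fc1=Linear(), act=ReLU(), fc2=Linear())
with CompositeSketch(module_map).context(model):
    print([child.hooks for _, child in model.named_modules()])
    # [[], ['hook for fc1'], [], ['hook for fc2']]
print([child.hooks for _, child in model.named_modules()])  # [[], [], [], []]
```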