zennit.core

Core functions and classes

Functions

collect_leaves

Generator function to collect all leaf modules of a module.

mod_params

Context manager to temporarily modify parameter attributes (all by default) of a module.

stabilize

Stabilize input for safe division.

Classes

BasicHook

A hook to compute the layerwise attribution of the module it is attached to.

Composite

A Composite to apply canonizers and register hooks to modules.

CompositeContext

A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.

Hook

Base class for hooks to be used to compute layerwise attributions.

Identity

Identity to add a grad_fn to a tensor, so a backward hook can be applied.

RemovableHandle

Create a weak reference to call .remove on some instance.

RemovableHandleList

A list to hold handles, with the ability to call remove on all of its members.

class zennit.core.BasicHook(input_modifiers=None, param_modifiers=None, output_modifiers=None, gradient_mapper=None, reducer=None, param_keys=None, require_params=True)

Bases: Hook

A hook to compute the layerwise attribution of the module it is attached to. A BasicHook instance may only be registered with a single module.

Parameters
  • input_modifiers (list[callable], optional) – A list of functions, each producing one modified input from the module’s input. Default is a single modifier, the identity.

  • param_modifiers (list[callable], optional) – A list of functions to temporarily modify the parameters of the attached module for each input produced with input_modifiers. Default is to leave the parameters unmodified for each input.

  • output_modifiers (list[callable], optional) – A list of functions to modify the module’s output, computed with the modified parameters, before the gradient computation for each input produced with input_modifiers. Default is the identity for each output.

  • gradient_mapper (callable, optional) – Function to modify the upper relevance. Its call signature is (grad_output, outputs), and it must return a tuple of the same size as outputs; outputs has the same size as input_modifiers and param_modifiers. Default is the output gradient divided by each of the stabilized outputs.

  • reducer (callable, optional) – Function to reduce all the inputs and gradients produced through input_modifiers and param_modifiers. Its call signature is (inputs, gradients), where inputs and gradients have the same length as input_modifiers and param_modifiers. Default is the sum of the element-wise products of each input and its corresponding gradient.

  • param_keys (list[str], optional) – A list of parameters that shall be modified. If None (default), all parameters are modified (which may be none). If [], no parameters are modified and param_modifiers are ignored.

  • require_params (bool, optional) – Whether the existence of the module’s parameters is mandatory (True by default). If an attribute exists but is None, it is not considered missing, and the modifier is not applied to it.
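As a rough illustration of how these defaults combine, the following pure-Python sketch (no torch or zennit required) applies them to a toy bias-free linear layer y = W x. All function names are hypothetical, output_modifiers are left at their identity default and therefore omitted, and applying each parameter modifier to individual weight entries is a simplification; this is not zennit's implementation.

```python
# Pure-Python sketch of BasicHook's documented defaults for a toy linear layer
# y = W x without bias. All names are hypothetical helpers, not zennit API.
# output_modifiers are omitted: their default is the identity.


def identity(value):
    return value


def stabilize(values, epsilon=1e-6):
    # Documented contract of zennit.core.stabilize: zeros are replaced by epsilon.
    return [epsilon if value == 0 else value for value in values]


def linear(weight, x):
    # y[j] = sum_i W[j][i] * x[i]
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def linear_grad(weight, grad_y):
    # Gradient of sum_j grad_y[j] * y[j] with respect to x, i.e. W^T grad_y.
    return [sum(row[i] * g for row, g in zip(weight, grad_y)) for i in range(len(weight[0]))]


def basic_hook_relevance(weight, x, relevance_out,
                         input_modifiers=(identity,), param_modifiers=(identity,)):
    inputs, gradients = [], []
    for input_modifier, param_modifier in zip(input_modifiers, param_modifiers):
        modified_input = input_modifier(x)
        # Simplification: the parameter modifier is applied to each weight entry.
        modified_weight = [[param_modifier(w) for w in row] for row in weight]
        output = linear(modified_weight, modified_input)
        # Default gradient_mapper: upper relevance divided by the stabilized output.
        mapped = [r / o for r, o in zip(relevance_out, stabilize(output))]
        inputs.append(modified_input)
        gradients.append(linear_grad(modified_weight, mapped))
    # Default reducer: sum over all modifier pairs of input * gradient, element-wise.
    return [sum(inp[i] * grad[i] for inp, grad in zip(inputs, gradients))
            for i in range(len(x))]


weight = [[1.0, 2.0], [3.0, 4.0]]
x = [1.0, 1.0]
y = linear(weight, x)
relevance_in = basic_hook_relevance(weight, x, relevance_out=y)
print(y, relevance_in)  # [3.0, 7.0] [4.0, 6.0]
```

With a single identity modifier and no bias, the upper relevance is divided by the output and redistributed in proportion to each input's contribution, so the input relevance sums to the same total as the output relevance (here 10.0).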

backward(module, grad_input, grad_output)

Backward hook to compute LRP based on the class attributes.

copy()

Return a copy of this hook. This allows a single hook instance to describe the hooks of several modules.

forward(module, input, output)

Forward hook to save the module’s inputs and outputs.

class zennit.core.Composite(module_map=None, canonizers=None)

Bases: object

A Composite to apply canonizers and register hooks to modules. One Composite instance may only be applied to a single module at a time.

Parameters
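  • module_map (callable, optional) – A function (ctx: dict, name: str, module: torch.nn.Module) -> Hook or None which maps a context, a module name, and a module to the Hook to register for that module, or to None if no hook should be registered.

  • canonizers (list[Canonizer], optional) – List of canonizer instances to be applied before the hooks are registered.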
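The module_map signature can be illustrated with a pure-Python sketch. The classes Linear, ReLU, and ToyHook are hypothetical stand-ins, and the registration loop at the end is an assumption about how a composite might consume module_map (one shared ctx dict, and a copy of the returned hook per module, which is what copy() is documented to support); it is not zennit's implementation.

```python
# Hypothetical stand-ins: none of these classes are zennit or torch API.
class Linear:
    pass


class ReLU:
    pass


class ToyHook:
    def __init__(self, rule):
        self.rule = rule

    def copy(self):
        # Like Hook.copy(): a fresh instance describing the same rule.
        return ToyHook(self.rule)


FIRST_LAYER = ToyHook("first-layer rule")
DEFAULT = ToyHook("default rule")


def module_map(ctx, name, module):
    # Returning None means: register no hook for this module.
    if not isinstance(module, Linear):
        return None
    # ctx persists across calls, so it can record that a Linear was already seen.
    if not ctx.get("seen_linear", False):
        ctx["seen_linear"] = True
        return FIRST_LAYER
    return DEFAULT


# Assumed registration loop: one shared ctx, one copy of the template per module.
named_modules = [("0", Linear()), ("1", ReLU()), ("2", Linear())]
ctx = {}
hooks = {}
for name, module in named_modules:
    template = module_map(ctx, name, module)
    if template is not None:
        hooks[name] = template.copy()

print({name: hook.rule for name, hook in hooks.items()})
# {'0': 'first-layer rule', '2': 'default rule'}
```

Copying the template keeps the one-module-per-hook restriction of BasicHook intact while still letting one hook instance in module_map describe many modules.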

context(module)

Return a CompositeContext object with this instance and the supplied module.

Parameters

module (torch.nn.Module) – Module for which to register this composite in the context.

register(module)

Apply all canonizers and register all hooks to a module and, recursively, its children. Any canonizers previously applied by this composite are reverted, and all hooks previously registered by it are removed. The module or any of its children (recursively) may still have other hooks attached.

Parameters

module (torch.nn.Module) – Hooks and canonizers will be applied to this module recursively according to module_map and canonizers.

remove()

Remove all handles for hooks and canonizers. Hooks are simply removed from their corresponding modules; canonizers revert the state of the modules they changed.

class zennit.core.CompositeContext(module, composite)

Bases: object

A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.

Parameters
  • module (torch.nn.Module) – The module to which the composite should be registered.

  • composite (Composite) – The composite which shall be registered to the module.
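The register-on-enter, remove-on-exit behavior amounts to a standard Python context manager. The sketch below uses hypothetical toy classes instead of zennit and torch, and it assumes that entering the context returns the module; the real class may differ in details.

```python
# Pure-Python sketch of the CompositeContext pattern. ToyComposite and
# ToyCompositeContext are hypothetical stand-ins, not zennit classes.


class ToyComposite:
    def __init__(self):
        self.registered_to = None

    def register(self, module):
        # As documented for Composite.register: undo any previous registration.
        self.remove()
        self.registered_to = module

    def remove(self):
        self.registered_to = None


class ToyCompositeContext:
    def __init__(self, module, composite):
        self.module = module
        self.composite = composite

    def __enter__(self):
        self.composite.register(self.module)
        return self.module  # assumption: the context yields the module

    def __exit__(self, exc_type, exc_value, traceback):
        # Always clean up, even if the body raised; do not swallow the exception.
        self.composite.remove()
        return False


model = object()
composite = ToyComposite()
with ToyCompositeContext(model, composite) as modified:
    registered_inside = composite.registered_to is model
registered_after = composite.registered_to

# Hooks and canonizers are also removed when the body raises.
try:
    with ToyCompositeContext(model, composite):
        raise ValueError("failure inside the context")
except ValueError:
    pass
```

Tying removal to __exit__ means an exception during the forward or backward pass cannot leave hooks or canonizer changes behind on the model.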

class zennit.core.Hook

Bases: object

Base class for hooks to be used to compute layerwise attributions.

backward(module, grad_input, grad_output)

Hook applied during the backward pass.

copy()

Return a copy of this hook. This allows a single hook instance to describe the hooks of several modules.

forward(module, input, output)

Hook applied during the forward pass.

post_forward(module, input, output)

Register a backward hook to the resulting tensor right after the forward pass.

pre_backward(module, grad_input, grad_output)

Store the grad_output for the backward hook.

pre_forward(module, input)

Apply an Identity to the input before the module to register a backward hook.

register(module)

Register this instance by registering all hooks to the supplied module.

remove()

When removing hooks, remove all references to stored tensors.

class zennit.core.Identity(*args, **kwargs)

Bases: Function

Identity to add a grad_fn to a tensor, so a backward hook can be applied.

static backward(ctx, *grad_outputs)

Backward identity.

static forward(ctx, *inputs)

Forward identity.

class zennit.core.RemovableHandle(instance)

Bases: object

Create a weak reference to call .remove on some instance.

remove()

Call remove on the weakly referenced instance if it still exists.
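A minimal sketch of the weak-reference idea in plain Python, using the standard weakref module: the handle does not keep the instance alive, and remove() becomes a no-op once the instance is gone. The Counter class is a hypothetical stand-in for a hook, and the sketch is not taken from zennit's source.

```python
import gc
import weakref


class RemovableHandle:
    # Sketch: keep only a weak reference so the handle does not keep the instance alive.
    def __init__(self, instance):
        self.instance_ref = weakref.ref(instance)

    def remove(self):
        instance = self.instance_ref()  # None once the instance has been garbage collected
        if instance is not None:
            instance.remove()


class Counter:
    # Hypothetical object with a remove() method, standing in for a hook.
    def __init__(self):
        self.removed = 0

    def remove(self):
        self.removed += 1


counter = Counter()
handle = RemovableHandle(counter)
handle.remove()
calls_while_alive = counter.removed  # 1

del counter
gc.collect()  # explicit for clarity; CPython already frees the object on del
handle.remove()  # silently does nothing: the instance no longer exists
instance_gone = handle.instance_ref() is None
```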

class zennit.core.RemovableHandleList(iterable=(), /)

Bases: list

A list to hold handles, with the ability to call remove on all of its members.

remove()

Call remove on all members, effectively removing handles from modules, or reverting canonizers.
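A sketch of such a list in plain Python. The LoggingHandle class is hypothetical, and clearing the list after removal is an assumption made here so that calling remove() twice is harmless.

```python
class RemovableHandleList(list):
    # Sketch: a list whose remove() takes no argument and removes every member.
    # This shadows list.remove(value), as the documented signature implies.
    def remove(self):
        for handle in self:
            handle.remove()
        # Assumption: empty the list so a second remove() is a harmless no-op.
        self.clear()


class LoggingHandle:
    # Hypothetical handle that records when it is removed.
    def __init__(self, name, log):
        self.name = name
        self.log = log

    def remove(self):
        self.log.append(self.name)


log = []
handles = RemovableHandleList([LoggingHandle("hook", log), LoggingHandle("canonizer", log)])
handles.remove()
handles.remove()  # nothing left to remove
print(log)  # ['hook', 'canonizer']
```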

zennit.core.collect_leaves(module)

Generator function to collect all leaf modules of a module.

Parameters

module (torch.nn.Module) – A module for which the leaves will be collected.

Yields

leaf (torch.nn.Module) – Either a leaf of the module structure, or the module itself if it has no children.
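The recursion can be sketched in plain Python. Node is a hypothetical stand-in for torch.nn.Module that models only children(), which on a real module yields the immediate submodules; the sketch is not zennit's source.

```python
class Node:
    # Hypothetical stand-in for torch.nn.Module; only children() is modeled.
    def __init__(self, name, children=()):
        self.name = name
        self._children = list(children)

    def children(self):
        return iter(self._children)


def collect_leaves(module):
    # Recurse into children; a module without children is itself a leaf.
    is_leaf = True
    for child in module.children():
        is_leaf = False
        yield from collect_leaves(child)
    if is_leaf:
        yield module


model = Node("model", [Node("block", [Node("conv"), Node("relu")]), Node("linear")])
print([leaf.name for leaf in collect_leaves(model)])  # ['conv', 'relu', 'linear']
print([leaf.name for leaf in collect_leaves(Node("single"))])  # ['single']
```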

zennit.core.mod_params(module, modifier, param_keys=None, require_params=True)

Context manager to temporarily modify parameter attributes (all by default) of a module.

Parameters
  • module (torch.nn.Module) – Module of which to modify parameters. If require_params is True, it must have all elements given in param_keys as attributes (attributes are allowed to be None, in which case they are ignored).

  • modifier (function) – A function used to modify parameter attributes. If param_keys is empty, this is not used.

  • param_keys (list[str], optional) – A list of parameters that shall be modified. If None (default), all parameters are modified (which may be none). If [], no parameters are modified and modifier is ignored.

  • require_params (bool, optional) – Whether the existence of the module’s parameters is mandatory (True by default). If an attribute exists but is None, it is not considered missing, and the modifier is not applied to it.

Raises

RuntimeError – If require_params is True and module is missing an attribute listed in param_keys.

Yields

module (torch.nn.Module) – The module with the appropriate parameters temporarily modified.
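The following plain-Python sketch mirrors the documented semantics on simple objects: missing attributes raise RuntimeError when required, None attributes are skipped, and the originals are restored on exit, even if the body raises. The param_names attribute, used to find all parameters when param_keys is None, is a hypothetical convention of this sketch; a real torch.nn.Module exposes its own parameters through named_parameters(recurse=False).

```python
from contextlib import contextmanager


@contextmanager
def mod_params(module, modifier, param_keys=None, require_params=True):
    # Sketch of the documented semantics on plain Python objects, not zennit's source.
    if param_keys is None:
        # Hypothetical convention: stand-in modules list their parameter names here.
        param_keys = list(module.param_names)
    missing = [key for key in param_keys if not hasattr(module, key)]
    if require_params and missing:
        raise RuntimeError(f"Module {module!r} is missing required parameters: {missing}")
    stored = {}
    try:
        for key in param_keys:
            value = getattr(module, key, None)
            if value is None:
                continue  # missing (and not required) or explicitly None: left untouched
            stored[key] = value
            setattr(module, key, modifier(value))
        yield module
    finally:
        # Restore the original attributes, even if the body of the with-block raised.
        for key, value in stored.items():
            setattr(module, key, value)


class ToyLayer:
    # Hypothetical stand-in for a layer with a weight and an absent (None) bias.
    param_names = ("weight", "bias")

    def __init__(self):
        self.weight = [1.0, -2.0]
        self.bias = None


layer = ToyLayer()
with mod_params(layer, lambda param: [max(v, 0.0) for v in param]) as modified:
    weight_inside = modified.weight  # [1.0, 0.0]
weight_after = layer.weight  # [1.0, -2.0], restored on exit

missing_error = None
try:
    with mod_params(layer, lambda param: param, param_keys=["gamma"]):
        pass
except RuntimeError as error:
    missing_error = str(error)

# With require_params=False, the missing attribute is simply skipped.
with mod_params(layer, lambda param: param, param_keys=["gamma"], require_params=False):
    pass
```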

zennit.core.stabilize(input, epsilon=1e-06)

Stabilize input for safe division.

Parameters
  • input (torch.Tensor) – Tensor to stabilize.

  • epsilon (float, optional) – Value to replace zero elements with.

Returns

A new tensor copied from input, with all zero elements set to epsilon.

Return type

torch.Tensor
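A plain-Python sketch of the documented contract, together with a hypothetical safe_divide helper showing why it is useful. It works on lists rather than tensors and follows only the documentation above, so the library's actual implementation may differ in details.

```python
def stabilize(values, epsilon=1e-6):
    # Documented contract: return a new sequence in which zeros become epsilon.
    return [epsilon if value == 0 else value for value in values]


def safe_divide(numerators, denominators, epsilon=1e-6):
    # Hypothetical helper: element-wise division that cannot divide by zero.
    return [n / d for n, d in zip(numerators, stabilize(denominators, epsilon))]


denominators = [0.0, 2.0, -4.0]
stable = stabilize(denominators)
print(stable)  # [1e-06, 2.0, -4.0]
print(safe_divide([0.0, 1.0, 1.0], denominators))  # [0.0, 0.5, -0.25]
print(denominators)  # [0.0, 2.0, -4.0]: the input is not modified
```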