zennit.core

Core functions and classes

Functions

collect_leaves

Generator function to collect all leaf modules of a module.

expand

Expand a scalar value or tensor to a shape.

stabilize

Stabilize input for safe division.

zero_wrap

Create a function wrapper factory (i.e. a decorator) which returns zeros for parameters whose names match zero_params.

Classes

BasicHook

A hook to compute the layer-wise attribution of the module it is attached to.

Composite

A Composite to apply canonizers and register hooks to modules.

CompositeContext

A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.

Hook

Base class for hooks to be used to compute layer-wise attributions.

Identity

Identity to add a grad_fn to a tensor, so a backward hook can be applied.

ParamMod

Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.

RemovableHandle

Create a weak reference to call .remove on some instance.

RemovableHandleList

A list to hold handles, with the ability to call remove on all of its members.

Stabilizer

Class to create a stabilizer callable.

class zennit.core.BasicHook(input_modifiers=None, param_modifiers=None, output_modifiers=None, gradient_mapper=None, reducer=None, stabilizer=1e-06)[source]

Bases: Hook

A hook to compute the layer-wise attribution of the module it is attached to. A BasicHook instance may only be registered with a single module.

Parameters:
  • input_modifiers (list[callable], optional) – A list of functions (input: torch.Tensor) -> torch.Tensor to produce multiple inputs. Default is a single input which is the identity.

  • param_modifiers (list[ParamMod or callable], optional) – A list of ParamMod instances or functions (obj: torch.Tensor, name: str) -> torch.Tensor, with parameter tensor obj, registered in the root model as name, to temporarily modify the parameters of the attached module for each input produced with input_modifiers. Default is unmodified parameters for each input. Use a ParamMod instance to specify which parameters should be modified, whether they are required, and which should be set to zero.

  • output_modifiers (list[callable], optional) – A list of functions (input: torch.Tensor) -> torch.Tensor to modify the module’s output computed using the modified parameters before gradient computation for each input produced with input_modifiers. Default is the identity for each output.

  • gradient_mapper (callable, optional) – Function (out_grad: torch.Tensor, outputs: list[torch.Tensor]) -> list[torch.Tensor] to modify upper relevance. A list or tuple of the same size as outputs is expected to be returned. outputs has the same size as input_modifiers and param_modifiers. Default is a stabilized normalization by each of the outputs, multiplied with the output gradient.

  • reducer (callable) – Function (inputs: list[torch.Tensor], gradients: list[torch.Tensor]) -> torch.Tensor to reduce all the inputs and gradients produced through input_modifiers and param_modifiers. inputs and gradients have the same size as input_modifiers and param_modifiers. Default is the sum of the multiplications of each input and its corresponding gradient.
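
The interplay of these parameters can be illustrated with a scalar sketch. The following is not zennit code: the “tensors” are plain floats, the module is a hypothetical single-weight multiplication, and the param_modifiers here take only the parameter value (the real signature also receives its name). It shows the default gradient_mapper (stabilized normalization by each output) and the default reducer (sum of input times gradient):

```python
# Hypothetical scalar sketch of the BasicHook pipeline (not zennit code).
# Each (input_modifier, param_modifier) pair yields one modified forward
# pass; gradient_mapper distributes the output gradient over the outputs,
# and reducer combines inputs and gradients into the relevance.

def basic_hook_sketch(x, weight, out_grad, input_modifiers, param_modifiers):
    # one modified forward pass per modifier pair
    inputs = [mod(x) for mod in input_modifiers]
    outputs = [mod(weight) * inp for mod, inp in zip(param_modifiers, inputs)]

    # default gradient_mapper: stabilized normalization by each output
    def gradient_mapper(out_grad, outputs):
        return [out_grad / (out + 1e-6) for out in outputs]

    # gradient of each modified forward pass wrt. its input is the weight
    grad_outputs = gradient_mapper(out_grad, outputs)
    gradients = [mod(weight) * g for mod, g in zip(param_modifiers, grad_outputs)]

    # default reducer: sum of input times gradient
    return sum(inp * grad for inp, grad in zip(inputs, gradients))

relevance = basic_hook_sketch(
    2.0, 3.0, 1.0,
    input_modifiers=[lambda v: v],
    param_modifiers=[lambda w: w],
)
```

With identity modifiers, the full output gradient is redistributed to the input (relevance is approximately 1.0 here), reflecting the conservation property of LRP.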

backward(module, grad_input, grad_output)[source]

Backward hook to compute LRP based on the class attributes.

copy()[source]

Return a copy of this hook. This is used to describe hooks of different modules by a single hook instance.

forward(module, input, output)[source]

Forward hook to save module in-/outputs.

class zennit.core.Composite(module_map=None, canonizers=None)[source]

Bases: object

A Composite to apply canonizers and register hooks to modules. One Composite instance may only be applied to a single module at a time.

Parameters:
  • module_map (callable, optional) – A function (ctx: dict, name: str, module: torch.nn.Module) -> Hook or None which maps a context, name and module to a matching Hook, or None if there is no matching Hook.

  • canonizers (list[zennit.canonizers.Canonizer], optional) – List of canonizer instances to be applied before applying hooks.
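
A module_map is a plain callable with the signature described above. The following sketch uses stand-in classes (LinearHook and Linear are hypothetical placeholders, not zennit or torch types) to show the dispatch and the shared context dict:

```python
# Hypothetical sketch of a module_map callable for Composite.
# LinearHook and Linear are illustrative stand-ins only.

class LinearHook:
    """Stand-in for a zennit Hook matching linear layers."""

def module_map(ctx, name, module):
    # ctx is a dict shared across calls, e.g. to track visited modules
    ctx.setdefault('visited', []).append(name)
    # return a hook for matching modules, None otherwise
    if type(module).__name__ == 'Linear':
        return LinearHook()
    return None

class Linear:  # stand-in for torch.nn.Linear
    pass

ctx = {}
hook = module_map(ctx, 'model.fc', Linear())      # matches -> LinearHook
skipped = module_map(ctx, 'model.relu', object()) # no match -> None
```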

context(module)[source]

Return a CompositeContext object with this instance and the supplied module.

Parameters:

module (torch.nn.Module) – Module for which to register this composite in the context.

Returns:

A context object which registers the composite to module on entering, and removes it on exiting.

Return type:

zennit.core.CompositeContext

inactive()[source]

Context manager to temporarily deactivate the gradient modification. This can be used to compute the gradient of the modified gradient.

register(module)[source]

Apply all canonizers and register all hooks to a module (and its recursive children). Previous canonizers of this composite are reverted and all hooks registered by this composite are removed. The module or any of its children (recursively) may still have other hooks attached.

Parameters:

module (torch.nn.Module) – Hooks and canonizers will be applied to this module recursively according to module_map and canonizers.

remove()[source]

Remove all handles for hooks and canonizers. Hooks will simply be removed from their corresponding Modules. Canonizers will revert the state of the modules they changed.

class zennit.core.CompositeContext(module, composite)[source]

Bases: object

A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.

Parameters:
  • module (torch.nn.Module) – The module to which composite should be registered.

  • composite (zennit.core.Composite) – The composite which shall be registered to module.

class zennit.core.Hook[source]

Bases: object

Base class for hooks to be used to compute layer-wise attributions.

backward(module, grad_input, grad_output)[source]

Hook applied during the backward pass.

copy()[source]

Return a copy of this hook. This is used to describe hooks of different modules by a single hook instance.

forward(module, input, output)[source]

Hook applied during the forward pass.

post_forward(module, input, output)[source]

Register a backward-hook to the resulting tensor right after the forward.

pre_backward(module, grad_input, grad_output)[source]

Store the grad_output for the backward hook.

pre_forward(module, input)[source]

Apply an Identity to the input before the module to register a backward hook.

register(module)[source]

Register this instance by registering all hooks to the supplied module.

remove()[source]

When removing hooks, remove all references to stored tensors.

class zennit.core.Identity(*args, **kwargs)[source]

Bases: Function

Identity to add a grad_fn to a tensor, so a backward hook can be applied.

static backward(ctx, *grad_outputs)[source]

Backward identity.

static forward(ctx, *inputs)[source]

Forward identity.

class zennit.core.ParamMod(modifier, param_keys=None, zero_params=None, require_params=True)[source]

Bases: object

Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.

Parameters:
  • modifier (function) – A function used to modify parameter attributes. If param_keys is empty, this is not used.

  • param_keys (list[str], optional) – A list of parameter names that shall be modified. If None (default), all parameters are modified (which may be none). If [], no parameters are modified and modifier is ignored.

  • zero_params (list[str], optional) – A list of parameter names that shall be set to zero. If None (default), no parameters are set to zero.

  • require_params (bool, optional) – Whether existence of module’s params is mandatory (True by default). If the attribute exists but is None, it is not considered missing, and the modifier is not applied.
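
The behaviour can be pictured as a small context manager that swaps an attribute and restores it on exit. This is a simplified sketch, not zennit’s implementation (which also handles param_keys, zero_params and require_params):

```python
from contextlib import contextmanager

# Simplified sketch of temporarily modifying one attribute of an object,
# restoring the original value when the context exits.

@contextmanager
def modified_param(obj, name, modifier):
    original = getattr(obj, name)
    setattr(obj, name, modifier(original))
    try:
        yield obj
    finally:
        setattr(obj, name, original)

class ModuleSketch:  # stand-in for a torch.nn.Module with one parameter
    weight = -2.0

module = ModuleSketch()
with modified_param(module, 'weight', lambda w: 0.0):
    inside = module.weight  # modified value (zeroed)
after = module.weight       # original value, restored on exit
```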

classmethod ensure(modifier)[source]

If modifier is an instance of ParamMod, return it as-is; if it is callable, create a new instance with modifier as the ParamMod’s function; otherwise, raise a TypeError.

Parameters:

modifier (ParamMod or callable) – The modifier which, if necessary, will be used to construct a ParamMod.

Returns:

Either modifier as is, or a ParamMod constructed using modifier.

Return type:

ParamMod

Raises:

TypeError – If modifier is neither an instance of ParamMod, nor callable.
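
The ensure pattern can be sketched as a three-way dispatch; ParamModSketch is a hypothetical stand-in, not the real class:

```python
# Sketch of the 'ensure' dispatch pattern: pass instances through,
# wrap callables, and reject anything else.

class ParamModSketch:
    def __init__(self, modifier):
        self.modifier = modifier

    @classmethod
    def ensure(cls, modifier):
        if isinstance(modifier, cls):
            return modifier
        if callable(modifier):
            return cls(modifier)
        raise TypeError(
            f'{modifier!r} is neither a ParamModSketch nor callable'
        )

same = ParamModSketch.ensure(ParamModSketch(abs))  # instance passed through
wrapped = ParamModSketch.ensure(abs)               # callable wrapped
```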

class zennit.core.RemovableHandle(instance)[source]

Bases: object

Create a weak reference to call .remove on some instance.

remove()[source]

Call remove on the weakly referenced instance if it still exists.
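
A minimal sketch of the weak-reference behaviour, using stand-in classes rather than zennit’s: the point is that the handle does not keep the hook alive, and remove becomes a no-op once the instance is gone.

```python
import weakref

# Sketch of a removable handle that holds only a weak reference,
# so the handle itself does not keep the hook alive.

class RemovableHandleSketch:
    def __init__(self, instance):
        self.ref = weakref.ref(instance)

    def remove(self):
        instance = self.ref()
        if instance is not None:
            instance.remove()

class HookSketch:
    removed = False
    def remove(self):
        self.removed = True

hook = HookSketch()
handle = RemovableHandleSketch(hook)
handle.remove()               # forwards to hook.remove()
was_removed = hook.removed
del hook                      # hook is gone; the weak reference is dead
handle.remove()               # silently does nothing
```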

class zennit.core.RemovableHandleList(iterable=(), /)[source]

Bases: list

A list to hold handles, with the ability to call remove on all of its members.

remove()[source]

Call remove on all members, effectively removing handles from modules, or reverting canonizers.

class zennit.core.Stabilizer(epsilon=1e-06, clip=False, norm_scale=False, dim=None)[source]

Bases: object

Class to create a stabilizer callable.

Parameters:
  • epsilon (float, optional) – Value by which to shift/clip elements of input.

  • clip (bool, optional) – If False (default), add epsilon multiplied by each entry’s sign (+1 for 0). If True, instead clip the absolute value of input and multiply it by each entry’s original sign.

  • norm_scale (bool, optional) – If False (default), epsilon is added to/used to clip input. If True, scale epsilon by the square root of the mean over the squared elements of the specified dimensions dim.

  • dim (tuple[int], optional) – If norm_scale is True, specifies the dimension over which the scaled norm should be computed (all except dimension 0 by default).

classmethod ensure(value)[source]

Given a value, return a stabilizer. If value is a float, a Stabilizer with that epsilon value is returned. If value is callable, it will be used directly as a stabilizer. Otherwise a TypeError will be raised.

Parameters:

value (float, int, or callable) – The value used to produce a valid stabilizer function.

Returns:

A callable to be used as a stabilizer.

Return type:

callable or Stabilizer

Raises:

TypeError – If no valid stabilizer could be produced from value.

zennit.core.collect_leaves(module)[source]

Generator function to collect all leaf modules of a module.

Parameters:

module (torch.nn.Module) – A module for which the leaves will be collected.

Yields:

leaf (torch.nn.Module) – Either a leaf of the module structure, or the module itself if it has no children.
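
The traversal can be sketched with a stand-in Node class in place of torch.nn.Module (only a children method is assumed):

```python
# Sketch of leaf collection over a tiny module-like tree; Node stands
# in for torch.nn.Module.

class Node:
    def __init__(self, name, children=()):
        self.name = name
        self._children = list(children)

    def children(self):
        return iter(self._children)

def collect_leaves_sketch(module):
    # yield leaves depth-first; a module with no children is itself a leaf
    is_leaf = True
    for child in module.children():
        is_leaf = False
        yield from collect_leaves_sketch(child)
    if is_leaf:
        yield module

tree = Node('root', [
    Node('block', [Node('conv'), Node('relu')]),
    Node('fc'),
])
leaves = [leaf.name for leaf in collect_leaves_sketch(tree)]
```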

zennit.core.expand(tensor, shape, cut_batch_dim=False)[source]

Expand a scalar value or tensor to a shape. In addition to torch.Tensor.expand, this will also accept non-torch.Tensor objects, which will be used to create a new tensor. If tensor has fewer dimensions than shape, singleton dimensions will be appended to match the size of shape before expanding.

Parameters:
  • tensor (int, float or torch.Tensor) – Scalar or tensor to expand to the size of shape.

  • shape (tuple[int]) – Shape to which tensor will be expanded.

  • cut_batch_dim (bool, optional) – If True, take only the first shape[0] entries along dimension 0 of the expanded tensor, if it has more entries in dimension 0 than shape. Default (False) is not to cut, which will instead cause a RuntimeError due to the size mismatch.

Returns:

A new tensor expanded from tensor with shape shape.

Return type:

torch.Tensor

Raises:

RuntimeError – If tensor could not be expanded to shape due to incompatible shapes.
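
The shape logic (though not the tensor copying) can be sketched without torch; expanded_shape is a hypothetical helper that only validates and computes the resulting shape:

```python
# Torch-free sketch of the shape logic: append singleton dimensions,
# check expandability per dimension, optionally cut the batch dimension.

def expanded_shape(shape_in, shape_out, cut_batch_dim=False):
    # append singleton dimensions until the ranks match
    shape_in = tuple(shape_in) + (1,) * (len(shape_out) - len(shape_in))
    result = []
    for dim, (size_in, size_out) in enumerate(zip(shape_in, shape_out)):
        if dim == 0 and cut_batch_dim and size_in > size_out:
            # take only the first shape_out[0] entries along dimension 0
            size_in = size_out
        if size_in not in (1, size_out):
            raise RuntimeError(f'cannot expand size {size_in} to {size_out}')
        result.append(size_out)
    return tuple(result)

grown = expanded_shape((3,), (3, 4))                      # (3,) -> (3, 1) -> (3, 4)
cut = expanded_shape((5, 1), (2, 4), cut_batch_dim=True)  # batch dim cut to 2
```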

zennit.core.stabilize(input, epsilon=1e-06, clip=False, norm_scale=False, dim=None)[source]

Stabilize input for safe division.

Parameters:
  • input (torch.Tensor) – Tensor to stabilize.

  • epsilon (float, optional) – Value by which to shift/clip elements of input.

  • clip (bool, optional) – If False (default), add epsilon multiplied by each entry’s sign (+1 for 0). If True, instead clip the absolute value of input and multiply it by each entry’s original sign.

  • norm_scale (bool, optional) – If False (default), epsilon is added to/used to clip input. If True, scale epsilon by the square root of the mean over the squared elements of the specified dimensions dim.

  • dim (tuple[int], optional) – If norm_scale is True, specifies the dimension over which the scaled norm should be computed. Defaults to all except dimension 0.

Returns:

New Tensor copied from input with values shifted by epsilon.

Return type:

torch.Tensor
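
For a single value, the default and clip behaviours can be sketched without torch; stabilize_scalar is a hypothetical scalar analogue, ignoring norm_scale and dim:

```python
# Scalar sketch of stabilization: shift each value by epsilon times its
# sign (0 counts as positive), or clip its magnitude to at least epsilon.

def stabilize_scalar(value, epsilon=1e-6, clip=False):
    sign = 1.0 if value >= 0 else -1.0
    if clip:
        # clip the absolute value to at least epsilon, keep the sign
        return sign * max(abs(value), epsilon)
    # shift away from zero so subsequent division is safe
    return value + sign * epsilon

shifted_zero = stabilize_scalar(0.0)           # 0 is treated as positive
shifted_neg = stabilize_scalar(-2.0)           # shifted further from zero
clipped_tiny = stabilize_scalar(1e-9, clip=True)
clipped_big = stabilize_scalar(2.0, clip=True)  # unchanged, already large
```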

zennit.core.zero_wrap(zero_params)[source]

Create a function wrapper factory (i.e. a decorator) which takes a single function argument (name, param) -> tensor, such that the wrapped function is only called if name does not match zero_params (equality if zero_params is a string, membership otherwise); if it does match, torch.zeros_like of that tensor is returned instead.

Parameters:

zero_params (str or list[str]) – String or list of strings compared to name.

Returns:

The function wrapper to be called on the function.

Return type:

function
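
A torch-free sketch of this wrapper; returning 0.0 stands in for torch.zeros_like:

```python
# Sketch of zero_wrap: a decorator factory that short-circuits the
# wrapped function for parameter names matching zero_params.

def zero_wrap_sketch(zero_params):
    def wrapper(func):
        def wrapped(name, param):
            matches = (
                name == zero_params if isinstance(zero_params, str)
                else name in zero_params
            )
            if matches:
                return 0.0  # torch.zeros_like(param) in zennit
            return func(name, param)
        return wrapped
    return wrapper

@zero_wrap_sketch(['bias'])
def negate(name, param):
    return -param

kept = negate('weight', 3.0)   # name does not match: function is called
zeroed = negate('bias', 3.0)   # name matches: zeros are returned instead
```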