zennit.core

Core functions and classes

Functions

collect_leaves

Generator function to collect all leaf modules of a module.

expand

Expand a scalar value or tensor to a shape.

stabilize

Stabilize input for safe division.

uncompress

Generator which, given a compressed iterable produced by itertools.compress and (some iterable similar to) the original data and selector used for compress, yields values from compressed or data depending on selector.

zero_wrap

Create a decorator which wraps a single function (name, param) -> tensor such that the function is only called if name is not equal to zero_params (if zero_params is a string) or is not in zero_params (if it is a list).

Classes

BasicHook

A hook to compute the layer-wise attribution of the module it is attached to.

Composite

A Composite to apply canonizers and register hooks to modules.

CompositeContext

A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.

Hook

Base class for hooks to be used to compute layer-wise attributions.

Identity

Identity to add a grad_fn to a tensor, so a backward hook can be applied.

ParamMod

Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.

RemovableHandle

Create a weak reference to call .remove on some instance.

RemovableHandleList

A list to hold handles, with the ability to call remove on all of its members.

Stabilizer

Class to create a stabilizer callable.

class zennit.core.Stabilizer

Bases: object

Class to create a stabilizer callable.

Parameters:
  • epsilon (float, optional) – Value by which to shift/clip elements of input.

  • clip (bool, optional) – If False (default), add epsilon multiplied by each entry’s sign (+1 for 0). If True, instead clip the absolute value of input from below at epsilon and multiply it by each entry’s original sign.

  • norm_scale (bool, optional) – If False (default), epsilon is used as-is to shift or clip input. If True, scale epsilon by the square root of the mean of the squared elements of input over the dimensions dim.

  • dim (tuple[int], optional) – If norm_scale is True, specifies the dimensions over which the scaled norm should be computed (all except dimension 0 by default).

__init__(epsilon=1e-6, clip=False, norm_scale=False, dim=None)
classmethod ensure(value)

Given a value, return a stabilizer. If value is a float (or int), a Stabilizer with that epsilon value is returned. If value is callable, it will be used directly as a stabilizer. Otherwise, a TypeError is raised.

Parameters:

value (float, int, or callable) – The value used to produce a valid stabilizer function.

Returns:

A callable to be used as a stabilizer.

Return type:

callable or Stabilizer

Raises:

TypeError – If no valid stabilizer could be produced from value.

zennit.core.stabilize(input, epsilon=1e-6, clip=False, norm_scale=False, dim=None)

Stabilize input for safe division.

Parameters:
  • input (torch.Tensor) – Tensor to stabilize.

  • epsilon (float, optional) – Value by which to shift/clip elements of input.

  • clip (bool, optional) – If False (default), add epsilon multiplied by each entry’s sign (+1 for 0). If True, instead clip the absolute value of input from below at epsilon and multiply it by each entry’s original sign.

  • norm_scale (bool, optional) – If False (default), epsilon is used as-is to shift or clip input. If True, scale epsilon by the square root of the mean of the squared elements of input over the dimensions dim.

  • dim (tuple[int], optional) – If norm_scale is True, specifies the dimensions over which the scaled norm should be computed. Defaults to all except dimension 0.

Returns:

A new tensor computed from input, with values shifted (or, if clip is True, clipped) by epsilon.

Return type:

torch.Tensor
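The parameter descriptions above can be condensed into a short sketch in plain PyTorch. It is a re-implementation written from this documentation, not zennit's source. The name stabilize_sketch is hypothetical, and reading clip=True as "clip the magnitude from below at epsilon" is an assumption.

```python
import torch


def stabilize_sketch(input, epsilon=1e-6, clip=False, norm_scale=False, dim=None):
    """Hypothetical re-implementation of the documented behaviour of stabilize."""
    # sign of each entry, counting exact zeros as +1
    sign = (input == 0.).to(input) + input.sign()
    if norm_scale:
        if dim is None:
            # default: every dimension except the batch dimension 0
            dim = tuple(range(1, input.ndim))
        # scale epsilon by the root mean square of input over `dim`
        epsilon = epsilon * (input ** 2).mean(dim=dim, keepdim=True).sqrt()
    if clip:
        # keep each entry's sign, but force its magnitude to be at least epsilon
        floor = torch.as_tensor(epsilon, dtype=input.dtype, device=input.device)
        return sign * torch.maximum(input.abs(), floor)
    # shift every entry away from zero by epsilon
    return input + sign * epsilon


x = torch.tensor([[-2.0, 0.0, 3.0]])
shifted = stabilize_sketch(x, epsilon=0.5)  # [[-2.5, 0.5, 3.5]]
clipped = stabilize_sketch(torch.tensor([[-0.1, 0.0, 3.0]]), epsilon=0.5, clip=True)  # [[-0.5, 0.5, 3.0]]
```

Because zero entries count as positive, the result is never zero, which makes it safe to use as a denominator.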

zennit.core.expand(tensor, shape, cut_batch_dim=False)

Expand a scalar value or tensor to a shape. In addition to what torch.Tensor.expand accepts, this will also accept non-tensor objects (e.g. Python scalars), which will be used to create a new tensor. If tensor has fewer dimensions than shape, singleton dimensions will be appended until the number of dimensions matches shape, before expanding.

Parameters:
  • tensor (int, float or torch.Tensor) – Scalar or tensor to expand to the size of shape.

  • shape (tuple[int]) – Shape to which tensor will be expanded.

  • cut_batch_dim (bool, optional) – If True, take only the first shape[0] entries along dimension 0 of the expanded tensor, if it has more entries in dimension 0 than shape. Default (False) is not to cut, which will instead cause a RuntimeError due to the size mismatch.

Returns:

A new tensor expanded from tensor with shape shape.

Return type:

torch.Tensor

Raises:

RuntimeError – If tensor could not be expanded to shape due to incompatible shapes.
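The following sketch illustrates the documented behaviour in plain PyTorch. It is a re-implementation written from the description above, not zennit's source, and the name expand_sketch is hypothetical. Note the difference from PyTorch's own broadcasting: singleton dimensions are appended rather than prepended, so a tensor of shape (3,) is aligned with dimension 0 of the target shape.

```python
import torch


def expand_sketch(tensor, shape, cut_batch_dim=False):
    """Hypothetical re-implementation of the documented behaviour of expand."""
    if not isinstance(tensor, torch.Tensor):
        # scalars and other non-tensor objects are turned into a new tensor
        tensor = torch.tensor(tensor)
    shape = tuple(shape)
    if tensor.ndim < len(shape):
        # append (not prepend) singleton dimensions until the ranks match
        tensor = tensor.reshape(tuple(tensor.shape) + (1,) * (len(shape) - tensor.ndim))
    if cut_batch_dim and tensor.shape[0] > shape[0]:
        # keep only the first shape[0] entries along dimension 0
        tensor = tensor[:shape[0]]
    # torch raises a RuntimeError here if the shapes are incompatible
    return tensor.expand(shape)


scalar = expand_sketch(2.0, (2, 3))          # a (2, 3) tensor filled with 2.0
bias = torch.tensor([1.0, 2.0, 3.0])
per_dim0 = expand_sketch(bias, (3, 4, 5))    # bias varies along dimension 0
big = torch.arange(4.0).reshape(4, 1)
cut = expand_sketch(big, (2, 3), cut_batch_dim=True)  # uses only the first 2 rows
```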

zennit.core.zero_wrap(zero_params)

Create a decorator which wraps a single function (name, param) -> tensor such that the function is only called if name is not equal to zero_params (if zero_params is a string) or is not in zero_params (if it is a list). Otherwise, torch.zeros_like of that tensor is returned instead.

Parameters:

zero_params (str or list[str]) – String or list of strings compared to name.

Returns:

The decorator to be applied to the function.

Return type:

function
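A minimal sketch of such a decorator factory is shown below. It is written from the description above and is not zennit's source. The names zero_wrap_sketch and positive_part are hypothetical, and the wrapped function's argument order (param, name) is an assumption borrowed from the (obj, name) signature documented for param_modifiers in BasicHook.

```python
import functools

import torch


def zero_wrap_sketch(zero_params):
    """Hypothetical decorator factory following the documented zero_wrap behaviour."""
    def decorator(func):
        @functools.wraps(func)
        def wrapped(param, name):
            # a string is compared for equality, a list is checked for membership
            if isinstance(zero_params, str):
                matches = name == zero_params
            else:
                matches = name in zero_params
            if matches:
                # func is not called at all; zeros of the same shape/dtype/device are returned
                return torch.zeros_like(param)
            return func(param, name)
        return wrapped
    return decorator


@zero_wrap_sketch(['bias'])
def positive_part(param, name):
    # hypothetical parameter modifier: keep only the positive part
    return param.clamp(min=0.)


w = torch.tensor([-1.0, 2.0])
weight_out = positive_part(w, 'weight')  # tensor([0., 2.])
bias_out = positive_part(w, 'bias')      # tensor([0., 0.])
```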

zennit.core.uncompress(data, selector, compressed) -> Generator

Generator which, given a compressed iterable produced by itertools.compress and (some iterable similar to) the original data and selector used for compress, yields values from compressed or data depending on selector. For each True value in selector, data is advanced by one element, which is discarded, and the next value from compressed is yielded. For each False value, the next value from data is yielded.

Parameters:
  • data (iterable) – An iterable similar to (or identical to) the original data. Positions with False values in selector are filled with values from this iterable, while positions with True values cause one of its elements to be skipped.

  • selector (iterable of bool) – The original selector used to produce compressed. Chooses whether elements from data or from compressed will be yielded.

  • compressed (iterable) – The results of itertools.compress. Will be yielded for each True element in selector.

Yields:

object – An element of data if the associated element of selector is False, otherwise an element of compressed while skipping data one ahead.

Return type:

Generator
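The interplay with itertools.compress can be shown with a short standard-library sketch. It is written from the description above, not taken from zennit. The name uncompress_sketch is hypothetical, and it assumes data and selector have the same length.

```python
from itertools import compress


def uncompress_sketch(data, selector, compressed):
    """Hypothetical re-implementation of the documented behaviour of uncompress."""
    data_iter = iter(data)
    compressed_iter = iter(compressed)
    for selected in selector:
        # always advance data by one so it stays aligned with selector
        value = next(data_iter)
        if selected:
            # True: discard the data element and yield the processed one instead
            yield next(compressed_iter)
        else:
            # False: pass the original element through unchanged
            yield value


data = [1, 2, 3, 4]
selector = [True, False, True, False]
# process only the selected elements (here: square them)
processed = [x ** 2 for x in compress(data, selector)]
merged = list(uncompress_sketch(data, selector, processed))
print(merged)  # [1, 2, 9, 4]
```

This pattern is useful when only some elements of a sequence (e.g. only the tensors among a module's arguments) need to be transformed and then merged back in their original positions.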

class zennit.core.ParamMod

Bases: object

Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.

Parameters:
  • modifier (function) – A function used to modify parameter attributes. If param_keys is empty, this is not used.

  • param_keys (list[str], optional) – A list of parameter names that shall be modified. If None (default), all parameters are modified (which may be none). If [], no parameters are modified and modifier is ignored.

  • zero_params (list[str], optional) – A list of parameter names that shall be set to zero. If None (default), no parameters are set to zero.

  • require_params (bool, optional) – Whether the existence of the module’s parameters is mandatory (True by default). If an attribute exists but is None, it is not considered missing, and the modifier is not applied to it.

__init__(modifier, param_keys=None, zero_params=None, require_params=True)
state_dicts(module)

Return the original and the modified state dicts of the module’s parameters.

Parameters:

module (torch.nn.Module) – The module for which parameters shall be modified.

Returns:

  • original_state (dict of torch.Tensor) – The original, unmodified parameters.

  • modified_state (dict of torch.Tensor) – The modified parameters.

Raises:

RuntimeError – If parameters are missing and self.require_params has been set to True.

classmethod ensure(modifier)

If modifier is an instance of ParamMod, return it as-is; if it is callable, create a new instance with modifier as the ParamMod’s function; otherwise, raise a TypeError.

Parameters:

modifier (ParamMod or callable) – The modifier which, if necessary, will be used to construct a ParamMod.

Returns:

Either modifier as is, or a ParamMod constructed using modifier.

Return type:

ParamMod

Raises:

TypeError – If modifier is neither an instance of ParamMod, nor callable.
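The general idea of temporarily modifying parameters inside a context and restoring them afterwards can be sketched as follows. This is not ParamMod’s implementation. modified_params is a hypothetical helper that copies values in place, its modifier uses the (param, name) argument order described for param_modifiers in BasicHook, and it ignores require_params.

```python
from contextlib import contextmanager

import torch


@contextmanager
def modified_params(module, modifier, param_keys=None, zero_params=()):
    """Hypothetical helper: temporarily overwrite parameters of `module`, then restore them."""
    if param_keys is None:
        # default: all parameters registered directly on this module
        param_keys = [name for name, _ in module.named_parameters(recurse=False)]
    # keep copies of the original values so they can be restored afterwards
    original = {key: getattr(module, key).detach().clone() for key in param_keys}
    try:
        with torch.no_grad():
            for key in param_keys:
                param = getattr(module, key)
                if key in zero_params:
                    param.zero_()
                else:
                    param.copy_(modifier(param, key))
        yield module
    finally:
        # restore the original values, even if the body raised an exception
        with torch.no_grad():
            for key, value in original.items():
                getattr(module, key).copy_(value)


layer = torch.nn.Linear(3, 2)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[1.0, -2.0, 3.0], [-4.0, 5.0, -6.0]]))
    layer.bias.fill_(1.0)

with modified_params(layer, lambda param, name: param.clamp(min=0.), zero_params=('bias',)):
    # inside the context: negative weights clipped to 0, bias zeroed
    weight_inside = layer.weight.detach().clone()
    bias_inside = layer.bias.detach().clone()
# outside the context: the original values are back
```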

zennit.core.collect_leaves(module) -> Iterator[Module]

Generator function to collect all leaf modules of a module.

Parameters:

module (torch.nn.Module) – A module for which the leaves will be collected.

Yields:

leaf (torch.nn.Module) – Either a leaf of the module structure, or the module itself if it has no children.

Return type:

Iterator[Module]
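The documented behaviour can be sketched with a short recursive generator over torch.nn.Module.children(). It is not zennit’s source, and the name collect_leaves_sketch is hypothetical.

```python
import torch


def collect_leaves_sketch(module):
    """Hypothetical generator: yield modules without children, or `module` itself if it has none."""
    children = list(module.children())
    if not children:
        # a module without children is its own (only) leaf
        yield module
        return
    for child in children:
        # descend recursively into every child
        yield from collect_leaves_sketch(child)


model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Linear(8, 2)),
)
leaf_names = [type(leaf).__name__ for leaf in collect_leaves_sketch(model)]
print(leaf_names)  # ['Linear', 'ReLU', 'Linear']
```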

class zennit.core.Identity

Bases: Function

Identity to add a grad_fn to a tensor, so a backward hook can be applied.

static forward(ctx, *inputs)

Forward identity.

Parameters:
  • ctx (object) – The function context.

  • *inputs (tuple of torch.Tensor) – Inputs to forward.

Returns:

inputs – The unmodified inputs.

Return type:

tuple of torch.Tensor

static backward(ctx, *grad_outputs)

Backward identity.

Parameters:
  • ctx (object) – The function context.

  • *grad_outputs (tuple of torch.Tensor) – Output gradients.

Returns:

grad_outputs – The unmodified output gradients.

Return type:

tuple of torch.Tensor
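The purpose of such an identity Function can be illustrated with a small torch.autograd.Function sketch. IdentitySketch is a hypothetical name, not zennit’s class, and its forward returns clones of its inputs, which may differ from zennit’s implementation. The example also assumes PyTorch’s grad_fn.register_hook API, whose callback receives (grad_inputs, grad_outputs) of the node.

```python
import torch


class IdentitySketch(torch.autograd.Function):
    """Hypothetical identity Function: its outputs get a grad_fn that hooks can attach to."""

    @staticmethod
    def forward(ctx, *inputs):
        # return (copies of) the inputs unchanged
        return tuple(inp.clone() for inp in inputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # pass the output gradients through unchanged
        return grad_outputs


x = torch.tensor([1.0, 2.0], requires_grad=True)
(y,) = IdentitySketch.apply(x)
has_grad_fn = y.grad_fn is not None  # True: y now has its own backward node

captured = []
# node hooks receive (grad_inputs, grad_outputs); this one only records grad_outputs
y.grad_fn.register_hook(lambda grad_inputs, grad_outputs: captured.append(grad_outputs[0]))
(3 * y).sum().backward()
```

Because the hook returns None, the gradient passes through unchanged. A hook that returns a tuple instead would replace the gradients flowing further back.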

class zennit.core.Hook

Bases: object

Base class for hooks to be used to compute layer-wise attributions.

__init__()
pre_forward(module, args, kwargs)

Apply an Identity to the input before the module to register a backward hook.

Parameters:
  • module (torch.nn.Module) – The module to which this hook is attached.

  • args (tuple of torch.Tensor) – The input tensors passed to module.forward.

  • kwargs (dict) – The keyword arguments passed to module.forward.

Returns:

A tuple of the modified input tensors.

Return type:

tuple of torch.Tensor, optional

post_forward(module, args, kwargs, output)

Register a backward hook to the resulting tensor right after the forward pass.

Parameters:
  • module (torch.nn.Module) – The module to which this hook is attached.

  • args (tuple of torch.Tensor) – The input tensors passed to module.forward.

  • kwargs (dict) – The keyword arguments passed to module.forward.

  • output (torch.Tensor) – The output tensor.

Returns:

A tuple of the modified output tensors.

Return type:

tuple of torch.Tensor, optional

pre_backward(module, grad_input, grad_output)

Store the grad_output for the backward hook.

Parameters:
  • module (torch.nn.Module) – The module to which this hook is attached.

  • grad_input (torch.Tensor) – The input gradient tensor.

  • grad_output (torch.Tensor) – The output gradient tensor.

forward(module, args, kwargs, output)

Hook applied during forward-pass.

Parameters:
  • module (torch.nn.Module) – The module to which this hook is attached.

  • args (tuple of torch.Tensor) – The input tensors passed to module.forward.

  • kwargs (dict) – The keyword arguments passed to module.forward.

  • output (torch.Tensor) – The output tensor.

backward(module, grad_input, grad_output)

Hook applied during backward-pass.

Parameters:
  • module (torch.nn.Module) – The module to which this hook is attached.

  • grad_input (torch.Tensor) – The input gradient tensor.

  • grad_output (torch.Tensor) – The output gradient tensor.

copy()

Return a copy of this hook. This is used to describe hooks of different modules by a single hook instance.

Returns:

A copy of this hook.

Return type:

Hook

remove()

When removing hooks, remove all references to stored tensors.

register(module)

Register this instance by registering all hooks to the supplied module.

Parameters:

module (torch.nn.Module) – The module to which to register.

Returns:

A list of removable handles, one for each registered hook.

Return type:

RemovableHandleList
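Hook implements its backward pass with its own Identity-based mechanism (pre_forward/post_forward). The sketch below illustrates the same idea in a different way, using PyTorch’s built-in Module.register_forward_hook and Module.register_full_backward_hook. The backward callback of the latter has the same (module, grad_input, grad_output) signature as Hook.backward, and it may return a replacement for grad_input. The class DoubleGradHook and its gradient-doubling rule are hypothetical and only mirror the forward/backward/register/remove layout of Hook.

```python
import torch


class DoubleGradHook:
    """Hypothetical hook with a forward/backward/register/remove layout similar to Hook."""

    def __init__(self):
        self.stored_input = None
        self.handles = []

    def forward(self, module, args, output):
        # keep the module input for the backward pass
        self.stored_input = args[0].detach()

    def backward(self, module, grad_input, grad_output):
        # return a replacement for the gradient w.r.t. the module inputs (here: doubled)
        return tuple(None if grad is None else 2 * grad for grad in grad_input)

    def register(self, module):
        self.handles = [
            module.register_forward_hook(self.forward),
            module.register_full_backward_hook(self.backward),
        ]
        return self.handles

    def remove(self):
        for handle in self.handles:
            handle.remove()
        self.handles = []
        # drop references to stored tensors, as described for Hook.remove
        self.stored_input = None


layer = torch.nn.Linear(3, 1)
x = torch.randn(2, 3, requires_grad=True)
layer(x).sum().backward()
plain_grad = x.grad.clone()

x.grad = None
hook = DoubleGradHook()
hook.register(layer)
layer(x).sum().backward()
hooked_grad = x.grad.clone()
hook.remove()
```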

class zennit.core.BasicHook

Bases: Hook

A hook to compute the layer-wise attribution of the module it is attached to. A BasicHook instance may only be registered with a single module.

Parameters:
  • input_modifiers (list[callable], optional) – A list of functions (input: torch.Tensor) -> torch.Tensor to produce multiple inputs. Default is a single input which is the identity.

  • param_modifiers (list[ParamMod or callable], optional) – A list of ParamMod instances or functions (obj: torch.Tensor, name: str) -> torch.Tensor, with parameter tensor obj, registered in the root model as name, to temporarily modify the parameters of the attached module for each input produced with input_modifiers. Default is unmodified parameters for each input. Use a ParamMod instance to specify which parameters should be modified, whether they are required, and which should be set to zero.

  • output_modifiers (list[callable], optional) – A list of functions (input: torch.Tensor) -> torch.Tensor to modify the module’s output computed using the modified parameters before gradient computation for each input produced with input_modifiers. Default is the identity for each output.

  • gradient_mapper (callable, optional) – Function (out_grad: torch.Tensor, outputs: list[torch.Tensor]) -> list[torch.Tensor] to modify upper relevance. A list or tuple of the same size as outputs is expected to be returned. outputs has the same size as input_modifiers and param_modifiers. Default is a stabilized normalization by each of the outputs, multiplied with the output gradient.

  • reducer (callable, optional) – Function (inputs: list[torch.Tensor], gradients: list[torch.Tensor]) -> torch.Tensor to reduce all the inputs and gradients produced through input_modifiers and param_modifiers. inputs and gradients have the same length as input_modifiers and param_modifiers. Default is the sum of the products of each input and its corresponding gradient.

  • stabilizer (callable or float, optional) – Stabilization parameter for rules other than Epsilon. If stabilizer is a float, it will be added to the denominator with the same sign as each respective entry. If it is callable, a function (input: torch.Tensor) -> torch.Tensor is expected, of which the output corresponds to the stabilized denominator.

__init__(input_modifiers=None, param_modifiers=None, output_modifiers=None, gradient_mapper=None, reducer=None, stabilizer=1e-6)
forward(module, args, kwargs, output)

Forward hook to save module in-/outputs.

Parameters:
  • module (torch.nn.Module) – The module to which this hook is attached.

  • args (tuple of torch.Tensor) – The input tensors passed to module.forward.

  • kwargs (dict) – The keyword arguments passed to module.forward.

  • output (torch.Tensor) – The output tensor.

backward(module, grad_input, grad_output)

Backward hook to compute LRP based on the class attributes.

Parameters:
  • module (torch.nn.Module) – The module to which this hook is attached.

  • grad_input (torch.Tensor) – The input gradient tensor.

  • grad_output (torch.Tensor) – The output gradient tensor.

Returns:

The modified input gradient tensors.

Return type:

tuple of torch.Tensor

copy()

Return a copy of this hook. This is used to describe hooks of different modules by a single hook instance.

Returns:

A copy of this hook.

Return type:

BasicHook
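Taken together, the defaults listed for BasicHook describe a specific computation for a single layer. The defaults are:

  • identity input, parameter, and output modifiers;
  • a gradient_mapper that divides the output gradient by the stabilized output;
  • a reducer that sums input times gradient.

The standalone sketch below performs that computation directly. It is illustrative only. BasicHook runs inside the backward pass through hooks, whereas basic_rule_sketch (a hypothetical name) calls torch.autograd.grad directly and assumes a fixed, sign-aware epsilon as stabilizer.

```python
import torch


def basic_rule_sketch(module, input, relevance_out, epsilon=1e-6):
    """Hypothetical one-layer version of the documented BasicHook defaults."""
    input = input.detach().requires_grad_()
    # default input/param/output modifiers are identities, so a plain forward pass suffices
    output = module(input)
    # default gradient_mapper: output gradient divided by the stabilized output
    sign = (output == 0.).to(output) + output.sign()
    mapped = relevance_out / (output + sign * epsilon)
    # propagate the mapped gradient back to the input
    (gradient,) = torch.autograd.grad(output, input, grad_outputs=mapped.detach())
    # default reducer: sum of input * gradient (a single term here)
    return input.detach() * gradient


layer = torch.nn.Linear(3, 2, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[1.0, 2.0, -1.0], [0.5, -1.0, 2.0]]))
x = torch.tensor([[1.0, 2.0, 3.0]])
relevance_out = layer(x).detach()  # start from the layer output itself: [[2.0, 4.5]]
relevance_in = basic_rule_sketch(layer, x, relevance_out)  # approx. [[1.5, 2.0, 3.0]]
```

With zero bias, the input relevance sums (up to the effect of epsilon) to the output relevance.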

class zennit.core.RemovableHandle

Bases: object

Create a weak reference to call .remove on some instance.

Parameters:

instance (object) – The instance to which to create the reference.

__init__(instance)
remove()

Call remove on the weakly referenced instance if it still exists.

class zennit.core.RemovableHandleList

Bases: list

A list to hold handles, with the ability to call remove on all of its members.

remove()

Call remove on all members, effectively removing handles from modules, or reverting canonizers.
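The weak-reference handle and the handle list can be sketched with the standard library alone. The class names below (and the Dummy stand-in for a hook or canonizer) are hypothetical, written from the descriptions above rather than taken from zennit. The comment on immediate garbage collection relies on CPython’s reference counting.

```python
import weakref


class RemovableHandleSketch:
    """Hypothetical handle: holds only a weak reference, so it does not keep `instance` alive."""

    def __init__(self, instance):
        self.instance = weakref.ref(instance)

    def remove(self):
        instance = self.instance()
        # call remove only if the referenced object still exists
        if instance is not None:
            instance.remove()


class RemovableHandleListSketch(list):
    """Hypothetical list of handles; remove() (no argument) removes every member."""

    def remove(self):
        for item in self:
            item.remove()


class Dummy:
    """Hypothetical stand-in for something removable, e.g. a hook."""

    def __init__(self):
        self.removed = False

    def remove(self):
        self.removed = True


a, b = Dummy(), Dummy()
handles = RemovableHandleListSketch([RemovableHandleSketch(a), RemovableHandleSketch(b)])
handles.remove()
print(a.removed, b.removed)  # True True

c = Dummy()
orphan = RemovableHandleSketch(c)
del c             # in CPython the object is freed immediately, so the weak reference dies
orphan.remove()   # does nothing instead of raising
```

Note that overriding remove on a list subclass replaces list.remove(value) with a no-argument method, matching the documented remove() signature.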

class zennit.core.CompositeContext

Bases: object

A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.

Parameters:
  • module (torch.nn.Module) – The module to which composite should be registered.

  • composite (zennit.core.Composite) – The composite which shall be registered to module.

__init__(module, composite)
class zennit.core.Composite

Bases: object

A Composite to apply canonizers and register hooks to modules. One Composite instance may only be applied to a single module at a time.

Parameters:
  • module_map (callable, optional) – A function (ctx: dict, name: str, module: torch.nn.Module) -> Hook or None which maps a context, name and module to a matching Hook, or None if there is no matching Hook.

  • canonizers (list[zennit.canonizers.Canonizer], optional) – List of canonizer instances to be applied before applying hooks.

__init__(module_map=None, canonizers=None)
register(module)

Apply all canonizers and register all hooks to a module (and its recursive children). Previous canonizers of this composite are reverted and all hooks registered by this composite are removed. The module or any of its children (recursively) may still have other hooks attached.

Parameters:

module (torch.nn.Module) – Hooks and canonizers will be applied to this module recursively according to module_map and canonizers.

remove()

Remove all handles for hooks and canonizers. Hooks will simply be removed from their corresponding modules. Canonizers will revert the state of the modules they changed.

context(module)

Return a CompositeContext object with this instance and the supplied module.

Parameters:

module (torch.nn.Module) – Module for which to register this composite in the context.

Returns:

A context object which registers the composite to module on entering, and removes it on exiting.

Return type:

zennit.core.CompositeContext

inactive() -> Generator

Context manager to temporarily deactivate the gradient modification. This can be used to compute the gradient of the modified gradient.

Yields:

self (Composite) – The instance of this composite.

Return type:

Generator
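The register/remove/context life cycle described for Composite can be illustrated with plain PyTorch forward hooks. This sketch is hypothetical: MiniComposite is not part of zennit. Unlike Composite, it attaches plain functions instead of Hook copies, applies no canonizers, and has no inactive() mode. Its module_map follows the documented (ctx, name, module) signature.

```python
from contextlib import contextmanager

import torch


class MiniComposite:
    """Hypothetical stand-in for Composite: attaches forward hooks chosen by module_map."""

    def __init__(self, module_map):
        self.module_map = module_map
        self.handles = []

    def register(self, module):
        # as documented for Composite.register: first undo this instance's previous hooks
        self.remove()
        ctx = {}
        for name, child in module.named_modules():
            hook = self.module_map(ctx, name, child)
            if hook is not None:
                self.handles.append(child.register_forward_hook(hook))

    def remove(self):
        for handle in self.handles:
            handle.remove()
        self.handles = []

    @contextmanager
    def context(self, module):
        # register on entering, remove on exiting (even if an exception occurs)
        self.register(module)
        try:
            yield module
        finally:
            self.remove()


calls = []


def module_map(ctx, name, module):
    # hypothetical mapping: record the name of every Linear layer that runs
    if isinstance(module, torch.nn.Linear):
        return lambda mod, args, output: calls.append(name)
    return None


model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 1))
composite = MiniComposite(module_map)
with composite.context(model) as hooked:
    hooked(torch.randn(2, 4))
model(torch.randn(2, 4))  # hooks were removed, so nothing is recorded
print(calls)  # ['0', '2']
```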