A note about the code from the tutorial: while it would be straightforward to write a composite which takes the layer index into account, I would suggest to rather go with the name-map approach. If you want to automatically create a name-map as the composite from the tutorial/paper, here's some code on how to do it:

```python
import torch
from torch.nn import Conv2d, AvgPool2d
from torchvision.models import vgg16
from zennit.composites import NameMapComposite
from zennit.core import BasicHook, collect_leaves, stabilize
from zennit.rules import Gamma, Epsilon
# the LRP-Epsilon from the tutorial
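# the denominator below is the tutorial's stabilizer incr(z) = z + delta * sqrt(mean(z ** 2)),
# wrapped in zennit's stabilize() so a small epsilon additionally guards
# against division by zero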
class GMontavonEpsilon(BasicHook):
    def __init__(self, epsilon=1e-6, delta=0.25):
        super().__init__(
            input_modifiers=[lambda input: input],
            param_modifiers=[lambda param, _: param],
            output_modifiers=[lambda output: output],
            gradient_mapper=(lambda out_grad, outputs: out_grad / stabilize(
                outputs[0] + delta * (outputs[0] ** 2).mean() ** .5, epsilon
            )),
            reducer=(lambda inputs, gradients: inputs[0] * gradients[0]),
        )
model = vgg16()
# only these get rules, linear layers will be attributed by the gradient alone
target_types = (Conv2d, AvgPool2d)
# lookup module -> name
child_name = {module: name for name, module in model.named_modules()}
# the layers in sequential order without any containers etc.
layers = list(enumerate(collect_leaves(model)))
# list of tuples [([names..], rule)] as used by NameMapComposite
name_map = [
    ([child_name[module] for n, module in layers if n <= 16 and isinstance(module, target_types)], Gamma(0.25)),
    ([child_name[module] for n, module in layers if 17 <= n <= 30 and isinstance(module, target_types)], GMontavonEpsilon(1e-9, 0.25)),
    ([child_name[module] for n, module in layers if 31 <= n and isinstance(module, target_types)], Epsilon(1e-9)),
]
# look at the name_map and you will see that there is no layer for which the last condition holds
print(name_map)
# create the composite from the name map
composite = NameMapComposite(name_map)
with composite.context(model) as modified_model:
    # compute attribution
    data = torch.randn(1, 3, 224, 224, requires_grad=True)
    output = modified_model(data)
    output.backward(torch.eye(1000)[[0]])
    # print absolute sum of attribution
    print(data.grad.abs().sum().item())
```
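To avoid repeating the index conditions, the name-map construction above can be wrapped in a small helper. The following is only a sketch reusing the imports, `model`, and `GMontavonEpsilon` from the snippet above; `rules_by_index` and its inclusive `(first, last, rule)` ranges are hypothetical names, not part of zennit:

```python
from zennit.composites import NameMapComposite
from zennit.core import collect_leaves


def rules_by_index(model, ranges, target_types):
    # hypothetical helper: build a NameMapComposite from inclusive index
    # ranges; only leaf modules of the given types are assigned a rule
    child_name = {module: name for name, module in model.named_modules()}
    layers = list(enumerate(collect_leaves(model)))
    name_map = [
        ([
            child_name[module]
            for n, module in layers
            if first <= n <= last and isinstance(module, target_types)
        ], rule)
        for first, last, rule in ranges
    ]
    return NameMapComposite(name_map)


# usage mirroring the name_map above
composite = rules_by_index(model, [
    (0, 16, Gamma(0.25)),
    (17, 30, GMontavonEpsilon(1e-9, 0.25)),
    (31, float('inf'), Epsilon(1e-9)),
], (Conv2d, AvgPool2d))
```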
@chr5tphr Awesome! Thank you for your time, the snippet has been of great help.

Clarification about the tutorial: with your help, I've been able to reproduce the results from the LRP tutorial; here are the changes I made:
Reproduce gmontavon/lrp-tutorial with the zennit framework. Related issue: chr5tphr/zennit#76.
I'd like to assign LRP rules by layer index, as shown in the screenshots.
Please correct me if this is already possible; I've taken a look at the code and the paper, but it seems it's currently only possible by type.
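For context, assigning rules by module type is what zennit's LayerMapComposite does. Here is a minimal sketch of the current by-type approach, assuming the abstract types from zennit.types; the exact layer_map below is illustrative, not taken from this issue:

```python
import torch
from torchvision.models import vgg16
from zennit.composites import LayerMapComposite
from zennit.rules import Epsilon, Gamma
from zennit.types import Convolution

# by type: every convolution gets Gamma, every dense layer gets Epsilon,
# with no way to distinguish early from late convolutions
composite = LayerMapComposite([
    (Convolution, Gamma(0.25)),
    (torch.nn.Linear, Epsilon(1e-9)),
])

model = vgg16()
with composite.context(model) as modified_model:
    data = torch.randn(1, 3, 224, 224, requires_grad=True)
    output = modified_model(data)
    output.backward(torch.eye(1000)[[0]])
```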
Source: gmontavon/lrp-tutorial
Source: Layer-Wise Relevance Propagation: An Overview
Thanks for the great framework! I especially like its architecture.