
Add option to skip normalization of output layer relevance #22

Merged · 7 commits · Oct 11, 2024
10 changes: 5 additions & 5 deletions src/crp.jl

@@ -35,18 +35,18 @@ end
 function call_analyzer(
     input::AbstractArray{T,N}, crp::CRP, ns::AbstractOutputSelector
 ) where {T,N}
-    rules = crp.lrp.rules
-    layers = crp.lrp.model.layers
-    modified_layers = crp.lrp.modified_layers
+    # Unpack internal LRP analyzer
+    (; model, rules, modified_layers, normalize_output_relevance) = crp.lrp
+    layers = model.layers

     n_layers = length(layers)
     n_features = number_of_features(crp.features)
     batchsize = size(input, N)

     # Forward pass
     as = get_activations(crp.lrp.model, input) # compute activations aᵏ for all layers k
-    Rs = similar.(as) # allocate relevances Rᵏ for all layers k
-    mask_output_neuron!(Rs[end], as[end], ns) # compute relevance Rᴺ of output layer N
+    Rs = similar.(as) # allocate relevances Rᵏ for all layers k
+    mask_output_neuron!(Rs[end], as[end], ns, normalize_output_relevance) # compute relevance Rᴺ of output layer N

     # Allocate array for returned relevance, adding features to batch dimension
     R_return = similar(input, size(input)[1:(end - 1)]..., batchsize * n_features)
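Aside, not part of the diff: the `(; a, b) = x` form introduced above is Julia's property destructuring (available since Julia 1.7), which binds local variables from same-named fields of any object. A minimal self-contained sketch with made-up names:

    # Property destructuring pulls fields out of a struct by name.
    struct Wrapper
        model::String
        rules::Vector{String}
    end

    w = Wrapper("flattened chain", ["ZeroRule", "EpsilonRule"])
    (; model, rules) = w   # equivalent to model = w.model; rules = w.rules
    @assert model == "flattened chain"
    @assert rules == ["ZeroRule", "EpsilonRule"]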
31 changes: 22 additions & 9 deletions src/lrp.jl

@@ -11,8 +11,10 @@ The analyzer can either be created by passing an array of LRP-rules
 or by passing a composite, see [`Composite`](@ref) for an example.

 # Keyword arguments
-- `skip_checks::Bool`: Skip checks whether model is compatible with LRP and contains output softmax. Default is `false`.
-- `verbose::Bool`: Select whether the model checks should print a summary on failure. Default is `true`.
+- `normalize_output_relevance`: Selects whether the output relevance should be set to 1 before applying the LRP backward pass.
+  Defaults to `true` to match the literature. If `false`, the values of the output activations are used.
+- `skip_checks::Bool`: Skip checks whether the model is compatible with LRP and contains an output softmax. Defaults to `false`.
+- `verbose::Bool`: Select whether the model checks should print a summary on failure. Defaults to `true`.

 # References
 [1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
@@ -22,10 +24,16 @@ struct LRP{C<:Chain,R<:ChainTuple,L<:ChainTuple} <: AbstractXAIMethod
     model::C
     rules::R
     modified_layers::L
+    normalize_output_relevance::Bool

     # Construct LRP analyzer by assigning a rule to each layer
     function LRP(
-        model::Chain, rules::ChainTuple; skip_checks=false, flatten=true, verbose=true
+        model::Chain,
+        rules::ChainTuple;
+        normalize_output_relevance::Bool=true,
+        skip_checks=false,
+        flatten=true,
+        verbose=true,
     )
         if flatten
             model = chainflatten(model)
@@ -37,7 +45,7 @@ struct LRP{C<:Chain,R<:ChainTuple,L<:ChainTuple} <: AbstractXAIMethod
         end
         modified_layers = get_modified_layers(rules, model)
         return new{typeof(model),typeof(rules),typeof(modified_layers)}(
-            model, rules, modified_layers
+            model, rules, modified_layers, normalize_output_relevance
         )
     end
 end
@@ -59,20 +67,25 @@ function call_analyzer(
     input::AbstractArray, lrp::LRP, ns::AbstractOutputSelector; layerwise_relevances=false
 )
     as = get_activations(lrp.model, input) # compute activations aᵏ for all layers k
-    Rs = similar.(as) # allocate relevances Rᵏ for all layers k
-    mask_output_neuron!(Rs[end], as[end], ns) # compute relevance Rᴺ of output layer N

+    Rs = similar.(as)
+    mask_output_neuron!(Rs[end], as[end], ns, lrp.normalize_output_relevance) # compute relevance Rᴺ of output layer N
     lrp_backward_pass!(Rs, as, lrp.rules, lrp.model, lrp.modified_layers)
     extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing
     return Explanation(first(Rs), input, last(as), ns(last(as)), :LRP, :attribution, extras)
 end

 get_activations(model, input) = (input, Flux.activations(model, input)...)

-function mask_output_neuron!(R_out, a_out, ns::AbstractOutputSelector)
+function mask_output_neuron!(
+    R_out, a_out, ns::AbstractOutputSelector, normalize_output_relevance::Bool
+)
     fill!(R_out, 0)
     idx = ns(a_out)
-    R_out[idx] .= 1
+    if normalize_output_relevance
+        R_out[idx] .= 1
+    else
+        R_out[idx] .= a_out[idx]
+    end
     return R_out
 end
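To isolate what the new flag changes: with `normalize_output_relevance=true` the selected output neuron receives relevance 1; with `false` it keeps its raw activation. A standalone sketch that mimics `mask_output_neuron!` without the package (the `select_max` helper stands in for an `AbstractOutputSelector` and is hypothetical):

    # Toy version of the masking step; not the package implementation.
    select_max(a) = argmax(a)

    function toy_mask(a_out; normalize_output_relevance::Bool=true)
        R_out = zero(a_out)          # start from all-zero relevance
        idx = select_max(a_out)      # pick the explained output neuron
        # Normalized: relevance 1. Unnormalized: the raw activation value.
        R_out[idx] = normalize_output_relevance ? 1 : a_out[idx]
        return R_out
    end

    a = [0.1, 2.5, 0.4]
    toy_mask(a)                                    # [0.0, 1.0, 0.0]
    toy_mask(a; normalize_output_relevance=false)  # [0.0, 2.5, 0.0]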
23 changes: 23 additions & 0 deletions test/test_batches.jl

@@ -39,3 +39,26 @@ for (name, method) in ANALYZERS
         @test expl2_bd.val ≈ expl_batch.val[:, 2]
     end
 end
+
+@testset "Normalized output relevance" begin
+    analyzer1 = LRP(model)
+    analyzer2 = LRP(model; normalize_output_relevance=false)
+
+    e1 = analyze(input_batch, analyzer1)
+    e2 = analyze(input_batch, analyzer2)
+    v1_bd1 = e1.val[:, 1]
+    v1_bd2 = e1.val[:, 2]
+    v2_bd1 = e2.val[:, 1]
+    v2_bd2 = e2.val[:, 2]
+
+    @test isapprox(sum(v1_bd1), 1; atol=0.05)
+    @test isapprox(sum(v1_bd2), 1; atol=0.05)
+    @test !isapprox(sum(v2_bd1), 1; atol=0.05)
+    @test !isapprox(sum(v2_bd2), 1; atol=0.05)
+
+    ratio_bd1 = first(v1_bd1) / first(v2_bd1)
+    ratio_bd2 = first(v1_bd2) / first(v2_bd2)
+    @test !isapprox(ratio_bd1, ratio_bd2)
+    @test v1_bd1 ≈ v2_bd1 * ratio_bd1
+    @test v1_bd2 ≈ v2_bd2 * ratio_bd2
+end
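The proportionality checks at the end rest on the LRP backward pass being linear in the output relevance: scaling Rᴺ by a constant scales every layer's relevance map by that constant, so per sample the unnormalized heatmap equals the normalized one times the selected output activation. A toy illustration of that homogeneity, with made-up weights:

    # One LRP-style redistribution step is linear in the incoming relevance.
    redistribute(R, w) = w .* R ./ sum(w)   # stand-in for one backward rule

    w = [0.2, 0.5, 0.3]
    R_normalized   = redistribute(1.0, w)   # output relevance set to 1
    R_unnormalized = redistribute(2.5, w)   # raw output activation, e.g. 2.5
    @assert R_unnormalized ≈ 2.5 .* R_normalized   # same map up to scale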
15 changes: 15 additions & 0 deletions test/test_cnn.jl

@@ -110,3 +110,18 @@ end
     @test lwr1[1] ≈ lwr2[1]
     @test lwr1[end] ≈ lwr2[end]
 end
+
+@testset "Normalized output relevance" begin
+    analyzer1 = LRP(model)
+    analyzer2 = LRP(model; normalize_output_relevance=false)
+
+    e1 = analyze(input, analyzer1)
+    e2 = analyze(input, analyzer2)
+    v1, v2 = e1.val, e2.val
+
+    @test isapprox(sum(v1), 1; atol=0.05)
+    @test !isapprox(sum(v2), 1; atol=0.05)
+
+    ratio = first(v1) / first(v2)
+    @test v1 ≈ v2 * ratio
+end