From 9088127cdc93f5f851f3fcde116259545e121917 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 10 Oct 2024 15:03:32 +0200 Subject: [PATCH 1/5] Remove support for Julia <1.10 --- Project.toml | 6 +++--- README.md | 2 +- src/RelevancePropagation.jl | 1 - src/compat.jl | 6 ------ 4 files changed, 4 insertions(+), 11 deletions(-) delete mode 100644 src/compat.jl diff --git a/Project.toml b/Project.toml index b28d12e..4b34ab1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RelevancePropagation" uuid = "0be6dd02-ae9e-43eb-b318-c6e81d6890d8" authors = ["Adrian Hill "] -version = "2.0.3-DEV" +version = "3.0.0-DEV" [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -20,6 +20,6 @@ Markdown = "1" Random = "1" Reexport = "1" Statistics = "1" -XAIBase = "3" +XAIBase = "4" Zygote = "0.6" -julia = "1.6" +julia = "1.10" diff --git a/README.md b/README.md index e33953f..81434fd 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ This package is part of the [Julia-XAI ecosystem](https://github.com/Julia-XAI) [ExplainableAI.jl](https://github.com/Julia-XAI/ExplainableAI.jl). ## Installation -This package supports Julia ≥1.6. To install it, open the Julia REPL and run +This package supports Julia ≥1.10. To install it, open the Julia REPL and run ```julia-repl julia> ]add RelevancePropagation ``` diff --git a/src/RelevancePropagation.jl b/src/RelevancePropagation.jl index b914dc5..755601b 100644 --- a/src/RelevancePropagation.jl +++ b/src/RelevancePropagation.jl @@ -12,7 +12,6 @@ using Zygote using Markdown using Statistics: mean, std -include("compat.jl") include("bibliography.jl") include("layer_types.jl") include("layer_utils.jl") diff --git a/src/compat.jl b/src/compat.jl deleted file mode 100644 index 368cdca..0000000 --- a/src/compat.jl +++ /dev/null @@ -1,6 +0,0 @@ -if VERSION < v"1.8.0-DEV.1494" # 98e60ffb11ee431e462b092b48a31a1204bd263d - export allequal - allequal(itr) = isempty(itr) ? true : all(isequal(first(itr)), itr) - allequal(c::Union{AbstractSet,AbstractDict}) = length(c) <= 1 - allequal(r::AbstractRange) = iszero(step(r)) || length(r) <= 1 -end From 238c65b1718281a9c87511693a9b83f704513016 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 10 Oct 2024 15:03:47 +0200 Subject: [PATCH 2/5] Update XAIBase to v4 --- src/RelevancePropagation.jl | 1 + src/crp.jl | 6 ++++-- src/lrp.jl | 6 +++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/RelevancePropagation.jl b/src/RelevancePropagation.jl index 755601b..4930612 100644 --- a/src/RelevancePropagation.jl +++ b/src/RelevancePropagation.jl @@ -2,6 +2,7 @@ module RelevancePropagation using Reexport @reexport using XAIBase +import XAIBase: call_analyzer using XAIBase: AbstractFeatureSelector, number_of_features using Base.Iterators diff --git a/src/crp.jl b/src/crp.jl index 9d57790..6964265 100644 --- a/src/crp.jl +++ b/src/crp.jl @@ -32,7 +32,9 @@ end # Call to CRP analyzer # #======================# -function (crp::CRP)(input::AbstractArray{T,N}, ns::AbstractOutputSelector) where {T,N} +function call_analyzer( + input::AbstractArray{T,N}, crp::CRP, ns::AbstractOutputSelector +) where {T,N} rules = crp.lrp.rules layers = crp.lrp.model.layers modified_layers = crp.lrp.modified_layers @@ -88,5 +90,5 @@ function (crp::CRP)(input::AbstractArray{T,N}, ns::AbstractOutputSelector) where end end end - return Explanation(R_return, last(as), ns(last(as)), :CRP, :attribution, nothing) + return Explanation(R_return, input, last(as), ns(last(as)), :CRP, :attribution, nothing) end diff --git a/src/lrp.jl b/src/lrp.jl index 758264d..f3d0be1 100644 --- a/src/lrp.jl +++ b/src/lrp.jl @@ -55,8 +55,8 @@ LRP(model::Chain, c::Composite; kwargs...) = LRP(model, lrp_rules(model, c); kwa # Call to the LRP analyzer # #==========================# -function (lrp::LRP)( - input::AbstractArray, ns::AbstractOutputSelector; layerwise_relevances=false +function call_analyzer( + input::AbstractArray, lrp::LRP, ns::AbstractOutputSelector; layerwise_relevances=false ) as = get_activations(lrp.model, input) # compute activations aᵏ for all layers k Rs = similar.(as) # allocate relevances Rᵏ for all layers k @@ -64,7 +64,7 @@ function (lrp::LRP)( lrp_backward_pass!(Rs, as, lrp.rules, lrp.model, lrp.modified_layers) extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing - return Explanation(first(Rs), last(as), ns(last(as)), :LRP, :attribution, extras) + return Explanation(first(Rs), input, last(as), ns(last(as)), :LRP, :attribution, extras) end get_activations(model, input) = (input, Flux.activations(model, input)...) From 7d420eb8429407fab5ee232de2372b524e147d2a Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 10 Oct 2024 15:12:46 +0200 Subject: [PATCH 3/5] Drop Flux `v0.13` --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4b34ab1..46ccaab 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,7 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Flux = "0.13, 0.14" +Flux = "0.14" MacroTools = "0.5" Markdown = "1" Random = "1" From 169fc50b72d1b1ae77d3ff834ff1722799882006 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 10 Oct 2024 15:13:03 +0200 Subject: [PATCH 4/5] `add_batch_dim` got removed in XAIBase `v4` --- test/test_batches.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/test_batches.jl b/test/test_batches.jl index 74c063a..5cbbec4 100644 --- a/test/test_batches.jl +++ b/test/test_batches.jl @@ -30,14 +30,6 @@ ANALYZERS = Dict( for (name, method) in ANALYZERS @testset "$name" begin - # Using `add_batch_dim=true` should result in same explanation - # as input reshaped to have a batch dimension - analyzer = method(model) - expl1_no_bd = analyzer(input1_no_bd; add_batch_dim=true) - analyzer = method(model) - expl1_bd = analyzer(input1_bd) - @test expl1_bd.val ≈ expl1_no_bd.val - # Analyzing a batch should have the same result # as analyzing inputs in batch individually analyzer = method(model) From 1ac5f5194435d6d61c1f2411018dee5946a0b077 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 10 Oct 2024 15:29:14 +0200 Subject: [PATCH 5/5] Fix batch test --- test/test_batches.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_batches.jl b/test/test_batches.jl index 5cbbec4..09c92dc 100644 --- a/test/test_batches.jl +++ b/test/test_batches.jl @@ -36,6 +36,6 @@ for (name, method) in ANALYZERS expl2_bd = analyzer(input2_bd) analyzer = method(model) expl_batch = analyzer(input_batch) - @test expl1_bd.val ≈ expl_batch.val[:, 1] + @test expl2_bd.val ≈ expl_batch.val[:, 2] end end