diff --git a/Project.toml b/Project.toml index 2cec830..54ed97e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "XAIBase" uuid = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" authors = ["Adrian Hill "] -version = "2.0.0-DEV" +version = "1.1.0-DEV" [deps] TextHeatmaps = "2dd6718a-6083-4824-b9f7-90e4a57f72d2" diff --git a/src/XAIBase.jl b/src/XAIBase.jl index 8c797a0..6a810c8 100644 --- a/src/XAIBase.jl +++ b/src/XAIBase.jl @@ -26,9 +26,12 @@ include("analyze.jl") # Heatmapping for vision and NLP tasks. include("heatmaps.jl") +# To be removed in next breaking release: +include("deprecated.jl") + export AbstractXAIMethod export AbstractNeuronSelector export Explanation export analyze -export heatmap, textheatmap +export heatmap end #module diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 0000000..e3c2e6f --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1,10 @@ +function Explanation(val, output, output_selection, analyzer::Symbol) + @warn "Creating an Explanation without a heatmap style is being deprecated. Defaulting to `heatmap=:attribution`." + return Explanation(val, output, output_selection, analyzer, :attribution, nothing) +end +function Explanation( + val, output, output_selection, analyzer::Symbol, extras::Union{Nothing,NamedTuple} +) + @warn "Creating an Explanation without a heatmap style is being deprecated. Defaulting to `heatmap=:attribution`." + return Explanation(val, output, output_selection, analyzer, :attribution, extras) +end diff --git a/src/explanation.jl b/src/explanation.jl index adb9a57..32c11b2 100644 --- a/src/explanation.jl +++ b/src/explanation.jl @@ -19,6 +19,6 @@ struct Explanation{V,O,S,E<:Union{Nothing,NamedTuple}} heatmap::Symbol extras::E end -function Explanation(val, output, output_selection, analyzer, heatmap) +function Explanation(val, output, output_selection, analyzer::Symbol, heatmap::Symbol) return Explanation(val, output, output_selection, analyzer, heatmap, nothing) end diff --git a/src/heatmaps.jl b/src/heatmaps.jl index 6a29ba9..1684efa 100644 --- a/src/heatmaps.jl +++ b/src/heatmaps.jl @@ -43,8 +43,8 @@ Visualize `Explanation` from XAIBase as a vision heatmap. Assumes WHCN convention (width, height, channels, batchsize) for `explanation.val`. ## Keyword arguments -- `colorscheme::Union{ColorScheme,Symbol}`: Color scheme from ColorSchemes.jl. - Defaults to `seismic`. +- `colorscheme::Union{ColorScheme,Symbol}`: color scheme from ColorSchemes.jl. + Defaults to `:$DEFAULT_COLORSCHEME`. - `reduce::Symbol`: Selects how color channels are reduced to a single number to apply a color scheme. The following methods can be selected, which are then applied over the color channels for each "pixel" in the array: @@ -80,7 +80,7 @@ end #===============# """ - textheatmap(explanation, text) + heatmap(explanation, text) Visualize `Explanation` from XAIBase as text heatmap. Text should be a vector containing vectors of strings, one for each input in the batched explanation. @@ -92,7 +92,7 @@ Text should be a vector containing vectors of strings, one for each input in the before the color scheme is applied. Can be either `:extrema` or `:centered`. Defaults to `:$DEFAULT_RANGESCALE` for use with the default color scheme `:$DEFAULT_COLORSCHEME`. """ -function textheatmap( +function heatmap( expl::Explanation, texts::AbstractVector{<:AbstractVector{<:AbstractString}}; kwargs... ) ndims(expl.val) != 2 && throw( @@ -113,6 +113,6 @@ function textheatmap( ] end -function textheatmap(expl::Explanation, text::AbstractVector{<:AbstractString}; kwargs...) - return textheatmap(expl, [text]; kwargs...) +function heatmap(expl::Explanation, text::AbstractVector{<:AbstractString}; kwargs...) + return heatmap(expl, [text]; kwargs...) end diff --git a/test/runtests.jl b/test/runtests.jl index 5dc09ac..530e54a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,4 +24,9 @@ using Aqua @info "Testing text heatmaps..." include("test_textheatmap.jl") end + @testset "Deprecations" begin + # To be removed in next breaking release + @info "Testing deprecations..." + include("test_deprecated.jl") + end end diff --git a/test/test_deprecated.jl b/test/test_deprecated.jl new file mode 100644 index 0000000..b09bf30 --- /dev/null +++ b/test/test_deprecated.jl @@ -0,0 +1,6 @@ +# Test deprecation warnings +shape = (2, 2, 3, 1) +val = output = reshape(collect(Float32, 1:prod(shape)), shape) +neuron_selection = [CartesianIndex(1, 2)] # irrelevant +@test_logs (:warn,) expl = Explanation(val, output, [neuron_selection], :LRP) +@test_logs (:warn,) expl = Explanation(val, output, [neuron_selection], :LRP, nothing) diff --git a/test/test_textheatmap.jl b/test/test_textheatmap.jl index 451d479..2468f18 100644 --- a/test/test_textheatmap.jl +++ b/test/test_textheatmap.jl @@ -2,21 +2,21 @@ val = output = [1 6; 2 5; 3 4] text = [["Test", "Text", "Heatmap"], ["another", "dummy", "input"]] neuron_selection = [CartesianIndex(1, 2), CartesianIndex(3, 4)] # irrelevant expl = Explanation(val, output, neuron_selection, :Gradient, :sensitivity) -h = textheatmap(expl, text) +h = heatmap(expl, text) @test_reference "references/Gradient1.txt" repr("text/plain", h[1]) @test_reference "references/Gradient2.txt" repr("text/plain", h[2]) expl = Explanation( val[:, 1:1], output[:, 1:1], neuron_selection[1], :Gradient, :sensitivity ) -h = textheatmap(expl, text[1]) +h = heatmap(expl, text[1]) @test_reference "references/Gradient1.txt" repr("text/plain", only(h)) expl = Explanation(val, output, neuron_selection, :LRP, :attribution) -h = textheatmap(expl, text) +h = heatmap(expl, text) @test_reference "references/LRP1.txt" repr("text/plain", h[1]) @test_reference "references/LRP2.txt" repr("text/plain", h[2]) -h = textheatmap(expl, text; rangescale=:extrema) +h = heatmap(expl, text; rangescale=:extrema) @test_reference "references/LRP1_extrema.txt" repr("text/plain", h[1]) @test_reference "references/LRP2_extrema.txt" repr("text/plain", h[2])