diff --git a/src/input_augmentation.jl b/src/input_augmentation.jl index 02a6675..026b441 100644 --- a/src/input_augmentation.jl +++ b/src/input_augmentation.jl @@ -8,80 +8,6 @@ struct AugmentationSelector{I} <: AbstractOutputSelector end (s::AugmentationSelector)(out) = s.indices -""" - augment_batch_dim(input, n) - -Repeat each sample in input batch n-times along batch dimension. -This turns arrays of size `(..., B)` into arrays of size `(..., B*n)`. - -## Example -```julia-repl -julia> A = [1 2; 3 4] -2×2 Matrix{Int64}: - 1 2 - 3 4 - -julia> augment_batch_dim(A, 3) -2×6 Matrix{Int64}: - 1 1 1 2 2 2 - 3 3 3 4 4 4 -``` -""" -function augment_batch_dim(input::AbstractArray{T,N}, n) where {T,N} - return repeat(input; inner=(ntuple(Returns(1), N - 1)..., n)) -end - -""" - reduce_augmentation(augmented_input, n) - -Reduce augmented input batch by averaging the explanation for each augmented sample. -""" -function reduce_augmentation(input::AbstractArray{T,N}, n) where {T<:AbstractFloat,N} - # Allocate output array - in_size = size(input) - in_size[end] % n != 0 && - throw(ArgumentError("Can't reduce augmented batch size of $(in_size[end]) by $n")) - out_size = (in_size[1:(end - 1)]..., div(in_size[end], n)) - out = similar(input, eltype(input), out_size) - - axs = axes(input, N) - colons = ntuple(Returns(:), N - 1) - for (i, ax) in enumerate(first(axs):n:last(axs)) - view(out, colons..., i) .= - dropdims(sum(view(input, colons..., ax:(ax + n - 1)); dims=N); dims=N) / n - end - return out -end - -""" - augment_indices(indices, n) - -Strip batch indices and return indices for batch augmented by n samples. - -## Example -```julia-repl -julia> inds = [CartesianIndex(5,1), CartesianIndex(3,2)] -2-element Vector{CartesianIndex{2}}: - CartesianIndex(5, 1) - CartesianIndex(3, 2) - -julia> augment_indices(inds, 3) -6-element Vector{CartesianIndex{2}}: - CartesianIndex(5, 1) - CartesianIndex(5, 2) - CartesianIndex(5, 3) - CartesianIndex(3, 4) - CartesianIndex(3, 5) - CartesianIndex(3, 6) -``` -""" -function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N} - indices_wo_batch = [i.I[1:(end - 1)] for i in inds] - return map(enumerate(repeat(indices_wo_batch; inner=n))) do (i, idx) - CartesianIndex{N}(idx..., i) - end -end - """ NoiseAugmentation(analyzer, n) NoiseAugmentation(analyzer, n, std::Real) @@ -104,38 +30,53 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <: n::Int distribution::D rng::R -end -function NoiseAugmentation(analyzer, n, distribution::Sampleable, rng=GLOBAL_RNG) - return NoiseAugmentation(analyzer, n, distribution::Sampleable, rng) + + function NoiseAugmentation( + analyzer::A, n::Int, distribution::D, rng::R + ) where {A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} + n < 2 && + throw(ArgumentError("Number of noise samples `n` needs to be larger than one.")) + return new{A,D,R}(analyzer, n, distribution, rng) + end end function NoiseAugmentation(analyzer, n, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real} return NoiseAugmentation(analyzer, n, Normal(zero(T), std^2), rng) end +function NoiseAugmentation(analyzer, n, distribution::Sampleable, rng=GLOBAL_RNG) + return NoiseAugmentation(analyzer, n, distribution, rng) +end function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...) # Regular forward pass of model output = aug.analyzer.model(input) output_indices = ns(output) - - # Call regular analyzer on augmented batch - augmented_input = add_noise(augment_batch_dim(input, aug.n), aug.distribution, aug.rng) - augmented_indices = augment_indices(output_indices, aug.n) - augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices)) + output_selector = AugmentationSelector(output_indices) + + # First augmentation + input_aug = similar(input) + input_aug = sample_noise!(input_aug, input, aug) + expl_aug = aug.analyzer(input_aug, output_selector) + sum_val = expl_aug.val + + # Further augmentations + for _ in 2:(aug.n) + input_aug = sample_noise!(input_aug, input, aug) + expl_aug = aug.analyzer(input_aug, output_selector) + sum_val += expl_aug.val + end # Average explanation + val = sum_val / aug.n + return Explanation( - reduce_augmentation(augmented_expl.val, aug.n), - input, - output, - output_indices, - augmented_expl.analyzer, - augmented_expl.heatmap, - nothing, + val, input, output, output_indices, expl_aug.analyzer, expl_aug.heatmap, nothing ) end -function add_noise(A::AbstractArray{T}, distr::Distribution, rng::AbstractRNG) where {T} - return A + T.(rand(rng, distr, size(A))) +function sample_noise!( + out::A, input::A, aug::NoiseAugmentation +) where {T,A<:AbstractArray{T}} + out .= input .+ rand(aug.rng, aug.distribution, size(input)) end """ @@ -149,6 +90,13 @@ difference between the input and the reference input. struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod analyzer::A n::Int + + function InterpolationAugmentation(analyzer::A, n::Int) where {A<:AbstractXAIMethod} + n < 2 && throw( + ArgumentError("Number of interpolation steps `n` needs to be larger than one."), + ) + return new{A}(analyzer, n) + end end function call_analyzer( @@ -160,57 +108,25 @@ function call_analyzer( # Regular forward pass of model output = aug.analyzer.model(input) output_indices = ns(output) - - # Call regular analyzer on augmented batch - augmented_input = interpolate_batch(input, input_ref, aug.n) - augmented_indices = augment_indices(output_indices, aug.n) - augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices)) + output_selector = AugmentationSelector(output_indices) + + # First augmentations + input_aug = input_ref + expl_aug = aug.analyzer(input_aug, output_selector) + sum_val = expl_aug.val + + # Further augmentations + input_delta = (input - input_ref) / (aug.n - 1) + for _ in 1:(aug.n) + input_aug += input_delta + expl_aug = aug.analyzer(input_aug, output_selector) + sum_val += expl_aug.val + end # Average gradients and compute explanation - expl = (input - input_ref) .* reduce_augmentation(augmented_expl.val, aug.n) + val = (input - input_ref) .* sum_val / aug.n return Explanation( - expl, - input, - output, - output_indices, - augmented_expl.analyzer, - augmented_expl.heatmap, - nothing, + val, input, output, output_indices, expl_aug.analyzer, expl_aug.heatmap, nothing ) end - -""" - interpolate_batch(x, x0, nsamples) - -Augment batch along batch dimension using linear interpolation between input `x` and a reference input `x0`. - -## Example -```julia-repl -julia> x = Float16.(reshape(1:4, 2, 2)) -2×2 Matrix{Float16}: - 1.0 3.0 - 2.0 4.0 - -julia> x0 = zero(x) -2×2 Matrix{Float16}: - 0.0 0.0 - 0.0 0.0 - -julia> interpolate_batch(x, x0, 5) -2×10 Matrix{Float16}: - 0.0 0.25 0.5 0.75 1.0 0.0 0.75 1.5 2.25 3.0 - 0.0 0.5 1.0 1.5 2.0 0.0 1.0 2.0 3.0 4.0 -``` -""" -function interpolate_batch( - x::AbstractArray{T,N}, x0::AbstractArray{T,N}, nsamples -) where {T,N} - in_size = size(x) - outs = similar(x, (in_size[1:(end - 1)]..., in_size[end] * nsamples)) - colons = ntuple(Returns(:), N - 1) - for (i, t) in enumerate(range(zero(T), oneunit(T); length=nsamples)) - outs[colons..., i:nsamples:end] .= x0 + t * (x - x0) - end - return outs -end diff --git a/test/references/cnn/IntegratedGradients_max.jld2 b/test/references/cnn/IntegratedGradients_max.jld2 index f78ec5a..4e3b68d 100644 Binary files a/test/references/cnn/IntegratedGradients_max.jld2 and b/test/references/cnn/IntegratedGradients_max.jld2 differ diff --git a/test/runtests.jl b/test/runtests.jl index 8bd5819..18c842c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,10 +21,6 @@ using JET end end - @testset "Input augmentation" begin - @info "Testing input augmentation..." - include("test_input_augmentation.jl") - end @testset "CNN" begin @info "Testing analyzers on CNN..." include("test_cnn.jl") diff --git a/test/test_input_augmentation.jl b/test/test_input_augmentation.jl deleted file mode 100644 index 78bdaa4..0000000 --- a/test/test_input_augmentation.jl +++ /dev/null @@ -1,52 +0,0 @@ -using ExplainableAI: augment_batch_dim, augment_indices, reduce_augmentation -using ExplainableAI: interpolate_batch -using Test - -# augment_batch_dim -A = [1 2; 3 4] -B = @inferred augment_batch_dim(A, 3) -@test B == [ - 1 1 1 2 2 2 - 3 3 3 4 4 4 -] -B = @inferred augment_batch_dim(A, 4) -@test B == [ - 1 1 1 1 2 2 2 2 - 3 3 3 3 4 4 4 4 -] -A = reshape(1:8, 2, 2, 2) -B = @inferred augment_batch_dim(A, 3) -@test B[:, :, 1] == A[:, :, 1] -@test B[:, :, 2] == A[:, :, 1] -@test B[:, :, 3] == A[:, :, 1] -@test B[:, :, 4] == A[:, :, 2] -@test B[:, :, 5] == A[:, :, 2] -@test B[:, :, 6] == A[:, :, 2] - -# augment_batch_dim -inds = [CartesianIndex(5, 1), CartesianIndex(3, 2)] -augmented_inds = @inferred augment_indices(inds, 3) -@test augmented_inds == [ - CartesianIndex(5, 1) - CartesianIndex(5, 2) - CartesianIndex(5, 3) - CartesianIndex(3, 4) - CartesianIndex(3, 5) - CartesianIndex(3, 6) -] - -# reduce_augmentation -A = Float32.(reshape(1:10, 1, 1, 10)) -R = @inferred reduce_augmentation(A, 5) -@test R == reshape([sum(1:5), sum(6:10)] / 5, 1, 1, :) -A = Float64.(reshape(1:10, 1, 1, 1, 1, 10)) -R = @inferred reduce_augmentation(A, 2) -@test R == reshape([3, 7, 11, 15, 19] / 2, 1, 1, 1, 1, :) - -x = Float16.(reshape(1:4, 2, 2)) -x0 = zero(x) -A = @inferred interpolate_batch(x, x0, 5) -@test A ≈ [ - 0.0 0.25 0.5 0.75 1.0 0.0 0.75 1.5 2.25 3.0 - 0.0 0.5 1.0 1.5 2.0 0.0 1.0 2.0 3.0 4.0 -]