Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constant memory input augmentations #180

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 55 additions & 139 deletions src/input_augmentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,80 +8,6 @@ struct AugmentationSelector{I} <: AbstractOutputSelector
end
(s::AugmentationSelector)(out) = s.indices

"""
augment_batch_dim(input, n)

Repeat each sample in input batch n-times along batch dimension.
This turns arrays of size `(..., B)` into arrays of size `(..., B*n)`.

## Example
```julia-repl
julia> A = [1 2; 3 4]
2×2 Matrix{Int64}:
1 2
3 4

julia> augment_batch_dim(A, 3)
2×6 Matrix{Int64}:
1 1 1 2 2 2
3 3 3 4 4 4
```
"""
function augment_batch_dim(input::AbstractArray{T,N}, n) where {T,N}
return repeat(input; inner=(ntuple(Returns(1), N - 1)..., n))
end

"""
reduce_augmentation(augmented_input, n)

Reduce augmented input batch by averaging the explanation for each augmented sample.
"""
function reduce_augmentation(input::AbstractArray{T,N}, n) where {T<:AbstractFloat,N}
# Allocate output array
in_size = size(input)
in_size[end] % n != 0 &&
throw(ArgumentError("Can't reduce augmented batch size of $(in_size[end]) by $n"))
out_size = (in_size[1:(end - 1)]..., div(in_size[end], n))
out = similar(input, eltype(input), out_size)

axs = axes(input, N)
colons = ntuple(Returns(:), N - 1)
for (i, ax) in enumerate(first(axs):n:last(axs))
view(out, colons..., i) .=
dropdims(sum(view(input, colons..., ax:(ax + n - 1)); dims=N); dims=N) / n
end
return out
end

"""
augment_indices(indices, n)

Strip batch indices and return indices for batch augmented by n samples.

## Example
```julia-repl
julia> inds = [CartesianIndex(5,1), CartesianIndex(3,2)]
2-element Vector{CartesianIndex{2}}:
CartesianIndex(5, 1)
CartesianIndex(3, 2)

julia> augment_indices(inds, 3)
6-element Vector{CartesianIndex{2}}:
CartesianIndex(5, 1)
CartesianIndex(5, 2)
CartesianIndex(5, 3)
CartesianIndex(3, 4)
CartesianIndex(3, 5)
CartesianIndex(3, 6)
```
"""
function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
indices_wo_batch = [i.I[1:(end - 1)] for i in inds]
return map(enumerate(repeat(indices_wo_batch; inner=n))) do (i, idx)
CartesianIndex{N}(idx..., i)
end
end

"""
NoiseAugmentation(analyzer, n)
NoiseAugmentation(analyzer, n, std::Real)
Expand All @@ -104,38 +30,53 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
n::Int
distribution::D
rng::R
end
function NoiseAugmentation(analyzer, n, distribution::Sampleable, rng=GLOBAL_RNG)
return NoiseAugmentation(analyzer, n, distribution::Sampleable, rng)

function NoiseAugmentation(
analyzer::A, n::Int, distribution::D, rng::R
) where {A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG}
n < 2 &&
throw(ArgumentError("Number of noise samples `n` needs to be larger than one."))
return new{A,D,R}(analyzer, n, distribution, rng)
end
end
function NoiseAugmentation(analyzer, n, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
return NoiseAugmentation(analyzer, n, Normal(zero(T), std^2), rng)
end
function NoiseAugmentation(analyzer, n, distribution::Sampleable, rng=GLOBAL_RNG)
return NoiseAugmentation(analyzer, n, distribution, rng)
end

function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)
# Regular forward pass of model
output = aug.analyzer.model(input)
output_indices = ns(output)

# Call regular analyzer on augmented batch
augmented_input = add_noise(augment_batch_dim(input, aug.n), aug.distribution, aug.rng)
augmented_indices = augment_indices(output_indices, aug.n)
augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices))
output_selector = AugmentationSelector(output_indices)

# First augmentation
input_aug = similar(input)
input_aug = sample_noise!(input_aug, input, aug)
expl_aug = aug.analyzer(input_aug, output_selector)
sum_val = expl_aug.val

# Further augmentations
for _ in 2:(aug.n)
input_aug = sample_noise!(input_aug, input, aug)
expl_aug = aug.analyzer(input_aug, output_selector)
sum_val += expl_aug.val
end

# Average explanation
val = sum_val / aug.n

return Explanation(
reduce_augmentation(augmented_expl.val, aug.n),
input,
output,
output_indices,
augmented_expl.analyzer,
augmented_expl.heatmap,
nothing,
val, input, output, output_indices, expl_aug.analyzer, expl_aug.heatmap, nothing
)
end

function add_noise(A::AbstractArray{T}, distr::Distribution, rng::AbstractRNG) where {T}
return A + T.(rand(rng, distr, size(A)))
function sample_noise!(
out::A, input::A, aug::NoiseAugmentation
) where {T,A<:AbstractArray{T}}
out .= input .+ rand(aug.rng, aug.distribution, size(input))
end

"""
Expand All @@ -149,6 +90,13 @@ difference between the input and the reference input.
struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod
analyzer::A
n::Int

function InterpolationAugmentation(analyzer::A, n::Int) where {A<:AbstractXAIMethod}
n < 2 && throw(
ArgumentError("Number of interpolation steps `n` needs to be larger than one."),
)
return new{A}(analyzer, n)
end
end

function call_analyzer(
Expand All @@ -160,57 +108,25 @@ function call_analyzer(
# Regular forward pass of model
output = aug.analyzer.model(input)
output_indices = ns(output)

# Call regular analyzer on augmented batch
augmented_input = interpolate_batch(input, input_ref, aug.n)
augmented_indices = augment_indices(output_indices, aug.n)
augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices))
output_selector = AugmentationSelector(output_indices)

# First augmentations
input_aug = input_ref
expl_aug = aug.analyzer(input_aug, output_selector)
sum_val = expl_aug.val

# Further augmentations
input_delta = (input - input_ref) / (aug.n - 1)
for _ in 1:(aug.n)
input_aug += input_delta
expl_aug = aug.analyzer(input_aug, output_selector)
sum_val += expl_aug.val
end

# Average gradients and compute explanation
expl = (input - input_ref) .* reduce_augmentation(augmented_expl.val, aug.n)
val = (input - input_ref) .* sum_val / aug.n

return Explanation(
expl,
input,
output,
output_indices,
augmented_expl.analyzer,
augmented_expl.heatmap,
nothing,
val, input, output, output_indices, expl_aug.analyzer, expl_aug.heatmap, nothing
)
end

"""
interpolate_batch(x, x0, nsamples)

Augment batch along batch dimension using linear interpolation between input `x` and a reference input `x0`.

## Example
```julia-repl
julia> x = Float16.(reshape(1:4, 2, 2))
2×2 Matrix{Float16}:
1.0 3.0
2.0 4.0

julia> x0 = zero(x)
2×2 Matrix{Float16}:
0.0 0.0
0.0 0.0

julia> interpolate_batch(x, x0, 5)
2×10 Matrix{Float16}:
0.0 0.25 0.5 0.75 1.0 0.0 0.75 1.5 2.25 3.0
0.0 0.5 1.0 1.5 2.0 0.0 1.0 2.0 3.0 4.0
```
"""
function interpolate_batch(
x::AbstractArray{T,N}, x0::AbstractArray{T,N}, nsamples
) where {T,N}
in_size = size(x)
outs = similar(x, (in_size[1:(end - 1)]..., in_size[end] * nsamples))
colons = ntuple(Returns(:), N - 1)
for (i, t) in enumerate(range(zero(T), oneunit(T); length=nsamples))
outs[colons..., i:nsamples:end] .= x0 + t * (x - x0)
end
return outs
end
Binary file modified test/references/cnn/IntegratedGradients_max.jld2
Binary file not shown.
4 changes: 0 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ using JET
end
end

@testset "Input augmentation" begin
@info "Testing input augmentation..."
include("test_input_augmentation.jl")
end
@testset "CNN" begin
@info "Testing analyzers on CNN..."
include("test_cnn.jl")
Expand Down
52 changes: 0 additions & 52 deletions test/test_input_augmentation.jl

This file was deleted.

Loading