Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add image rotation #553

Merged
merged 32 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ee40dec
Add dependencies
roflmaostc Dec 17, 2023
7ac1f01
Add code
roflmaostc Dec 17, 2023
b130c93
Add test
roflmaostc Dec 17, 2023
5a2ea46
clear docs
roflmaostc Dec 17, 2023
76479bc
Fix FiniteDifferences.fdm_central
roflmaostc Dec 17, 2023
9c1bb4c
Include Refactoring
roflmaostc Jan 13, 2024
94f44d6
Code cleaning and fix adjoint for trivial rotations
roflmaostc Jan 13, 2024
158e263
Code cleaning
roflmaostc Jan 13, 2024
63022c6
Fix bug with even and odd arrays and trivial rotations
roflmaostc Jan 13, 2024
2fbc467
First parts of test are generalized, not gradients yet
roflmaostc Jan 13, 2024
240abea
Add gradtests, some fail
roflmaostc Jan 13, 2024
80b3baf
Tests working and subtle bug fixed for trivial rotations
roflmaostc Jan 13, 2024
cee44f4
Fix space before && [skip ci]
roflmaostc Jan 13, 2024
51295e5
Add documentation and fix issue with FillArrays
roflmaostc Jan 13, 2024
7b3ccb9
Fix function name
roflmaostc Jan 15, 2024
101cbf4
Fix size(arr)
roflmaostc Jan 15, 2024
9e7e4cb
Relax some tests since they failed on CUDA
roflmaostc Jan 15, 2024
aece567
Test with rel error of 1f-2
roflmaostc Jan 15, 2024
8923d49
Refine rotation tests
roflmaostc Jan 15, 2024
c245a56
Introduce show statement for buildkite
roflmaostc Jan 15, 2024
761534e
Remove show statement, introduce even test case again
roflmaostc Jan 16, 2024
6b5a5be
Rename midpoint to rotation_center and change rounding
roflmaostc Jan 16, 2024
9ce9cdb
Add more tests, nearest neighbour fails sometimes
roflmaostc Jan 16, 2024
1fb1e27
Lower tolerance tests
roflmaostc Jan 16, 2024
ae3f45e
Proper error handling
roflmaostc Jan 17, 2024
e12310e
Add underscore _ to internal methods. Clean docs
roflmaostc Jan 17, 2024
a99aa18
Change to Float64 test
roflmaostc Jan 18, 2024
fcba1c3
Lower testing accuracy for Float64
roflmaostc Jan 21, 2024
b9d92f8
Revert "Lower testing accuracy for Float64"
roflmaostc Jan 21, 2024
b4413ba
Rerun CI
roflmaostc Jan 29, 2024
2677391
Fix typo, rerun CI
roflmaostc Jan 29, 2024
b666ab0
Improve docstring [skip ci]
roflmaostc Jan 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand All @@ -62,4 +64,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeCore", "EnzymeTestUtils"]
test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeCore", "EnzymeTestUtils", "Interpolations", "ImageTransformations"]
3 changes: 3 additions & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,7 @@ include("impl/depthwiseconv_im2col.jl")
include("impl/pooling_direct.jl")
include("deprecations.jl")

include("rotation.jl")
export imrotate

end # module NNlib
250 changes: 250 additions & 0 deletions src/rotation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
"""
rotate_coordinates(sinθ, cosθ, i, j, midpoint, round_or_floor)

this rotates the coordinates and either applies round(nearest neighbour)
or floor for :bilinear interpolation)
"""
@inline function rotate_coordinates(sinθ, cosθ, i, j, midpoint, round_or_floor)
y = i - midpoint[1]
x = j - midpoint[2]
yrot = cosθ * y - sinθ * x + midpoint[1]
xrot = sinθ * y + cosθ * x + midpoint[2]
yrot_f = round_or_floor(yrot)
xrot_f = round_or_floor(xrot)
yrot_int = round_or_floor(Int, yrot)
xrot_int = round_or_floor(Int, xrot)
return yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int
end


"""
bilinear_helper(yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int)

Some helper variables
"""
@inline function bilinear_helper(yrot, xrot, yrot_f, xrot_f)
xdiff = (xrot - xrot_f)
xdiff_1minus = 1 - xdiff
ydiff = (yrot - yrot_f)
ydiff_1minus = 1 - ydiff

return ydiff, ydiff_1minus, xdiff, xdiff_1minus
end


"""
_prepare_imrotate(arr, θ, midpoint)

Prepate `sin` and `cos`, creates the output array and converts type
of `midpoint` if required.
"""
function _prepare_imrotate(arr::AbstractArray{T}, θ, midpoint) where T
# needed for rotation matrix
θ = mod(real(T)(θ), real(T)(2π))
midpoint = real(T).(midpoint)
sinθ, cosθ = sincos(real(T)(θ))
out = similar(arr)
fill!(out, 0)
return sinθ, cosθ, midpoint, out
end


"""
_check_trivial_rotations!(out, arr, θ, midpoint)

When `θ = 0 || π /2 || π || 3/2 || π` and if `midpoint`
is in the middle of the array.
For an even array of size 4, the midpoint would need to be 2.5.
For an odd array of size 5, the midpoint would need to be 3.

In those cases, rotations are trivial just by reversing or swapping some axes.
"""
function _check_trivial_rotations!(out, arr, θ, midpoint; adjoint=false)
if iszero(θ)
out .= arr
return true
end
# check for special cases where rotations are trivial
if (iseven(size(arr, 1)) && iseven(size(arr, 2)) &&
midpoint[1] ≈ size(arr, 1) ÷ 2 + 0.5 && midpoint[2] ≈ size(arr, 2) ÷ 2 + 0.5) ||
(isodd(size(arr, 1)) && isodd(size(arr, 2)) &&
(midpoint[1] == size(arr, 1) ÷ 2 + 1 && midpoint[1] == size(arr, 2) ÷ 2 + 1))
if θ ≈ π / 2
if adjoint == false
out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(2,))
else
out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(1,))
end
return true
elseif θ ≈ π
out .= reverse(arr, dims=(1,2))
return true
elseif θ ≈ 3 / 2 * π
if adjoint == false
out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(1,))
else
out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(2,))
end
return true
end
end

return false
end

function _∇check_trivial_rotations!(out, arr, θ, midpoint)
if iszero(θ)
out .= arr
return true
end
# check for special cases where rotations are trivial
if (iseven(size(arr, 1)) && iseven(size(arr, 2)) &&
midpoint[1] ≈ size(arr, 1) ÷ 2 + 0.5 && midpoint[2] ≈ size(arr, 2) ÷ 2 + 0.5) ||
(isodd(size(arr, 1)) && isodd(size(arr, 2)) &&
(midpoint[1] == size(arr, 1) ÷ 2 + 1 && midpoint[1] == size(arr, 2) ÷ 2 + 1))
if θ ≈ π / 2
out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(1,))
return true
elseif θ ≈ π
out .= reverse(arr, dims=(1,2))
return true
elseif θ ≈ 3 / 2 * π
out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(2,))
return true
end
end

return false
end

"""
imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear, midpoint=size(arr) .÷ 2 .+ 1)

roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
Rotates a matrix around the center pixel `midpoint`.
The angle `θ` is interpreted in radians.

The adjoint is defined with ChainRulesCore.jl. This method also runs with CUDA (and in principle all KernelAbstractions.jl supported backends).

# Keywords
* `method=:bilinear` for bilinear interpolation or `method=:nearest` for nearest neighbour
* `midpoint=size(arr) .÷ 2 .+ 1` means there is always a real center pixel around it is rotated.

# Examples
```julia-repl

```
"""
function imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear, midpoint=size(arr) .÷ 2 .+ 1) where T
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
@assert (T <: Integer && method==:nearest || !(T <: Integer)) "If the array has an Int eltype, only method=:nearest is supported"
@assert typeof(midpoint) <: Tuple "midpoint keyword has to be a tuple"

# prepare out, the sin and cos and type of midpoint
sinθ, cosθ, midpoint, out = _prepare_imrotate(arr, θ, midpoint)
# such as 0°, 90°, 180°, 270° and only if the midpoint is suitable
_check_trivial_rotations!(out, arr, θ, midpoint) && return out

# KernelAbstractions specific
backend = KernelAbstractions.get_backend(arr)
if method == :bilinear
kernel! = imrotate_kernel_bilinear!(backend)
elseif method == :nearest
kernel! = imrotate_kernel_nearest!(backend)
else
throw(ArgumentError("No interpolation method such as $method"))
end
kernel!(out, arr, sinθ, cosθ, midpoint, size(arr, 1), size(arr, 2),
ndrange=(size(arr, 1), size(arr, 2), size(arr, 3), size(arr, 4)))
return out
end

"""
∇imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear,

Adjoint for `imrotate`. Gradient only with respect to `arr` and not `θ`.
"""
function ∇imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear,
midpoint=size(arr) .÷ 2 .+ 1) where T

sinθ, cosθ, midpoint, out = _prepare_imrotate(arr, θ, midpoint)
# for the adjoint, the trivial rotations go in the other direction!
_check_trivial_rotations!(out, arr, θ, midpoint, adjoint=true) && return out

backend = KernelAbstractions.get_backend(arr)
if method == :bilinear
kernel! = ∇imrotate_kernel_bilinear!(backend)
elseif method == :nearest
kernel! = ∇imrotate_kernel_nearest!(backend)
else
throw(ArgumentError("No interpolation method such as $method"))
end
kernel!(out, arr, sinθ, cosθ, midpoint, size(arr, 1), size(arr, 2),
ndrange=(size(arr, 1), size(arr, 2), size(arr, 3), size(arr, 4)))
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
return out
end


@kernel function imrotate_kernel_nearest!(out, arr, sinθ, cosθ, midpoint, imax, jmax)
i, j, c, b = @index(Global, NTuple)

_, _, _, _, yrot_int, xrot_int = rotate_coordinates(sinθ, cosθ, i, j, midpoint, round)
if 1 ≤ yrot_int ≤ imax && 1 ≤ xrot_int ≤ jmax
@inbounds out[i, j, c, b] = arr[yrot_int, xrot_int, c, b]
end
end


@kernel function imrotate_kernel_bilinear!(out, arr, sinθ, cosθ, midpoint, imax, jmax)
i, j, c, b = @index(Global, NTuple)

yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int = rotate_coordinates(sinθ, cosθ, i, j, midpoint, floor)
if 1 ≤ yrot_int ≤ imax - 1 && 1 ≤ xrot_int ≤ jmax - 1

ydiff, ydiff_1minus, xdiff, xdiff_1minus =
bilinear_helper(yrot, xrot, yrot_f, xrot_f)
@inbounds out[i, j, c, b] =
( xdiff_1minus * ydiff_1minus * arr[yrot_int , xrot_int , c, b]
+ xdiff_1minus * ydiff * arr[yrot_int + 1 , xrot_int , c, b]
+ xdiff * ydiff_1minus * arr[yrot_int , xrot_int + 1 , c, b]
+ xdiff * ydiff * arr[yrot_int + 1 , xrot_int + 1 , c, b])
end
end


@kernel function ∇imrotate_kernel_nearest!(out, arr, sinθ, cosθ, midpoint, imax, jmax)
i, j, c, b = @index(Global, NTuple)

_, _, _, _, yrot_int, xrot_int = rotate_coordinates(sinθ, cosθ, i, j, midpoint, round)
if 1 ≤ yrot_int ≤ imax && 1 ≤ xrot_int ≤ jmax
Atomix.@atomic out[yrot_int, xrot_int, c, b] += arr[i, j, c, b]
end
end


@kernel function ∇imrotate_kernel_bilinear!(out, arr, sinθ, cosθ, midpoint, imax, jmax)
i, j, c, b = @index(Global, NTuple)

yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int = rotate_coordinates(sinθ, cosθ, i, j, midpoint, floor)
if 1 ≤ yrot_int ≤ imax - 1 && 1 ≤ xrot_int ≤ jmax - 1
o = arr[i, j, c, b]
ydiff, ydiff_1minus, xdiff, xdiff_1minus =
bilinear_helper(yrot, xrot, yrot_f, xrot_f)
Atomix.@atomic out[yrot_int , xrot_int , c, b] += xdiff_1minus * ydiff_1minus * o
Atomix.@atomic out[yrot_int + 1 , xrot_int , c, b] += xdiff_1minus * ydiff * o
Atomix.@atomic out[yrot_int , xrot_int + 1, c, b] += xdiff * ydiff_1minus * o
Atomix.@atomic out[yrot_int + 1 , xrot_int + 1, c, b] += xdiff * ydiff * o
end
end


# is this rrule good?
# no @thunk and @unthunk
function ChainRulesCore.rrule(::typeof(imrotate), array::AbstractArray{T}, θ;
method=:bilinear, midpoint=size(array) .÷ 2 .+ 1) where T
res = imrotate(array, θ; method, midpoint)
function pb_rotate(dy)
ad = ∇imrotate(unthunk(collect(dy)), θ; method, midpoint)
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
return NoTangent(), ad, NoTangent()
end

return res, pb_rotate
end
90 changes: 90 additions & 0 deletions test/rotation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
function upsample_testsuite(Backend)
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
device(x) = adapt(Backend(), x)
gradtest_fn = Backend == CPU ? gradtest : gputest
T = Float32
atol = T == Float32 ? 1e-3 : 1e-6

@testset "Image Rotation" begin
@testset "Simple test" begin
arr = device(zeros((6, 6, 1, 1)));
arr[3:4, 4, 1, 1] .= 1;
@test all(cpu(NNlib.imrotate(arr, deg2rad(45))) .≈ [0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.29289321881345254 0.585786437626905 0.0; 0.0 0.0 0.08578643762690495 1.0 0.2928932188134524 0.0; 0.0 0.0 0.0 0.08578643762690495 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0])
end
end


@testset "Compare with ImageTransformations" begin
arr = device(zeros(T, (51, 51, 1, 1)))
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
arr[15:40, 15:40, :, :] .= device(1 .+ randn((26, 26)))

arr2 = device(zeros(T, (51, 51, 1, 5)))
arr2[15:40, 15:40, :, :] .= device(arr[15:40, 15:40, :, :])


for method in [:nearest, :bilinear]
for angle in deg2rad.([0, 35, 45, 90, 135, 170, 180, 270, 360])
res1 = cpu(NNlib.imrotate(arr, angle; method))
res3 = cpu(NNlib.imrotate(arr2, angle; method))
if method == :nearest
res2 = ImageTransformations.imrotate(cpu(arr)[:, :, 1, 1], angle, axes(arr)[1:2], method=Constant(), fillvalue=0)
elseif method == :bilinear
res2 = ImageTransformations.imrotate(cpu(arr)[:, :, 1, 1], angle, axes(arr)[1:2], fillvalue=0)
end
@test all(1 .+ res1[:, :, :, :] .≈ 1 .+ res2[:, :])
@test all(1 .+ res1[:, :, :, :] .≈ 1 .+ res3[:, :,:, 1])
@test all(1 .+ res1[:, :, :, :] .≈ 1 .+ res3[:, :,:, 2])
end
end

arr = device(zeros(T, (52, 52, 1, 1)))
arr[15:40, 15:40, :, :] .= device(1 .+ randn((26, 26)))

arr2 = device(zeros(T, (52, 52, 5, 1)))
arr2[15:40, 15:40, :, 1] .= device(arr[15:40, 15:40, :, 1])

for method in [:nearest, :bilinear]
for angle in deg2rad.([0, 35, 90, 170, 180, 270, 360])
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
res1 = cpu(NNlib.imrotate(arr, angle; method, midpoint=size(arr) .÷2 .+0.5))
res3 = cpu(NNlib.imrotate(arr2, angle; method, midpoint=size(arr) .÷2 .+0.5))
if method == :nearest
res2 = ImageTransformations.imrotate(cpu(arr)[:, :, 1, 1], angle, axes(arr)[1:2], method=Constant(), fillvalue=0)
elseif method == :bilinear
res2 = ImageTransformations.imrotate(cpu(arr)[:, :, 1, 1], angle, axes(arr)[1:2], fillvalue=0)
end
@test all(1 .+ res1[:, :, :, :] .≈ 1 .+ res2[:, :])
@test all(1 .+ res1[:, :, :, :] .≈ 1 .+ res3[:, :, :, :])
@test all(1 .+ res1[:, :, :, :] .≈ 1 .+ res3[:, :, :, :])
end
end

end

@testset "Compare for plausibilty" begin
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
arr = zeros(T, (10, 10, 1, 3))
arr[6, 6, :, 1] .= 1
arr[6, 6, :, 2] .= 2
arr[6, 6, :, 3] .= 2

for method in [:bilinear, :nearest]
@test all(.≈(arr , NNlib.imrotate(arr, deg2rad(0); method)))
@test all(.≈(arr , NNlib.imrotate(arr, deg2rad(90); method)))
@test all(.≈(arr , NNlib.imrotate(arr, deg2rad(180); method)))
@test all(.≈(arr , NNlib.imrotate(arr, deg2rad(270); method)))
@test all(.≈(arr , NNlib.imrotate(arr, deg2rad(360); method)))
end
end


@testset "Test gradients" begin
for method in [:nearest, :bilinear]
for angle in deg2rad.([0, 35, 90, 170, 180, 270, 360])
gradtest_fn(
x -> NNlib.imrotate(x, angle; method),
device(rand(T, 11,11,1,1)); atol)
gradtest_fn(
x -> NNlib.imrotate(x, angle; method),
device(rand(T, 10,10,1,1)); atol)
end
end
end
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using Zygote: gradient
using StableRNGs
using Documenter
using Adapt
using ImageTransformations
using Interpolations: Constant
using KernelAbstractions
import ReverseDiff as RD # used in `pooling.jl`

Expand Down Expand Up @@ -43,11 +45,15 @@ cpu(x) = adapt(CPU(), x)
include("gather.jl")
include("scatter.jl")
include("upsample.jl")
include("rotation.jl")

function nnlib_testsuite(Backend; skip_tests = Set{String}())
@conditional_testset "Upsample" skip_tests begin
upsample_testsuite(Backend)
end
@conditional_testset "rotation" skip_tests begin
upsample_testsuite(Backend)
end
@conditional_testset "Gather" skip_tests begin
gather_testsuite(Backend)
end
Expand Down