Skip to content

Commit

Permalink
fix adapt for reinterpret(reshape, ...)
Browse files Browse the repository at this point in the history
This was added in JuliaLang/julia#37559 and is currently not handled correctly by Adapt.
  • Loading branch information
simeonschaub committed Jan 22, 2021
1 parent c7dad7f commit 567712d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Adapt"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.1.0"
version = "3.1.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
11 changes: 9 additions & 2 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ adapt_structure(to, A::PermutedDimsArray) =
PermutedDimsArray(adapt(to, Base.parent(A)), permutation(A))
adapt_structure(to, A::Base.ReshapedArray) =
Base.reshape(adapt(to, Base.parent(A)), size(A))
adapt_structure(to, A::Base.ReinterpretArray) =
Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A)))
@static if isdefined(Base, :NonReshapedReinterpretArray)
adapt_structure(to, A::Base.NonReshapedReinterpretArray) =
Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A)))
adapt_structure(to, A::Base.ReshapedReinterpretArray) =
Base.reinterpret(reshape, Base.eltype(A), adapt(to, Base.parent(A)))
else
adapt_structure(to, A::Base.ReinterpretArray) =
Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A)))
end

adapt_structure(to, A::LinearAlgebra.Adjoint) =
LinearAlgebra.adjoint(adapt(to, Base.parent(A)))
Expand Down
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ inds = CustomArray{Int,1}([1,2])

@test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomArray

@static if isdefined(Base, :NonReshapedReinterpretArray)
@test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomArray
end


## doubly-wrapped

Expand All @@ -129,6 +133,12 @@ inds = CustomArray{Int,1}([1,2])
@test_adapt CustomArray view(reshape(reinterpret(Int64,mat.arr), (2,2)), :, :) view(reshape(reinterpret(Int64,mat), (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)), :, :) view(reshape(reinterpret(Int64,view(mat,:,:)), (2,2)), :, :) AnyCustomArray

@static if isdefined(Base, :NonReshapedReinterpretArray)
@test_adapt CustomArray reinterpret(reshape,Int64,view(mat.arr,:,:)) reinterpret(reshape,Int64,view(mat,:,:)) AnyCustomArray
@test_adapt CustomArray view(reinterpret(reshape,Int64,mat.arr), :, :) view(reinterpret(reshape,Int64,mat), :, :) AnyCustomArray
@test_adapt CustomArray view(reinterpret(reshape,Int64,view(mat.arr,:,:)), :, :) view(reinterpret(reshape,Int64,view(mat,:,:)), :, :) AnyCustomArray
end


using LinearAlgebra

Expand Down

0 comments on commit 567712d

Please sign in to comment.