diff --git a/Project.toml b/Project.toml index bf9874f..395a48a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Adapt" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.1.0" +version = "3.1.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/wrappers.jl b/src/wrappers.jl index 42051c9..b103569 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -15,8 +15,15 @@ adapt_structure(to, A::PermutedDimsArray) = PermutedDimsArray(adapt(to, Base.parent(A)), permutation(A)) adapt_structure(to, A::Base.ReshapedArray) = Base.reshape(adapt(to, Base.parent(A)), size(A)) -adapt_structure(to, A::Base.ReinterpretArray) = - Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A))) +@static if isdefined(Base, :NonReshapedReinterpretArray) + adapt_structure(to, A::Base.NonReshapedReinterpretArray) = + Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A))) + adapt_structure(to, A::Base.ReshapedReinterpretArray) = + Base.reinterpret(reshape, Base.eltype(A), adapt(to, Base.parent(A))) +else + adapt_structure(to, A::Base.ReinterpretArray) = + Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A))) +end adapt_structure(to, A::LinearAlgebra.Adjoint) = LinearAlgebra.adjoint(adapt(to, Base.parent(A))) diff --git a/test/runtests.jl b/test/runtests.jl index b860e6d..73ab0ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -113,6 +113,10 @@ inds = CustomArray{Int,1}([1,2]) @test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomArray +@static if isdefined(Base, :NonReshapedReinterpretArray) + @test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomArray +end + ## doubly-wrapped @@ -129,6 +133,12 @@ inds = CustomArray{Int,1}([1,2]) @test_adapt CustomArray view(reshape(reinterpret(Int64,mat.arr), (2,2)), :, :) view(reshape(reinterpret(Int64,mat), (2,2)), :, :) AnyCustomArray @test_adapt CustomArray view(reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)), :, :) view(reshape(reinterpret(Int64,view(mat,:,:)), (2,2)), :, :) AnyCustomArray +@static if isdefined(Base, :NonReshapedReinterpretArray) + @test_adapt CustomArray reinterpret(reshape,Int64,view(mat.arr,:,:)) reinterpret(reshape,Int64,view(mat,:,:)) AnyCustomArray + @test_adapt CustomArray view(reinterpret(reshape,Int64,mat.arr), :, :) view(reinterpret(reshape,Int64,mat), :, :) AnyCustomArray + @test_adapt CustomArray view(reinterpret(reshape,Int64,view(mat.arr,:,:)), :, :) view(reinterpret(reshape,Int64,view(mat,:,:)), :, :) AnyCustomArray +end + using LinearAlgebra