From deafb750ccab044ee9b360248b4b847e07c981fe Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Mon, 13 May 2024 15:57:49 +0200 Subject: [PATCH] fixes to colonful `reshape` (#54261) Fixes #54245 --- base/reshapedarray.jl | 20 ++++++++++++++++---- test/abstractarray.jl | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/base/reshapedarray.jl b/base/reshapedarray.jl index 6cf2b9b4820165..501af97715c746 100644 --- a/base/reshapedarray.jl +++ b/base/reshapedarray.jl @@ -132,12 +132,24 @@ reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Int,Colon}}}) = reshape( "may have at most one omitted dimension specified by `Colon()`"))) @noinline throw2(A, dims) = throw(DimensionMismatch(string("array size $(length(A)) ", "must be divisible by the product of the new dimensions $dims"))) - pre = _before_colon(dims...) + pre = _before_colon(dims...)::Tuple{Vararg{Int}} post = _after_colon(dims...) _any_colon(post...) && throw1(dims) - sz, remainder = divrem(length(A), prod(pre)*prod(post)) - remainder == 0 || throw2(A, dims) - (pre..., Int(sz), post...) + post::Tuple{Vararg{Int}} + len = length(A) + sz, is_exact = if iszero(len) + (0, true) + else + let pr = Core.checked_dims(pre..., post...) # safe product + if iszero(pr) + throw2(A, dims) + end + (quo, rem) = divrem(len, pr) + (Int(quo), iszero(rem)) + end + end::Tuple{Int,Bool} + is_exact || throw2(A, dims) + (pre..., sz, post...)::Tuple{Int,Vararg{Int}} end @inline _any_colon() = false @inline _any_colon(dim::Colon, tail...) = true diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 4f0559f2fa89fa..aa173db63b11ca 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1852,6 +1852,24 @@ function check_pointer_strides(A::AbstractArray) return true end +@testset "colonful `reshape`, #54245" begin + @test reshape([], (0, :)) isa Matrix + @test_throws DimensionMismatch reshape([7], (0, :)) + let b = prevpow(2, typemax(Int)) + @test iszero(b*b) + @test_throws ArgumentError reshape([7], (b, :, b)) + @test reshape([], (b, :, b)) isa Array{<:Any, 3} + end + for iterator ∈ (7:6, 7:7, 7:8) + for it ∈ (iterator, map(BigInt, iterator)) + @test reshape(it, (:, Int(length(it)))) isa AbstractMatrix + @test reshape(it, (Int(length(it)), :)) isa AbstractMatrix + @test reshape(it, (1, :)) isa AbstractMatrix + @test reshape(it, (:, 1)) isa AbstractMatrix + end + end +end + @testset "strides for ReshapedArray" begin # Type-based contiguous Check a = vec(reinterpret(reshape, Int16, reshape(view(reinterpret(Int32, randn(10)), 2:11), 5, :)))