Skip to content

Commit

Permalink
Return float(T)[] so that refine_grid should be type stable
Browse files Browse the repository at this point in the history
  • Loading branch information
terasakisatoshi committed Dec 12, 2024
1 parent 96bbdec commit eee06e8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/_roots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ closeenough(a::T, b::T, ϵ) where {T<:AbstractFloat} = isapprox(a, b; rtol=0, at
closeenough(a::T, b::T, _) where {T<:Integer} = a == b

function refine_grid(grid::Vector{T}, ::Val{α}) where {T, α}
isempty(grid) && return grid
isempty(grid) && return float(T)[]
n = length(grid)
newn = α * (n - 1) + 1
newgrid = Vector{float(T)}(undef, newn)
Expand Down
16 changes: 10 additions & 6 deletions test/_roots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using SparseIR
@testset "Basic refinement" begin
# Test with α = 2
grid = [0.0, 1.0, 2.0]
refined = SparseIR.refine_grid(grid, Val(2))
refined = @inferred SparseIR.refine_grid(grid, Val(2))
@test length(refined) == 5 # α * (n-1) + 1 = 2 * (3-1) + 1 = 5
@test refined [0.0, 0.5, 1.0, 1.5, 2.0]

Expand All @@ -29,13 +29,13 @@ using SparseIR
@testset "Type stability" begin
# Integer grid
int_grid = [0, 1, 2]
refined_int = SparseIR.refine_grid(int_grid, Val(2))
@test eltype(refined_int) === Float64
@test refined_int == [0.0, 0.5, 1.0, 1.5, 2.0]
refined = @inferred SparseIR.refine_grid(int_grid, Val(2))
@test eltype(refined) === Float64
@test refined == [0.0, 0.5, 1.0, 1.5, 2.0]

# Float32 grid
f32_grid = Float32[0, 1, 2]
refined_f32 = SparseIR.refine_grid(f32_grid, Val(2))
refined_f32 = @inferred SparseIR.refine_grid(f32_grid, Val(2))
@test eltype(refined_f32) === Float32
@test refined_f32 Float32[0, 0.5, 1, 1.5, 2]
end
Expand All @@ -50,7 +50,11 @@ using SparseIR
# Empty grid
empty_grid = Float64[]
@test isempty(SparseIR.refine_grid(empty_grid, Val(2)))

# Empty grid
empty_grid = Int[]
out_grid = SparseIR.refine_grid(empty_grid, Val(2))
@test isempty(out_grid)
@test eltype(out_grid) === Float64
# Single point
single_point = [1.0]
@test SparseIR.refine_grid(single_point, Val(2)) == [1.0]
Expand Down

0 comments on commit eee06e8

Please sign in to comment.