Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Compiler use separate sorting algorithm. #47066

Merged
merged 9 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,10 @@ import Core.Compiler.CoreDocs
Core.atdoc!(CoreDocs.docm)

# sorting
function sort! end
function issorted end
include("ordering.jl")
using .Order
include("sort.jl")
using .Sort
include("compiler/sort.jl")

# some.jl is not included here, but this method is still useful:
# a leading `nothing` is dropped and the remaining arguments are tried in turn.
something(::Nothing, rest...) = something(rest...)
Expand Down
100 changes: 100 additions & 0 deletions base/compiler/sort.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# reference on sorted binary search:
# http://www.tbray.org/ongoing/When/200x/2003/03/22/Binary

# Binary search for the first index in `lo:hi` whose element of `v` is not
# ordered before `x` under `o` (i.e. is greater than or equal to `x`).
# Returns `hi + 1` when every element in the searched span is ordered before `x`.
function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
    # Number of candidate positions still under consideration, starting at `lo`.
    remaining = (hi + T(1)) - lo
    @inbounds while remaining != 0
        half = remaining >>> 0x01
        probe = lo + half
        if lt(o, v[probe], x)
            # Everything up to and including `probe` is too small: discard it.
            lo = probe + 1
            remaining -= half + 1
        else
            # `probe` may be the answer: keep only the candidates before it.
            remaining = half
        end
    end
    return lo
end

# Binary search for the last index in `lo:hi` whose element of `v` is not
# ordered after `x` under `o` (i.e. is less than or equal to `x`).
# Returns `lo - 1` when `x` is ordered before every element in the searched span.
function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
    unit = T(1)
    # Exclusive bounds: the answer always lies in `below:(above - 1)`.
    below = lo - unit
    above = hi + unit
    @inbounds while below < above - unit
        mid = midpoint(below, above)
        if lt(o, x, v[mid])
            above = mid
        else
            below = mid
        end
    end
    return below
end

# Returns the range of indices within `ilo:ihi` at which `v` is equal to `x`
# under the ordering `o`. When `x` does not occur, the result is an empty range
# positioned at the point where `x` would have to be inserted.
function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering)::UnitRange{keytype(v)} where T<:Integer
    unit = T(1)
    # Exclusive bounds around the region that may still contain `x`.
    below = ilo - unit
    above = ihi + unit
    @inbounds while below < above - unit
        mid = midpoint(below, above)
        if lt(o, v[mid], x)
            below = mid
            continue
        end
        if lt(o, x, v[mid])
            above = mid
            continue
        end
        # `v[mid]` matches `x`: locate both ends of the run of matches,
        # never looking outside the caller-supplied bounds.
        first_match = searchsortedfirst(v, x, max(below, ilo), mid, o)
        last_match = searchsortedlast(v, x, mid, min(above, ihi), o)
        return first_match : last_match
    end
    return (below + 1) : (above - 1)
end

# Convenience methods for each search function: search the whole vector, with
# the ordering supplied either as an `Ordering` object or through the
# `lt`/`by`/`rev`/`order` keywords.
for fn in (:searchsortedfirst, :searchsortedlast, :searchsorted)
    @eval $fn(v::AbstractVector, x, o::Ordering) =
        $fn(v, x, firstindex(v), lastindex(v), o)
    @eval $fn(v::AbstractVector, x;
              lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, order::Ordering=Forward) =
        $fn(v, x, ord(lt, by, rev, order))
end

# Unstable in-place sort for internal use.
# `by` maps each element to its sort key and `<` is the comparison applied to
# those keys. The vector `v` itself is returned.
function sort!(v::Vector; by::Function=identity, (<)::Function=<)
    # Empty input accounts for 95% of calls, so leave immediately.
    isempty(v) && return v

    if length(v) > 200
        # Hit by less than 1% of the non-empty calls.
        # Heap sort prevents quadratic runtime on long inputs.
        rev_order = ord(<, by, true)
        heapify!(v, rev_order)
        for tail in lastindex(v):-1:2
            displaced = v[tail]
            v[tail] = v[1]
            percolate_down!(v, 1, displaced, rev_order, tail-1)
        end
    else
        # Insertion sort for short inputs.
        @inbounds for start in 2:length(v)
            elt = v[start]
            key = by(elt)
            slot = start
            # Shift predecessors with a larger key one place right until `elt` fits.
            while slot > 1 && key < by(v[slot-1])
                v[slot] = v[slot-1]
                slot -= 1
            end
            v[slot] = elt
        end
    end

    return v
end
21 changes: 6 additions & 15 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -536,15 +536,6 @@ end
insert_node!(ir::IRCode, pos::Int, newinst::NewInstruction, attach_after::Bool=false) =
insert_node!(ir, SSAValue(pos), newinst, attach_after)

# For bootstrapping
# Builds the index vector 1, 2, ..., length(v) and sorts it in place with
# `Sort.DEFAULT_UNSTABLE` under the ordering `Order.Perm(Sort.Forward, v)`,
# then returns that index vector.
# NOTE(review): presumably `Order.Perm` compares two indices by the values of
# `v` they point at, making the result a sorting permutation of `v` -- confirm
# against the `Order` module.
function my_sortperm(v)
p = Vector{Int}(undef, length(v))
# Fill `p` with the identity permutation using an explicit loop.
for i = 1:length(v)
p[i] = i
end
sort!(p, Sort.DEFAULT_UNSTABLE, Order.Perm(Sort.Forward,v))
p
end

mutable struct IncrementalCompact
ir::IRCode
Expand Down Expand Up @@ -576,10 +567,9 @@ mutable struct IncrementalCompact

function IncrementalCompact(code::IRCode, allow_cfg_transforms::Bool=false)
# Sort by position with attach after nodes after regular ones
perm = my_sortperm(Int[let new_node = code.new_nodes.info[i]
(new_node.pos * 2 + Int(new_node.attach_after))
end for i in 1:length(code.new_nodes)])
new_len = length(code.stmts) + length(code.new_nodes)
info = code.new_nodes.info
perm = sort!(collect(eachindex(info)); by=i->(2info[i].pos+info[i].attach_after, i))
new_len = length(code.stmts) + length(info)
result = InstructionStream(new_len)
used_ssas = fill(0, new_len)
new_new_used_ssas = Vector{Int}()
Expand Down Expand Up @@ -631,8 +621,9 @@ mutable struct IncrementalCompact

# For inlining
function IncrementalCompact(parent::IncrementalCompact, code::IRCode, result_offset)
perm = my_sortperm(Int[code.new_nodes.info[i].pos for i in 1:length(code.new_nodes)])
new_len = length(code.stmts) + length(code.new_nodes)
info = code.new_nodes.info
perm = sort!(collect(eachindex(info)); by=i->(info[i].pos, i))
new_len = length(code.stmts) + length(info)
ssa_rename = Any[SSAValue(i) for i = 1:new_len]
bb_rename = Vector{Int}()
pending_nodes = NewNodeStream()
Expand Down
2 changes: 1 addition & 1 deletion test/choosetests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ function choosetests(choices = [])
filtertests!(tests, "subarray")
filtertests!(tests, "compiler", [
"compiler/datastructures", "compiler/inference", "compiler/effects",
"compiler/validation", "compiler/ssair", "compiler/irpasses",
"compiler/validation", "compiler/sort", "compiler/ssair", "compiler/irpasses",
"compiler/codegen", "compiler/inline", "compiler/contextual",
"compiler/AbstractInterpreter", "compiler/EscapeAnalysis/local",
"compiler/EscapeAnalysis/interprocedural"])
Expand Down
12 changes: 12 additions & 0 deletions test/compiler/datastructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,15 @@ end
end
end
end

# Make sure that the compiler can sort things.
# https://github.com/JuliaLang/julia/issues/47065
@testset "Compiler Sorting" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is redundant with (or at least should be in the same place as) test/compiler/sort.jl

for len in (0, 1, 10, 100, 10000)
v = Core.Compiler.sort!(rand(Int8,len))
@test length(v) == len
@test issorted(v)
Core.Compiler.sort!(v, by=abs)
@test issorted(v, by=abs)
end
end
16 changes: 16 additions & 0 deletions test/compiler/interpreter_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,19 @@ let m = Meta.@lower 1 + 1
global test29262 = false
@test :b === @eval $m
end

@testset "many basic blocks" begin
    # Build a function body consisting of `depth` nested `if` statements wrapped
    # around `return 1`, so the compiled function has a large number of basic blocks.
    depth = 1000
    body = :(return 1)
    for _ in 1:depth
        body = :(if rand()<.1
                     $(body) end)
    end
    @eval function f_1000()
        $body
        return 0
    end
    # Reaching `return 1` needs every one of the nested conditions to hold.
    @test f_1000()===0
end
Comment on lines +110 to +124
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused at this test case. What's being tested here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that this is testing compilation of a function with a bunch of basic blocks. Is there a better version of this test?

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the compilation performance being tested, or is there any correctness issue? And why does f_1000 always return 0?

Copy link
Member Author

@oscardssmith oscardssmith Oct 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a correctness issue (#47065) where the compiler would try to sort a big list and fail. f_1000 doesn't technically always return 0, but it returns 1 only about once in 10^1000 calls, which is close enough to never that the test will not fail in practice.

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I understand. I'd prefer not having @testset there though.

44 changes: 44 additions & 0 deletions test/compiler/sort.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
@testset "searchsorted" begin
    CC = Core.Compiler

    # Vector with repeated values: (needle, expected first index, expected last index).
    for (x, lo, hi) in ((0, 1, 0), (1, 1, 2), (2, 3, 4), (4, 7, 6))
        @test CC.searchsorted([1, 1, 2, 2, 3, 3], x) === CC.UnitRange(lo, hi)
    end
    @test CC.searchsorted([1, 1, 2, 2, 3, 3], 2.5; lt=<) === CC.UnitRange(5, 4)

    # The same kind of queries against a range instead of a vector.
    for (x, lo, hi) in ((0, 1, 0), (1, 1, 1), (2, 2, 2), (4, 4, 3))
        @test CC.searchsorted(CC.UnitRange(1, 3), x) === CC.UnitRange(lo, hi)
    end

    # Custom key function, then explicit search bounds.
    @test CC.searchsorted([1:10;], 1, by=(x -> x >= 5)) === CC.UnitRange(1, 4)
    @test CC.searchsorted([1:10;], 10, by=(x -> x >= 5)) === CC.UnitRange(5, 10)
    @test CC.searchsorted([1:5; 1:5; 1:5], 1, 6, 10, CC.Forward) === CC.UnitRange(6, 6)
    @test CC.searchsorted(fill(1, 15), 1, 6, 10, CC.Forward) === CC.UnitRange(6, 10)

    # Searching a range must agree with searching its collected vector,
    # both in forward order and reversed.
    for (rg, needles) in Any[(CC.UnitRange(49, 57), 47:59),
                             (CC.StepRange(1, 2, 17), -1:19)]
        rg_rev = CC.reverse(rg)
        vec_fwd = CC.collect(rg)
        vec_rev = CC.collect(rg_rev)
        for needle in needles
            @test CC.searchsorted(rg, needle) === CC.searchsorted(vec_fwd, needle)
            @test CC.searchsorted(rg_rev, needle, rev=true) === CC.searchsorted(vec_rev, needle, rev=true)
        end
    end
end

@testset "basic sort" begin
    v = [3,1,2]
    @test v == [3,1,2]

    # Each call must hand back the very same vector, now in the expected order.
    ascending = Core.Compiler.sort!(v)
    @test ascending === v && v == [1,2,3]

    descending = Core.Compiler.sort!(v, by = x -> -x)
    @test descending === v && v == [3,2,1]

    # Negated keys compared with `>` give ascending order again.
    restored = Core.Compiler.sort!(v, by = x -> -x, < = >)
    @test restored === v && v == [1,2,3]
end

@testset "randomized sorting tests" begin
    # Compare the compiler's sort with `sort` on random integer vectors, for
    # every combination of length, value range, key function and comparison.
    for n in (0, 1, 3, 10, 30, 100, 300)
        for k in (0, 30, 2n)
            v = rand(-1:k, n)
            for by in (identity, x -> -x, x -> x^2 + .1x), lt in (<, >)
                expected = sort(v; by, lt)
                actual = Core.Compiler.sort!(copy(v); by, < = lt)
                @test expected == actual
            end
        end
    end
end