diff --git a/.github/workflows/ci.yml b/.github/workflows/CI.yml similarity index 100% rename from .github/workflows/ci.yml rename to .github/workflows/CI.yml diff --git a/.github/workflows/CassetteBaseCI.yml b/.github/workflows/CassetteBaseCI.yml new file mode 100644 index 0000000..2a7637c --- /dev/null +++ b/.github/workflows/CassetteBaseCI.yml @@ -0,0 +1,53 @@ +name: CassetteBaseCI + +on: + push: + branches: + - master + pull_request: + +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false # don't stop CI even when one of them fails + matrix: + include: + - version: '1' # current stable + os: ubuntu-latest + arch: x64 + - version: '1.10' # lowerest version supported + os: ubuntu-latest + arch: x64 + - version: '1.11-nightly' # next release + os: ubuntu-latest + arch: x64 + - version: 'nightly' # dev + os: ubuntu-latest + arch: x64 + - version: '1' # x86 ubuntu + os: ubuntu-latest + arch: x86 + - version: '1' # x86 windows + os: windows-latest + arch: x86 + - version: '1' # x64 windows + os: windows-latest + arch: x64 + - version: '1' # x64 macOS + os: macos-latest + arch: x64 + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - name: test CassetteBase + shell: julia --color=yes --project=. {0} # this is necessary for the next command to work on Windows + run: 'using Pkg; Pkg.activate("CassetteBase"); Pkg.instantiate(); Pkg.test(coverage=true)' + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + file: ./lcov.info diff --git a/CassetteBase/LICENSE.md b/CassetteBase/LICENSE.md new file mode 100644 index 0000000..d325b45 --- /dev/null +++ b/CassetteBase/LICENSE.md @@ -0,0 +1,19 @@ +Copyright (c) 2024 JuliaHub, Inc. and other contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/CassetteBase/Project.toml b/CassetteBase/Project.toml new file mode 100644 index 0000000..032e484 --- /dev/null +++ b/CassetteBase/Project.toml @@ -0,0 +1,13 @@ +name = "CassetteBase" +uuid = "6dd3e646-b1c5-42c7-94be-00277fa12e22" +authors = ["Shuhei Kadowaki "] +version = "0.1.0" + +[compat] +julia = "1.10" + +[extras] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test"] diff --git a/CassetteBase/README.md b/CassetteBase/README.md new file mode 100644 index 0000000..69cf5ef --- /dev/null +++ b/CassetteBase/README.md @@ -0,0 +1 @@ +# CassetteBase.jl diff --git a/CassetteBase/src/CassetteBase.jl b/CassetteBase/src/CassetteBase.jl new file mode 100644 index 0000000..6d83acc --- /dev/null +++ b/CassetteBase/src/CassetteBase.jl @@ -0,0 +1,165 @@ +module CassetteBase + +export cassette_transform!, generate_internalerr_ex, generate_lambda_ex + +using Core.IR +using Core: SimpleVector + +function cassette_transform!(src::CodeInfo, mi::MethodInstance, nargs::Int, + selfname::Symbol, fargsname::Symbol) + method = mi.def::Method + mnargs = Int(method.nargs) + + src.slotnames = Symbol[selfname, fargsname, src.slotnames[mnargs+1:end]...] + src.slotflags = UInt8[ 0x00, 0x00, src.slotflags[mnargs+1:end]...] + + code = src.code + fargsslot = SlotNumber(2) + precode = Any[] + local ssaid = 0 + for i = 1:mnargs + if method.isva && i == mnargs + tuplecall = Expr(:call, tuple) + for j = i:nargs + push!(precode, Expr(:call, getfield, fargsslot, j)) + ssaid += 1 + push!(tuplecall.args, SSAValue(ssaid)) + end + push!(precode, tuplecall) + else + push!(precode, Expr(:call, getfield, fargsslot, i)) + end + ssaid += 1 + end + prepend!(code, precode) + @static if VERSION < v"1.12.0-DEV.173" + prepend!(src.codelocs, [0 for i = 1:ssaid]) + else + di = Core.Compiler.DebugInfoStream(mi, src.debuginfo, length(code)) + src.debuginfo = Core.DebugInfo(di, length(code)) + end + prepend!(src.ssaflags, [0x00 for i = 1:ssaid]) + src.ssavaluetypes += ssaid + if @static isdefined(Base, :__has_internal_change) && Base.__has_internal_change(v"1.12-alpha", :codeinfonargs) + src.nargs = 2 + src.isva = true + end + + function map_slot_number(slot::Int) + @assert slot ≥ 1 + if 1 ≤ slot ≤ mnargs + if method.isva && slot == mnargs + return SSAValue(ssaid) + else + return SSAValue(slot) + end + else + return SlotNumber(slot - mnargs + 2) + end + end + map_ssa_value(id::Int) = id + ssaid + for i = (ssaid+1:length(code)) + code[i] = transform_stmt(code[i], map_slot_number, map_ssa_value, mi.def.sig, mi.sparam_vals) + end + + src.edges = MethodInstance[mi] + src.method_for_inference_limit_heuristics = method + + return src +end + +function transform_stmt(@nospecialize(x), map_slot_number, map_ssa_value, @nospecialize(spsig), sparams::SimpleVector) + transform(@nospecialize x′) = transform_stmt(x′, map_slot_number, map_ssa_value, spsig, sparams) + if isa(x, Expr) + head = x.head + if head === :call + return Expr(:call, SlotNumber(1), map(transform, x.args[1:end])...) + elseif head === :foreigncall + arg1 = x.args[1] + if Meta.isexpr(arg1, :call) + # first argument of :foreigncall may be a magic tuple call, and it should be preserved + arg1 = Expr(:call, map(transform, arg1.args)...) + else + arg1 = transform(x.args[1]) + end + arg2 = @ccall jl_instantiate_type_in_env(x.args[2]::Any, spsig::Any, sparams::Ptr{Any})::Any + arg3 = Core.svec(Any[ + @ccall jl_instantiate_type_in_env(argt::Any, spsig::Any, sparams::Ptr{Any})::Any + for argt in x.args[3]::SimpleVector ]...) + return Expr(:foreigncall, arg1, arg2, arg3, map(transform, x.args[4:end])...) + elseif head === :enter + return Expr(:enter, map_ssa_value(x.args[1]::Int)) + elseif head === :static_parameter + return sparams[x.args[1]::Int] + elseif head === :isdefined + arg1 = x.args[1] + if Meta.isexpr(arg1, :static_parameter) + return 1 ≤ arg1.args[1]::Int ≤ length(sparams) + end + end + return Expr(head, map(transform, x.args)...) + elseif isa(x, GotoNode) + return GotoNode(map_ssa_value(x.label)) + elseif isa(x, GotoIfNot) + return GotoIfNot(transform(x.cond), map_ssa_value(x.dest)) + elseif isa(x, ReturnNode) + return ReturnNode(transform(x.val)) + elseif isa(x, SlotNumber) + return map_slot_number(x.id) + elseif isa(x, NewvarNode) + return NewvarNode(map_slot_number(x.slot.id)) + elseif isa(x, SSAValue) + return SSAValue(map_ssa_value(x.id)) + elseif @static @isdefined(EnterNode) && isa(x, EnterNode) + if isdefined(x, :scope) + return EnterNode(map_ssa_value(x.catch_dest), transform(x.scope)) + else + return EnterNode(map_ssa_value(x.catch_dest)) + end + end + return x +end + +struct CassetteInternalError + err + bt::Vector + context::Symbol + metadata # allow preserving arbitrary data for debugging + function CassetteInternalError(err, bt::Vector, context::Symbol, metadata=nothing) + @nospecialize err metadata + new(err, bt, context, metadata) + end +end +function Base.showerror(io::IO, err::CassetteInternalError) + print(io, "Internal error happened in `$(err.context)`:") + println(io) + buf = IOBuffer() + ioctx = IOContext(buf, IOContext(io)) + Base.showerror(ioctx, err.err) + Base.show_backtrace(ioctx, err.bt) + printstyled(io, " ┌", '─'^48, '\n'; color=:red) + for l in split(String(take!(buf)), '\n') + printstyled(io, " │ "; color=:red) + println(io, l) + end + printstyled(io, " └", '─'^48; color=:red) +end + +function generate_internalerr_ex(err, bt::Vector, context::Symbol, + world::UInt, source::LineNumberNode, + argnames::SimpleVector, spnames::SimpleVector, + metadata=nothing) + @nospecialize err metadata + throw_ex = :(throw($CassetteInternalError( + $(QuoteNode(err)), $bt, $(QuoteNode(context)), $(QuoteNode(metadata))))) + return generate_lambda_ex(world, source, argnames, spnames, throw_ex) +end + +function generate_lambda_ex(world::UInt, source::LineNumberNode, + argnames::SimpleVector, spnames::SimpleVector, + body::Expr) + stub = Core.GeneratedFunctionStub(identity, argnames, spnames) + return stub(world, source, body) +end + +end # module CassetteBase diff --git a/CassetteBase/test/runtests.jl b/CassetteBase/test/runtests.jl new file mode 100644 index 0000000..2496132 --- /dev/null +++ b/CassetteBase/test/runtests.jl @@ -0,0 +1,5 @@ +using Test + +@testset "CassetteBase" begin + @testset "test_basic.jl" include("test_basic.jl") +end diff --git a/CassetteBase/test/test_basic.jl b/CassetteBase/test/test_basic.jl new file mode 100644 index 0000000..f065221 --- /dev/null +++ b/CassetteBase/test/test_basic.jl @@ -0,0 +1,83 @@ +module test_basic + +using Test, CassetteBase + +function make_basic_generator(selfname::Symbol, fargsname::Symbol, raise::Bool) + function basic_generator(world::UInt, source::LineNumberNode, passtype, fargtypes) + @nospecialize passtype fargtypes + try + return generate_basic_src(world, source, passtype, fargtypes, + selfname, fargsname; raise) + catch err + # internal error happened - return an expression to raise the special exception + return generate_internalerr_ex( + err, #=bt=#catch_backtrace(), #=context=#:basic_generator, world, source, + #=argnames=#Core.svec(selfname, fargsname), #=spnames=#Core.svec(), + #=metadata=#(; world, source, passtype, fargtypes)) + end + end +end +function generate_basic_src(world::UInt, source::LineNumberNode, passtype, fargtypes, + selfname::Symbol, fargsname::Symbol; raise::Bool) + @nospecialize passtype fargtypes + tt = Base.to_tuple_type(fargtypes) + match = Base._which(tt; raise, world) + match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError + mi = Core.Compiler.specialize_method(match) + src = Core.Compiler.retrieve_code_info(mi, world) + src === nothing && return nothing # code generation failed - the fallback implementation will re-raise it + cassette_transform!(src, mi, length(fargtypes), selfname, fargsname) + return src +end + +struct BasicPass end +@eval function (pass::BasicPass)(fargs...) + $(Expr(:meta, :generated, make_basic_generator(:pass, :fargs, #=raise=#false))) + return first(fargs)(Base.tail(fargs)...) +end +let pass = BasicPass() + @test pass(sin, 1) == sin(1) + @test_throws MethodError pass("1") do x; sin(x); end +end + +struct RaisePass end +@eval function (pass::RaisePass)(fargs...) + $(Expr(:meta, :generated, make_basic_generator(:pass, :fargs, #=raise=#true))) + return first(fargs)(Base.tail(fargs)...) +end +let pass = RaisePass() + @test pass(sin, 1) == sin(1) + @test_throws CassetteBase.CassetteInternalError pass("1") do x; sin(x); end + local err + try + pass("1") do + sin(x) + end + catch e + err = e + end + @test @isdefined(err) + @test err isa CassetteBase.CassetteInternalError + msg = let + buf = IOBuffer() + showerror(buf, err) + String(take!(buf)) + end + @test occursin("Internal error happened in `basic_generator`:", msg) + local err_expected + try + Base._which(Tuple{typeof(sin),String}) + catch e + err_expected = e + end + @test @isdefined(err_expected) + @test err.err == err_expected + msg_expected = let + buf = IOBuffer() + showerror(buf, err_expected) + String(take!(buf)) + end + @test occursin(msg_expected, msg) +end + +end # module test_basic diff --git a/LICENSE.md b/LICENSE.md index 66349e5..d325b45 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,4 +1,4 @@ -Copyright (c) 2022 JuliaHub, Inc. and other contributors. +Copyright (c) 2024 JuliaHub, Inc. and other contributors. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal