Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

factor basis Cassette features into CassetteBase.jl #54

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
53 changes: 53 additions & 0 deletions .github/workflows/CassetteBaseCI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
name: CassetteBaseCI

on:
push:
branches:
- master
pull_request:

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false # don't stop CI even when one of them fails
matrix:
include:
- version: '1' # current stable
os: ubuntu-latest
arch: x64
- version: '1.10' # lowerest version supported
os: ubuntu-latest
arch: x64
- version: '1.11-nightly' # next release
os: ubuntu-latest
arch: x64
- version: 'nightly' # dev
os: ubuntu-latest
arch: x64
- version: '1' # x86 ubuntu
os: ubuntu-latest
arch: x86
- version: '1' # x86 windows
os: windows-latest
arch: x86
- version: '1' # x64 windows
os: windows-latest
arch: x64
- version: '1' # x64 macOS
os: macos-latest
arch: x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- name: test CassetteBase
shell: julia --color=yes --project=. {0} # this is necessary for the next command to work on Windows
run: 'using Pkg; Pkg.activate("CassetteBase"); Pkg.instantiate(); Pkg.test(coverage=true)'
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
file: ./lcov.info
19 changes: 19 additions & 0 deletions CassetteBase/LICENSE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Copyright (c) 2024 JuliaHub, Inc. and other contributors.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
13 changes: 13 additions & 0 deletions CassetteBase/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name = "CassetteBase"
uuid = "6dd3e646-b1c5-42c7-94be-00277fa12e22"
authors = ["Shuhei Kadowaki <aviatesk@gmail.com>"]
version = "0.1.0"

[compat]
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
1 change: 1 addition & 0 deletions CassetteBase/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# CassetteBase.jl
165 changes: 165 additions & 0 deletions CassetteBase/src/CassetteBase.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
module CassetteBase

export cassette_transform!, generate_internalerr_ex, generate_lambda_ex

using Core.IR
using Core: SimpleVector

function cassette_transform!(src::CodeInfo, mi::MethodInstance, nargs::Int,
selfname::Symbol, fargsname::Symbol)
method = mi.def::Method
mnargs = Int(method.nargs)

src.slotnames = Symbol[selfname, fargsname, src.slotnames[mnargs+1:end]...]
src.slotflags = UInt8[ 0x00, 0x00, src.slotflags[mnargs+1:end]...]

code = src.code
fargsslot = SlotNumber(2)
precode = Any[]
local ssaid = 0
for i = 1:mnargs
if method.isva && i == mnargs
tuplecall = Expr(:call, tuple)
for j = i:nargs
push!(precode, Expr(:call, getfield, fargsslot, j))
ssaid += 1
push!(tuplecall.args, SSAValue(ssaid))
end
push!(precode, tuplecall)
else
push!(precode, Expr(:call, getfield, fargsslot, i))
end
ssaid += 1
end
prepend!(code, precode)
@static if VERSION < v"1.12.0-DEV.173"
prepend!(src.codelocs, [0 for i = 1:ssaid])
else
di = Core.Compiler.DebugInfoStream(mi, src.debuginfo, length(code))
src.debuginfo = Core.DebugInfo(di, length(code))
end
prepend!(src.ssaflags, [0x00 for i = 1:ssaid])
src.ssavaluetypes += ssaid
if @static isdefined(Base, :__has_internal_change) && Base.__has_internal_change(v"1.12-alpha", :codeinfonargs)
src.nargs = 2
src.isva = true
end

function map_slot_number(slot::Int)
@assert slot ≥ 1
if 1 ≤ slot ≤ mnargs
if method.isva && slot == mnargs
return SSAValue(ssaid)
else
return SSAValue(slot)
end
else
return SlotNumber(slot - mnargs + 2)
end
end
map_ssa_value(id::Int) = id + ssaid
for i = (ssaid+1:length(code))
code[i] = transform_stmt(code[i], map_slot_number, map_ssa_value, mi.def.sig, mi.sparam_vals)
end

src.edges = MethodInstance[mi]
src.method_for_inference_limit_heuristics = method

return src
end

function transform_stmt(@nospecialize(x), map_slot_number, map_ssa_value, @nospecialize(spsig), sparams::SimpleVector)
transform(@nospecialize x′) = transform_stmt(x′, map_slot_number, map_ssa_value, spsig, sparams)
if isa(x, Expr)
head = x.head
if head === :call
return Expr(:call, SlotNumber(1), map(transform, x.args[1:end])...)
elseif head === :foreigncall
arg1 = x.args[1]
if Meta.isexpr(arg1, :call)
# first argument of :foreigncall may be a magic tuple call, and it should be preserved
arg1 = Expr(:call, map(transform, arg1.args)...)
else
arg1 = transform(x.args[1])
end
arg2 = @ccall jl_instantiate_type_in_env(x.args[2]::Any, spsig::Any, sparams::Ptr{Any})::Any
arg3 = Core.svec(Any[
@ccall jl_instantiate_type_in_env(argt::Any, spsig::Any, sparams::Ptr{Any})::Any
for argt in x.args[3]::SimpleVector ]...)
return Expr(:foreigncall, arg1, arg2, arg3, map(transform, x.args[4:end])...)
elseif head === :enter
return Expr(:enter, map_ssa_value(x.args[1]::Int))
elseif head === :static_parameter
return sparams[x.args[1]::Int]
elseif head === :isdefined
arg1 = x.args[1]
if Meta.isexpr(arg1, :static_parameter)
return 1 ≤ arg1.args[1]::Int ≤ length(sparams)
end
end
return Expr(head, map(transform, x.args)...)
elseif isa(x, GotoNode)
return GotoNode(map_ssa_value(x.label))
elseif isa(x, GotoIfNot)
return GotoIfNot(transform(x.cond), map_ssa_value(x.dest))
elseif isa(x, ReturnNode)
return ReturnNode(transform(x.val))
elseif isa(x, SlotNumber)
return map_slot_number(x.id)
elseif isa(x, NewvarNode)
return NewvarNode(map_slot_number(x.slot.id))
elseif isa(x, SSAValue)
return SSAValue(map_ssa_value(x.id))
elseif @static @isdefined(EnterNode) && isa(x, EnterNode)
if isdefined(x, :scope)
return EnterNode(map_ssa_value(x.catch_dest), transform(x.scope))
else
return EnterNode(map_ssa_value(x.catch_dest))
end
end
return x
end

struct CassetteInternalError
err
bt::Vector
context::Symbol
metadata # allow preserving arbitrary data for debugging
function CassetteInternalError(err, bt::Vector, context::Symbol, metadata=nothing)
@nospecialize err metadata
new(err, bt, context, metadata)
end
end
function Base.showerror(io::IO, err::CassetteInternalError)
print(io, "Internal error happened in `$(err.context)`:")
println(io)
buf = IOBuffer()
ioctx = IOContext(buf, IOContext(io))
Base.showerror(ioctx, err.err)
Base.show_backtrace(ioctx, err.bt)
printstyled(io, " ┌", '─'^48, '\n'; color=:red)
for l in split(String(take!(buf)), '\n')
printstyled(io, " │ "; color=:red)
println(io, l)
end
printstyled(io, " └", '─'^48; color=:red)
end

function generate_internalerr_ex(err, bt::Vector, context::Symbol,
world::UInt, source::LineNumberNode,
argnames::SimpleVector, spnames::SimpleVector,
metadata=nothing)
@nospecialize err metadata
throw_ex = :(throw($CassetteInternalError(
$(QuoteNode(err)), $bt, $(QuoteNode(context)), $(QuoteNode(metadata)))))
return generate_lambda_ex(world, source, argnames, spnames, throw_ex)
end

function generate_lambda_ex(world::UInt, source::LineNumberNode,
argnames::SimpleVector, spnames::SimpleVector,
body::Expr)
stub = Core.GeneratedFunctionStub(identity, argnames, spnames)
return stub(world, source, body)
end

end # module CassetteBase
5 changes: 5 additions & 0 deletions CassetteBase/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using Test

@testset "CassetteBase" begin
@testset "test_basic.jl" include("test_basic.jl")
end
83 changes: 83 additions & 0 deletions CassetteBase/test/test_basic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
module test_basic

using Test, CassetteBase

function make_basic_generator(selfname::Symbol, fargsname::Symbol, raise::Bool)
function basic_generator(world::UInt, source::LineNumberNode, passtype, fargtypes)
@nospecialize passtype fargtypes
try
return generate_basic_src(world, source, passtype, fargtypes,
selfname, fargsname; raise)
catch err
# internal error happened - return an expression to raise the special exception
return generate_internalerr_ex(
err, #=bt=#catch_backtrace(), #=context=#:basic_generator, world, source,
#=argnames=#Core.svec(selfname, fargsname), #=spnames=#Core.svec(),
#=metadata=#(; world, source, passtype, fargtypes))
end
end
end
function generate_basic_src(world::UInt, source::LineNumberNode, passtype, fargtypes,
selfname::Symbol, fargsname::Symbol; raise::Bool)
@nospecialize passtype fargtypes
tt = Base.to_tuple_type(fargtypes)
match = Base._which(tt; raise, world)
match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError
mi = Core.Compiler.specialize_method(match)
src = Core.Compiler.retrieve_code_info(mi, world)
src === nothing && return nothing # code generation failed - the fallback implementation will re-raise it
cassette_transform!(src, mi, length(fargtypes), selfname, fargsname)
return src
end

struct BasicPass end
@eval function (pass::BasicPass)(fargs...)
$(Expr(:meta, :generated, make_basic_generator(:pass, :fargs, #=raise=#false)))
return first(fargs)(Base.tail(fargs)...)
end
let pass = BasicPass()
@test pass(sin, 1) == sin(1)
@test_throws MethodError pass("1") do x; sin(x); end
end

struct RaisePass end
@eval function (pass::RaisePass)(fargs...)
$(Expr(:meta, :generated, make_basic_generator(:pass, :fargs, #=raise=#true)))
return first(fargs)(Base.tail(fargs)...)
end
let pass = RaisePass()
@test pass(sin, 1) == sin(1)
@test_throws CassetteBase.CassetteInternalError pass("1") do x; sin(x); end
local err
try
pass("1") do
sin(x)
end
catch e
err = e
end
@test @isdefined(err)
@test err isa CassetteBase.CassetteInternalError
msg = let
buf = IOBuffer()
showerror(buf, err)
String(take!(buf))
end
@test occursin("Internal error happened in `basic_generator`:", msg)
local err_expected
try
Base._which(Tuple{typeof(sin),String})
catch e
err_expected = e
end
@test @isdefined(err_expected)
@test err.err == err_expected
msg_expected = let
buf = IOBuffer()
showerror(buf, err_expected)
String(take!(buf))
end
@test occursin(msg_expected, msg)
end

end # module test_basic
2 changes: 1 addition & 1 deletion LICENSE.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Copyright (c) 2022 JuliaHub, Inc. and other contributors.
Copyright (c) 2024 JuliaHub, Inc. and other contributors.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
Loading