From fc4e53ddb2bc469197419a58556621e425ff51cd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 17 Jan 2025 18:20:11 -0600 Subject: [PATCH] Format code (#562) Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com> --- ext/ReactantCUDAExt.jl | 24 ++-- ext/ReactantOffsetArraysExt.jl | 5 +- src/Compiler.jl | 18 ++- src/Reactant.jl | 23 +++- src/Tracing.jl | 202 +++++++++++++++++++++++++-------- test/integration/python.jl | 14 +-- test/struct.jl | 9 +- test/tracing.jl | 91 ++++++++++----- 8 files changed, 280 insertions(+), 106 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 437cf77f1..8cb820eb2 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -428,7 +428,6 @@ function get_field_offset(T::Type, path) offset = 0 current_type = T - for field in path # Get the field index field_idx = if field isa Integer @@ -450,9 +449,7 @@ function get_field_offset(T::Type, path) # Update current_type to the field's type for next iteration current_type = tcurrent_type - end - return offset end @@ -470,7 +467,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( threaddim = CUDA.CuDim3(threads) if convert == Val(true) - args = recudaconvert.(args) + args = recudaconvert.(args) end mlir_args = MLIR.IR.Value[] @@ -662,7 +659,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( end location = MLIR.IR.Location() - @assert length(restys) == length(aliases) + @assert length(restys) == length(aliases) call = MLIR.Dialects.enzymexla.kernel_call( blk_operands..., mlir_args; @@ -723,13 +720,19 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(A::Type{<:CuTracedArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) + @nospecialize(A::Type{<:CuTracedArray}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type) ) return A end Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(A::Type{<:CUDA.CuArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) + @nospecialize(A::Type{<:CUDA.CuArray}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type) ) T = eltype(A) N = ndims(A) @@ -746,7 +749,12 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( end function Reactant.make_tracer( - seen, @nospecialize(prev::CUDA.CuArray), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), kwargs... + seen, + @nospecialize(prev::CUDA.CuArray), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + kwargs..., ) RT = Core.Typeof(prev) if haskey(seen, prev) diff --git a/ext/ReactantOffsetArraysExt.jl b/ext/ReactantOffsetArraysExt.jl index dedfb35e9..fc77ef0e1 100644 --- a/ext/ReactantOffsetArraysExt.jl +++ b/ext/ReactantOffsetArraysExt.jl @@ -5,7 +5,10 @@ using OffsetArrays: OffsetArray using Reactant: Reactant, MLIR, Ops, TracedRArray Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(OA::Type{<:OffsetArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type=Union{}) + @nospecialize(OA::Type{<:OffsetArray}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type = Union{}) ) N = ndims(OA) T = OffsetArrays.parenttype(OA) diff --git a/src/Compiler.jl b/src/Compiler.jl index 931a5db6f..7f898d318 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -298,18 +298,14 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false) ",", ) func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") - passes = [ - "inline{default-pipeline=canonicalize max-iterations=4}" - ] + passes = ["inline{default-pipeline=canonicalize max-iterations=4}"] if sroa - push!(passes, "sroa-wrappers") - push!(passes, "libdevice-funcs-raise") - push!(passes, "canonicalize") + push!(passes, "sroa-wrappers") + push!(passes, "libdevice-funcs-raise") + push!(passes, "canonicalize") end push!(passes, func_passes) - return join(passes, - ',', - ) + return join(passes, ',') end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate @@ -389,7 +385,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] end if DEBUG_KERNEL[] - curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult") + curesulthandler = XLA.Libdl.dlsym( + Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" + ) @assert curesulthandler !== nothing curesulthandler = Base.reinterpret(UInt, curesulthandler) kern = "lower-kernel{debug=true cuResultHandlerPtr=$curesulthandler run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) cuStreamSynchronizePtr=$(cuSync[])},symbol-dce" diff --git a/src/Reactant.jl b/src/Reactant.jl index 1abb98f9b..20cba2573 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -24,22 +24,37 @@ else end @static if isdefined(Core, :BFloat16) - const ReactantComplexFloat = Union{Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}} + const ReactantComplexFloat = Union{ + Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64} + } else const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}} end const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} -const ReactantComplexInt = Union{Complex{Int8},Complex{UInt8},Complex{Int16},Complex{UInt16},Complex{Int32},Complex{UInt32},Complex{Int64},Complex{UInt64},Complex{Int128},Complex{UInt128}} +const ReactantComplexInt = Union{ + Complex{Int8}, + Complex{UInt8}, + Complex{Int16}, + Complex{UInt16}, + Complex{Int32}, + Complex{UInt32}, + Complex{Int64}, + Complex{UInt64}, + Complex{Int128}, + Complex{UInt128}, +} const ReactantFloatInt = Union{ Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... } const ReactantPrimitive = Union{ - Bool,Base.uniontypes(ReactantFloatInt)..., - Base.uniontypes(ReactantComplexInt)...,Base.uniontypes(ReactantComplexFloat)... + Bool, + Base.uniontypes(ReactantFloatInt)..., + Base.uniontypes(ReactantComplexInt)..., + Base.uniontypes(ReactantComplexFloat)..., } abstract type RNumber{T<:ReactantPrimitive} <: Number end diff --git a/src/Tracing.jl b/src/Tracing.jl index c0e9aba8b..34f5d0643 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -7,7 +7,9 @@ NoStopTracedTrack = 6 end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type) +) if T === Any return T end @@ -131,18 +133,41 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, throw(NoFieldMatchError(T, TT2)) end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{Union{}}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{Union{}}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) return T end -for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, AbstractFloat, Integer, RNumber) - @eval Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:$T}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +for T in ( + DataType, + Module, + Nothing, + Symbol, + AbstractChar, + AbstractString, + AbstractFloat, + Integer, + RNumber, +) + @eval Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:$T}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) + ) return T end end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ReactantPrimitive}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:ReactantPrimitive}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type) ) if Mode == ArrayToConcrete && T <: track_numbers return ConcreteRNumber{T} @@ -151,7 +176,10 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(C::Type{<:Complex}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type) + @nospecialize(C::Type{<:Complex}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type) ) if !(C isa UnionAll) return Complex{traced_type_inner(C.parameters[1], seen, mode, track_numbers)} @@ -160,7 +188,12 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Function}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) # functions are directly returned if sizeof(T) == 0 return T @@ -187,7 +220,12 @@ end @inline is_concrete_tuple(x::T2) where {T2} = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Tuple}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Tuple}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll throw(AssertionError("Type $T is not concrete type or concrete tuple")) elseif is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) @@ -201,19 +239,30 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Tuple return Tuple{TT...} end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:NamedTuple}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:NamedTuple}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) N = T.parameters[1] V = T.parameters[2] return NamedTuple{N,traced_type_inner(V, seen, mode, track_numbers)} end - Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict}) = nothing -Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict{K}}) where K = K +Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict{K}}) where {K} = K Base.@nospecializeinfer @inline dict_value(::Type{<:AbstractDict}) = nothing -Base.@nospecializeinfer @inline dict_value(::Type{<:(AbstractDict{K,V} where K)}) where V = V +Base.@nospecializeinfer @inline dict_value( + ::Type{<:(AbstractDict{K,V} where {K})} +) where {V} = V -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:AbstractDict}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:AbstractDict}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) V = dict_value(T) if V === nothing return T @@ -231,13 +280,16 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Abstr if K !== nothing return dictty{K,V2} else - return (dictty{KT,V2} where KT) + return (dictty{KT,V2} where {KT}) end end end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T0::Type{<:ConcreteRNumber}), seen, mode::TraceMode, @nospecialize(track_numbers::Type) + @nospecialize(T0::Type{<:ConcreteRNumber}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) ) T = T0.parameters[1] if mode == ConcreteToTraced @@ -248,15 +300,22 @@ Base.@nospecializeinfer function traced_type_inner( throw("Abstract RNumber cannot be made concrete") end end - -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) = UnionAll(TV.var, base_typet(TV.body)) -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) = TracedRArray{TV.parameters...} - -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) = UnionAll(TV.var, base_typec(TV.body)) -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} + +Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) = + UnionAll(TV.var, base_typet(TV.body)) +Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) = + TracedRArray{TV.parameters...} + +Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) = + UnionAll(TV.var, base_typec(TV.body)) +Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = + (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ConcreteRArray}), seen, mode::TraceMode, @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:ConcreteRArray}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) ) if mode == ConcreteToTraced return base_typet(T) @@ -267,7 +326,12 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:ConcreteRNG}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:ConcreteRNG}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) if mode == ConcreteToTraced return TracedRNG elseif mode == TracedToConcrete @@ -278,7 +342,10 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Concr end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:TracedType}), seen, mode::TraceMode, @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:TracedType}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) ) T <: MissingTracedValue && error("TODO") if mode == ConcreteToTraced @@ -292,7 +359,12 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:TracedRNG}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:TracedRNG}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) if mode == ConcreteToTraced throw("TracedRNG cannot be traced") elseif mode == TracedToConcrete @@ -304,12 +376,20 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Trace end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:XLAArray}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:XLAArray}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) throw("XLA $T array cannot be traced") end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(A::Type{<:Array}), seen, mode::TraceMode, @nospecialize(track_numbers::Type) + @nospecialize(A::Type{<:Array}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) ) T = eltype(A) N = ndims(A) @@ -321,13 +401,23 @@ Base.@nospecializeinfer function traced_type_inner( end for P in (Ptr, Core.LLVMPtr, Base.RefValue) - @eval Base.@nospecializeinfer function traced_type_inner(@nospecialize(PT::Type{<:$P}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) + @eval Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{<:$P}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) + ) T = eltype(PT) return $P{traced_type_inner(T, seen, mode, track_numbers)} end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(VT::Type{<:Val}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type) +) if VT isa UnionAll return VT end @@ -338,7 +428,7 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val} throw("Val type $(Val{T}) cannot be traced") end -const traced_type_cache = Dict{Tuple{TraceMode, Type}, Dict{Type, Type}}() +const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}() # function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type)) # @nospecialize @@ -431,16 +521,18 @@ const traced_type_cache = Dict{Tuple{TraceMode, Type}, Dict{Type, Type}}() # $(Expr(:meta, :generated, traced_type_generator)) # end -Base.@assume_effects :total @inline function traced_type(T::Type, ::Val{mode}, track_numbers::Type) where mode +Base.@assume_effects :total @inline function traced_type( + T::Type, ::Val{mode}, track_numbers::Type +) where {mode} cache = nothing cache_key = (mode, track_numbers) if haskey(traced_type_cache, cache_key) cache = traced_type_cache[cache_key] else - cache = Dict{Type, Type}() + cache = Dict{Type,Type}() traced_type_cache[cache_key] = cache end - res1 = traced_type_inner(T, cache, mode, track_numbers) + return res1 = traced_type_inner(T, cache, mode, track_numbers) end abstract type TracedTypeException <: Exception end @@ -467,7 +559,7 @@ end function make_tracer( seen, - @nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}), + @nospecialize(prev::Union{Base.ExceptionStack,Core.MethodInstance}), @nospecialize(path), mode; kwargs..., @@ -483,7 +575,7 @@ function make_tracer( mode; toscalar=false, tobatch=nothing, - @nospecialize(track_numbers::Type=Union{}), + @nospecialize(track_numbers::Type = Union{}), kwargs..., ) if mode != NoStopTracedTrack && haskey(seen, prev) @@ -584,7 +676,9 @@ function make_tracer( return res end -function make_tracer(seen, prev::ConcreteRNumber{T}, @nospecialize(path), mode; kwargs...) where {T} +function make_tracer( + seen, prev::ConcreteRNumber{T}, @nospecialize(path), mode; kwargs... +) where {T} if mode == ArrayToConcrete return prev end @@ -739,7 +833,12 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::Number), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), kwargs... + seen, + @nospecialize(prev::Number), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + kwargs..., ) RT = Core.Typeof(prev) if RT <: track_numbers @@ -792,7 +891,12 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::Array), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), kwargs... + seen, + @nospecialize(prev::Array), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + kwargs..., ) RT = Core.Typeof(prev) if mode != NoStopTracedTrack && haskey(seen, prev) @@ -822,9 +926,7 @@ function make_tracer( return newa end -function make_tracer( - seen, @nospecialize(prev::Tuple), @nospecialize(path), mode; kwargs... -) +function make_tracer(seen, @nospecialize(prev::Tuple), @nospecialize(path), mode; kwargs...) return ( ( make_tracer(seen, v, append_path(path, i), mode; kwargs...) for @@ -838,7 +940,7 @@ function make_tracer( @nospecialize(prev::NamedTuple), @nospecialize(path), mode; - @nospecialize(track_numbers::Type=Union{}), + @nospecialize(track_numbers::Type = Union{}), kwargs..., ) NT = Core.Typeof(prev) @@ -882,15 +984,23 @@ end return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers) end -function to_rarray_internal(@nospecialize(::TracedRArray), @nospecialize(track_numbers::Type)) +function to_rarray_internal( + @nospecialize(::TracedRArray), @nospecialize(track_numbers::Type) +) return error("Cannot convert TracedRArray to ConcreteRArray") end -@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type)) = x -@inline function to_rarray_internal(@nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type)) +@inline to_rarray_internal( + @nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type) +) = x +@inline function to_rarray_internal( + @nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type) +) return ConcreteRArray(x) end -@inline to_rarray_internal(@nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type)) = x +@inline to_rarray_internal( + @nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type) +) = x @inline function to_rarray_internal( @nospecialize(x::ReactantPrimitive), @nospecialize(track_numbers::Type) ) diff --git a/test/integration/python.jl b/test/integration/python.jl index 91921dedf..128f11950 100644 --- a/test/integration/python.jl +++ b/test/integration/python.jl @@ -5,13 +5,13 @@ using Test # Jax on Github CI dislikes X86 macos @static if !Sys.isapple() || Sys.ARCH != :x86_64 -using PythonCall + using PythonCall -@testset "PythonCall" begin - jax = pyimport("jax") + @testset "PythonCall" begin + jax = pyimport("jax") - result = @jit jax.numpy.sum(Reactant.to_rarray(Float32[1, 2, 3])) - @test typeof(result) == ConcreteRNumber{Float32} - @test result ≈ 6 + result = @jit jax.numpy.sum(Reactant.to_rarray(Float32[1, 2, 3])) + @test typeof(result) == ConcreteRNumber{Float32} + @test result ≈ 6 + end end -end \ No newline at end of file diff --git a/test/struct.jl b/test/struct.jl index 9f9f8930c..aadba4839 100644 --- a/test/struct.jl +++ b/test/struct.jl @@ -22,9 +22,14 @@ mutable struct MutableMockTensor{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N inds::Vector{Symbol} end -Base.@nospecializeinfer function Reactant.traced_type_inner(@nospecialize(A::Type{<:MockTensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(A::Type{<:MockTensor}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type) +) T2 = Reactant.traced_type_inner(A.parameters[3], seen, mode, track_numbers) - MT = MockTensor{eltype(T2), ndims(A), T2} + MT = MockTensor{eltype(T2),ndims(A),T2} return MT end diff --git a/test/tracing.jl b/test/tracing.jl index 9543b841d..c196f562b 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -57,33 +57,65 @@ using Test (Complex{UInt128}, Complex{UInt128}, TracedRNumber{Complex{UInt128}}), # RArray types - (ConcreteRArray{Float64,0}, TracedRArray{Float64,0}, TracedRArray{Float64, 0}), - (ConcreteRArray{Float64,1}, TracedRArray{Float64,1}, TracedRArray{Float64, 1}), - (ConcreteRArray{Float64,2}, TracedRArray{Float64,2}, TracedRArray{Float64, 2}), - (ConcreteRArray{Float64,3}, TracedRArray{Float64,3}, TracedRArray{Float64, 3}), + ( + ConcreteRArray{Float64,0}, + TracedRArray{Float64,0}, + TracedRArray{Float64,0}, + ), + ( + ConcreteRArray{Float64,1}, + TracedRArray{Float64,1}, + TracedRArray{Float64,1}, + ), + ( + ConcreteRArray{Float64,2}, + TracedRArray{Float64,2}, + TracedRArray{Float64,2}, + ), + ( + ConcreteRArray{Float64,3}, + TracedRArray{Float64,3}, + TracedRArray{Float64,3}, + ), # Array types - (Array{Float64,1}, Array{Float64,1}, Array{TracedRNumber{Float64}, 1}), - (Array{ConcreteRArray{Float64,2},1}, Array{TracedRArray{Float64,2},1}, Array{TracedRArray{Float64,2}, 1}), + (Array{Float64,1}, Array{Float64,1}, Array{TracedRNumber{Float64},1}), + ( + Array{ConcreteRArray{Float64,2},1}, + Array{TracedRArray{Float64,2},1}, + Array{TracedRArray{Float64,2},1}, + ), # Union types - (Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing, TracedRNumber{Int}}), + (Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing,TracedRNumber{Int}}), ( Union{Nothing,ConcreteRArray{Float64,1}}, Union{Nothing,TracedRArray{Float64,1}}, - Union{Nothing, TracedRArray{Float64, 1}} + Union{Nothing,TracedRArray{Float64,1}}, ), # Ptr types (Ptr{Float64}, Ptr{Float64}, Ptr{TracedRNumber{Float64}}), - (Ptr{ConcreteRArray{Float64,1}}, Ptr{TracedRArray{Float64,1}}, Ptr{TracedRArray{Float64,1}}), - (Core.LLVMPtr{Float64}, Core.LLVMPtr{Float64}, Core.LLVMPtr{TracedRNumber{Float64}}), + ( + Ptr{ConcreteRArray{Float64,1}}, + Ptr{TracedRArray{Float64,1}}, + Ptr{TracedRArray{Float64,1}}, + ), + ( + Core.LLVMPtr{Float64}, + Core.LLVMPtr{Float64}, + Core.LLVMPtr{TracedRNumber{Float64}}, + ), ( Core.LLVMPtr{ConcreteRArray{Float64,1}}, Core.LLVMPtr{TracedRArray{Float64,1}}, - Core.LLVMPtr{TracedRArray{Float64,1}} + Core.LLVMPtr{TracedRArray{Float64,1}}, + ), + ( + Base.RefValue{Float64}, + Base.RefValue{Float64}, + Base.RefValue{TracedRNumber{Float64}}, ), - (Base.RefValue{Float64}, Base.RefValue{Float64}, Base.RefValue{TracedRNumber{Float64}}), ( Base.RefValue{ConcreteRArray{Float64,1}}, Base.RefValue{TracedRArray{Float64,1}}, @@ -94,23 +126,28 @@ using Test (Val{0}, Val{0}, Val{0}), (Val{0.5}, Val{0.5}, Val{0.5}), (Val{:x}, Val{:x}, Val{:x}), - - - (Dict{Int, ConcreteRArray{Float64,0}}, Dict{Int, TracedRArray{Float64,0}}, Dict{Int, TracedRArray{Float64, 0}}), + ( + Dict{Int,ConcreteRArray{Float64,0}}, + Dict{Int,TracedRArray{Float64,0}}, + Dict{Int,TracedRArray{Float64,0}}, + ), (Dict{Int}, Dict{Int}, Dict{Int}), (Dict, Dict, Dict), - ((Dict{A, ConcreteRArray{Float64,0}} where A), (Dict{A, TracedRArray{Float64,0}} where A), (Dict{A, TracedRArray{Float64,0}} where A)), - - (Base.Pairs{Symbol, Union{}}, Base.Pairs{Symbol, Union{}}, Base.Pairs{Symbol, Union{}}) + ( + (Dict{A,ConcreteRArray{Float64,0}} where {A}), + (Dict{A,TracedRArray{Float64,0}} where {A}), + (Dict{A,TracedRArray{Float64,0}} where {A}), + ), + ( + Base.Pairs{Symbol,Union{}}, + Base.Pairs{Symbol,Union{}}, + Base.Pairs{Symbol,Union{}}, + ), ] - tracedty = traced_type( - origty, Val(ConcreteToTraced), Union{} - ) + tracedty = traced_type(origty, Val(ConcreteToTraced), Union{}) @test tracedty == targetty - tracedty2 = traced_type( - origty, Val(ConcreteToTraced), ReactantPrimitive - ) + tracedty2 = traced_type(origty, Val(ConcreteToTraced), ReactantPrimitive) @test tracedty2 == targetty end @@ -121,7 +158,7 @@ using Test TracedRArray{Float64,3}, ] @test_throws Union{ErrorException,String} traced_type( - type, Val(ConcreteToTraced), Union{} + type, Val(ConcreteToTraced), Union{} ) end end @@ -130,9 +167,7 @@ using Test x::Vector{Float64} y::Union{Nothing,Node} end - @test_throws NoFieldMatchError traced_type( - Node, Val(ArrayToConcrete), Union{} - ) + @test_throws NoFieldMatchError traced_type(Node, Val(ArrayToConcrete), Union{}) end end