Skip to content

Commit

Permalink
Format code (#562)
Browse files Browse the repository at this point in the history
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and enzyme-ci-bot[bot] authored Jan 18, 2025
1 parent 6310f83 commit fc4e53d
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 106 deletions.
24 changes: 16 additions & 8 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,6 @@ function get_field_offset(T::Type, path)
offset = 0
current_type = T


for field in path
# Get the field index
field_idx = if field isa Integer
Expand All @@ -450,9 +449,7 @@ function get_field_offset(T::Type, path)

# Update current_type to the field's type for next iteration
current_type = tcurrent_type

end


return offset
end
Expand All @@ -470,7 +467,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
threaddim = CUDA.CuDim3(threads)

if convert == Val(true)
args = recudaconvert.(args)
args = recudaconvert.(args)
end

mlir_args = MLIR.IR.Value[]
Expand Down Expand Up @@ -662,7 +659,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
end

location = MLIR.IR.Location()
@assert length(restys) == length(aliases)
@assert length(restys) == length(aliases)
call = MLIR.Dialects.enzymexla.kernel_call(
blk_operands...,
mlir_args;
Expand Down Expand Up @@ -723,13 +720,19 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(A::Type{<:CuTracedArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
@nospecialize(A::Type{<:CuTracedArray}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type)
)
return A
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(A::Type{<:CUDA.CuArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
@nospecialize(A::Type{<:CUDA.CuArray}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type)
)
T = eltype(A)
N = ndims(A)
Expand All @@ -746,7 +749,12 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
end

function Reactant.make_tracer(
seen, @nospecialize(prev::CUDA.CuArray), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), kwargs...
seen,
@nospecialize(prev::CUDA.CuArray),
@nospecialize(path),
mode;
@nospecialize(track_numbers::Type = Union{}),
kwargs...,
)
RT = Core.Typeof(prev)
if haskey(seen, prev)
Expand Down
5 changes: 4 additions & 1 deletion ext/ReactantOffsetArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ using OffsetArrays: OffsetArray
using Reactant: Reactant, MLIR, Ops, TracedRArray

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(OA::Type{<:OffsetArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type=Union{})
@nospecialize(OA::Type{<:OffsetArray}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type = Union{})
)
N = ndims(OA)
T = OffsetArrays.parenttype(OA)
Expand Down
18 changes: 8 additions & 10 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,18 +298,14 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
",",
)
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
passes = [
"inline{default-pipeline=canonicalize max-iterations=4}"
]
passes = ["inline{default-pipeline=canonicalize max-iterations=4}"]
if sroa
push!(passes, "sroa-wrappers")
push!(passes, "libdevice-funcs-raise")
push!(passes, "canonicalize")
push!(passes, "sroa-wrappers")
push!(passes, "libdevice-funcs-raise")
push!(passes, "canonicalize")
end
push!(passes, func_passes)
return join(passes,
',',
)
return join(passes, ',')
end

# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
Expand Down Expand Up @@ -389,7 +385,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
end
if DEBUG_KERNEL[]
curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult")
curesulthandler = XLA.Libdl.dlsym(
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
)
@assert curesulthandler !== nothing
curesulthandler = Base.reinterpret(UInt, curesulthandler)
kern = "lower-kernel{debug=true cuResultHandlerPtr=$curesulthandler run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) cuStreamSynchronizePtr=$(cuSync[])},symbol-dce"
Expand Down
23 changes: 19 additions & 4 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,37 @@ else
end

@static if isdefined(Core, :BFloat16)
const ReactantComplexFloat = Union{Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}}
const ReactantComplexFloat = Union{
Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}
}
else
const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}}
end

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}

const ReactantComplexInt = Union{Complex{Int8},Complex{UInt8},Complex{Int16},Complex{UInt16},Complex{Int32},Complex{UInt32},Complex{Int64},Complex{UInt64},Complex{Int128},Complex{UInt128}}
const ReactantComplexInt = Union{
Complex{Int8},
Complex{UInt8},
Complex{Int16},
Complex{UInt16},
Complex{Int32},
Complex{UInt32},
Complex{Int64},
Complex{UInt64},
Complex{Int128},
Complex{UInt128},
}

const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
}

const ReactantPrimitive = Union{
Bool,Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,Base.uniontypes(ReactantComplexFloat)...
Bool,
Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,
Base.uniontypes(ReactantComplexFloat)...,
}

abstract type RNumber{T<:ReactantPrimitive} <: Number end
Expand Down
Loading

0 comments on commit fc4e53d

Please sign in to comment.