From cbecf17bd726d1ffcbf6fc3ad11b030fa8a5b12f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Dec 2024 19:24:13 +0530 Subject: [PATCH 01/15] feat: add materialize_traced_array for all other wrappers --- src/Compiler.jl | 3 + src/TracedUtils.jl | 54 ++-------- src/stdlibs/LinearAlgebra.jl | 204 ++++++++++++++++++++++++++++++++--- 3 files changed, 198 insertions(+), 63 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 62cf685b6..62611047b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -41,6 +41,9 @@ function create_result(tocopy::T, path, result_stores) where {T} elems = Union{Symbol,Expr}[] for i in 1:fieldcount(T) + # If the field is undefined we don't set it. A common example for this is `du2` + # for Tridiagonal + isdefined(tocopy, i) || continue ev = create_result(getfield(tocopy, i), append_path(path, i), result_stores) push!(elems, ev) end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index b69027549..a12c3cd10 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -3,7 +3,6 @@ # within compilation. However, it means these functions are a _lot_ faster to compile. module TracedUtils -using LinearAlgebra: LinearAlgebra using Adapt: Adapt, WrappedReshapedArray using ..Reactant: Reactant, @@ -19,34 +18,20 @@ using ..Reactant: Ops materialize_traced_array(x::TracedRArray) = x + materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] + function materialize_traced_array( x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}} ) where {T,N,M} return Ops.reshape(materialize_traced_array(parent(x)), size(x)...) end -function materialize_traced_array( - x::LinearAlgebra.Transpose{TracedRNumber{T},TracedRArray{T,N}} -) where {T,N} - px = parent(x) - A = ndims(px) == 1 ? 
reshape(px, :, 1) : px - return permutedims(A, (2, 1)) -end -function materialize_traced_array( - x::LinearAlgebra.Adjoint{TracedRNumber{T},TracedRArray{T,N}} -) where {T,N} - return conj(materialize_traced_array(transpose(parent(x)))) -end + function materialize_traced_array( x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}} ) where {T,N,perm,iperm} return permutedims(parent(x), perm) end -function materialize_traced_array( - x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}} -) where {T} - return LinearAlgebra.diagm(parent(x)) -end get_mlir_data(x::TracedRNumber) = x.mlir_data set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) @@ -58,6 +43,7 @@ function set_mlir_data!(x::TracedRArray, data) x.mlir_data = data return x end + function set_mlir_data!( x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, data ) where {T,N,M} @@ -65,42 +51,14 @@ function set_mlir_data!( set_mlir_data!(parent(x), res_mlir_data) return x end -function set_mlir_data!( - x::LinearAlgebra.Transpose{TracedRNumber{T},TracedRArray{T,N}}, data -) where {T,N} - tdata = TracedRArray(data) - px = parent(x) - px.mlir_data = ( - if ndims(px) == 1 - Ops.reshape(tdata, length(tdata)) - else - Ops.transpose(tdata, [2, 1]) - end - ).mlir_data - return x -end -function set_mlir_data!( - x::LinearAlgebra.Adjoint{TracedRNumber{T},TracedRArray{T,N}}, data -) where {T,N} - tdata = TracedRArray(data) - px = parent(x) - transposed_data = - ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) - px.mlir_data = (T <: Real ? 
transposed_data : Ops.conj(transposed_data)).mlir_data - return x -end + function set_mlir_data!( x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}}, data ) where {T,N,perm,iperm} parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data return x end -function set_mlir_data!( - x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data -) where {T} - parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data - return x -end + function set_mlir_data!(x::AnyTracedRArray, data) setindex!(x, TracedRArray(data), axes(x)...) return x diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index c011f8aec..b5df5bb0f 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -1,19 +1,190 @@ module TracedLinearAlgebra -using ..Reactant -import ..TracedRArray -import ..TracedRNumber -import ..AnyTracedRArray -import ..AnyTracedRMatrix -import ..AnyTracedRVector - -import ..TracedUtils -using ..TracedUtils: get_mlir_data, materialize_traced_array, set_mlir_data! - -import ..Ops -import ..MLIR +using ..Reactant: + TracedRArray, + TracedRNumber, + AnyTracedRArray, + AnyTracedRMatrix, + AnyTracedRVector, + Ops, + MLIR + +using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_mlir_data! + using LinearAlgebra +# Various Wrapper Arrays defined in LinearAlgebra +function materialize_traced_array( + x::Transpose{TracedRNumber{T},TracedRArray{T,N}} +) where {T,N} + px = parent(x) + A = ndims(px) == 1 ? 
reshape(px, :, 1) : px + return permutedims(A, (2, 1)) +end + +function materialize_traced_array( + x::Adjoint{TracedRNumber{T},TracedRArray{T,N}} +) where {T,N} + return conj(materialize_traced_array(transpose(parent(x)))) +end + +function materialize_traced_array( + x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}} +) where {T} + return LinearAlgebra.diagm(parent(x)) +end + +function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} + (_, update_function) = make_mlir_fn( + simple_update_overwrite, + (promote_to(TracedRNumber{T}, 0), promote_to(TracedRNumber{T}, 0)), + (), + string(gensym("update_computation")), + false; + return_dialect=:stablehlo, + no_args_in_result=true, + ) + update_computation = MLIR.IR.Region() + MLIR.API.mlirRegionTakeBody( + update_computation, MLIR.API.mlirOperationGetRegion(update_function, 0) + ) + MLIR.IR.rmfromparent!(update_function) + + init_array = Ops.constant(fill(zero(T), size(x))).mlir_data + + scatter_indices = Vector{Int64}[] + for i in (-1, 0, 1) + idxs = diagind(x, i, IndexCartesian()) + for idx in idxs + push!(scatter_indices, Int64[idx[1] - 1, idx[2] - 1]) + end + end + scatter_indices = + Ops.transpose(Ops.constant(reduce(hcat, scatter_indices)), [2, 1]).mlir_data + + updates = MLIR.IR.result( + MLIR.Dialects.stablehlo.concatenate( + [x.dl.mlir_data, x.d.mlir_data, x.du.mlir_data]; dimension=0 + ), + 1, + ) + + #! format: off + scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( + MLIR.IR.context(), + 0, Int64[], + 2, Int64[0, 1], + 0, Int64[], + 0, Int64[], + 2, Int64[0, 1], + 1 + ) + #! 
format: on + + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.scatter( + [init_array], + scatter_indices, + [updates]; + result_0=[mlir_type(TracedRArray{T,2}, size(x))], + update_computation, + scatter_dimension_numbers, + ), + 1, + ) + + return TracedRArray{T,2}((), res, size(x)) +end + +for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) + uAT = Symbol(:Unit, AT) + @eval begin + function TracedUtils.materialize_traced_array( + x::$(AT){T,TracedRArray{T,2}} + ) where {T} + m, n = size(x) + row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) + col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) + indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(comp)) + return Ops.select(indicator, parent(x), zero(parent(x))) + end + + function TracedUtils.materialize_traced_array( + x::$(uAT){T,TracedRArray{T,2}} + ) where {T} + m, n = size(x) + row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) + col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) + nondiag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="NE") + x = materialize_traced_array($(AT)(parent(x))) + return Ops.select(nondiag_indicator, x, one.(x)) + end + end +end + +function TracedUtils.materialize_traced_array(x::Symmetric{T,TracedRArray{T,2}}) where {T} + m, n = size(x) + row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) + col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) + if x.uplo == 'L' + indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="GT") + x_lt = Ops.select(indicator, parent(x), zero(parent(x))) + x_ltd = materialize_traced_array(LowerTriangular(parent(x))) + return Ops.add(x_lt, Ops.transpose(x_ltd, [2, 1])) + else + indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="LT") + x_ut = Ops.select(indicator, parent(x), zero(parent(x))) + x_utd = materialize_traced_array(UpperTriangular(parent(x))) + return Ops.add(Ops.transpose(x_utd, [2, 1]), x_ut) + end +end + +function TracedUtils.set_mlir_data!( + 
x::Transpose{TracedRNumber{T},TracedRArray{T,N}}, data +) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + px.mlir_data = ( + if ndims(px) == 1 + Ops.reshape(tdata, length(tdata)) + else + Ops.transpose(tdata, [2, 1]) + end + ).mlir_data + return x +end + +function TracedUtils.set_mlir_data!( + x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}, data +) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + transposed_data = + ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) + px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data + return x +end + +function TracedUtils.set_mlir_data!(x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data) where {T} + parent(x).mlir_data = diag(TracedRArray(data)).mlir_data + return x +end + +# TODO: UnitLowerTriangular +# TODO: LowerTriangular +# TODO: UnitUpperTriangular +# TODO: UpperTriangular +# TODO: Symmetric + +function set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T} + tdata = TracedRArray(data) + set_mlir_data!(x.dl, diag(tdata, -1).mlir_data) + set_mlir_data!(x.d, diag(tdata, 0).mlir_data) + set_mlir_data!(x.du, diag(tdata, 1).mlir_data) + return x +end + +# Core functions function LinearAlgebra.mul!( @nospecialize(C::TracedRArray{T,1}), @nospecialize(A::AnyTracedRMatrix), @@ -23,7 +194,7 @@ function LinearAlgebra.mul!( ) where {T} # TODO: The reshape operations are not getting optimized, we should directly call dot_general rC = Ops.reshape(C, length(C), 1) - LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β) + mul!(rC, A, reshape(B, :, 1), α, β) C.mlir_data = get_mlir_data(vec(rC)) return C end @@ -35,7 +206,7 @@ function LinearAlgebra.mul!( α::Number=true, β::Number=false, ) where {T} - LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β) + mul!(C, A, reshape(B, :, 1), α, β) return C end @@ -146,9 +317,10 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} end function 
LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} - return LinearAlgebra.diagm(length(v), length(v), v) + return diagm(length(v), length(v), v) end function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} + # TODO: Use scatter for this m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check v = materialize_traced_array(v) @@ -165,4 +337,6 @@ function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) wh ) end +simple_update_overwrite(x, y) = y + end From df22ad8286d45d68fba87631cfa858cd3d4fd5b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Dec 2024 19:44:46 +0530 Subject: [PATCH 02/15] refactor: use scatter for generating diagm --- src/stdlibs/LinearAlgebra.jl | 148 +++++++++++++++++++---------------- 1 file changed, 81 insertions(+), 67 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index b5df5bb0f..46aa90bf0 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -35,65 +35,25 @@ function materialize_traced_array( end function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} - (_, update_function) = make_mlir_fn( - simple_update_overwrite, - (promote_to(TracedRNumber{T}, 0), promote_to(TracedRNumber{T}, 0)), - (), - string(gensym("update_computation")), - false; - return_dialect=:stablehlo, - no_args_in_result=true, - ) - update_computation = MLIR.IR.Region() - MLIR.API.mlirRegionTakeBody( - update_computation, MLIR.API.mlirOperationGetRegion(update_function, 0) - ) - MLIR.IR.rmfromparent!(update_function) - - init_array = Ops.constant(fill(zero(T), size(x))).mlir_data - - scatter_indices = Vector{Int64}[] - for i in (-1, 0, 1) - idxs = diagind(x, i, IndexCartesian()) - for idx in idxs - push!(scatter_indices, Int64[idx[1] - 1, idx[2] - 1]) - end - end - scatter_indices = - Ops.transpose(Ops.constant(reduce(hcat, scatter_indices)), [2, 1]).mlir_data - - updates = MLIR.IR.result( - 
MLIR.Dialects.stablehlo.concatenate( - [x.dl.mlir_data, x.d.mlir_data, x.du.mlir_data]; dimension=0 - ), - 1, - ) - - #! format: off - scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( - MLIR.IR.context(), - 0, Int64[], - 2, Int64[0, 1], - 0, Int64[], - 0, Int64[], - 2, Int64[0, 1], - 1 + scatter_indices = vcat( + diagonal_indices_zero_indexed(size(x, 1), size(x, 2), -1), + diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 0), + diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 1), ) - #! format: on + scatter_indices = Ops.constant(scatter_indices) - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.scatter( - [init_array], - scatter_indices, - [updates]; - result_0=[mlir_type(TracedRArray{T,2}, size(x))], - update_computation, - scatter_dimension_numbers, + updates = TracedRArray{T,1}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.concatenate( + [x.dl.mlir_data, x.d.mlir_data, x.du.mlir_data]; dimension=0 + ), + 1, ), - 1, + (size(scatter_indices, 1),), ) - return TracedRArray{T,2}((), res, size(x)) + return simple_scatter_op(size(x), scatter_indices, updates) end for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) @@ -320,23 +280,77 @@ function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} return diagm(length(v), length(v), v) end function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} - # TODO: Use scatter for this m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check - v = materialize_traced_array(v) - D = length(v) - row_idxs = Ops.iota(Int, [D, D]; iota_dimension=1) - col_idxs = Ops.iota(Int, [D, D]; iota_dimension=2) - diag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="EQ") - - mat = (v .+ zero(v)') .* diag_indicator - return Ops.pad( - mat, - TracedUtils.promote_to(TracedRNumber{T}, 0); - high=[m - length(v), n - length(v)], - ) + indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :]) + return simple_scatter_op((m, 
n), indices, materialize_traced_array(v)) end simple_update_overwrite(x, y) = y +## This is quite handy to have but is not generalized enough to be put into Ops? Or maybe +## we can document it and place it there under a different name. It takes a list of values +## and a list of indices and constructs a matrix with the values at the indices. +function simple_scatter_op( + shape, scatter_indices::TracedRArray{Int64,2}, updates::TracedRArray{T,1} +) where {T} + @assert length(updates) == size(scatter_indices, 1) + @assert size(scatter_indices, 2) == 2 + + # TODO: Directly use `Ops.hlo_call` for this part + (_, update_function) = make_mlir_fn( + simple_update_overwrite, + (promote_to(TracedRNumber{T}, 0), promote_to(TracedRNumber{T}, 0)), + false; + return_dialect=:stablehlo, + no_args_in_result=true, + ) + update_computation = MLIR.IR.Region() + MLIR.API.mlirRegionTakeBody( + update_computation, MLIR.API.mlirOperationGetRegion(update_function, 0) + ) + MLIR.IR.rmfromparent!(update_function) + + init_array = Ops.constant(fill(zero(T), shape)).mlir_data + + #! format: off + scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( + MLIR.IR.context(), + 0, Int64[], + 2, Int64[0, 1], + 0, Int64[], + 0, Int64[], + 2, Int64[0, 1], + 1 + ) + #! format: on + + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.scatter( + [init_array], + scatter_indices.mlir_data, + [updates.mlir_data]; + result_0=[mlir_type(TracedRArray{T,2}, shape)], + update_computation, + scatter_dimension_numbers, + ), + 1, + ) + + return TracedRArray{T,2}((), res, shape) +end + +# The cartesian version doesn't exist in julia 1.10 +function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0) + Cstart = CartesianIndex(1 + max(0, -k), 1 + max(0, k)) + Cstep = CartesianIndex(1, 1) + res = StepRangeLen(Cstart, Cstep, max(0, k <= 0 ? 
min(m + k, n) : min(m, n - k))) + indices = Matrix{Int}(undef, (length(res), 2)) + for (i, idx) in enumerate(res) + indices[i, 1] = idx[1] - 1 + indices[i, 2] = idx[2] - 1 + end + return indices +end + end From 0fe5e9efd93513a79900b5963d559eab2fdb1451 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Dec 2024 20:19:10 +0530 Subject: [PATCH 03/15] refactor: directly generate the region for simple_scatter_op --- src/stdlibs/LinearAlgebra.jl | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 46aa90bf0..9190a6e9c 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -281,13 +281,11 @@ function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} end function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check - indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :]) return simple_scatter_op((m, n), indices, materialize_traced_array(v)) end -simple_update_overwrite(x, y) = y - +# Common Utilities ## This is quite handy to have but is not generalized enough to be put into Ops? Or maybe ## we can document it and place it there under a different name. It takes a list of values ## and a list of indices and constructs a matrix with the values at the indices. 
@@ -297,19 +295,15 @@ function simple_scatter_op( @assert length(updates) == size(scatter_indices, 1) @assert size(scatter_indices, 2) == 2 - # TODO: Directly use `Ops.hlo_call` for this part - (_, update_function) = make_mlir_fn( - simple_update_overwrite, - (promote_to(TracedRNumber{T}, 0), promote_to(TracedRNumber{T}, 0)), - false; - return_dialect=:stablehlo, - no_args_in_result=true, - ) update_computation = MLIR.IR.Region() - MLIR.API.mlirRegionTakeBody( - update_computation, MLIR.API.mlirOperationGetRegion(update_function, 0) + block = MLIR.IR.Block( + [mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})], + [MLIR.IR.Location(), MLIR.IR.Location()], ) - MLIR.IR.rmfromparent!(update_function) + return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)]) + MLIR.IR.rmfromparent!(return_op) + push!(block, return_op) + pushfirst!(update_computation, block) init_array = Ops.constant(fill(zero(T), shape)).mlir_data From ef7e6362d470a5790f07f343af3e3a59bb293f96 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Dec 2024 20:31:06 +0530 Subject: [PATCH 04/15] feat: generalize diagm --- src/stdlibs/LinearAlgebra.jl | 61 +++++++++++++++--------------------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 9190a6e9c..e9d9da892 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -34,26 +34,8 @@ function materialize_traced_array( return LinearAlgebra.diagm(parent(x)) end -function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} - scatter_indices = vcat( - diagonal_indices_zero_indexed(size(x, 1), size(x, 2), -1), - diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 0), - diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 1), - ) - scatter_indices = Ops.constant(scatter_indices) - - updates = TracedRArray{T,1}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.concatenate( - [x.dl.mlir_data, 
x.d.mlir_data, x.du.mlir_data]; dimension=0 - ), - 1, - ), - (size(scatter_indices, 1),), - ) - - return simple_scatter_op(size(x), scatter_indices, updates) +function materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} + return diagm(-1 => x.dl, 0 => x.d, 1 => x.du) end for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) @@ -276,13 +258,23 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} return TracedRArray{T,1}((), res, (diag_length,)) end -function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} - return diagm(length(v), length(v), v) -end -function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} - m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check - indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :]) - return simple_scatter_op((m, n), indices, materialize_traced_array(v)) +function LinearAlgebra._diagm( + shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}... +) where {T} + m, n = LinearAlgebra.diagm_size(shape, kv...) 
+ scatter_indices = Matrix{Int64}[] + concat_inputs = MLIR.IR.Value[] + for (k, v) in kv + push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :]) + push!(concat_inputs, get_mlir_data(v)) + end + scatter_indices = Ops.constant(reduce(vcat, scatter_indices)) + values = TracedRArray{T,1}( + (), + MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1), + (size(scatter_indices, 1),), + ) + return simple_scatter_op((m, n), scatter_indices, values) end # Common Utilities @@ -334,15 +326,14 @@ function simple_scatter_op( return TracedRArray{T,2}((), res, shape) end -# The cartesian version doesn't exist in julia 1.10 +## The cartesian version doesn't exist in julia 1.10 function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0) - Cstart = CartesianIndex(1 + max(0, -k), 1 + max(0, k)) - Cstep = CartesianIndex(1, 1) - res = StepRangeLen(Cstart, Cstep, max(0, k <= 0 ? min(m + k, n) : min(m, n - k))) - indices = Matrix{Int}(undef, (length(res), 2)) - for (i, idx) in enumerate(res) - indices[i, 1] = idx[1] - 1 - indices[i, 2] = idx[2] - 1 + idx1, idx2 = 1 + max(0, -k), 1 + max(0, k) + L = max(0, k ≤ 0 ? 
min(m + k, n) : min(m, n - k)) + indices = Matrix{Int}(undef, (L, 2)) + for i in axes(indices, 1) + indices[i, 1] = idx1 + i - 2 + indices[i, 2] = idx2 + i - 2 end return indices end From 5216810edad98a1c0d2bc313cbc941cba483d9f0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Dec 2024 22:56:01 +0530 Subject: [PATCH 05/15] feat: efficient non-contiguous setindex --- src/Ops.jl | 55 ++++++++++++++++++++++++++++++++++++ src/TracedRArray.jl | 48 ++++++++++++++++++++++--------- src/stdlibs/LinearAlgebra.jl | 52 ++-------------------------------- test/basic.jl | 20 +++++++++++++ 4 files changed, 112 insertions(+), 63 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 928dfbefe..802e2ff27 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1418,4 +1418,59 @@ julia> Reactant.@jit( end end +""" + scatter_setindex(dest, scatter_indices, updates) + +Uses [`MLIR.Dialects.stablehlo.scatter`](@ref) to set the values of `dest` at the indices +specified by `scatter_indices` to the values in `updates`. If the indices are contiguous it +is recommended to directly use [`MLIR.Dialects.stablehlo.dynamic_update_slice`](@ref) +instead. +""" +function scatter_setindex( + dest::TracedRArray{T,N}, + scatter_indices::TracedRArray{Int64,2}, + updates::TracedRArray{T,1}, +) where {T,N} + @assert length(updates) == size(scatter_indices, 1) + @assert size(scatter_indices, 2) == N + + update_computation = MLIR.IR.Region() + block = MLIR.IR.Block( + [mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})], + [MLIR.IR.Location(), MLIR.IR.Location()], + ) + return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)]) + MLIR.IR.rmfromparent!(return_op) + push!(block, return_op) + pushfirst!(update_computation, block) + + #! format: off + scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( + MLIR.IR.context(), + 0, Int64[], + N, collect(Int64, 0:(N - 1)), + 0, Int64[], + 0, Int64[], + N, collect(Int64, 0:(N - 1)), + 1 + ) + #! 
format: on + + return TracedRArray{T,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.scatter( + [dest.mlir_data], + scatter_indices.mlir_data, + [updates.mlir_data]; + result_0=[mlir_type(TracedRArray{T,N}, size(dest))], + update_computation, + scatter_dimension_numbers, + ), + 1, + ), + size(dest), + ) +end + end # module Ops diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 2f8c07eb3..262c87fab 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -67,11 +67,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} return i end - foreach(indices) do idxs - idxs isa Number && return nothing + for idxs in indices + idxs isa Number && continue contiguous = all(isone, diff(idxs)) # XXX: We want to throw error even for dynamic indexing - if typeof(a) <: Bool + if typeof(contiguous) <: Bool contiguous || error("non-contiguous indexing is not supported") end end @@ -99,16 +99,40 @@ function Base.getindex(a::WrappedTracedRArray, indices...) return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...) end -function Base.setindex!( - a::TracedRArray{T,N}, - v, - indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, -) where {T,N} +function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N} indices = map(enumerate(indices)) do (idx, i) - i isa Int ? (i:i) : (i isa Colon ? 
(1:size(a, idx)) : i) + i isa Colon && return 1:size(a, idx) + i isa CartesianIndex && return Tuple(i) + return i end + + non_contiguous_setindex = false + for idxs in indices + idxs isa Number && continue + contiguous = all(isone, diff(idxs)) + # XXX: We want to throw error even for dynamic indexing + if typeof(contiguous) <: Bool && !contiguous + non_contiguous_setindex = true + break + end + end + + if non_contiguous_setindex + indices_tuples = collect(Iterators.product(indices...)) + indices = Matrix{Int}(undef, (length(indices_tuples), 2)) + for (i, idx) in enumerate(indices_tuples) + indices[i, 1] = idx[1] - 1 + indices[i, 2] = idx[2] - 1 + end + indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices) + res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v))) + a.mlir_data = res.mlir_data + return v + end + v = TracedUtils.broadcast_to_size(v, length.(indices)) v = TracedUtils.promote_to(TracedRArray{T,N}, v) + indices = [ ( TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1 @@ -124,11 +148,7 @@ function Base.setindex!( return v end -function Base.setindex!( - a::AnyTracedRArray{T,N}, - v, - indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, -) where {T,N} +function Base.setindex!(a::AnyTracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N} ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...) setindex!(ancestor(a), v, ancestor_indices...) return a diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index e9d9da892..f1c7f1c9d 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -274,58 +274,12 @@ function LinearAlgebra._diagm( MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1), (size(scatter_indices, 1),), ) - return simple_scatter_op((m, n), scatter_indices, values) -end - -# Common Utilities -## This is quite handy to have but is not generalized enough to be put into Ops? 
Or maybe -## we can document it and place it there under a different name. It takes a list of values -## and a list of indices and constructs a matrix with the values at the indices. -function simple_scatter_op( - shape, scatter_indices::TracedRArray{Int64,2}, updates::TracedRArray{T,1} -) where {T} - @assert length(updates) == size(scatter_indices, 1) - @assert size(scatter_indices, 2) == 2 - - update_computation = MLIR.IR.Region() - block = MLIR.IR.Block( - [mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})], - [MLIR.IR.Location(), MLIR.IR.Location()], - ) - return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)]) - MLIR.IR.rmfromparent!(return_op) - push!(block, return_op) - pushfirst!(update_computation, block) - - init_array = Ops.constant(fill(zero(T), shape)).mlir_data - - #! format: off - scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( - MLIR.IR.context(), - 0, Int64[], - 2, Int64[0, 1], - 0, Int64[], - 0, Int64[], - 2, Int64[0, 1], - 1 + return Ops.scatter_setindex( + Ops.constant(fill(zero(T), (m, n))), scatter_indices, values ) - #! format: on - - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.scatter( - [init_array], - scatter_indices.mlir_data, - [updates.mlir_data]; - result_0=[mlir_type(TracedRArray{T,2}, shape)], - update_computation, - scatter_dimension_numbers, - ), - 1, - ) - - return TracedRArray{T,2}((), res, shape) end +# Common Utilities ## The cartesian version doesn't exist in julia 1.10 function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0) idx1, idx2 = 1 + max(0, -k), 1 + max(0, k) diff --git a/test/basic.jl b/test/basic.jl index 8e97ed41d..ade01078b 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -442,6 +442,26 @@ end @test @allowscalar all(isone, x_ra_array[4, :]) end +function non_contiguous_setindex!(x) + x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0 + return x +end + +@testset "non-contiguous setindex!" 
begin + x = rand(6, 6) + x_ra = Reactant.to_rarray(x) + + y = @jit(non_contiguous_setindex!(x_ra)) + y = Array(y) + x_ra = Array(x_ra) + @test all(isone, y[1:3, 1:4]) + @test all(isone, x_ra[1:3, 1:4]) + @test !all(isone, y[4:end, :]) + @test !all(isone, x_ra[4:end, :]) + @test !all(isone, y[:, 5:end]) + @test !all(isone, x_ra[:, 5:end]) +end + tuple_byref(x) = (; a=(; b=x)) tuple_byref2(x) = abs2.(x), tuple_byref2(x) From 8c32d5116baaf8b29f4f97f0b1cab120ded10fd3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Dec 2024 23:22:08 +0530 Subject: [PATCH 06/15] fix: non-contiguous indexing is now supported --- src/Ops.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ src/Reactant.jl | 2 +- src/TracedRArray.jl | 22 ++++++++++++++++++---- 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 802e2ff27..ee870e80c 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1473,4 +1473,47 @@ function scatter_setindex( ) end +""" + gather_getindex(src, gather_indices) + +Uses [`MLIR.Dialects.stablehlo.gather`](@ref) to get the values of `src` at the indices +specified by `gather_indices`. If the indices are contiguous it is recommended to directly +use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead. +""" +function gather_getindex( + src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2} +) where {T,N} + @assert size(gather_indices, 2) == N + + #! format: off + dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( + MLIR.IR.context(), + Int64(1), Int64[1], + Int64(N - 1), collect(Int64, 0:(N - 2)), + Int64(0), Int64[], + Int64(0), Int64[], + Int64(N), collect(Int64, 0:(N - 1)), + 1 + ) + #! 
format: on + + return reshape( + TracedRArray{T,2}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.gather( + src.mlir_data, + gather_indices.mlir_data; + dimension_numbers, + slice_sizes=fill(Int64(1), N), + indices_are_sorted=false, + ), + 1, + ), + (size(gather_indices, 1), 1), + ), + size(gather_indices, 1), + ) +end + end # module Ops diff --git a/src/Reactant.jl b/src/Reactant.jl index 73cfb516b..b6185a84e 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -105,7 +105,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N} ) where {T,N} shape = Tuple(shape) if !isnothing(mlir_data) - @assert size(MLIR.IR.type(mlir_data)) == shape + @assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))" end return new{T,N}(paths, mlir_data, shape) end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 262c87fab..9e720205e 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -14,8 +14,9 @@ using ..Reactant: MLIR, ancestor, unwrapped_eltype +using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array + using ReactantCore: ReactantCore -using ..TracedUtils: TracedUtils, materialize_traced_array using GPUArraysCore: GPUArraysCore ReactantCore.is_traced(::TracedRArray) = true @@ -59,7 +60,6 @@ function Base.getindex(a::TracedRArray{T,0}) where {T} return TracedRNumber{T}((), a.mlir_data) end -# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} indices = map(enumerate(indices)) do (idx, i) i isa Colon && return 1:size(a, idx) @@ -67,13 +67,27 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} return i end + non_contiguous_getindex = false for idxs in indices idxs isa Number && continue contiguous = all(isone, diff(idxs)) # XXX: We want to throw error even for dynamic indexing - if typeof(contiguous) <: Bool - contiguous 
|| error("non-contiguous indexing is not supported") + if typeof(contiguous) <: Bool && !contiguous + non_contiguous_getindex = true + break + end + end + + if non_contiguous_getindex + indices_tuples = collect(Iterators.product(indices...)) + indices = Matrix{Int}(undef, (length(indices_tuples), 2)) + for (i, idx) in enumerate(indices_tuples) + indices[i, 1] = idx[1] - 1 + indices[i, 2] = idx[2] - 1 end + indices = promote_to(TracedRArray{Int,2}, indices) + res = Ops.gather_getindex(a, indices) + return Ops.reshape(res, size(indices_tuples)...) end start_indices = map(indices) do i From ecd6420f53145e1a04b1a299e511a360e195cd09 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 Dec 2024 10:57:56 +0530 Subject: [PATCH 07/15] feat: implement set_mlir_data for the remaining types --- src/Reactant.jl | 14 +++++++---- src/TracedUtils.jl | 6 ++--- src/stdlibs/LinearAlgebra.jl | 45 ++++++++++++++++++++++++++++-------- 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index b6185a84e..f8dd8008b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -123,11 +123,17 @@ const AnyTracedRMatrix{T} = Union{ } const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} -function TracedRArray(data::MLIR.IR.Value) +function TracedRArray{T}(data::MLIR.IR.Value) where {T} data_type = MLIR.IR.type(data) - return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( - (), data, size(data_type) - ) + if T == eltype(MLIR.IR.julia_type(data_type)) + return TracedRArray{T,ndims(data_type)}((), data, size(data_type)) + end + tdata = TracedRArray(data) + return Ops.convert(TracedRArray{T,ndims(data_type)}, tdata) +end + +function TracedRArray(data::MLIR.IR.Value) + return TracedRArray{eltype(MLIR.IR.julia_type(MLIR.IR.type(data)))}(data) end struct XLAArray{T,N} <: RArray{T,N} end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index a12c3cd10..02095b003 100644 --- a/src/TracedUtils.jl +++ 
b/src/TracedUtils.jl @@ -47,7 +47,7 @@ end function set_mlir_data!( x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, data ) where {T,N,M} - res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data + res_mlir_data = Ops.reshape(TracedRArray{T}(data), size(parent(x))...).mlir_data set_mlir_data!(parent(x), res_mlir_data) return x end @@ -55,12 +55,12 @@ end function set_mlir_data!( x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}}, data ) where {T,N,perm,iperm} - parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data + parent(x).mlir_data = permutedims(TracedRArray{T}(data), iperm).mlir_data return x end function set_mlir_data!(x::AnyTracedRArray, data) - setindex!(x, TracedRArray(data), axes(x)...) + setindex!(x, TracedRArray{T}(data), axes(x)...) return x end diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index f1c7f1c9d..99cdec1ee 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -84,7 +84,7 @@ end function TracedUtils.set_mlir_data!( x::Transpose{TracedRNumber{T},TracedRArray{T,N}}, data ) where {T,N} - tdata = TracedRArray(data) + tdata = TracedRArray{T}(data) px = parent(x) px.mlir_data = ( if ndims(px) == 1 @@ -99,7 +99,7 @@ end function TracedUtils.set_mlir_data!( x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}, data ) where {T,N} - tdata = TracedRArray(data) + tdata = TracedRArray{T}(data) px = parent(x) transposed_data = ndims(px) == 1 ? 
Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) @@ -108,18 +108,45 @@ function TracedUtils.set_mlir_data!( end function TracedUtils.set_mlir_data!(x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data) where {T} - parent(x).mlir_data = diag(TracedRArray(data)).mlir_data + parent(x).mlir_data = diag(TracedRArray{T}(data)).mlir_data return x end -# TODO: UnitLowerTriangular -# TODO: LowerTriangular -# TODO: UnitUpperTriangular -# TODO: UpperTriangular -# TODO: Symmetric +for (AT, dcomp, ocomp) in ( + (:LowerTriangular, "GE", "LT"), + (:UnitLowerTriangular, "GT", "LE"), + (:UpperTriangular, "LE", "GT"), + (:UnitUpperTriangular, "LT", "GE"), +) + @eval function set_mlir_data!( + x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data + ) where {T} + tdata = TracedRArray{T}(data) + z = zero(tdata) + m, n = size(x) + row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) + col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) + data_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(dcomp)) + original_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(ocomp)) + res = Ops.add( + Ops.select(data_indicator, tdata, z), Ops.select(original_indicator, x.data, z) + ) + set_mlir_data!(x.data, res.mlir_data) + return x + end +end + +function set_mlir_data!(x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data) where {T} + if x.uplo == 'L' + set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data) + else + set_mlir_data!(LinearAlgebra.UpperTriangular(parent(x)), data) + end + return x +end function set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T} - tdata = TracedRArray(data) + tdata = TracedRArray{T}(data) set_mlir_data!(x.dl, diag(tdata, -1).mlir_data) set_mlir_data!(x.d, diag(tdata, 0).mlir_data) set_mlir_data!(x.du, diag(tdata, 1).mlir_data) From 24425fffb2f0786efe5f28baf5053f7e3c98f439 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 Dec 2024 11:01:03 +0530 Subject: [PATCH 08/15] refactor: use 
`Ops.gather_getindex` to implement diag --- src/stdlibs/LinearAlgebra.jl | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 99cdec1ee..c9e2455c7 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -259,30 +259,7 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} # :0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64> length(indices) ≤ 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[]) - idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,2}, indices)) - - #! format: off - dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( - MLIR.IR.context(), - Int64(0), Int64[], - Int64(2), Int64[0, 1], - Int64(0), Int64[], - Int64(0), Int64[], - Int64(2), Int64[0, 1], - Int64(1) - ) - #! format: on - - slice_sizes = get_mlir_data( - Reactant.TracedUtils.promote_to(TracedRArray{Int,1}, [1, 1]) - ) - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.dynamic_gather( - get_mlir_data(y), idxs, slice_sizes; dimension_numbers - ), - 1, - ) - return TracedRArray{T,1}((), res, (diag_length,)) + return Ops.gather_getindex(x, promote_to(TracedRArray{Int,2}, indices)) end function LinearAlgebra._diagm( From b811c074c0360ddcd8574d8d268d6fda37211949 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 19 Dec 2024 11:49:43 +0530 Subject: [PATCH 09/15] fix: noinline ops --- src/Ops.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index ee870e80c..a8b1d82a2 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1426,7 +1426,7 @@ specified by `scatter_indices` to the values in `updates`. If the indices are co is recommended to directly use [`MLIR.Dialects.stablehlo.dynamic_update_slice`](@ref) instead. 
""" -function scatter_setindex( +@noinline function scatter_setindex( dest::TracedRArray{T,N}, scatter_indices::TracedRArray{Int64,2}, updates::TracedRArray{T,1}, @@ -1480,7 +1480,7 @@ Uses [`MLIR.Dialects.stablehlo.gather`](@ref) to get the values of `src` at the specified by `gather_indices`. If the indices are contiguous it is recommended to directly use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead. """ -function gather_getindex( +@noinline function gather_getindex( src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2} ) where {T,N} @assert size(gather_indices, 2) == N From 31ee2ffedd0baa602c2fcb4a5bdfe511f529fd27 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 19 Dec 2024 11:53:00 +0530 Subject: [PATCH 10/15] fix: incorrect rebase --- src/Ops.jl | 12 ++++++------ src/TracedUtils.jl | 2 +- src/stdlibs/LinearAlgebra.jl | 12 +++++++----- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index a8b1d82a2..7c41453f8 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1447,12 +1447,12 @@ instead. #! format: off scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( MLIR.IR.context(), - 0, Int64[], - N, collect(Int64, 0:(N - 1)), - 0, Int64[], - 0, Int64[], - N, collect(Int64, 0:(N - 1)), - 1 + Int64(0), Int64[], + Int64(N), collect(Int64, 0:(N - 1)), + Int64(0), Int64[], + Int64(0), Int64[], + Int64(N), collect(Int64, 0:(N - 1)), + Int64(1) ) #! format: on diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 02095b003..7b491f4b7 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -59,7 +59,7 @@ function set_mlir_data!( return x end -function set_mlir_data!(x::AnyTracedRArray, data) +function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T} setindex!(x, TracedRArray{T}(data), axes(x)...) 
return x end diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index c9e2455c7..457c73a43 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -34,7 +34,7 @@ function materialize_traced_array( return LinearAlgebra.diagm(parent(x)) end -function materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} +function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} return diagm(-1 => x.dl, 0 => x.d, 1 => x.du) end @@ -118,7 +118,7 @@ for (AT, dcomp, ocomp) in ( (:UpperTriangular, "LE", "GT"), (:UnitUpperTriangular, "LT", "GE"), ) - @eval function set_mlir_data!( + @eval function TracedUtils.set_mlir_data!( x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data ) where {T} tdata = TracedRArray{T}(data) @@ -136,7 +136,9 @@ for (AT, dcomp, ocomp) in ( end end -function set_mlir_data!(x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data) where {T} +function TracedUtils.set_mlir_data!( + x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data +) where {T} if x.uplo == 'L' set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data) else @@ -145,7 +147,7 @@ function set_mlir_data!(x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data) w return x end -function set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T} +function TracedUtils.set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T} tdata = TracedRArray{T}(data) set_mlir_data!(x.dl, diag(tdata, -1).mlir_data) set_mlir_data!(x.d, diag(tdata, 0).mlir_data) @@ -259,7 +261,7 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} # :0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64> length(indices) ≤ 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[]) - return Ops.gather_getindex(x, promote_to(TracedRArray{Int,2}, indices)) + return Ops.gather_getindex(x, TracedUtils.promote_to(TracedRArray{Int,2}, indices)) end function LinearAlgebra._diagm( From 
c536e5708d3540a243be7ca1d8c09139ad4dd718 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 28 Dec 2024 22:00:46 -0500 Subject: [PATCH 11/15] fix: dispatches --- src/stdlibs/LinearAlgebra.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 457c73a43..4b9ca80aa 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -14,7 +14,7 @@ using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_m using LinearAlgebra # Various Wrapper Arrays defined in LinearAlgebra -function materialize_traced_array( +function TracedUtils.materialize_traced_array( x::Transpose{TracedRNumber{T},TracedRArray{T,N}} ) where {T,N} px = parent(x) @@ -22,16 +22,16 @@ function materialize_traced_array( return permutedims(A, (2, 1)) end -function materialize_traced_array( +function TracedUtils.materialize_traced_array( x::Adjoint{TracedRNumber{T},TracedRArray{T,N}} ) where {T,N} return conj(materialize_traced_array(transpose(parent(x)))) end -function materialize_traced_array( - x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}} +function TracedUtils.materialize_traced_array( + x::Diagonal{TracedRNumber{T},TracedRArray{T,1}} ) where {T} - return LinearAlgebra.diagm(parent(x)) + return diagm(parent(x)) end function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} @@ -42,7 +42,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) uAT = Symbol(:Unit, AT) @eval begin function TracedUtils.materialize_traced_array( - x::$(AT){T,TracedRArray{T,2}} + x::$(AT){TracedRNumber{T},TracedRArray{T,2}} ) where {T} m, n = size(x) row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) @@ -52,7 +52,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) end function TracedUtils.materialize_traced_array( - x::$(uAT){T,TracedRArray{T,2}} + 
x::$(uAT){TracedRNumber{T},TracedRArray{T,2}} ) where {T} m, n = size(x) row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) @@ -64,7 +64,9 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) end end -function TracedUtils.materialize_traced_array(x::Symmetric{T,TracedRArray{T,2}}) where {T} +function TracedUtils.materialize_traced_array( + x::Symmetric{TracedRNumber{T},TracedRArray{T,2}} +) where {T} m, n = size(x) row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) @@ -107,7 +109,9 @@ function TracedUtils.set_mlir_data!( return x end -function TracedUtils.set_mlir_data!(x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data) where {T} +function TracedUtils.set_mlir_data!( + x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data +) where {T} parent(x).mlir_data = diag(TracedRArray{T}(data)).mlir_data return x end @@ -119,7 +123,7 @@ for (AT, dcomp, ocomp) in ( (:UnitUpperTriangular, "LT", "GE"), ) @eval function TracedUtils.set_mlir_data!( - x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data + x::$(AT){TracedRNumber{T},TracedRArray{T,2}}, data ) where {T} tdata = TracedRArray{T}(data) z = zero(tdata) @@ -137,17 +141,19 @@ for (AT, dcomp, ocomp) in ( end function TracedUtils.set_mlir_data!( - x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data + x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}, data ) where {T} if x.uplo == 'L' - set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data) + set_mlir_data!(LowerTriangular(parent(x)), data) else - set_mlir_data!(LinearAlgebra.UpperTriangular(parent(x)), data) + set_mlir_data!(UpperTriangular(parent(x)), data) end return x end -function TracedUtils.set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T} +function TracedUtils.set_mlir_data!( + x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, data +) where {T} tdata = TracedRArray{T}(data) set_mlir_data!(x.dl, diag(tdata, -1).mlir_data) set_mlir_data!(x.d, diag(tdata, 0).mlir_data) 
From 1affd14205c6ce7f0ea8285da43417ddf2f09c93 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 29 Dec 2024 09:32:54 -0500 Subject: [PATCH 12/15] fix: diagm for repeated indices and initial tests --- src/stdlibs/LinearAlgebra.jl | 13 ++++++++++++- test/integration/linear_algebra.jl | 11 +++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 4b9ca80aa..95369e2a2 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -274,9 +274,20 @@ function LinearAlgebra._diagm( shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}... ) where {T} m, n = LinearAlgebra.diagm_size(shape, kv...) + + # For repeated indices we need to aggregate the values + kv_updated = Dict{Integer,AnyTracedRArray{T,1}}() + for (k, v) in kv + if haskey(kv_updated, k) + kv_updated[k] = kv_updated[k] + v + else + kv_updated[k] = v + end + end + scatter_indices = Matrix{Int64}[] concat_inputs = MLIR.IR.Value[] - for (k, v) in kv + for (k, v) in pairs(kv_updated) push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :]) push!(concat_inputs, get_mlir_data(v)) end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 0c6efc5fd..0aac0f11c 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -130,6 +130,17 @@ end @test @jit(diagm(4, 5, x_ra)) ≈ diagm(4, 5, x) @test @jit(diagm(6, 6, x_ra)) ≈ diagm(6, 6, x) @test_throws DimensionMismatch @jit(diagm(3, 3, x_ra)) + + x1 = rand(3) + x2 = rand(3) + x3 = rand(2) + x_ra1 = Reactant.to_rarray(x1) + x_ra2 = Reactant.to_rarray(x2) + x_ra3 = Reactant.to_rarray(x3) + + @test @jit(diagm(1 => x_ra1)) ≈ diagm(1 => x1) + @test @jit(diagm(1 => x_ra1, -1 => x_ra3)) ≈ diagm(1 => x1, -1 => x3) + @test @jit(diagm(1 => x_ra1, 1 => x_ra2)) ≈ diagm(1 => x1, 1 => x2) end # TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be From 
b60ca6cc6de8a6d9476bde3a72ef0fb31cd0b647 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 29 Dec 2024 11:27:36 -0500 Subject: [PATCH 13/15] fix: higher dimensional indexing + tests --- src/Ops.jl | 42 ++++++++++++++++++++++------------- src/TracedRArray.jl | 20 ++++++++--------- test/basic.jl | 54 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 25 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 7c41453f8..f67300787 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1445,14 +1445,21 @@ instead. pushfirst!(update_computation, block) #! format: off + update_window_dims = Int64[] + inserted_window_dims = collect(Int64, 0:(N - 1)) + input_batching_dims = Int64[] + scatter_indices_batching_dims = Int64[] + scatter_dims_to_operand_dims = collect(Int64, 0:(N - 1)) + index_vector_dim = Int64(1) + scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( MLIR.IR.context(), - Int64(0), Int64[], - Int64(N), collect(Int64, 0:(N - 1)), - Int64(0), Int64[], - Int64(0), Int64[], - Int64(N), collect(Int64, 0:(N - 1)), - Int64(1) + length(update_window_dims), update_window_dims, + length(inserted_window_dims), inserted_window_dims, + length(input_batching_dims), input_batching_dims, + length(scatter_indices_batching_dims), scatter_indices_batching_dims, + length(scatter_dims_to_operand_dims), scatter_dims_to_operand_dims, + index_vector_dim, ) #! format: on @@ -1486,20 +1493,26 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead. @assert size(gather_indices, 2) == N #! 
format: off + offset_dims = Int64[1] + collapsed_slice_dims = collect(Int64, 0:(N - 2)) + operand_batching_dims = Int64[] + start_indices_batching_dims = Int64[] + start_index_map = collect(Int64, 0:(N - 1)) + index_vector_dim = Int64(1) + dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( MLIR.IR.context(), - Int64(1), Int64[1], - Int64(N - 1), collect(Int64, 0:(N - 2)), - Int64(0), Int64[], - Int64(0), Int64[], - Int64(N), collect(Int64, 0:(N - 1)), - 1 + Int64(length(offset_dims)), offset_dims, + Int64(length(collapsed_slice_dims)), collapsed_slice_dims, + Int64(length(operand_batching_dims)), operand_batching_dims, + Int64(length(start_indices_batching_dims)), start_indices_batching_dims, + Int64(length(start_index_map)), start_index_map, + Int64(index_vector_dim), ) #! format: on return reshape( - TracedRArray{T,2}( - (), + TracedRArray{T}( MLIR.IR.result( MLIR.Dialects.stablehlo.gather( src.mlir_data, @@ -1510,7 +1523,6 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead. 
), 1, ), - (size(gather_indices, 1), 1), ), size(gather_indices, 1), ) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 9e720205e..275f6dd92 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -56,9 +56,7 @@ function Base.getindex( return TracedRNumber{T}((), res2) end -function Base.getindex(a::TracedRArray{T,0}) where {T} - return TracedRNumber{T}((), a.mlir_data) -end +Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data) function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} indices = map(enumerate(indices)) do (idx, i) @@ -80,12 +78,13 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} if non_contiguous_getindex indices_tuples = collect(Iterators.product(indices...)) - indices = Matrix{Int}(undef, (length(indices_tuples), 2)) + indices = Matrix{Int}( + undef, (length(indices_tuples), length(first(indices_tuples))) + ) for (i, idx) in enumerate(indices_tuples) - indices[i, 1] = idx[1] - 1 - indices[i, 2] = idx[2] - 1 + indices[i, :] .= idx .- 1 end - indices = promote_to(TracedRArray{Int,2}, indices) + indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices) res = Ops.gather_getindex(a, indices) return Ops.reshape(res, size(indices_tuples)...) 
end @@ -133,10 +132,11 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where { if non_contiguous_setindex indices_tuples = collect(Iterators.product(indices...)) - indices = Matrix{Int}(undef, (length(indices_tuples), 2)) + indices = Matrix{Int}( + undef, (length(indices_tuples), length(first(indices_tuples))) + ) for (i, idx) in enumerate(indices_tuples) - indices[i, 1] = idx[1] - 1 - indices[i, 2] = idx[2] - 1 + indices[i, :] .= idx .- 1 end indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices) res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v))) diff --git a/test/basic.jl b/test/basic.jl index ade01078b..3522cd59e 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -737,3 +737,57 @@ end @test res[1] isa ConcreteRArray{Float64,2} @test res[2] isa ConcreteRNumber{Float64} end + +@testset "non-contiguous indexing" begin + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1(x) = x[[1, 3, 2], :, :] + non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]] + + @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x) + @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x) + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1(x) = x[[1, 3, 2], :] + non_contiguous_indexing2(x) = x[:, [1, 2, 2]] + + @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x) + @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x) + + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2 + non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2 + + @jit(non_contiguous_indexing1!(x_ra)) + non_contiguous_indexing1!(x) + @test x_ra ≈ x + + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + @jit(non_contiguous_indexing2!(x_ra)) + non_contiguous_indexing2!(x) + @test x_ra ≈ x + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1!(x) = x[[1, 3, 
2], :] .= 2 + non_contiguous_indexing2!(x) = x[:, [1, 2, 2]] .= 2 + + @jit(non_contiguous_indexing1!(x_ra)) + non_contiguous_indexing1!(x) + @test x_ra ≈ x + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + @jit(non_contiguous_indexing2!(x_ra)) + non_contiguous_indexing2!(x) + @test x_ra ≈ x +end From 89fe3d4fcb9c97dd2be9a153950bd7143f12ec0b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 29 Dec 2024 12:10:12 -0500 Subject: [PATCH 14/15] fix: matrix multiplication of wrapper types --- src/Overlay.jl | 25 +++++++++++++++++++++++++ src/Reactant.jl | 4 +++- src/stdlibs/LinearAlgebra.jl | 14 ++++++++------ test/integration/linear_algebra.jl | 30 +++++++++++++++++++++++------- 4 files changed, 59 insertions(+), 14 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index b9785b7fa..701d9c711 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -115,3 +115,28 @@ for randfun in (:rand, :randn, :randexp) # end end end + +# LinearAlgebra.jl overloads +## `_mul!` goes through too many layers of abstractions and we aren't able to overload +## without specializing on every possible combination of types +@reactant_overlay @noinline function LinearAlgebra.mul!( + C::AbstractVector, A::AbstractMatrix, B::AbstractVector, α::Number, β::Number +) + if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) + TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) + else + LinearAlgebra._mul!(C, A, B, α, β) + end + return C +end + +@reactant_overlay @noinline function LinearAlgebra.mul!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractVecOrMat, α::Number, β::Number +) + if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) + TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) + else + LinearAlgebra._mul!(C, A, B, α, β) + end + return C +end diff --git a/src/Reactant.jl b/src/Reactant.jl index f8dd8008b..d06784c13 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -119,7 +119,9 @@ const WrappedTracedRArray{T,N} = WrappedArray{ const AnyTracedRArray{T,N} = 
Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} const AnyTracedRVector{T} = AnyTracedRArray{T,1} const AnyTracedRMatrix{T} = Union{ - AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}} + AnyTracedRArray{T,2}, + LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}, + LinearAlgebra.Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, } const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 95369e2a2..aa56c7b92 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -34,7 +34,9 @@ function TracedUtils.materialize_traced_array( return diagm(parent(x)) end -function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} +function TracedUtils.materialize_traced_array( + x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}} +) where {T} return diagm(-1 => x.dl, 0 => x.d, 1 => x.du) end @@ -162,7 +164,7 @@ function TracedUtils.set_mlir_data!( end # Core functions -function LinearAlgebra.mul!( +function overloaded_mul!( @nospecialize(C::TracedRArray{T,1}), @nospecialize(A::AnyTracedRMatrix), @nospecialize(B::AnyTracedRVector), @@ -171,23 +173,23 @@ function LinearAlgebra.mul!( ) where {T} # TODO: The reshape operations are not getting optimized, we should directly call dot_general rC = Ops.reshape(C, length(C), 1) - mul!(rC, A, reshape(B, :, 1), α, β) + overloaded_mul!(rC, A, reshape(B, :, 1), α, β) C.mlir_data = get_mlir_data(vec(rC)) return C end -function LinearAlgebra.mul!( +function overloaded_mul!( @nospecialize(C::TracedRArray{T,2}), @nospecialize(A::AnyTracedRMatrix), @nospecialize(B::AnyTracedRVector), α::Number=true, β::Number=false, ) where {T} - mul!(C, A, reshape(B, :, 1), α, β) + overloaded_mul!(C, A, reshape(B, :, 1), α, β) return C end -function LinearAlgebra.mul!( +function overloaded_mul!( @nospecialize(C::TracedRArray{T,2}), @nospecialize(A::AnyTracedRMatrix), 
@nospecialize(B::AnyTracedRMatrix), diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 0aac0f11c..ea39556f9 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -1,4 +1,4 @@ -using LinearAlgebra, Reactant +using LinearAlgebra, Reactant, Test function muladd2(A, x, b) C = similar(A, promote_type(eltype(A), eltype(b)), size(A, 1), size(x, 2)) @@ -143,13 +143,29 @@ end @test @jit(diagm(1 => x_ra1, 1 => x_ra2)) ≈ diagm(1 => x1, 1 => x2) end -# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be -# optimized +# TODO: Currently WrapperType(x) * x goes down the generic matmul path but it should +# clearly be optimized mul_diagonal(x) = Diagonal(x) * x - -@testset "mul_diagonal" begin - x = rand(4) +mul_tridiagonal(x) = Tridiagonal(x) * x +mul_unit_lower_triangular(x) = UnitLowerTriangular(x) * x +mul_unit_upper_triangular(x) = UnitUpperTriangular(x) * x +mul_lower_triangular(x) = LowerTriangular(x) * x +mul_upper_triangular(x) = UpperTriangular(x) * x +mul_symmetric(x) = Symmetric(x) * x + +@testset "Wrapper Types Matrix Multiplication" begin + x = rand(4, 4) x_ra = Reactant.to_rarray(x) - @test @jit(mul_diagonal(x_ra)) ≈ mul_diagonal(x) + @testset "$(wrapper_type)" for (wrapper_type, fn) in [ + (Diagonal, mul_diagonal), + (Tridiagonal, mul_tridiagonal), + (UnitLowerTriangular, mul_unit_lower_triangular), + (UnitUpperTriangular, mul_unit_upper_triangular), + (LowerTriangular, mul_lower_triangular), + (UpperTriangular, mul_upper_triangular), + (Symmetric, mul_symmetric), + ] + @test @jit(fn(x_ra)) ≈ fn(x) + end end From 51c0d6e5300dad8928f7129363ed99dd98414812 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 29 Dec 2024 12:41:15 -0500 Subject: [PATCH 15/15] fix: de-specialize 3 arg mul!
--- src/Compiler.jl | 2 +- src/Overlay.jl | 36 +++++++++++++++++++----------------- test/wrapped_arrays.jl | 30 ++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 62611047b..3c4c3996d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -105,7 +105,7 @@ function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDic end function create_result( - tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol}, + tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char}, path, result_stores, ) diff --git a/src/Overlay.jl b/src/Overlay.jl index 701d9c711..0b5844464 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -119,24 +119,26 @@ end # LinearAlgebra.jl overloads ## `_mul!` goes through too many layers of abstractions and we aren't able to overload ## without specializing on every possible combination of types -@reactant_overlay @noinline function LinearAlgebra.mul!( - C::AbstractVector, A::AbstractMatrix, B::AbstractVector, α::Number, β::Number +for (cT, aT, bT) in ( + (:AbstractVector, :AbstractMatrix, :AbstractVector), + (:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat), ) - if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) - TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) - else - LinearAlgebra._mul!(C, A, B, α, β) - end - return C -end + @eval begin + @reactant_overlay @noinline function LinearAlgebra.mul!( + C::$cT, A::$aT, B::$bT, α::Number, β::Number + ) + if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) + TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) + else + LinearAlgebra._mul!(C, A, B, α, β) + end + return C + end -@reactant_overlay @noinline function LinearAlgebra.mul!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractVecOrMat, α::Number, β::Number -) - if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) - TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) - else - 
LinearAlgebra._mul!(C, A, B, α, β) + # Needed mostly for 1.10 where 3-arg mul is often specialized + @reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT) + call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false) + return C + end end - return C end diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index f5418e5c8..c522bcd17 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -172,3 +172,33 @@ end @test all(iszero, y_res) end end + +function lower_triangular_write(x) + y = LowerTriangular(copy(x)) + @. y *= 2 + return y +end + +function upper_triangular_write(x) + y = UpperTriangular(copy(x)) + @. y *= 2 + return y +end + +function tridiagonal_write(x) + y = Tridiagonal(copy(x)) + @. y *= 2 + return y +end + +@testset "Broadcasted Multiply and Allocate" begin + @testset "$(aType)" for (aType, fn) in [ + ("LowerTriangular", lower_triangular_write), + ("UpperTriangular", upper_triangular_write), + ("Tridiagonal", tridiagonal_write), + ] + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + @test @jit(fn(x_ra)) ≈ fn(x) + end +end