Skip to content

Commit

Permalink
Merge 5ca367d into 1e6037f
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx authored Feb 1, 2025
2 parents 1e6037f + 5ca367d commit 4aa1d68
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 20 deletions.
72 changes: 62 additions & 10 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,62 @@ end
x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
) where {T<:Number}
x isa TracedRNumber && return x
res = constant(fill(x); location)
res = fill(x; location)
return TracedRNumber{T}((), res.mlir_data)
end

fill(v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)) = fill(v, dims; location)
function fill(v, dims::NTuple{N,Union{Integer,Base.OneTo}}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)) where {N}
return fill(v, map(Base.to_dim, dims); location)
end
fill(v, dims::NTuple{N,Integer}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)) where {N} = fill(v, collect(dims); location)
fill(v, ::Tuple{}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)) = fill(v, Int[]; location)

for (T, mlir_func) in (
(Bool, :mlirDenseElementsAttrBoolSplatGet),
(UInt8, :mlirDenseElementsAttrUInt8SplatGet),
(Int8, :mlirDenseElementsAttrInt8SplatGet),
(UInt32, :mlirDenseElementsAttrUInt32SplatGet),
(Int32, :mlirDenseElementsAttrInt32SplatGet),
(UInt64, :mlirDenseElementsAttrUInt64SplatGet),
(Int64, :mlirDenseElementsAttrInt64SplatGet),
(Float32, :mlirDenseElementsAttrFloatSplatGet),
(Float64, :mlirDenseElementsAttrDoubleSplatGet),
)
@eval begin
@noinline function fill(
number::$T,
shape::Vector{Int};
location=mlir_stacktrace("fill", @__FILE__, @__LINE__),
)
tt = MLIR.IR.TensorType(shape, MLIR.IR.Type($T); location=location)

splatattr = MLIR.API.$mlir_func(tt, number)
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
cst = MLIR.IR.result(cst_op)
ta = TracedRArray{$T,length(shape)}((), cst, shape)
return ta
end
end
end

_fill_element_attr(x) = MLIR.IR.Attribute(x)
_fill_element_attr(x::Complex) = MLIR.IR.Attribute([
MLIR.IR.Attribute(Base.real(x)),
MLIR.IR.Attribute(Base.imag(x)),
])

@noinline function fill(
element::T, shape::Vector{Int}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
) where {T}
tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element))
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
cst = MLIR.IR.result(cst_op)
ta = TracedRArray{T,length(shape)}((), cst, shape)
return ta
end

# unary elementwise ops
for (dialect, op) in [
(:stablehlo, :abs),
Expand Down Expand Up @@ -350,9 +402,9 @@ end
@noinline function pad(
x::TracedRArray{T,N},
padding_value::TracedRNumber{T};
low=fill(0, N),
high=fill(0, N),
interior=fill(0, N),
low=Base.fill(0, N),
high=Base.fill(0, N),
interior=Base.fill(0, N),
location=mlir_stacktrace("pad", @__FILE__, @__LINE__),
) where {T,N}
rsize = size(x) .+ low .+ high .+ max.(size(x) .- 1, 0) .* interior
Expand Down Expand Up @@ -1056,7 +1108,7 @@ end
op = chlo.top_k(x.mlir_data; values, indices, k, location)
indices = add(
TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
constant(fill(Int32(1), Tuple(rsize))),
fill(Int32(1), Tuple(rsize)),
) # return the 1-indexed index
indices = convert(TracedRArray{Int64,N}, indices) # julia indexes with Int64 generally
values = TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize)
Expand Down Expand Up @@ -1160,7 +1212,7 @@ end
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
output = divide(
convert(TracedRArray{T,ndims(output)}, output),
constant(fill(T(typemax(uT)), Tuple(shape)); location),
fill(T(typemax(uT)), Tuple(shape); location),
)
return (; output_state, output)
end
Expand Down Expand Up @@ -1200,11 +1252,11 @@ fields:
rand_uniform = res.output
seed = res.output_state
scaled_uniform = subtract(
multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))),
constant(fill(T(1), size(rand_uniform))),
multiply(rand_uniform, fill(T(2), size(rand_uniform))),
fill(T(1), size(rand_uniform)),
)
probit = erf_inv(scaled_uniform)
rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform))))
rand_normal = multiply(probit, fill(Base.sqrt(T(2)), size(rand_uniform)))
return (; output_state=seed, output=rand_normal)
end

Expand Down Expand Up @@ -1570,7 +1622,7 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
src.mlir_data,
gather_indices.mlir_data;
dimension_numbers,
slice_sizes=fill(Int64(1), N),
slice_sizes=Base.fill(Int64(1), N),
indices_are_sorted=false,
),
1,
Expand Down
15 changes: 7 additions & 8 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,8 @@ Base.collect(x::TracedRArray) = copy(x) # XXX: Is this correct?

Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A))

# TODO is there a way to create an unitialized `tensor`? does it show an advantage? maybe `fill`?
function Base.similar(::TracedRArray, ::Type{T}, dims::Dims{N}) where {T,N}
return Ops.constant(zeros(unwrapped_eltype(T), dims))
return Ops.fill(zero(unwrapped_eltype(T)), dims)
end

function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOContext}}
Expand Down Expand Up @@ -992,12 +991,12 @@ function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
# Compute linear indices
strds = strides(x)
iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)]
iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices))))
linear_indices = Ops.constant(fill(Int64(1), size(indices)))
iotas[dims] = Ops.subtract(indices, Ops.fill(Int64(1), size(indices)))
linear_indices = Ops.fill(Int64(1), size(indices))
for d in eachindex(iotas)
linear_indices = Ops.add(
linear_indices,
Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))),
Ops.multiply(iotas[d], Ops.fill(Int64(strds[d]), size(iotas[d]))),
)
end

Expand All @@ -1021,12 +1020,12 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
# Compute linear indices
strds = strides(x)
iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)]
iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices))))
linear_indices = Ops.constant(fill(Int64(1), size(indices)))
iotas[dims] = Ops.subtract(indices, Ops.fill(Int64(1), size(indices)))
linear_indices = Ops.fill(Int64(1), size(indices))
for d in eachindex(iotas)
linear_indices = Ops.add(
linear_indices,
Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))),
Ops.multiply(iotas[d], Ops.fill(Int64(strds[d]), size(iotas[d]))),
)
end

Expand Down
2 changes: 1 addition & 1 deletion src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
)
end
rhs isa Number &&
return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs))))
return TracedUtils.promote_to(TracedRNumber{T}, Ops.fill(T(rhs)))
return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(collect(rhs)))
end

Expand Down
2 changes: 1 addition & 1 deletion src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ function LinearAlgebra._diagm(
(size(scatter_indices, 1),),
)
return Ops.scatter_setindex(
Ops.constant(fill(zero(T), (m, n))), scatter_indices, values
Ops.fill(zero(T), (m, n)), scatter_indices, values
)
end

Expand Down

0 comments on commit 4aa1d68

Please sign in to comment.