Add Ops.hlo_call(::String, args...) #358
Conversation
idea: we could hash the code and just call the function if a function with the same hash is already present in the module (preventing duplicate functions). So that:

for i in 1:N
    x = Ops.hlo_call(code, x)
end

has the function emitted just once (but called multiple times).

Edit: implemented this by using the code hash in the function names.
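A rough sketch of that naming scheme (hlo_call_fn_name is a hypothetical helper, not the actual implementation):

```julia
# Sketch: derive the imported function's symbol name from a hash of the code
# string, so identical hlo_call code strings map to one function in the module.
hlo_call_fn_name(code::AbstractString) = string("hlo_call_", string(hash(code); base=16))

# Two calls with the same code produce the same symbol, so the function body
# only needs to be emitted once even if it is called N times.
```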
that can be done directly with

get!(dict, code) do
    compile(code)
end
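For reference, a self-contained sketch of that get! pattern (note the closure passed to get! takes no arguments; the global cache and the work done inside it are just placeholders, not Reactant internals):

```julia
# Sketch: memoize per-code-string work with get!, so repeated hlo_call's with
# the same code string only pay the parse/import cost once.
const HLO_CALL_CACHE = Dict{String,Any}()  # placeholder for wherever this state lives

function cached_hlo_import(code::String)
    return get!(HLO_CALL_CACHE, code) do
        # Placeholder for the real work (parse the module, import the function, ...).
        hash(code)
    end
end
```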
ah right, do we have a place to store such state within a single trace? Otherwise my idea is to name these functions like …
I have a question... What is …
No, it is returning a tuple of traced arrays returned by the call to the function.
src/Ops.jl
Outdated
function hlo_call(code, args...; location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__))
    new_mod = parse(MLIR.IR.Module, code)
    body = MLIR.IR.body(new_mod)
    fn = MLIR.IR.first_op(body)
what if the first op is not main? Like what if the code was traced by us and we added some function barriers.
maybe we can add a kwarg for selecting the target function (and default it to main), so we just iterate over the ops doing first(Iterators.filter(op -> String(IR.attr(op, "sym_name")) == target_fn, OperationIterator(body)))
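Spelled out as a sketch (this assumes the IR.attr / OperationIterator helpers used above and that every top-level op carries a sym_name; target_fn is the proposed kwarg, not existing API):

```julia
# Sketch: select the function op whose sym_name matches the requested target
# (defaulting to "main") instead of blindly taking the first op in the module.
function find_target_fn(body, target_fn::String="main")
    return first(
        Iterators.filter(
            op -> String(MLIR.IR.attr(op, "sym_name")) == target_fn,
            MLIR.IR.OperationIterator(body),
        ),
    )
end
```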
This is inside the code that was given by the caller. Currently, the expectation is that there is only one function inside the given module. We can surely revisit that with a keyword indeed, or a tuple as for Core.llvmcall
yeah, I would instead ideally have a keyword argument fn=main, and we extract that fn as the top-level one
added a func_name::String kwarg for this
Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
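A usage sketch of that kwarg (mirroring the tests in this PR; x and y stand for traced inputs inside a @jit'ed function, and the default func_name="main" is an assumption):

```julia
# Sketch: pick a specific entry function out of the given module by name.
x, = Ops.hlo_call(
    """
    module {
      func.func @my_add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
        %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
        return %0 : tensor<3xf32>
      }
    }
    """,
    x,
    y;
    func_name="my_add",
)
```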
MLIR.IR.rmfromparent!(fn)

current_module = MLIR.IR.mmodule()
top_level_block = MLIR.IR.body(current_module)
we also need to mark all fns as private, as well as make sure to move all fns in the module (e.g. the main function could call something)
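A rough sketch of that; the setter spellings (MLIR.IR.attr!, push!(::Block, ::Operation)) are assumptions about the MLIR.jl API, with MLIR.API.mlirOperationSetAttributeByName as the underlying C-API fallback:

```julia
# Rough sketch: move every top-level op from the parsed module into the
# current module, marking all functions except the entry one as private.
for op in collect(MLIR.IR.OperationIterator(body))  # collect first, since we mutate the block
    name = String(MLIR.IR.attr(op, "sym_name"))
    if name != func_name
        # Assumed setter spelling; sym_visibility is a plain string attribute.
        MLIR.IR.attr!(op, "sym_visibility", MLIR.IR.Attribute("private"))
    end
    MLIR.IR.rmfromparent!(op)
    push!(top_level_block, op)  # assumes Block supports push! to append an op
end
```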
we have some utilities here: https://github.com/EnzymeAD/Enzyme-JAX/blob/f6587e37ff7298f2a1a273b08c24d69fca7ff30f/src/enzyme_ad/jax/compile_with_xla.cc#L190 and https://github.com/EnzymeAD/Enzyme-JAX/blob/f6587e37ff7298f2a1a273b08c24d69fca7ff30f/src/enzyme_ad/jax/primitives.py#L811 in Enzyme-JAX for explicitly making sure we can do all the things
Done, do you know if we can encounter ops other than func.func (maybe gpu.func in the future?) and what to do with them?
I think it’s fine to assume func for now but if desired we could generalize to function interface or whatnot
src/Ops.jl
Outdated
new_mod = parse(MLIR.IR.Module, code)
body = MLIR.IR.body(new_mod)

for op in MLIR.IR.OperationIterator(body)
Oh, we should import all the ops, not just func, but it's okay to be limited to just func as the entry function. E.g. if main calls a gpu function, that would be fine. Or if there is a global constant op.
y_reactant = Reactant.to_rarray(y)

@test Reactant.@jit(
    Ops.hlo_call(
Can you add a test with multiple functions in the module?
And can you also add a test with two (different) hlo_calls that happen to contain functions of the same name (to make sure we do the symbol rename properly).
So hlo_call currently returns a tuple of arrays (one for each result); should we special-case the single-result case?
x, = Ops.hlo_call(
    """
    module {
      func.func @my_add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
Can you also add a version of this where the two definitions are different?
Just because if we fix caching then we might not actually emit it twice (and thus not check things).
Added a test with the same name but different definitions:
Lines 945 to 970 in 8eb71cc
function f_multiple_hlo_calls(x, y)
    x, = Ops.hlo_call(
        """
        module {
          func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
            %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
            return %0 : tensor<3xf32>
          }
        }
        """,
        x,
        y,
    )
    return Ops.hlo_call(
        """
        module {
          func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
            %0 = stablehlo.multiply %arg0, %arg1 : tensor<3xf32>
            return %0 : tensor<3xf32>
          }
        }
        """,
        x,
        y,
    )
end
Honestly I think it's better to always return the tuple. That way folks using it don't need to special-case whether there are multiple returns or not.
Awesome stuff!
this is amazing and opens a way to do things like #354
Other potential applications:
Strings really are the universal model format (we can add …)