Missing support for func.call in MLIR #50
Add this to […]

The issue here is that ArithToStableHLOPass is run before StableHLOToTTIRPass. It doesn't mark […]

Also add […]
Will do. I'm interested in what we do after this gets fixed. Do we go with

```mlir
module @jit_module_remainder attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<3x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<3x3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.remainder %arg0, %arg1 : tensor<3x3xf32>
    return %0 : tensor<3x3xf32>
  }
}
```

or

```mlir
module @jit_module_remainder attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<3x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<3x3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @remainder(%arg0, %arg1) : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
    return %0 : tensor<3x3xf32>
  }
  func.func private @remainder(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x3xf32> {mhlo.layout_mode = "default"}) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.remainder %arg0, %arg1 : tensor<3x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
    %2 = stablehlo.compare NE, %0, %1, FLOAT : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xi1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %3 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
    %4 = stablehlo.compare LT, %0, %3, FLOAT : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xi1>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %5 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
    %6 = stablehlo.compare LT, %arg1, %5, FLOAT : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xi1>
    %7 = stablehlo.compare NE, %4, %6, UNSIGNED : (tensor<3x3xi1>, tensor<3x3xi1>) -> tensor<3x3xi1>
    %8 = stablehlo.and %7, %2 : tensor<3x3xi1>
    %9 = stablehlo.add %0, %arg1 : tensor<3x3xf32>
    %10 = stablehlo.select %8, %9, %0 : tensor<3x3xi1>, tensor<3x3xf32>
    return %10 : tensor<3x3xf32>
  }
}
```

in order to test

```python
def test_remainder_op():
    def module_remainder(a, b):
        return jnp.remainder(a, b)  # or return jax.lax.rem(a, b)

    verify_module(module_remainder, [(3, 3), (3, 3)])
    verify_module(module_remainder, [(3, 3, 3), (3, 3, 3)])
```

@uazizTT @LPanosTT ? Do we need/want this complication with implicit broadcast, or is it needless for such a simple test to go through that path?
@kmitrovicTT I think we should just support both cases. Models may end up using either. What do all those extra ops do anyway?
Looking at it, I'd guess it's some kind of implicit broadcast, which in this case is a no-op since the shapes match. Part of why I ask is that I've seen some other tests resort to […]
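For reference: reading the jnp graph above, the compare/and/add/select sequence appears to implement Python/NumPy remainder semantics (the result takes the sign of the divisor) on top of the C-style truncated `lax.rem` (the result takes the sign of the dividend), and the `broadcast_in_dim` ops just splat the scalar zero used in the comparisons. A minimal jax sketch of the same computation (`remainder_from_rem` is a hypothetical name, used only for illustration):

```python
import jax.numpy as jnp
from jax import lax

def remainder_from_rem(a, b):
    # Hypothetical helper mirroring the extra StableHLO ops in the jnp graph.
    # Start from the C-style truncated remainder (sign follows the dividend).
    r = lax.rem(a, b)
    # Fix up nonzero results whose sign disagrees with the divisor's,
    # which yields Python/NumPy remainder semantics. This corresponds to
    # the compare NE / compare LT / and ops in the graph.
    needs_fix = (r != 0) & ((r < 0) != (b < 0))
    # The add + select pair in the graph: shift into the divisor's sign.
    return jnp.where(needs_fix, r + b, r)

a = jnp.array([7.0, -7.0, 7.0, -7.0])
b = jnp.array([3.0, 3.0, -3.0, -3.0])
assert jnp.allclose(remainder_from_rem(a, b), jnp.remainder(a, b))
```

So for matching shapes the broadcasts are indeed no-ops; the real extra work is the sign fix-up.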
Well, I would say just allow it. The function call will get inlined early on in the TTIR-to-TTNN pass. @AleksKnezevic Thoughts?
@LPanosTT I don't follow your comment; allow what?

Support both cases.

Got it, yeah, I agree, both should be supported.
While working on tenstorrent/tt-mlir#1126 I wrote a jax test, which resulted in an error. Using `jax.lax.rem` instead passed right away. The issue is that in tt-mlir we don't support nested MLIR func calls. Attached are two StableHLO graphs generated by running the python test, one using `jnp.remainder` and the other using `jax.lax.rem`. The issue is in jnp_graph.mlir.txt.

Attachments: jnp_graph.mlir.txt, lax_graph.mlir.txt
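For anyone reproducing this, the attached graphs can be regenerated by dumping the lowered StableHLO straight from jax. A minimal sketch, assuming a reasonably recent jax version (shapes and dtype chosen to match the test above):

```python
import jax
import jax.numpy as jnp

def module_remainder(a, b):
    return jnp.remainder(a, b)  # swap in jax.lax.rem(a, b) for the lax graph

a = jnp.zeros((3, 3), jnp.float32)
b = jnp.ones((3, 3), jnp.float32)

# Lower without executing and print the resulting StableHLO module.
print(jax.jit(module_remainder).lower(a, b).as_text())
```

Per the attached graphs, the `jnp.remainder` variant lowers through a private `@remainder` function invoked via `func.call`, while the `jax.lax.rem` variant lowers to a single `stablehlo.remainder` directly in `@main`.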