
Missing support for func.call in mlir #50

Closed
kmitrovicTT opened this issue Oct 31, 2024 · 9 comments · Fixed by tenstorrent/tt-mlir#1157
@kmitrovicTT

While working on tenstorrent/tt-mlir#1126 I wrote this JAX test:

def test_remainder_op():
    def module_remainder(a, b):
        return jnp.remainder(a, b)

    verify_module(module_remainder, [(3, 3), (3, 3)])
    verify_module(module_remainder, [(3, 3, 3), (3, 3, 3)])

which resulted in

[screenshot: error output from the failing test]

Using jax.lax.rem instead passed right away.
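For reference, a sketch of how StableHLO dumps like the ones attached below can be generated (assuming a standard JAX install; `jax.jit(...).lower(...).as_text()` is JAX's public lowering API, though whether `jnp.remainder` is outlined into a separate function may depend on the JAX version):

```python
import jax
import jax.numpy as jnp

def module_remainder(a, b):
    return jnp.remainder(a, b)

# Lower to StableHLO without executing; as_text() dumps the MLIR module.
lowered = jax.jit(module_remainder).lower(jnp.ones((3, 3)), jnp.ones((3, 3)))
print(lowered.as_text())
```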

The issue is that tt-mlir doesn't support nested MLIR func calls. Attached are two StableHLO graphs generated by running the Python test, one using jnp.remainder and the other using jax.lax.rem. The issue is in:

[screenshot: the offending func.call in the jnp.remainder graph]

jnp_graph.mlir.txt
lax_graph.mlir.txt

LPanosTT commented Nov 1, 2024

Add this to ArithToStableHLOPass.cpp:

target.addLegalOp<mlir::func::CallOp>();

The issue here is that ArithToStableHLOPass runs before StableHLOToTTIRPass, but it doesn't mark func.call as legal, so the conversion will fail.

uazizTT commented Nov 1, 2024

Also add target.addLegalOp<mlir::func::CallOp>(); to StableHLOToTTIRPass

LPanosTT assigned kmitrovicTT and unassigned LPanosTT Nov 1, 2024
kmitrovicTT commented Nov 1, 2024

Will do.

I'm interested in what we do after this gets fixed. Do we go with jax.lax.rem, which produces a trivial graph:

module @jit_module_remainder attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<3x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<3x3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.remainder %arg0, %arg1 : tensor<3x3xf32>
    return %0 : tensor<3x3xf32>
  }
}

or jnp.remainder, which produces:

module @jit_module_remainder attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<3x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<3x3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @remainder(%arg0, %arg1) : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
    return %0 : tensor<3x3xf32>
  }
  func.func private @remainder(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x3xf32> {mhlo.layout_mode = "default"}) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.remainder %arg0, %arg1 : tensor<3x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
    %2 = stablehlo.compare  NE, %0, %1,  FLOAT : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xi1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %3 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
    %4 = stablehlo.compare  LT, %0, %3,  FLOAT : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xi1>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %5 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<3x3xf32>
    %6 = stablehlo.compare  LT, %arg1, %5,  FLOAT : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xi1>
    %7 = stablehlo.compare  NE, %4, %6,  UNSIGNED : (tensor<3x3xi1>, tensor<3x3xi1>) -> tensor<3x3xi1>
    %8 = stablehlo.and %7, %2 : tensor<3x3xi1>
    %9 = stablehlo.add %0, %arg1 : tensor<3x3xf32>
    %10 = stablehlo.select %8, %9, %0 : tensor<3x3xi1>, tensor<3x3xf32>
    return %10 : tensor<3x3xf32>
  }
}

in order to test

def test_remainder_op():
    def module_remainder(a, b):
        return jnp.remainder(a, b) # or return jax.lax.rem(a, b)

    verify_module(module_remainder, [(3, 3), (3, 3)])
    verify_module(module_remainder, [(3, 3, 3), (3, 3, 3)])

@uazizTT @LPanosTT, thoughts? Do we need/want this complication with the implicit broadcast, or is it needless for such a simple test to go through that path?

LPanosTT commented Nov 1, 2024

@kmitrovicTT I think we should just support both cases. Models may end up using either. What do all those extra ops do anyway?

@kmitrovicTT

Looking at it, I'd guess some kind of implicit broadcast, which in this case is a no-op since the shapes match. It might be that jax.lax.rem would break if I tried non-broadcastable shapes.

I ask because I see some other tests resorted to the jax.lax variant as well, so I'm wondering what we should do with these simple tests in general.
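For reference, the two ops also differ in sign semantics, which may be what the extra ops account for (my reading, not verified against the graphs above): jnp.remainder follows Python's % (the result takes the sign of the divisor), while jax.lax.rem is C-style (the result takes the sign of the dividend). A minimal check:

```python
import jax.numpy as jnp
from jax import lax

# jnp.remainder matches Python's %: the result takes the sign of the divisor.
print(float(jnp.remainder(7.0, -3.0)))  # -2.0

# lax.rem is C-style: the result takes the sign of the dividend.
print(float(lax.rem(7.0, -3.0)))        # 1.0
```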

LPanosTT commented Nov 1, 2024

Well, I would say just allow it. The function call will get inlined early on in the TTIR-to-TTNN pass. @AleksKnezevic, thoughts?

@AleksKnezevic

@LPanosTT I don't follow your comment, allow what?

LPanosTT commented Nov 1, 2024

> @LPanosTT I don't follow your comment, allow what?

Support both cases: jax.lax.rem and jnp.remainder. One of them just adds a ton of extra ops that work out to a no-op if the input shapes match.

@AleksKnezevic

Got it, yeah, I agree, both should be supported.
