
[StableHLO] gather op with float32 operand #1349

Open
mmanzoorTT opened this issue Nov 20, 2024 · 0 comments
Labels: bug (Something isn't working), stablehlo conversion bug (Bugs in StableHLO conversion)


@mmanzoorTT

tt-metal supports embedding only for the bfloat16 data type. We have a use case in our tt-torch models where the input operand is float32, which causes a failure. The StableHLO graph is below:

```mlir
module {
  func.func @main(%arg0: tensor<2048x32xf32>, %arg1: tensor<1x5xi64>) -> tensor<1x5x32xf32> {
    %0 = stablehlo.reshape %arg1 : (tensor<1x5xi64>) -> tensor<1x5x1xi64>
    %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 32>}> : (tensor<2048x32xf32>, tensor<1x5x1xi64>) -> tensor<1x5x32xf32>
    return %1 : tensor<1x5x32xf32>
  }
}
```
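With these dimension numbers (`offset_dims = [2]`, `collapsed_slice_dims = [0]`, `start_index_map = [0]`, `index_vector_dim = 2`, `slice_sizes = [1, 32]`), the gather picks one full 32-wide row of the 2048x32 table per index, which is why it maps onto an embedding lookup. Below is a minimal pure-Python sketch of those semantics, with nested lists standing in for tensors; the helper name `gather_rows` is hypothetical and not part of tt-mlir, StableHLO, or tt-metal.

```python
from typing import List

Row = List[float]


def gather_rows(table: List[Row], indices: List[List[List[int]]]) -> List[List[Row]]:
    """Hypothetical helper mirroring the gather in the graph above.

    table:   [num_rows][row_width]    (tensor<2048x32xf32> in the issue)
    indices: [batch][seq][1]          (tensor<1x5x1xi64>; index_vector_dim = 2)
    returns: [batch][seq][row_width]  (tensor<1x5x32xf32>)
    """
    num_rows = len(table)
    out: List[List[Row]] = []
    for batch in indices:
        rows: List[Row] = []
        for index_vector in batch:
            # start_index_map = [0]: the single index component addresses table dim 0.
            start = index_vector[0]
            # Gather clamps start indices so the 1-row slice stays in bounds.
            start = min(max(start, 0), num_rows - 1)
            # slice_sizes = [1, width] with collapsed_slice_dims = [0]: take one full
            # row and drop the size-1 row dim; offset_dims = [2] places it last.
            rows.append(list(table[start]))
        out.append(rows)
    return out


# Small stand-in for the 2048x32 table: 8 rows of width 4.
table = [[float(r * 10 + c) for c in range(4)] for r in range(8)]
# Shape [1][5][1], like the reshaped tensor<1x5x1xi64> indices.
ids = [[[3], [0], [7], [3], [1]]]
result = gather_rows(table, ids)
print(result[0][0])  # [30.0, 31.0, 32.0, 33.0]
```

Each output element `result[b][s]` is just `table[ids[b][s][0]]`, i.e. `out[b, s, :] = table[idx[b, s], :]`, the same contract as an embedding op with a 2048-entry vocabulary and hidden size 32.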

We could add a typecast that converts the float32 input operand to bfloat16 before the embedding lookup.
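A float32 to bfloat16 typecast keeps float32's 8-bit exponent but only 8 bits of significand precision (7 stored), so the table values are rounded. The sketch below shows that rounding for a float32 -> bfloat16 -> float32 round trip in pure Python. It assumes round-to-nearest-even; the rounding mode tt-metal actually uses is not stated in this issue, and NaN/Inf are not handled. Converting the result back to float32 is also an assumption here, made so the graph still returns `tensor<1x5x32xf32>`.

```python
import struct


def f32_bits(x: float) -> int:
    """Round x to float32 and return its 32-bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b: int) -> float:
    """Interpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFFFFFF))[0]


def to_bf16_bits(x: float) -> int:
    """Round a float32 value to bfloat16 (its upper 16 bits), nearest-even.

    Assumption: round-to-nearest-even; NaN/Inf are not special-cased here.
    """
    b = f32_bits(x)
    lsb = (b >> 16) & 1          # lowest bit that survives truncation
    rounding_bias = 0x7FFF + lsb  # ties round toward an even result
    return ((b + rounding_bias) >> 16) & 0xFFFF


def bf16_round_trip(x: float) -> float:
    """float32 -> bfloat16 -> float32, as a typecast pair around the lookup would do."""
    return bits_f32(to_bf16_bits(x) << 16)


print(bf16_round_trip(0.1))  # 0.10009765625
```

Values that already fit in bfloat16 (such as `1.0`) survive unchanged, while others like `0.1` move by roughly 1e-4, so numerical comparisons against a float32 reference will need a looser tolerance.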

@mmanzoorTT mmanzoorTT added bug Something isn't working stablehlo conversion bug Bugs in StableHLO conversion labels Nov 20, 2024
@mmanzoorTT mmanzoorTT added this to the [Third Party] HLO + XLA milestone Nov 20, 2024
@mmanzoorTT mmanzoorTT self-assigned this Nov 28, 2024