
[StableHLO] Unhandled use cases for stablehlo.gather op #1350

Open
mmanzoorTT opened this issue Nov 20, 2024 · 2 comments
Labels
bug Something isn't working stablehlo conversion bug Bugs in StableHLO conversion

Comments

@mmanzoorTT (Contributor) commented:

The following StableHLO graphs contain stablehlo.gather ops (coming from PyTorch models through tt-torch) that tt-mlir cannot handle and explicitly marks as illegal.

Example 1:

  func.func @test1(%arg0: tensor<1x7x2xbf16>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor<1x2xbf16> {
    %0 = stablehlo.reshape %arg1 : (tensor<1xi64>) -> tensor<1x1xi64>
    %1 = stablehlo.reshape %arg2 : (tensor<1xi64>) -> tensor<1x1xi64>
    %2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x2xi64>
    %3 = "stablehlo.gather"(%arg0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 2>}> : (tensor<1x7x2xbf16>, tensor<1x2xi64>) -> tensor<1x2xbf16>
    return %3 : tensor<1x2xbf16>
  }

Example 2:

  func.func @test2(%arg0: tensor<2x7x512xbf16>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>) -> tensor<2x512xbf16> {
    %0 = stablehlo.reshape %arg1 : (tensor<2xi64>) -> tensor<2x1xi64>
    %1 = stablehlo.reshape %arg2 : (tensor<2xi64>) -> tensor<2x1xi64>
    %2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<2x1xi64>, tensor<2x1xi64>) -> tensor<2x2xi64>
    %3 = "stablehlo.gather"(%arg0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 512>}> : (tensor<2x7x512xbf16>, tensor<2x2xi64>) -> tensor<2x512xbf16>
    return %3 : tensor<2x512xbf16>
  }

Example 3:

  func.func @test3(%arg0: tensor<732x12xbf16>, %arg1: tensor<38809xi64>) -> tensor<38809x12xbf16> {
    %0 = stablehlo.reshape %arg1 : (tensor<38809xi64>) -> tensor<38809x1xi64>
    %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 12>}> : (tensor<732x12xbf16>, tensor<38809x1xi64>) -> tensor<38809x12xbf16>
    return %1 : tensor<38809x12xbf16>
  }

Example 4:

  func.func @test4(%arg0: tensor<732x16xbf16>, %arg1: tensor<38809xi64>) -> tensor<38809x16xbf16> {
    %0 = stablehlo.reshape %arg1 : (tensor<38809xi64>) -> tensor<38809x1xi64>
    %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 16>}> : (tensor<732x16xbf16>, tensor<38809x1xi64>) -> tensor<38809x16xbf16>
    return %1 : tensor<38809x16xbf16>
  }
@mmanzoorTT mmanzoorTT added bug Something isn't working stablehlo conversion bug Bugs in StableHLO conversion labels Nov 20, 2024
@mmanzoorTT mmanzoorTT added this to the [Third Party] HLO + XLA milestone Nov 20, 2024
@ddilbazTT (Contributor) commented:

These might be cases where we should convert gather to reshape + slice + concatenate.
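For instance, Example 1 selects a single [1, 1, 2] window at runtime-determined start indices, so one hypothetical decomposition (a sketch, not tt-mlir's actual lowering) feeds the two scalar indices to stablehlo.dynamic_slice and drops the collapsed dimensions with reshapes:

```mlir
// Hypothetical decomposition of Example 1 via dynamic_slice.
func.func @test1_decomposed(%arg0: tensor<1x7x2xbf16>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor<1x2xbf16> {
  // Extract the two scalar start indices.
  %i = stablehlo.reshape %arg1 : (tensor<1xi64>) -> tensor<i64>
  %j = stablehlo.reshape %arg2 : (tensor<1xi64>) -> tensor<i64>
  %zero = stablehlo.constant dense<0> : tensor<i64>
  // Slice out the single [1, 1, 2] window selected by the start indices.
  %slice = stablehlo.dynamic_slice %arg0, %i, %j, %zero, sizes = [1, 1, 2] : (tensor<1x7x2xbf16>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x2xbf16>
  // Collapse the two sliced dimensions to match the gather result shape.
  %out = stablehlo.reshape %slice : (tensor<1x1x2xbf16>) -> tensor<1x2xbf16>
  return %out : tensor<1x2xbf16>
}
```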

These examples violate the constraints we check when lowering gather to embedding:

  • Output tensor shape: multi-dimensional, with the last dimension equal to the embedding size (hiddenDim)
  • Slice sizes: must be [1, hiddenDim], where hiddenDim matches the last output dimension
  • Offset dimensions: strictly [2]
  • Collapsed slice dimensions: strictly [0]
  • Start indices shape: must be compatible with the output shape
    • startIndices.size() < output.size()
    • if startIndices.size() == output.size(), then startIndices[-1] == 1
    • The last dimension of the start indices can be folded away by a reshape op.
    • This is because the embedding weights are required to have a smaller size than the output shape.
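For contrast with the failing examples, here is a gather that does satisfy the listed constraints (the shape of a typical torch.nn.Embedding lookup; the sizes are made up for illustration):

```mlir
// A gather matching the embedding-lowering constraints:
// offset_dims = [2], collapsed_slice_dims = [0], slice_sizes = [1, hiddenDim].
func.func @embedding_ok(%weights: tensor<1000x128xbf16>, %ids: tensor<4x32xi64>) -> tensor<4x32x128xbf16> {
  %0 = "stablehlo.gather"(%weights, %ids) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 128>}> : (tensor<1000x128xbf16>, tensor<4x32xi64>) -> tensor<4x32x128xbf16>
  return %0 : tensor<4x32x128xbf16>
}
```

The failing examples break these rules in different ways: Examples 1 and 2 collapse two slice dimensions and index with two start indices, while Examples 3 and 4 use offset_dims = [1] rather than [2].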

@ddilbazTT (Contributor) commented:

I collected the cases where we don't support gather (see stablehlo_gather_lines.txt), so we need to implement a different lowering.
Probably something like this:

  1. Reshape and flatten the input tensor.
  2. Reshape or permute the indices tensor to match the expected dimensionality for gathering.
  3. Extract slices based on the reshaped indices.
  4. Reshape the result to match the target output shape.
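The steps above can be sketched on Example 3 (a hypothetical rewrite, not the implemented lowering). The input is already 2-D, so step 1 is a no-op; the indices are reshaped into a 2-D batch so the gather takes the embedding-compatible form, and the result is reshaped back:

```mlir
func.func @test3_rewritten(%arg0: tensor<732x12xbf16>, %arg1: tensor<38809xi64>) -> tensor<38809x12xbf16> {
  // Step 2: reshape the indices to a 2-D batch so offset_dims can be [2].
  %idx = stablehlo.reshape %arg1 : (tensor<38809xi64>) -> tensor<1x38809xi64>
  // Step 3: extract the rows; this gather satisfies the embedding constraints.
  %g = "stablehlo.gather"(%arg0, %idx) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 12>}> : (tensor<732x12xbf16>, tensor<1x38809xi64>) -> tensor<1x38809x12xbf16>
  // Step 4: reshape the result to the target output shape.
  %out = stablehlo.reshape %g : (tensor<1x38809x12xbf16>) -> tensor<38809x12xbf16>
  return %out : tensor<38809x12xbf16>
}
```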
