[Relax] Implement relax.transform.RemoveSymbolicExpressionsInSubroutine #17080

Open · wants to merge 1 commit into main
Conversation

Lunderberg (Contributor)
This is a follow-up commit to #16637, which updated `relax.transform.FuseOps` to provide additional parameters defining symbolic variables required by the fused functions. While this ensures that `relax.transform.FuseOps` produces well-formed Relax functions, these additional arguments can break some kernel implementations.

This commit implements a new transform, `RemoveSymbolicExpressionsInSubroutine`, to resolve this issue. This transform identifies function arguments whose sole purpose is to compute a symbolic expression, when that symbolic expression could be inferred from tensor shapes.

For example, consider the following Relax function:

```python
@R.function
def func(
    data: R.Tensor(["batch_size * seq_len", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
  ) -> R.Tensor(["batch_size * seq_len", "intermediate_size"]):

    batch_size = T.int64()
    seq_len = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([batch_size * seq_len, intermediate_size]) = R.matmul(data, weights)
    return output
```

The `data` tensor may be used to infer `hidden_size`, but cannot be used to infer `batch_size` or `seq_len`. The `R.Shape` parameter exists solely to define `batch_size` and `seq_len`, since all symbolic variables must be defined. However, neither `batch_size` nor `seq_len` is ever used outside of the expression `batch_size * seq_len`, and the value of `batch_size * seq_len` could be inferred from the shape of the `data` tensor.

This new transform identifies cases where an argument is otherwise unnecessary, and replaces the symbolic expression with a fresh variable (`data_dim0` in the example below). This leaves the `dummy_arg: R.Shape` parameter entirely unused, so a later use of `relax.transform.RemoveUnusedParameters()` can remove the parameter altogether.

```python
@R.function
def func(
    data: R.Tensor(["data_dim0", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
  ):

    data_dim0 = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([data_dim0, intermediate_size]) = R.matmul(data, weights)
    return output
```
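
For reference, a minimal sketch of how the new pass could be chained with `relax.transform.RemoveUnusedParameters` is shown below. The helper name `simplify_fused_signatures` and the use of `tvm.ir.transform.Sequential` are illustrative assumptions, not part of this PR; the exact constructor signature of the new pass is defined by this change.

```python
import tvm
from tvm import relax


def simplify_fused_signatures(mod: tvm.IRModule) -> tvm.IRModule:
    """Illustrative pipeline (assumed usage, not part of this PR).

    First rewrite symbolic expressions that can be inferred from tensor
    shapes, then drop the R.Shape parameters left unused by that rewrite.
    """
    pipeline = tvm.ir.transform.Sequential(
        [
            relax.transform.RemoveSymbolicExpressionsInSubroutine(),
            relax.transform.RemoveUnusedParameters(),
        ]
    )
    return pipeline(mod)
```

Applied to a module containing `func` above, the first pass would rewrite `batch_size * seq_len` to `data_dim0`, and the second would then drop the unused `dummy_arg` parameter.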

Lunderberg (Contributor, Author)
This transform is intended to be used in the implementation of #16450, as recommended here.

Lunderberg requested a review from sunggg on June 18, 2024.
Lunderberg force-pushed the relax_remove_symbolic_expr_in_subroutine branch from 8f484d2 to 27a6820 on September 11, 2024.