-
Notifications
You must be signed in to change notification settings - Fork 74.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #10649: [ROCm] Triton in XLA for ROCm - ir_emitter_triton related …
…changes. Imported from GitHub PR openxla/xla#10649 Second commit of the series for enabling Triton in XLA for ROCm. Copybara import of the project: -- 23d442f83c731cd86131bcd1d91c4e3d7cc42468 by Zoran Jovanovic <zjovanov@amd.com>: [ROCm] Triton in XLA for ROCm - ir_emitter_triton related changes. Merging this change closes #10649 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10649 from ROCm:rocm_triton_backend_3 23d442f83c731cd86131bcd1d91c4e3d7cc42468 PiperOrigin-RevId: 621830985
- Loading branch information
1 parent
a4d9df4
commit acdaa12
Showing
24 changed files
with
2,478 additions
and
136 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include <utility> | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project | ||
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project | ||
#include "mlir/IR/MLIRContext.h" // from @llvm-project | ||
#include "mlir/IR/OpDefinition.h" // from @llvm-project | ||
#include "mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/IR/SymbolTable.h" // from @llvm-project | ||
#include "mlir/Support/LLVM.h" // from @llvm-project | ||
#include "mlir/Support/LogicalResult.h" // from @llvm-project | ||
#include "mlir/Support/TypeID.h" // from @llvm-project | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h" | ||
#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" | ||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" | ||
|
||
namespace mlir::quant::stablehlo { | ||
|
||
#define GEN_PASS_DEF_XLACALLMODULETOCALLPASS | ||
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" | ||
|
||
namespace { | ||
|
||
// Converts XlaCallModuleOps to func.call. | ||
class XlaCallModuleToCallPass | ||
: public impl::XlaCallModuleToCallPassBase<XlaCallModuleToCallPass> { | ||
public: | ||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaCallModuleToCallPass) | ||
|
||
explicit XlaCallModuleToCallPass() = default; | ||
|
||
private: | ||
void runOnOperation() override; | ||
}; | ||
|
||
// Converts XlaCallModuleOps to func.call. | ||
class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> { | ||
public: | ||
explicit XlaCallModuleOpToCallOp(MLIRContext* context) | ||
: OpRewritePattern<TF::XlaCallModuleOp>(context) {} | ||
|
||
LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, | ||
PatternRewriter& rewriter) const override { | ||
auto module_op = op->getParentOfType<ModuleOp>(); | ||
SymbolTable symbol_table(module_op); | ||
|
||
auto entry_func_op = dyn_cast_or_null<func::FuncOp>( | ||
symbol_table.lookup(GetEntryFunctionName(op))); | ||
if (!entry_func_op) return failure(); | ||
|
||
// Replace the XlaCallModuleOp with a new CallOp. | ||
rewriter.replaceOpWithNewOp<func::CallOp>(op, entry_func_op, op.getArgs()); | ||
return success(); | ||
} | ||
}; | ||
|
||
void XlaCallModuleToCallPass::runOnOperation() { | ||
ModuleOp module_op = getOperation(); | ||
MLIRContext* ctx = module_op.getContext(); | ||
RewritePatternSet patterns(&getContext()); | ||
patterns.add<XlaCallModuleOpToCallOp>(ctx); | ||
if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
|
||
} // namespace | ||
} // namespace mlir::quant::stablehlo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
23 changes: 23 additions & 0 deletions
23
tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/xla_call_module_to_call.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-xla-call-module-to-call | FileCheck %s | ||
|
||
// ----- | ||
|
||
// Tests composite tf.XlaCallModule is converted to func.call. | ||
|
||
module { | ||
// CHECK-LABEL: func.func @main | ||
func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { | ||
// CHECK: call @composite_dot_general_fn_1 | ||
// CHECK-SAME: (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> | ||
// CHECK-NOT: tf.XlaCallModule | ||
%0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> | ||
%2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> | ||
return %2 : tensor<1x3xf32> | ||
} | ||
// CHECK-LABEL: func.func private @composite_dot_general_fn_1 | ||
// CHECK-SAME: -> tensor<1x3xf32> | ||
func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { | ||
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> | ||
return %0 : tensor<1x3xf32> | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.