PR #10649: [ROCm] Triton in XLA for ROCm - ir_emitter_triton related changes.

Imported from GitHub PR openxla/xla#10649

Second commit of the series for enabling Triton in XLA for ROCm.

Copybara import of the project:

--
23d442f83c731cd86131bcd1d91c4e3d7cc42468 by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Triton in XLA for ROCm - ir_emitter_triton related changes.

Merging this change closes #10649

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10649 from ROCm:rocm_triton_backend_3 23d442f83c731cd86131bcd1d91c4e3d7cc42468
PiperOrigin-RevId: 621830985
zoranjovanovic-ns authored and tensorflower-gardener committed Apr 4, 2024
1 parent a4d9df4 commit acdaa12
Showing 24 changed files with 2,478 additions and 136 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/quantization/stablehlo/BUILD
@@ -68,6 +68,7 @@ cc_library(
"passes/restore_function_name.cc",
"passes/unfuse_mhlo_batch_norm.cc",
"passes/unwrap_xla_call_module_op.cc",
"passes/xla_call_module_to_call.cc",
],
hdrs = [
"passes/passes.h",
@@ -130,6 +130,13 @@ def PostQuantizePass : Pass<"stablehlo-post-quantize", "mlir::func::FuncOp"> {
  ];
}

def XlaCallModuleToCallPass : Pass<"stablehlo-xla-call-module-to-call", "ModuleOp"> {
  let summary = "Convert XlaCallModuleOp to func.call op";
  let dependentDialects = [
    "TF::TensorFlowDialect",
  ];
}

def UnwrapXlaCallModuleOpPass : Pass<"stablehlo-unwrap-xla-call-module-op", "ModuleOp"> {
  let summary = "Unwrap XlaCallModuleOps into inline functions if not used for quantizing fused patterns.";
  let dependentDialects = ["TF::TensorFlowDialect"];
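
The quoted string in a Pass<...> declaration is also the command-line flag under which the generated pass registers, so the new pass above can be driven straight from stablehlo-quant-opt. A minimal sketch of a lit test using that flag, assuming nothing beyond the registration this declaration generates (a module with no tf.XlaCallModule ops simply passes through unchanged); the real exercise of the flag is the new lit test further down in this commit:

// RUN: stablehlo-quant-opt %s -stablehlo-xla-call-module-to-call | FileCheck %s

// CHECK-LABEL: func.func @main
// CHECK-NEXT: return
func.func @main() {
  return
}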
@@ -98,6 +98,11 @@ void QuantizeCompositeFunctionsPass::runOnOperation() {
  pm.addPass(createQuantizePass(quantize_options));
  pm.addNestedPass<func::FuncOp>(createPostQuantizePass());

  // Convert XlaCallModuleOps that were lifted but not quantized into
  // func.call ops. An op may be left unquantized because:
  //   1. It was disabled by selective quantization.
  //   2. It is not supported, e.g. the add op for server.
  pm.addPass(createXlaCallModuleToCallPass());
  ModuleOp module_op = getOperation();
  if (const absl::Status pm_run_status =
          RunPassesOnModuleOp(mlir_dump_file_name_, pm, module_op);
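
To make this step concrete, here is a hedged sketch of the resulting IR for the unquantized case, adapted from the updated FileCheck expectations further down in this commit (tf_saved_model attributes elided): the lifted function stays in f32, but it is now reached through a plain func.call instead of a tf.XlaCallModule wrapper.

func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
  %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32>
  %0 = call @composite_dot_general_fn(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
  return %0 : tensor<1x3xf32>
}

func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
  %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
  return %0 : tensor<1x3xf32>
}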
@@ -0,0 +1,83 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir::quant::stablehlo {

#define GEN_PASS_DEF_XLACALLMODULETOCALLPASS
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc"

namespace {

// Converts XlaCallModuleOps to func.call.
class XlaCallModuleToCallPass
    : public impl::XlaCallModuleToCallPassBase<XlaCallModuleToCallPass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaCallModuleToCallPass)

  explicit XlaCallModuleToCallPass() = default;

 private:
  void runOnOperation() override;
};

// Converts XlaCallModuleOps to func.call.
class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
 public:
  explicit XlaCallModuleOpToCallOp(MLIRContext* context)
      : OpRewritePattern<TF::XlaCallModuleOp>(context) {}

  LogicalResult matchAndRewrite(TF::XlaCallModuleOp op,
                                PatternRewriter& rewriter) const override {
    auto module_op = op->getParentOfType<ModuleOp>();
    SymbolTable symbol_table(module_op);

    // Resolve the op's entry function through the module's symbol table;
    // leave the op untouched if it does not point at a func.func.
    auto entry_func_op = dyn_cast_or_null<func::FuncOp>(
        symbol_table.lookup(GetEntryFunctionName(op)));
    if (!entry_func_op) return failure();

    // Replace the XlaCallModuleOp with a new CallOp.
    rewriter.replaceOpWithNewOp<func::CallOp>(op, entry_func_op, op.getArgs());
    return success();
  }
};

void XlaCallModuleToCallPass::runOnOperation() {
  ModuleOp module_op = getOperation();
  MLIRContext* ctx = module_op.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<XlaCallModuleOpToCallOp>(ctx);
  if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) {
    signalPassFailure();
  }
}

} // namespace
} // namespace mlir::quant::stablehlo
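
For illustration, a minimal before/after sketch of what XlaCallModuleOpToCallOp does, condensed from the new lit test added later in this commit; the trailing comment shows the rewritten form.

module {
  func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> {
    %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32>
    // Before: the lifted computation sits behind a tf.XlaCallModule whose
    // _entry_function attribute names @composite_dot_general_fn_1.
    %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _stablehlo_module_attrs = {}} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
    // After the pattern runs, the wrapper is replaced in place by a direct
    // call with identical operands and result types:
    //   %1 = call @composite_dot_general_fn_1(%arg0, %0)
    //       : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
    return %1 : tensor<1x3xf32>
  }
  func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
    return %0 : tensor<1x3xf32>
  }
}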
@@ -1179,6 +1179,126 @@ def test_conv_weight_only_model(
        0.35,
    )

  @parameterized.parameters(
      testing.parameter_combinations([{
          'shape_dynamic': (
              False,
              True,
          ),
      }])
  )
  @test_util.run_in_graph_and_eager_modes
  def test_add_ptq_model(
      self,
      shape_dynamic: bool,
  ):
    input_shape = (None, 3, 4, 3) if shape_dynamic else (2, 3, 4, 3)
    self._create_add_model(
        input_shape,
        self._input_saved_model_path,
    )

    # Generate model input data.
    rng = np.random.default_rng(seed=42)
    static_input_shape = [dim if dim is not None else 2 for dim in input_shape]

    def data_gen() -> repr_dataset.RepresentativeDataset:
      for _ in range(100):
        yield {
            'input_tensor': rng.uniform(
                low=0.0, high=1.0, size=static_input_shape
            ).astype(np.float32)
        }

    dataset_path = self.create_tempfile('tfrecord').full_path
    path_map = {'serving_default': dataset_path}
    repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
        {'serving_default': data_gen()}
    )

    config = qc.QuantizationConfig(
        static_range_ptq_preset=qc.StaticRangePtqPreset(
            representative_datasets=[
                qc.RepresentativeDatasetConfig(
                    tf_record=qc.TfRecordFile(path=dataset_path)
                )
            ],
        ),
        tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
    )
    quantization.quantize_saved_model(
        self._input_saved_model_path,
        self._output_saved_model_path,
        config,
    )

    self.assertEqual(
        self._get_num_xla_call_module_op(self._output_saved_model_path), 1
    )
    module_str = self._extract_first_xla_call_module_op(
        self._output_saved_model_path
    )

    # Check add is not quantized.
    self.assertTrue(re.search(r'stablehlo.add.*f32>', module_str))

  @parameterized.parameters(
      testing.parameter_combinations([{
          'shape_dynamic': (
              False,
              True,
          ),
      }])
  )
  @test_util.run_in_graph_and_eager_modes
  def test_add_weight_only_model(
      self,
      shape_dynamic: bool,
  ):
    input_shape = (None, 3, 4, 3) if shape_dynamic else (2, 3, 4, 3)
    self._create_add_model(
        input_shape,
        self._input_saved_model_path,
    )

    # Generate model input data.
    rng = np.random.default_rng(seed=42)
    static_input_shape = [dim if dim is not None else 2 for dim in input_shape]

    def data_gen() -> repr_dataset.RepresentativeDataset:
      for _ in range(100):
        yield {
            'input_tensor': rng.uniform(
                low=0.0, high=1.0, size=static_input_shape
            ).astype(np.float32)
        }

    dataset_path = self.create_tempfile('tfrecord').full_path
    path_map = {'serving_default': dataset_path}
    repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
        {'serving_default': data_gen()}
    )

    config = qc.QuantizationConfig(
        weight_only_ptq_preset=qc.WeightOnlyPtqPreset(),
        tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
    )
    quantization.quantize_saved_model(
        self._input_saved_model_path,
        self._output_saved_model_path,
        config,
    )

    self.assertEqual(
        self._get_num_xla_call_module_op(self._output_saved_model_path), 1
    )
    module_str = self._extract_first_xla_call_module_op(
        self._output_saved_model_path
    )

    # Check add is not quantized.
    self.assertTrue(re.search(r'stablehlo.add.*f32>', module_str), module_str)


if __name__ == '__main__':
  test.main()
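
Both tests end by checking re.search(r'stablehlo.add.*f32>', module_str): the add must still appear as a float op inside the serialized StableHLO of the one remaining XlaCallModule. A sketch of the kind of function that assertion matches (the function name is illustrative; the shape is the static test case's):

func.func private @composite_add_fn(%arg0: tensor<2x3x4x3xf32>) -> tensor<2x3x4x3xf32> {
  // Still a plain f32 add: it was not rewritten to quantized ops.
  %0 = stablehlo.add %arg0, %arg0 : tensor<2x3x4x3xf32>
  return %0 : tensor<2x3x4x3xf32>
}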
@@ -73,6 +73,20 @@ def _extract_first_xla_call_module_op(
          return str(stablehlo_module)
    raise ValueError('No XlaCallModule found in saved model.')

  def _get_num_xla_call_module_op(self, output_saved_model_path: str) -> int:
    """Gets the number of XlaCallModule ops in the output saved model."""
    root = load.load(output_saved_model_path)
    tf_graph_def = root.signatures['serving_default'].graph.as_graph_def()
    count = 0
    for node_def in tf_graph_def.node:
      if node_def.op == 'XlaCallModule':
        count += 1
    for function in tf_graph_def.library.function:
      for node_def in function.node_def:
        if node_def.op == 'XlaCallModule':
          count += 1
    return count

  def _create_matmul_model(
      self,
      input_shape: Sequence[int],

@@ -339,6 +353,42 @@ def __call__(

    return GatherModel(use_variable)

  def _create_add_model(
      self,
      shape: Sequence[int],
      saved_model_path: str,
  ) -> module.Module:
    class AddModel(module.Module):
      """A simple model with a single add."""

      def __init__(self):
        pass

      @def_function.function
      def add(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]:
        """Performs an add operation.

        Args:
          input_tensor: Input tensor to perform add on.

        Returns:
          A map of: output key -> output result.
        """
        out = math_ops.add(input_tensor, input_tensor)
        return {'output': out}

    model = AddModel()
    saved_model_save.save(
        model,
        saved_model_path,
        signatures=model.add.get_concrete_function(
            tensor_spec.TensorSpec(
                shape=shape, dtype=dtypes.float32, name='input_tensor'
            )
        ),
    )
    return model

  # Prepares sample einsum input data shapes.
  # This function returns:
  # 1. Shape for input 1
@@ -69,11 +69,6 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1:
}
// CHECK-LABEL: func.func @main
// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x1024xf32>) -> tensor<1x3xf32>

// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<{{.*}}> : tensor<1024x3xf32>
// CHECK: "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]])

// CHECK: func.func private @composite_dot_general_fn_1
// CHECK-SAME: attributes {_from_xla_call_module}
// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general
// CHECK-SAME: contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
// CHECK: stablehlo.dot_general %[[ARG_0]], %[[CONST_0]]
// CHECK-NOT: tf.XlaCallModule
@@ -715,7 +715,7 @@ module attributes {tf_saved_model.semantics} {

// -----

// Tests that XlaCallModule op is not quantized without the quantfork.stats ops.
// Tests that the XlaCallModule op is not quantized, and is instead converted to func.call, when the quantfork.stats ops are missing.

module attributes {tf_saved_model.semantics} {
  func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
@@ -728,8 +728,8 @@ module attributes {tf_saved_model.semantics} {

// CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"}
// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32>
// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: return %[[XLA_CALL_MODULE_0]]
// CHECK: %[[CALL:.+]] = call @composite_dot_general_fn(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: return %[[CALL]]

  func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
@@ -0,0 +1,23 @@
// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-xla-call-module-to-call | FileCheck %s

// -----

// Tests composite tf.XlaCallModule is converted to func.call.

module {
  // CHECK-LABEL: func.func @main
  func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> {
    // CHECK: call @composite_dot_general_fn_1
    // CHECK-SAME: (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
    // CHECK-NOT: tf.XlaCallModule
    %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32>
    %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
    return %2 : tensor<1x3xf32>
  }
  // CHECK-LABEL: func.func private @composite_dot_general_fn_1
  // CHECK-SAME: -> tensor<1x3xf32>
  func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
    return %0 : tensor<1x3xf32>
  }
}
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/xla_client.py
@@ -49,7 +49,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 253
_version = 254

# Version number for MLIR:Python components.
mlir_api_version = 55