Update llvm to e5d5146323ffaa13eb5185616c6ae5c36b69352d (#1628)
Signed-off-by: Tung D. Le <tung@jp.ibm.com>
Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
tungld and AlexandreEichenberger authored Aug 23, 2022
1 parent 53a11a9 commit a9767f3
Showing 14 changed files with 64 additions and 76 deletions.
2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
@@ -14,7 +14,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout 061e0189a3dab6b1831a80d489ff1b15ad93aafb && cd ..
cd llvm-project && git checkout e5d5146323ffaa13eb5185616c6ae5c36b69352d && cd ..
```

[same-as-file]: <> (utils/build-mlir.sh)
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout 061e0189a3dab6b1831a80d489ff1b15ad93aafb && cd ..
cd llvm-project && git checkout e5d5146323ffaa13eb5185616c6ae5c36b69352d && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
31 changes: 7 additions & 24 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -39,6 +39,12 @@ struct ScalarOp<ONNXAddOp> {
using IOp = arith::AddIOp;
};

template <>
struct ScalarOp<ONNXAbsOp> {
using FOp = math::AbsFOp;
using IOp = math::AbsIOp;
};

template <>
struct ScalarOp<ONNXMulOp> {
using FOp = arith::MulFOp;
@@ -462,8 +468,8 @@ Value emitScalarOpFor<ONNXSoftsignOp>(ConversionPatternRewriter &rewriter,
// ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X)
Value operand = scalarOperands[0];

auto abs = rewriter.create<math::AbsOp>(loc, operand);
MathBuilder createMath(rewriter, loc);
Value abs = createMath.abs(operand);
Value one = createMath.constant(elementType, 1);
Value add = createMath.add(abs, one);
return createMath.div(operand, add);
@@ -573,29 +579,6 @@ Value emitScalarOpFor<ONNXMinOp>(ConversionPatternRewriter &rewriter,
return rewriter.create<arith::SelectOp>(loc, min, lhs, rhs);
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXAbsOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXAbsOp>(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, Type elementType,
ArrayRef<Value> scalarOperands) {
Value operand = scalarOperands[0];

if (elementType.isa<FloatType>()) {
return rewriter.create<math::AbsOp>(loc, operand);
} else if (elementType.isa<IntegerType>()) {
MathBuilder createMath(rewriter, loc);
Value zero = createMath.constant(elementType, 0);
auto lessThanZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, operand, zero);
Value negativeOperand = createMath.sub(zero, operand);
return createMath.select(lessThanZero, negativeOperand, operand);
} else {
llvm_unreachable("unsupported element type");
}
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXNegOp
//===----------------------------------------------------------------------===//
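The new `ScalarOp<ONNXAbsOp>` specialization is what lets the hand-written compare/select lowering for integer abs (removed above) go away: the shared elementwise emitter can now pick `math::AbsFOp` or `math::AbsIOp` from the template by element type. A minimal sketch of that dispatch, assuming the file's usual `using namespace mlir;` is in scope; the helper name `emitAbsLike` is illustrative and not part of the patch.

```cpp
// Illustrative sketch only: not the actual onnx-mlir emitter, which routes
// through the shared emitScalarOpFor<> machinery.
template <typename ONNXOp>
Value emitAbsLike(ConversionPatternRewriter &rewriter, Location loc,
    Type elementType, Value operand) {
  if (elementType.isa<FloatType>())
    // Float path: ScalarOp<ONNXAbsOp>::FOp is math::AbsFOp.
    return rewriter.create<typename ScalarOp<ONNXOp>::FOp>(loc, operand);
  if (elementType.isa<IntegerType>())
    // Integer path: ScalarOp<ONNXAbsOp>::IOp is math::AbsIOp, so no manual
    // cmp/select sequence is needed anymore.
    return rewriter.create<typename ScalarOp<ONNXOp>::IOp>(loc, operand);
  llvm_unreachable("unsupported element type");
}
```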
18 changes: 12 additions & 6 deletions src/Dialect/Mlir/DialectBuilder.cpp
@@ -45,6 +45,12 @@ namespace onnx_mlir {
// ONNX Integers as MLIR signless, and only flag the ONNX Unsigned Integer as
// MLIR unsigned integer.

Value MathBuilder::abs(Value val) const {
if (val.getType().isa<IntegerType>() || val.getType().isa<IndexType>())
return b.create<math::AbsIOp>(loc, val);
return b.create<math::AbsFOp>(loc, val);
}

Value MathBuilder::andi(Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
return b.create<arith::AndIOp>(loc, lhs, rhs);
@@ -61,12 +67,14 @@ Value MathBuilder::add(Value lhs, Value rhs) const {
return b.create<arith::AddIOp>(loc, lhs, rhs);
return b.create<arith::AddFOp>(loc, lhs, rhs);
}

Value MathBuilder::sub(Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
if (lhs.getType().isa<IntegerType>() || lhs.getType().isa<IndexType>())
return b.create<arith::SubIOp>(loc, lhs, rhs);
return b.create<arith::SubFOp>(loc, lhs, rhs);
}

Value MathBuilder::mul(Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
if (lhs.getType().isa<IntegerType>() || lhs.getType().isa<IndexType>())
@@ -877,7 +885,7 @@ Value LLVMBuilder::call(ArrayRef<Type> resultTypes, StringRef funcName,
// CallOp may return either 0 or 1 value.
if (resultTypes.empty())
return nullptr;
return callOp.getResult(0);
return callOp.getResult();
}

Value LLVMBuilder::call(ArrayRef<Type> resultTypes,
@@ -889,7 +897,7 @@ Value LLVMBuilder::call(ArrayRef<Type> resultTypes,
// CallOp may return either 0 or 1 value.
if (resultTypes.empty())
return nullptr;
return callOp.getResult(0);
return callOp.getResult();
}

void LLVMBuilder::condBr(Value cond, Block *trueBlock,
@@ -947,8 +955,7 @@ Value LLVMBuilder::constant(Type type, double val) const {

Value LLVMBuilder::extractValue(
Type resultType, Value container, ArrayRef<int64_t> position) const {
ArrayAttr posAttr = b.getI64ArrayAttr(position);
return b.create<LLVM::ExtractValueOp>(loc, resultType, container, posAttr);
return b.create<LLVM::ExtractValueOp>(loc, resultType, container, position);
}

LLVM::LLVMFuncOp LLVMBuilder::func(StringRef name, Type type) const {
@@ -973,9 +980,8 @@ Value LLVMBuilder::icmp(LLVM::ICmpPredicate cond, Value lhs, Value rhs) const {

Value LLVMBuilder::insertValue(Type resultType, Value container, Value val,
llvm::ArrayRef<int64_t> position) const {
ArrayAttr posAttr = b.getI64ArrayAttr(position);
return b.create<LLVM::InsertValueOp>(
loc, resultType, container, val, posAttr);
loc, resultType, container, val, position);
}

Value LLVMBuilder::load(Value addr) const {
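The `LLVMBuilder` changes in this file track API updates in the new LLVM commit: a single-result `LLVM::CallOp` is now read with `getResult()` (no index), and `ExtractValueOp`/`InsertValueOp` take their positions as plain `ArrayRef<int64_t>` instead of an `I64ArrayAttr`. A hedged sketch of a call site through the builder wrappers; the callee name and the `arg`, `structVal`, and `structTy` values are assumptions for illustration, not taken from the patch.

```cpp
// Usage sketch under assumed inputs; names here are not from the patch.
LLVMBuilder createLLVM(rewriter, loc);
Type i64Ty = rewriter.getI64Type();
// Zero- or one-result calls: the wrapper returns nullptr or the single Value.
Value ret = createLLVM.call({i64Ty}, "hypothetical_runtime_fn", {arg});
// Struct field access now passes positions as int64_t lists.
Value field = createLLVM.extractValue(i64Ty, structVal, {0});
Value updated = createLLVM.insertValue(structTy, structVal, field, {0});
```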
2 changes: 2 additions & 0 deletions src/Dialect/Mlir/DialectBuilder.hpp
@@ -62,6 +62,8 @@ struct MathBuilder final : DialectBuilder {
: DialectBuilder(b, loc) {}
MathBuilder(const DialectBuilder &db) : DialectBuilder(db) {}

mlir::Value abs(mlir::Value val) const;

mlir::Value andi(mlir::Value lhs, mlir::Value rhs) const;
mlir::Value ori(mlir::Value lhs, mlir::Value rhs) const;

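With `abs` declared on `MathBuilder`, lowering patterns can emit absolute value without branching on the element type themselves, as the Softsign change above already does. A short usage sketch, assuming `rewriter`, `loc`, `operand`, and `elementType` come from the surrounding conversion pattern:

```cpp
// Sketch of a call site; the builder picks math::AbsFOp or math::AbsIOp.
MathBuilder createMath(rewriter, loc);
Value abs = createMath.abs(operand);
Value one = createMath.constant(elementType, 1);
Value denom = createMath.add(abs, one);
Value softsign = createMath.div(operand, denom); // x / (1 + |x|)
```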
10 changes: 5 additions & 5 deletions test/mlir/conversion/onnx_to_mhlo/Math/GlobalPooling.mlir
@@ -7,7 +7,7 @@ func.func @test_global_average_pool(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x1x
// CHECK-LABEL: test_global_average_pool
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x5x5xf32>) -> tensor<1x3x1x1xf32> {
// CHECK: [[VAR_2_:%.+]] = mhlo.reduce([[PARAM_0_]] init: [[VAR_1_:%.+]]) applies mhlo.add across dimensions = [2, 3] : (tensor<1x3x5x5xf32>, tensor<f32>) -> tensor<1x3xf32>
// CHECK: [[VAR_3_:%.+]] = "mhlo.reshape"([[VAR_2_]]) : (tensor<1x3xf32>) -> tensor<1x3x1x1xf32>
// CHECK: [[VAR_3_:%.+]] = mhlo.reshape [[VAR_2_]] : (tensor<1x3xf32>) -> tensor<1x3x1x1xf32>
}

// -----
@@ -18,11 +18,11 @@ func.func @test_global_average_pool_dyn_dims(%arg0: tensor<1x?x?x5xf32>) -> tens
return %0 : tensor<1x?x1x1xf32>
// CHECK-LABEL: test_global_average_pool_dyn_dims
// CHECK: [[VAR_3_:%.+]] = mhlo.reduce([[PARAM_0_:%.+]] init: [[VAR_2_:%.+]]) applies mhlo.add across dimensions = [2, 3] : (tensor<1x?x?x5xf32>, tensor<f32>) -> tensor<1x?xf32>
// CHECK: [[VAR_4_:%.+]] = "mhlo.dynamic_reshape"([[VAR_3_]], [[VAR_0_:%.+]]) : (tensor<1x?xf32>, tensor<4xi64>) -> tensor<1x?x1x1xf32>
// CHECK: [[VAR_4_:%.+]] = mhlo.dynamic_reshape [[VAR_3_]], [[VAR_0_:%.+]] : (tensor<1x?xf32>, tensor<4xi64>) -> tensor<1x?x1x1xf32>
// CHECK: [[VAR_5_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<1x?x?x5xf32> -> tensor<4xindex>
// CHECK: [[VAR_6_:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[VAR_1_:%.+]], [[VAR_5_]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xindex>) -> tensor<1x?x?x5xf32>
// CHECK: [[VAR_7_:%.+]] = mhlo.reduce([[VAR_6_]] init: [[VAR_2_]]) applies mhlo.add across dimensions = [2, 3] : (tensor<1x?x?x5xf32>, tensor<f32>) -> tensor<1x?xf32>
// CHECK: [[VAR_8_:%.+]] = "mhlo.dynamic_reshape"([[VAR_7_]], [[VAR_0_]]) : (tensor<1x?xf32>, tensor<4xi64>) -> tensor<1x?x1x1xf32>
// CHECK: [[VAR_8_:%.+]] = mhlo.dynamic_reshape [[VAR_7_]], [[VAR_0_]] : (tensor<1x?xf32>, tensor<4xi64>) -> tensor<1x?x1x1xf32>
}

// -----
@@ -34,7 +34,7 @@ func.func @test_global_max_pool(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x1x1xf3
// CHECK-LABEL: test_global_max_pool
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x5x5xf32>) -> tensor<1x3x1x1xf32> {
// CHECK: [[VAR_1_:%.+]] = mhlo.reduce([[PARAM_0_]] init: [[VAR_0_:%.+]]) applies mhlo.maximum across dimensions = [2, 3] : (tensor<1x3x5x5xf32>, tensor<f32>) -> tensor<1x3xf32>
// CHECK: [[VAR_2_:%.+]] = "mhlo.reshape"([[VAR_1_]]) : (tensor<1x3xf32>) -> tensor<1x3x1x1xf32>
// CHECK: [[VAR_2_:%.+]] = mhlo.reshape [[VAR_1_]] : (tensor<1x3xf32>) -> tensor<1x3x1x1xf32>
}

// -----
@@ -45,5 +45,5 @@ func.func @test_global_max_pool_dyn_dims(%arg0: tensor<1x?x?x5xf32>) -> tensor<1
return %0 : tensor<1x?x1x1xf32>
// CHECK-LABEL: test_global_max_pool_dyn_dims
// CHECK: [[VAR_2_:%.+]] = mhlo.reduce([[PARAM_0_:%.+]] init: [[VAR_1_:%.+]]) applies mhlo.maximum across dimensions = [2, 3] : (tensor<1x?x?x5xf32>, tensor<f32>) -> tensor<1x?xf32>
// CHECK: [[VAR_3_:%.+]] = "mhlo.dynamic_reshape"([[VAR_2_]], [[VAR_0_:%.+]]) : (tensor<1x?xf32>, tensor<4xi64>) -> tensor<1x?x1x1xf32>
// CHECK: [[VAR_3_:%.+]] = mhlo.dynamic_reshape [[VAR_2_]], [[VAR_0_:%.+]] : (tensor<1x?xf32>, tensor<4xi64>) -> tensor<1x?x1x1xf32>
}
6 changes: 3 additions & 3 deletions test/mlir/conversion/onnx_to_mhlo/Math/Reduction.mlir
@@ -48,7 +48,7 @@ func.func @test_reducesum1(%arg0: tensor<3x2x2xf32>, %arg1: tensor<?xi64>) -> te
// CHECK-DAG: [[VAR_0:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG: [[VAR_1:%.+]] = mhlo.reduce([[PARAM_0:%.+]] init: [[VAR_0]]) applies mhlo.add across dimensions = [1] : (tensor<3x2x2xf32>, tensor<f32>) -> tensor<3x2xf32>
// CHECK-DAG: [[VAR_2:%.+]] = mhlo.constant dense<[3, 1, 2]> : tensor<3xi64>
// CHECK-DAG: [[VAR_3:%.+]] = "mhlo.dynamic_reshape"([[VAR_1]], [[VAR_2]]) : (tensor<3x2xf32>, tensor<3xi64>) -> tensor<3x1x2xf32>
// CHECK-DAG: [[VAR_3:%.+]] = mhlo.dynamic_reshape [[VAR_1]], [[VAR_2]] : (tensor<3x2xf32>, tensor<3xi64>) -> tensor<3x1x2xf32>
}

// -----
@@ -60,7 +60,7 @@ func.func @test_reducesum2(%arg0: tensor<3x2x2xf32>, %arg1: tensor<?xi64>) -> te
// CHECK-DAG: [[VAR_0:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG: [[VAR_1:%.+]] = mhlo.reduce([[PARAM_0:%.+]] init: [[VAR_0]]) applies mhlo.add across dimensions = [1] : (tensor<3x2x2xf32>, tensor<f32>) -> tensor<3x2xf32>
// CHECK-DAG: [[VAR_2:%.+]] = mhlo.constant dense<[3, 1, 2]> : tensor<3xi64>
// CHECK-DAG: [[VAR_3:%.+]] = "mhlo.dynamic_reshape"([[VAR_1]], [[VAR_2]]) : (tensor<3x2xf32>, tensor<3xi64>) -> tensor<3x1x2xf32>
// CHECK-DAG: [[VAR_3:%.+]] = mhlo.dynamic_reshape [[VAR_1]], [[VAR_2]] : (tensor<3x2xf32>, tensor<3xi64>) -> tensor<3x1x2xf32>
}

func.func @test_reducemean(%arg0 : tensor<3x2x2xf32>) -> tensor<3x2xf32> {
@@ -86,4 +86,4 @@ func.func @test_reducemean2(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
// CHECK-NEXT: [[VAR_4_:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[VAR_0_]], [[VAR_3_]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: [[VAR_5_:%.+]] = mhlo.reduce([[VAR_4_]] init: [[VAR_1_]]) applies mhlo.add across dimensions = [1] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK-NEXT: [[VAR_6_:%.+]] = mhlo.divide [[VAR_2_]], [[VAR_5_]] : tensor<?x?xf32>
}
}
12 changes: 6 additions & 6 deletions test/mlir/conversion/onnx_to_mhlo/Math/Softmax.mlir
@@ -8,12 +8,12 @@ func.func @test_softmax(%arg0 : tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG: [[VAR_1_:%.+]] = mhlo.constant dense<0xFF800000> : tensor<f32>
// CHECK-NEXT: [[VAR_2_:%.+]] = mhlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies mhlo.maximum across dimensions = [1] : (tensor<10x20x30xf32>, tensor<f32>) -> tensor<10x30xf32>
// CHECK-NEXT: [[VAR_3_:%.+]] = "mhlo.reshape"([[VAR_2_]]) : (tensor<10x30xf32>) -> tensor<10x1x30xf32>
// CHECK-NEXT: [[VAR_3_:%.+]] = mhlo.reshape [[VAR_2_]] : (tensor<10x30xf32>) -> tensor<10x1x30xf32>
// CHECK-NEXT: [[VAR_4_:%.+]] = "mhlo.broadcast_in_dim"([[VAR_3_]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x1x30xf32>) -> tensor<10x20x30xf32>
// CHECK-NEXT: [[VAR_5_:%.+]] = mhlo.subtract [[PARAM_0_]], [[VAR_4_]] : tensor<10x20x30xf32>
// CHECK-NEXT: [[VAR_6_:%.+]] = mhlo.exponential [[VAR_5_]] : tensor<10x20x30xf32>
// CHECK-NEXT: [[VAR_7_:%.+]] = mhlo.reduce([[VAR_6_]] init: [[VAR_0_]]) applies mhlo.add across dimensions = [1] : (tensor<10x20x30xf32>, tensor<f32>) -> tensor<10x30xf32>
// CHECK-NEXT: [[VAR_8_:%.+]] = "mhlo.reshape"([[VAR_7_]]) : (tensor<10x30xf32>) -> tensor<10x1x30xf32>
// CHECK-NEXT: [[VAR_8_:%.+]] = mhlo.reshape [[VAR_7_]] : (tensor<10x30xf32>) -> tensor<10x1x30xf32>
// CHECK-NEXT: [[VAR_9_:%.+]] = "mhlo.broadcast_in_dim"([[VAR_8_]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x1x30xf32>) -> tensor<10x20x30xf32>
// CHECK-NEXT: [[VAR_10_:%.+]] = mhlo.divide [[VAR_6_]], [[VAR_9_]] : tensor<10x20x30xf32>
// CHECK-NEXT: return [[VAR_10_]] : tensor<10x20x30xf32>
@@ -28,7 +28,7 @@ func.func @test_softmax_dynamic(%arg0 : tensor<?x20x30xf32>) -> tensor<?x20x30xf
// CHECK-DAG: [[VAR_1_:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG: [[VAR_2_:%.+]] = mhlo.constant dense<0xFF800000> : tensor<f32>
// CHECK-NEXT: [[VAR_3_:%.+]] = mhlo.reduce([[PARAM_0_]] init: [[VAR_2_]]) applies mhlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// CHECK-NEXT: [[VAR_4_:%.+]] = "mhlo.dynamic_reshape"([[VAR_3_]], [[VAR_0_:%.+]]) : (tensor<?x30xf32>, tensor<3xi64>) -> tensor<?x1x30xf32>
// CHECK-NEXT: [[VAR_4_:%.+]] = mhlo.dynamic_reshape [[VAR_3_]], [[VAR_0_:%.+]] : (tensor<?x30xf32>, tensor<3xi64>) -> tensor<?x1x30xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK-DAG: [[VAR_6_:%.+]] = shape.shape_of [[VAR_4_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// CHECK-NEXT: [[VAR_7_:%.+]] = shape.broadcast [[VAR_5_]], [[VAR_6_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
@@ -37,7 +37,7 @@ func.func @test_softmax_dynamic(%arg0 : tensor<?x20x30xf32>) -> tensor<?x20x30xf
// CHECK-NEXT: [[VAR_10_:%.+]] = mhlo.subtract [[VAR_8_]], [[VAR_9_]] : tensor<?x20x30xf32>
// CHECK-NEXT: [[VAR_11_:%.+]] = mhlo.exponential [[VAR_10_]] : tensor<?x20x30xf32>
// CHECK-NEXT: [[VAR_12_:%.+]] = mhlo.reduce([[VAR_11_]] init: [[VAR_1_]]) applies mhlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// CHECK-NEXT: [[VAR_13_:%.+]] = "mhlo.dynamic_reshape"([[VAR_12_]], [[VAR_0_]]) : (tensor<?x30xf32>, tensor<3xi64>) -> tensor<?x1x30xf32>
// CHECK-NEXT: [[VAR_13_:%.+]] = mhlo.dynamic_reshape [[VAR_12_]], [[VAR_0_]] : (tensor<?x30xf32>, tensor<3xi64>) -> tensor<?x1x30xf32>
// CHECK-DAG: [[VAR_14_:%.+]] = shape.shape_of [[VAR_11_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK-DAG: [[VAR_15_:%.+]] = shape.shape_of [[VAR_13_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// CHECK-NEXT: [[VAR_16_:%.+]] = shape.broadcast [[VAR_14_]], [[VAR_15_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
@@ -53,10 +53,10 @@ func.func @test_softmax_2d(%arg0 : tensor<1x10xf32>) -> tensor<1x10xf32> {
// CHECK-LABEL: func @test_softmax_2d
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x10xf32>) -> tensor<1x10xf32> {
// CHECK: [[VAR_2_:%.+]] = mhlo.reduce([[PARAM_0_]] init: [[VAR_1_:%.+]]) applies mhlo.maximum across dimensions = [1] : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
// CHECK: [[VAR_3_:%.+]] = "mhlo.reshape"([[VAR_2_]]) : (tensor<1xf32>) -> tensor<1x1xf32>
// CHECK: [[VAR_3_:%.+]] = mhlo.reshape [[VAR_2_]] : (tensor<1xf32>) -> tensor<1x1xf32>
// CHECK: [[VAR_4_:%.+]] = "mhlo.broadcast_in_dim"([[VAR_3_]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x10xf32>
// CHECK: [[VAR_7_:%.+]] = mhlo.reduce([[VAR_6_]] init: [[VAR_0_:%.+]]) applies mhlo.add across dimensions = [1] : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
// CHECK: [[VAR_8_:%.+]] = "mhlo.reshape"([[VAR_7_]]) : (tensor<1xf32>) -> tensor<1x1xf32>
// CHECK: [[VAR_8_:%.+]] = mhlo.reshape [[VAR_7_]] : (tensor<1xf32>) -> tensor<1x1xf32>
// CHECK: [[VAR_9_:%.+]] = "mhlo.broadcast_in_dim"([[VAR_8_]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x10xf32>
// CHECK: [[VAR_10_:%.+]] = mhlo.divide [[VAR_6_]], [[VAR_9_]] : tensor<1x10xf32>
}