From 0f69af756abd7136453a11cf169939d03e7c5f7d Mon Sep 17 00:00:00 2001 From: Jules Merckx Date: Tue, 17 Dec 2024 14:03:32 +0100 Subject: [PATCH 1/3] add enzyme.broadcast to `stablehlo.broadcast_in_dim`/`mhlo.broadcast` conversion in `arith-raise`. --- src/enzyme_ad/jax/Passes/ArithRaising.cpp | 27 +++++++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.td | 3 ++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/ArithRaising.cpp b/src/enzyme_ad/jax/Passes/ArithRaising.cpp index 1fad91c3..93c890cd 100644 --- a/src/enzyme_ad/jax/Passes/ArithRaising.cpp +++ b/src/enzyme_ad/jax/Passes/ArithRaising.cpp @@ -19,6 +19,8 @@ #include "src/enzyme_ad/jax/Passes/PassDetails.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Enzyme/MLIR/Dialect/Ops.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -91,6 +93,31 @@ struct ArithRaisingPass : public ArithRaisingPassBase { constOp.erase(); } }); + op->walk([=](enzyme::BroadcastOp broadcastOp) { + OpBuilder builder(broadcastOp); + Value newBroadcastOp; + if (use_stablehlo) { + SmallVector broadcastDims; + auto shape = broadcastOp.getInput().getType().cast().getShape(); + broadcastDims.reserve(shape.size()); + for (auto en : llvm::enumerate(shape)) { + // original dimensions end up one further because the batch dimension is prepended: + broadcastDims.push_back(en.index() + 1); + } + newBroadcastOp = builder.create( + broadcastOp.getLoc(), + broadcastOp.getType(), + broadcastOp.getInput(), + builder.getDenseI64ArrayAttr(broadcastDims) + ); + } else { + newBroadcastOp = builder.create( + broadcastOp.getLoc(), broadcastOp.getInput(), + builder.getI64TensorAttr({broadcastOp.getWidth()})); + } + broadcastOp.replaceAllUsesWith(newBroadcastOp); + broadcastOp.erase(); + }); } }; diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 2aafc894..12af5e9e 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -17,7 +17,8 @@ def ArithRaisingPass : Pass<"arith-raise"> { "arith::ArithDialect", "mhlo::MhloDialect", "stablehlo::StablehloDialect", - "chlo::ChloDialect" + "chlo::ChloDialect", + "enzyme::EnzymeDialect", ]; let constructor = "mlir::enzyme::createArithRaisingPass()"; let options = [ From a05b2f2fde96c7f0128fce96fb412a36a7f8dd6f Mon Sep 17 00:00:00 2001 From: Jules Merckx Date: Wed, 18 Dec 2024 16:36:25 +0100 Subject: [PATCH 2/3] test --- test/lit_tests/broadcastdiff.mlir | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 test/lit_tests/broadcastdiff.mlir diff --git a/test/lit_tests/broadcastdiff.mlir b/test/lit_tests/broadcastdiff.mlir new file mode 100644 index 00000000..ae23b7df --- /dev/null +++ b/test/lit_tests/broadcastdiff.mlir @@ -0,0 +1,16 @@ +// RUN: enzymexlamlir-opt --arith-raise %s | FileCheck %s + +module { + func.func @main(%arg0: tensor, %arg1: tensor<2xf64>) -> tensor<2xf64> { + %0 = "enzyme.broadcast"(%arg0) <{width = 2 : i64}> : (tensor) -> tensor<2xf64> + %1 = arith.addf %0, %arg1 : tensor<2xf64> + return %1 : tensor<2xf64> + } +} + +// CHECK: func.func @main(%arg0: tensor, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor) -> tensor<2xf64> +// CHECK-NEXT: %[[i1:.+]] = stablehlo.add %[[i0:.+]], %arg1 : tensor<2xf64> +// CHECK-NEXT: return %[[i1:.+]] : tensor<2xf64> +// CHECK-NEXT: } +// CHECK-NEXT: } From 656c855bb308a404635a9d8b2b20f5b9a9ce4a47 Mon Sep 17 00:00:00 2001 From: Jules Merckx Date: Fri, 20 Dec 2024 09:38:18 +0100 Subject: [PATCH 3/3] formatting --- src/enzyme_ad/jax/Passes/ArithRaising.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ArithRaising.cpp b/src/enzyme_ad/jax/Passes/ArithRaising.cpp index 93c890cd..fc19d411 100644 --- a/src/enzyme_ad/jax/Passes/ArithRaising.cpp +++ b/src/enzyme_ad/jax/Passes/ArithRaising.cpp @@ -10,6 +10,8 @@ // ops. //===----------------------------------------------------------------------===// +#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Enzyme/MLIR/Dialect/Ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/IR/Builders.h" @@ -19,8 +21,6 @@ #include "src/enzyme_ad/jax/Passes/PassDetails.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "Enzyme/MLIR/Dialect/Dialect.h" -#include "Enzyme/MLIR/Dialect/Ops.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -98,18 +98,17 @@ struct ArithRaisingPass : public ArithRaisingPassBase { Value newBroadcastOp; if (use_stablehlo) { SmallVector broadcastDims; - auto shape = broadcastOp.getInput().getType().cast().getShape(); + auto shape = + broadcastOp.getInput().getType().cast().getShape(); broadcastDims.reserve(shape.size()); for (auto en : llvm::enumerate(shape)) { - // original dimensions end up one further because the batch dimension is prepended: + // original dimensions end up one further because the batch dimension + // is prepended: broadcastDims.push_back(en.index() + 1); } newBroadcastOp = builder.create( - broadcastOp.getLoc(), - broadcastOp.getType(), - broadcastOp.getInput(), - builder.getDenseI64ArrayAttr(broadcastDims) - ); + broadcastOp.getLoc(), broadcastOp.getType(), broadcastOp.getInput(), + builder.getDenseI64ArrayAttr(broadcastDims)); } else { newBroadcastOp = builder.create( broadcastOp.getLoc(), broadcastOp.getInput(),