NOTE(review): this patch was re-flowed from a whitespace-collapsed copy in
which every <identifier> span had been stripped. The template arguments
SmallVector<int64_t>, cast<TensorType>, create<stablehlo::BroadcastInDimOp>,
create<mhlo::BroadcastOp>, ArithRaisingPassBase<ArithRaisingPass> and the
tensor<f64> types were reconstructed from context; confirm them (and the
indentation) against the upstream commit before applying.

diff --git a/src/enzyme_ad/jax/Passes/ArithRaising.cpp b/src/enzyme_ad/jax/Passes/ArithRaising.cpp
index 1fad91c3..fc19d411 100644
--- a/src/enzyme_ad/jax/Passes/ArithRaising.cpp
+++ b/src/enzyme_ad/jax/Passes/ArithRaising.cpp
@@ -10,6 +10,8 @@
 // ops.
 //===----------------------------------------------------------------------===//
 
+#include "Enzyme/MLIR/Dialect/Dialect.h"
+#include "Enzyme/MLIR/Dialect/Ops.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/IR/Builders.h"
@@ -91,6 +93,32 @@ struct ArithRaisingPass : public ArithRaisingPassBase<ArithRaisingPass> {
         constOp.erase();
       }
     });
+    // Raise enzyme.broadcast: to stablehlo.broadcast_in_dim when use_stablehlo
+    // is set, otherwise to mhlo.broadcast with the width as the only size.
+    op->walk([=](enzyme::BroadcastOp broadcastOp) {
+      OpBuilder builder(broadcastOp);
+      Value newBroadcastOp;
+      if (use_stablehlo) {
+        SmallVector<int64_t> broadcastDims;
+        auto shape =
+            broadcastOp.getInput().getType().cast<TensorType>().getShape();
+        broadcastDims.reserve(shape.size());
+        for (auto en : llvm::enumerate(shape)) {
+          // original dimensions end up one further because the batch dimension
+          // is prepended:
+          broadcastDims.push_back(en.index() + 1);
+        }
+        newBroadcastOp = builder.create<stablehlo::BroadcastInDimOp>(
+            broadcastOp.getLoc(), broadcastOp.getType(), broadcastOp.getInput(),
+            builder.getDenseI64ArrayAttr(broadcastDims));
+      } else {
+        newBroadcastOp = builder.create<mhlo::BroadcastOp>(
+            broadcastOp.getLoc(), broadcastOp.getInput(),
+            builder.getI64TensorAttr({broadcastOp.getWidth()}));
+      }
+      broadcastOp.replaceAllUsesWith(newBroadcastOp);
+      broadcastOp.erase();
+    });
   }
 };
 
diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index 2aafc894..12af5e9e 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -17,7 +17,8 @@ def ArithRaisingPass : Pass<"arith-raise"> {
     "arith::ArithDialect",
     "mhlo::MhloDialect",
     "stablehlo::StablehloDialect",
-    "chlo::ChloDialect"
+    "chlo::ChloDialect",
+    "enzyme::EnzymeDialect",
   ];
   let constructor = "mlir::enzyme::createArithRaisingPass()";
   let options = [
diff --git a/test/lit_tests/broadcastdiff.mlir b/test/lit_tests/broadcastdiff.mlir
new file mode 100644
index 00000000..ae23b7df
--- /dev/null
+++ b/test/lit_tests/broadcastdiff.mlir
@@ -0,0 +1,17 @@
+// RUN: enzymexlamlir-opt --arith-raise %s | FileCheck %s
+
+// A rank-0 operand has no dimensions to map, so the raised op has dims = [].
+module {
+  func.func @main(%arg0: tensor<f64>, %arg1: tensor<2xf64>) -> tensor<2xf64> {
+    %0 = "enzyme.broadcast"(%arg0) <{width = 2 : i64}> : (tensor<f64>) -> tensor<2xf64>
+    %1 = arith.addf %0, %arg1 : tensor<2xf64>
+    return %1 : tensor<2xf64>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<f64>, %arg1: tensor<2xf64>) -> tensor<2xf64> {
+// CHECK-NEXT: %[[i0:.+]] = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f64>) -> tensor<2xf64>
+// CHECK-NEXT: %[[i1:.+]] = stablehlo.add %[[i0]], %arg1 : tensor<2xf64>
+// CHECK-NEXT: return %[[i1]] : tensor<2xf64>
+// CHECK-NEXT: }
+// CHECK-NEXT: }