From 8d3ef084f645046fdd2d6aa9f144aafd1741e4d2 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Thu, 28 Nov 2024 15:50:04 +0100 Subject: [PATCH 01/12] add type conversions for width != 1. This still requires changes in the tblgenerated derivative files. For example, createForwardModeTangent in MulFOpFwdDerivative could be altered like this: ``` LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder, MGradientUtils *gutils) const { auto op = cast(op0); if (gutils->width != 1) { auto newop = gutils->getNewFromOriginal(op0); for (auto res : newop->getResults()) { res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType())); } } gutils->eraseIfUnused(op); if (gutils->isConstantInstruction(op)) return success(); mlir::Value res = nullptr; if (!gutils->isConstantValue(op->getOperand(0))) { auto dif = gutils->invertPointerM(op->getOperand(0), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); // TODO: gutils->makeBatched(...) auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1)); builder.create(op.getLoc(), fwdarg_0, fwdarg_1); }); itmp.dump(); if (!res) res = itmp; else { auto operandType = cast(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } if (!gutils->isConstantValue(op->getOperand(1))) { auto dif = gutils->invertPointerM(op->getOperand(1), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0)); builder.create(op.getLoc(), fwdarg_0, fwdarg_1); }); if (!res) res = itmp; else { auto operandType = cast(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } assert(res); gutils->setDiffe(op->getResult(0), res, builder); return success(); } ``` --- .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 7 +++++-- enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp | 17 +++++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git 
a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index d2d6ddfe19b..247dc5fe0a5 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -45,8 +45,11 @@ class FloatTypeInterface } Type getShadowType(Type self, unsigned width) const { - assert(width == 1 && "unsupported width != 1"); - return self; + if (width > 1) { + return RankedTensorType::get({width}, self); + } else { + return self; + } } bool isMutable(Type self) const { return false; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 69cfad436cf..8a9057a5853 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -27,7 +27,11 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode, for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip( FTy.getResults(), returnPrimals, returnShadows, ReturnActivity)) { if (returnPrimal) { - RetTypes.push_back(Ty); + if (width != 1) { + RetTypes.push_back(mlir::RankedTensorType::get({width}, Ty)); + } else { + RetTypes.push_back(Ty); + } } if (returnShadow) { assert(activity != DIFFE_TYPE::CONSTANT); @@ -39,7 +43,11 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode, SmallVector ArgTypes; for (auto &&[ITy, act] : llvm::zip(FTy.getInputs(), ArgActivity)) { - ArgTypes.push_back(ITy); + if (width != 1) { + ArgTypes.push_back(mlir::RankedTensorType::get({width}, ITy)); + } else { + ArgTypes.push_back(ITy); + } if (act == DIFFE_TYPE::DUP_ARG || act == DIFFE_TYPE::DUP_NONEED) { ArgTypes.push_back(getShadowType(ITy, width)); } else if (act == DIFFE_TYPE::OUT_DIFF) { @@ -232,6 +240,11 @@ FunctionOpInterface CloneFunctionWithReturns( { auto &blk = NewF.getFunctionBody().front(); + if (width != 1) { + for 
(auto &arg : blk.getArguments()) { + arg.setType(mlir::RankedTensorType::get({width}, arg.getType())); + } + } assert(F.getFunctionBody().front().getNumArguments() == ArgActivity.size()); for (ssize_t i = ArgActivity.size() - 1; i >= 0; i--) { mlir::Value oval = F.getFunctionBody().front().getArgument(i); From 5e647e5a79e74fba684eff5db81616fa520fe2a0 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Thu, 28 Nov 2024 16:07:53 +0100 Subject: [PATCH 02/12] add code to tblgen generator, this eventually needs to be a single function call. --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 900c5c813cd..221591dde54 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1819,6 +1819,12 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " LogicalResult createForwardModeTangent(Operation *op0, " "OpBuilder &builder, MGradientUtils *gutils) const {\n"; os << " auto op = cast<" << dialect << "::" << opName << ">(op0);\n"; + os << " if (gutils->width != 1) {\n" + << " auto newop = gutils->getNewFromOriginal(op0);\n" + << " for (auto res : newop->getResults()) {\n" + << " res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType()));\n" + << " }\n" + << " }\n"; origName = "op"; break; } From f69f97443e38b268d2d434118197cbbe5bef76e5 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Thu, 28 Nov 2024 16:51:26 +0100 Subject: [PATCH 03/12] a test and formatting --- enzyme/test/MLIR/ForwardMode/test_vector.mlir | 26 +++++++++++++++++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 3 ++- 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 enzyme/test/MLIR/ForwardMode/test_vector.mlir diff --git a/enzyme/test/MLIR/ForwardMode/test_vector.mlir b/enzyme/test/MLIR/ForwardMode/test_vector.mlir new file mode 100644 index 
00000000000..31bd6d8051b --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/test_vector.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +// TODO: actually test + +module { + func.func @square(%x : f64) -> f64{ + %y = arith.mulf %x, %x : f64 + return %y : f64 + } + func.func @dsq(%x : tensor<2xf64>, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (tensor<2xf64>, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<2xf64>, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<2xf64>, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : tensor<2xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : tensor<2xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> +// CHECK-NEXT: return %[[i2]] : tensor<2xf64> +// CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 221591dde54..6e9d8254fb4 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1822,7 +1822,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " if (gutils->width != 1) {\n" << " auto newop = gutils->getNewFromOriginal(op0);\n" << " for (auto res : newop->getResults()) {\n" - << " res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType()));\n" + << " res.setType(mlir::RankedTensorType::get({gutils->width}, " + 
"res.getType()));\n" << " }\n" << " }\n"; origName = "op"; From 623fccfc6155dc732a4b068e39ba3ddcd7896144 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Mon, 2 Dec 2024 11:10:13 +0100 Subject: [PATCH 04/12] use tensor splatop --- .../ArithAutoDiffOpInterfaceImpl.cpp | 8 +++++++ .../CoreDialectsAutoDiffImplementations.cpp | 1 + .../CoreDialectsAutoDiffImplementations.h | 1 + .../Enzyme/MLIR/Interfaces/CloneFunction.cpp | 21 ++++--------------- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 1 + enzyme/Enzyme/MLIR/Passes/Passes.h | 5 +++++ enzyme/Enzyme/MLIR/Passes/Passes.td | 3 ++- enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 1 + enzyme/test/MLIR/ForwardMode/test_vector.mlir | 14 ++++++------- 9 files changed, 30 insertions(+), 25 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp index 9b27503d79d..8d3650969d0 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp @@ -17,6 +17,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" @@ -69,3 +70,10 @@ void mlir::enzyme::registerArithDialectAutoDiffInterface( arith::ConstantOp::attachInterface(*context); }); } + +void mlir::enzyme::registerTensorDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 355808cdbcc..9fb5e93e860 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ 
b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -432,4 +432,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces( enzyme::registerCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); enzyme::registerFuncDialectAutoDiffInterface(registry); + enzyme::registerTensorDialectAutoDiffInterface(registry); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index d6f28ccfc73..650f6c6326b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -260,6 +260,7 @@ void registerCFDialectAutoDiffInterface(DialectRegistry &registry); void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry); void registerMathDialectAutoDiffInterface(DialectRegistry &registry); void registerFuncDialectAutoDiffInterface(DialectRegistry &registry); +void registerTensorDialectAutoDiffInterface(DialectRegistry &registry); void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry); diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 8a9057a5853..1d186018f5c 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -27,11 +27,7 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode, for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip( FTy.getResults(), returnPrimals, returnShadows, ReturnActivity)) { if (returnPrimal) { - if (width != 1) { - RetTypes.push_back(mlir::RankedTensorType::get({width}, Ty)); - } else { - RetTypes.push_back(Ty); - } + RetTypes.push_back(Ty); } if (returnShadow) { assert(activity != DIFFE_TYPE::CONSTANT); @@ -43,11 +39,7 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode, SmallVector ArgTypes; for 
(auto &&[ITy, act] : llvm::zip(FTy.getInputs(), ArgActivity)) { - if (width != 1) { - ArgTypes.push_back(mlir::RankedTensorType::get({width}, ITy)); - } else { - ArgTypes.push_back(ITy); - } + ArgTypes.push_back(ITy); if (act == DIFFE_TYPE::DUP_ARG || act == DIFFE_TYPE::DUP_NONEED) { ArgTypes.push_back(getShadowType(ITy, width)); } else if (act == DIFFE_TYPE::OUT_DIFF) { @@ -240,11 +232,6 @@ FunctionOpInterface CloneFunctionWithReturns( { auto &blk = NewF.getFunctionBody().front(); - if (width != 1) { - for (auto &arg : blk.getArguments()) { - arg.setType(mlir::RankedTensorType::get({width}, arg.getType())); - } - } assert(F.getFunctionBody().front().getNumArguments() == ArgActivity.size()); for (ssize_t i = ArgActivity.size() - 1; i >= 0; i--) { mlir::Value oval = F.getFunctionBody().front().getArgument(i); @@ -258,9 +245,9 @@ FunctionOpInterface CloneFunctionWithReturns( mlir::Value val = blk.getArgument(i); mlir::Value dval; if (i == ArgActivity.size() - 1) - dval = blk.addArgument(val.getType(), val.getLoc()); + dval = blk.addArgument(getShadowType(val.getType(), width), val.getLoc()); else - dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(), + dval = blk.insertArgument(blk.args_begin() + i + 1, getShadowType(val.getType(), width), val.getLoc()); ptrInputs.map(oval, dval); } diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 0445fc43064..99db4d80034 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms MLIRFuncDialect MLIRFuncTransforms MLIRGPUDialect + MLIRTensorDialect MLIRIR MLIRLLVMDialect MLIRMathDialect diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index 58c43be236d..fb6df3e2208 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include 
"mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "Dialect/Dialect.h" @@ -80,6 +81,10 @@ namespace affine { class AffineDialect; } // end namespace affine +namespace tensor { +class TensorDialect; +} // end namespace tensor + namespace LLVM { class LLVMDialect; } // end namespace LLVM diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 6458e63b273..c5b4df76917 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -16,7 +16,8 @@ def DifferentiatePass : Pass<"enzyme"> { let dependentDialects = [ "arith::ArithDialect", "complex::ComplexDialect", - "cf::ControlFlowDialect" + "cf::ControlFlowDialect", + "tensor::TensorDialect", ]; let constructor = "mlir::enzyme::createDifferentiatePass()"; } diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 0e6bdf7b101..99e7243129b 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -67,6 +67,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); diff --git a/enzyme/test/MLIR/ForwardMode/test_vector.mlir b/enzyme/test/MLIR/ForwardMode/test_vector.mlir index 31bd6d8051b..1aa1f9621fd 100644 --- a/enzyme/test/MLIR/ForwardMode/test_vector.mlir +++ b/enzyme/test/MLIR/ForwardMode/test_vector.mlir @@ -1,14 +1,12 @@ // RUN: %eopt --enzyme %s | FileCheck %s -// TODO: actually test - module { func.func @square(%x : f64) -> f64{ %y = arith.mulf %x, %x : f64 return %y : f64 } - func.func @dsq(%x : tensor<2xf64>, %dx : tensor<2xf64>) -> tensor<2xf64> { - %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (tensor<2xf64>, tensor<2xf64>) -> (tensor<2xf64>) + func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], 
ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>) return %r : tensor<2xf64> } } @@ -17,9 +15,11 @@ module { // CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xf64> // CHECK-NEXT: return %[[i0]] : tensor<2xf64> // CHECK-NEXT: } -// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<2xf64>, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { -// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : tensor<2xf64> -// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : tensor<2xf64> +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[s0:.+]] = tensor.splat %[[arg0]] : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64> +// CHECK-NEXT: %[[s1:.+]] = tensor.splat %[[arg0]] : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> // CHECK-NEXT: return %[[i2]] : tensor<2xf64> From f860e196a8fa4c0bda4d433b1653d96fd1ccd2f7 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Mon, 2 Dec 2024 11:51:03 +0100 Subject: [PATCH 05/12] remove stale enzyme-tblgen changes --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 6e9d8254fb4..900c5c813cd 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1819,13 +1819,6 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " LogicalResult createForwardModeTangent(Operation *op0, " "OpBuilder &builder, MGradientUtils *gutils) const {\n"; os << " auto op = cast<" << dialect << "::" << opName 
<< ">(op0);\n"; - os << " if (gutils->width != 1) {\n" - << " auto newop = gutils->getNewFromOriginal(op0);\n" - << " for (auto res : newop->getResults()) {\n" - << " res.setType(mlir::RankedTensorType::get({gutils->width}, " - "res.getType()));\n" - << " }\n" - << " }\n"; origName = "op"; break; } From 3a3bdf86630b8b82cb911c6cc240b2374c76347e Mon Sep 17 00:00:00 2001 From: jumerckx Date: Wed, 4 Dec 2024 11:26:59 +0100 Subject: [PATCH 06/12] do the simple batching in enzyme-tblgen --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 900c5c813cd..849e1de3093 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -275,8 +275,18 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, os << ord; } if (!vecValue && !startsWith(ord, "local")) { - if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) + if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) { os << ")"; + if (intrinsic == MLIRDerivatives) { + os << ";\n"; + os << "if (gutils->width != 1) {\n" + << " " << argName << "_" << (idx - 1) << " = builder.create(\n" + << " op.getLoc(),\n" + << " mlir::RankedTensorType::get({gutils->width}, " << argName << "_" << (idx - 1) << ".getType()),\n" + << " " << argName << "_" << (idx - 1) << ");\n" + << "}"; + } + } if (lookup && intrinsic != MLIRDerivatives) os << ", " << builder << ")"; From f1b2a6d882b1c472b2d4cbac4a099ef24be548ca Mon Sep 17 00:00:00 2001 From: jumerckx Date: Wed, 4 Dec 2024 11:27:28 +0100 Subject: [PATCH 07/12] include tensor in all AutoDiffOpInterfaceImpls --- .../Implementations/AffineAutoDiffOpInterfaceImpl.cpp | 1 + .../Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp | 1 + .../MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp | 1 + 
.../Implementations/ComplexAutoDiffOpInterfaceImpl.cpp | 1 + .../MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp | 1 + .../MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp | 1 + .../Implementations/LinalgAutoDiffOpInterfaceImpl.cpp | 1 + .../MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp | 8 ++++++++ .../Implementations/MemRefAutoDiffOpInterfaceImpl.cpp | 1 + .../MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp | 1 + .../MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp | 1 + 11 files changed, 18 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp index c27f0d60d12..4208287cb80 100644 --- a/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp @@ -14,6 +14,7 @@ #include "Implementations/CoreDialectsAutoDiffImplementations.h" #include "Interfaces/AutoDiffOpInterface.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IntegerSet.h" using namespace mlir; diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 247dc5fe0a5..5f416b2ffcf 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -15,6 +15,7 @@ #include "Interfaces/AutoDiffTypeInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp index 8f40db9d834..b8a9484f3e2 100644 --- 
a/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp @@ -18,6 +18,7 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/ComplexAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ComplexAutoDiffOpInterfaceImpl.cpp index eceddc03320..2bf22f4fcf5 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ComplexAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/ComplexAutoDiffOpInterfaceImpl.cpp @@ -18,6 +18,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp index 5308304f5b7..25243bb56b2 100644 --- a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp @@ -16,6 +16,7 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index b9e9ade7421..48bd2397810 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -16,6 +16,7 @@ #include 
"Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 3a72c3a5d35..1a826742cbb 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp index 2833eeb4472..1a7a40f5eed 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp @@ -17,6 +17,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" @@ -36,3 +37,10 @@ void mlir::enzyme::registerMathDialectAutoDiffInterface( registerInterfaces(context); }); } + +// void mlir::enzyme::registerTensorDialectAutoDiffInterface( +// DialectRegistry ®istry) { +// registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) { +// registerInterfaces(context); +// }); +// } diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index cdee04b7bf2..2b0dfaa60dd 100644 --- 
a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -18,6 +18,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp index 4d8116ce011..3db78055895 100644 --- a/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp @@ -16,6 +16,7 @@ #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index cbecedc2182..a92ab7ace31 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -18,6 +18,7 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" From 396a29b8efe1f13efdb5aaf11c452affc169ae1b Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 6 Dec 2024 16:46:32 +0100 Subject: [PATCH 08/12] add enzyme broadcastop --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 16 ++++++++++++++++ enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td 
b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index be139fb3d8b..3c359491fb3 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -192,4 +192,20 @@ def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> { } +def BroadcastOp : Enzyme_Op<"broadcast"> { + let description = [{ + Broadcast the operand by adding an extra dimension at the front with a size equal to the width attribute. + For scalar operands, a one-dimensional ranked tensor is created. + + NOTE: Only works for scalars and *ranked* tensors for now. + }]; + + let arguments = (ins AnyType:$input, I64Attr:$width); + let results = (outs AnyRankedTensor:$output); + + let builders = [ + OpBuilder<(ins "Value":$input, "int64_t":$width)> + ]; +} + #endif // ENZYME_OPS diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 3e318542730..971fc0fdc5a 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -191,3 +191,23 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } + +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input, int64_t width) { + auto widthAttr = builder.getI64IntegerAttr(width); + RankedTensorType output; + // TODO: support things other than scalars and ranked tensors, maybe reuse getShadowType here? 
+ if (auto tensorType = input.getType().dyn_cast()) { + auto shape = tensorType.getShape(); + SmallVector newShape; + newShape.push_back(width); + newShape.append(shape.begin(), shape.end()); + output = RankedTensorType::get(newShape, tensorType.getElementType()); + } else { + output = RankedTensorType::get({width}, input.getType()); + } + build(builder, result, output, input, widthAttr); +} From 64489a5662e6daec445141686ac4cede52468069 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 6 Dec 2024 16:46:48 +0100 Subject: [PATCH 09/12] getShadowType for TensorTypeInterface --- .../Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 5f416b2ffcf..983158812c6 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -110,7 +110,14 @@ class TensorTypeInterface } Type getShadowType(Type self, unsigned width) const { - assert(width == 1 && "unsupported width != 1"); + if (width != 1) { + auto tenType = self.cast(); + auto shape = tenType.getShape(); + SmallVector newShape; + newShape.push_back(width); + newShape.append(shape.begin(), shape.end()); + return RankedTensorType::get(newShape, tenType.getElementType()); + } return self; } From 0ccb37f7b0d043a04a944a6e27e2877e39569577 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 6 Dec 2024 16:54:59 +0100 Subject: [PATCH 10/12] create broadcastop in enzyme-tblgen --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 849e1de3093..64e2d9de320 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ 
b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -280,10 +280,10 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, if (intrinsic == MLIRDerivatives) { os << ";\n"; os << "if (gutils->width != 1) {\n" - << " " << argName << "_" << (idx - 1) << " = builder.create(\n" + << " " << argName << "_" << (idx - 1) << " = builder.create(\n" << " op.getLoc(),\n" - << " mlir::RankedTensorType::get({gutils->width}, " << argName << "_" << (idx - 1) << ".getType()),\n" - << " " << argName << "_" << (idx - 1) << ");\n" + << " " << argName << "_" << (idx - 1) << ",\n" + << " gutils->width);\n" << "}"; } } From 123a98584a21c47a632c8f170159e702e2187164 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 6 Dec 2024 16:56:40 +0100 Subject: [PATCH 11/12] Revert "include tensor in all AutoDiffOpInterfaceImpls" This reverts commit c06ed01709b51bff5b794a7e4dc83b63510b9a84. --- .../Implementations/AffineAutoDiffOpInterfaceImpl.cpp | 1 - .../Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp | 1 - .../MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp | 1 - .../Implementations/ComplexAutoDiffOpInterfaceImpl.cpp | 1 - .../MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp | 1 - .../MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp | 1 - .../Implementations/LinalgAutoDiffOpInterfaceImpl.cpp | 1 - .../MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp | 8 -------- .../Implementations/MemRefAutoDiffOpInterfaceImpl.cpp | 1 - .../MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp | 1 - .../MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp | 1 - 11 files changed, 18 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp index 4208287cb80..c27f0d60d12 100644 --- a/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp @@ -14,7 +14,6 @@ #include 
"Implementations/CoreDialectsAutoDiffImplementations.h" #include "Interfaces/AutoDiffOpInterface.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IntegerSet.h" using namespace mlir; diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 983158812c6..7c72b97d093 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -15,7 +15,6 @@ #include "Interfaces/AutoDiffTypeInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp index b8a9484f3e2..8f40db9d834 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp @@ -18,7 +18,6 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/ComplexAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ComplexAutoDiffOpInterfaceImpl.cpp index 2bf22f4fcf5..eceddc03320 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ComplexAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/ComplexAutoDiffOpInterfaceImpl.cpp @@ -18,7 +18,6 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Complex/IR/Complex.h" -#include 
"mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp index 25243bb56b2..5308304f5b7 100644 --- a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp @@ -16,7 +16,6 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 48bd2397810..b9e9ade7421 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -16,7 +16,6 @@ #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 1a826742cbb..3a72c3a5d35 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -21,7 +21,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp 
b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp index 1a7a40f5eed..2833eeb4472 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp @@ -17,7 +17,6 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" @@ -37,10 +36,3 @@ void mlir::enzyme::registerMathDialectAutoDiffInterface( registerInterfaces(context); }); } - -// void mlir::enzyme::registerTensorDialectAutoDiffInterface( -// DialectRegistry &registry) { -// registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) { -// registerInterfaces(context); -// }); -// } diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 2b0dfaa60dd..cdee04b7bf2 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -18,7 +18,6 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp index 3db78055895..4d8116ce011 100644 --- a/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp @@ -16,7 +16,6 @@ #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include
"mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index a92ab7ace31..cbecedc2182 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -18,7 +18,6 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" From 3d2911d464d68e603823642c35e386622fe5fc09 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 6 Dec 2024 17:08:04 +0100 Subject: [PATCH 12/12] test --- .../{test_vector.mlir => batched_scalar.mlir} | 8 +++--- .../test/MLIR/ForwardMode/batched_tensor.mlir | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) rename enzyme/test/MLIR/ForwardMode/{test_vector.mlir => batched_scalar.mlir} (74%) create mode 100644 enzyme/test/MLIR/ForwardMode/batched_tensor.mlir diff --git a/enzyme/test/MLIR/ForwardMode/test_vector.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir similarity index 74% rename from enzyme/test/MLIR/ForwardMode/test_vector.mlir rename to enzyme/test/MLIR/ForwardMode/batched_scalar.mlir index 1aa1f9621fd..09f85c5f68a 100644 --- a/enzyme/test/MLIR/ForwardMode/test_vector.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ -11,14 +11,14 @@ module { } } -// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<2xf64>, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { -// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64> 
// CHECK-NEXT: return %[[i0]] : tensor<2xf64> // CHECK-NEXT: } // CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { -// CHECK-NEXT: %[[s0:.+]] = tensor.splat %[[arg0]] : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : f64 -> tensor<2xf64> // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64> -// CHECK-NEXT: %[[s1:.+]] = tensor.splat %[[arg0]] : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : f64 -> tensor<2xf64> // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> diff --git a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir new file mode 100644 index 00000000000..5895a8bad24 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{ + %y = arith.mulf %x, %x : tensor<10xf64> + return %y : tensor<10xf64> + } + func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>) + return %r : tensor<2x10xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: 
%[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64> +// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64> +// CHECK-NEXT: } \ No newline at end of file