From a6cbcbe43038d3025725ae27093ed0141fb22157 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Wed, 20 Nov 2024 09:10:36 +0100 Subject: [PATCH] add width attribute to autodiff as well --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 5 +++-- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index b1b350c12d0..77755bf4b49 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -79,7 +79,8 @@ def PlaceholderOp : Enzyme_Op<"placeholder", def ForwardDiffOp : Enzyme_Op<"fwddiff", [DeclareOpInterfaceMethods]> { let summary = "Perform forward mode AD on a funcop"; - let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width); let results = (outs Variadic:$outputs); + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width); + let results = (outs Variadic:$outputs); let assemblyFormat = [{ $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) @@ -89,7 +90,7 @@ def ForwardDiffOp : Enzyme_Op<"fwddiff", def AutoDiffOp : Enzyme_Op<"autodiff", [DeclareOpInterfaceMethods]> { let summary = "Perform reverse mode AD on a funcop"; - let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width); let results = (outs Variadic:$outputs); let assemblyFormat = [{ diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index f3c8d37774b..c3fe53a7c4e 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -259,7 +259,7 @@ struct DifferentiatePass : public DifferentiatePassBase { MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); bool freeMemory = true; - size_t width = 1; + size_t width = CI.getWidth(); std::vector volatile_args; for (auto &a : fn.getFunctionBody().getArguments()) {