Skip to content

Commit

Permalink
mlir: Modify BatchOpInterface to allow generating more than one operation (#2142)
Browse files Browse the repository at this point in the history

* mlir: allow batch operation to generate multiple operations

* mlir: remove support for batchSizes in CloneFunction

* clang-format
  • Loading branch information
Pangoraw authored Nov 4, 2024
1 parent e5a36ad commit f1f4d8e
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ struct ArithConstantOpBatchInterface
: public BatchOpInterface::ExternalModel<ArithConstantOpBatchInterface,
arith::ConstantOp> {

mlir::Operation *createBatch(Operation *src, IRMapping &mapper,
Operation::CloneOptions options,
std::map<Operation *, Operation *> &opMap,
ArrayRef<int64_t> batchSizes) const {
mlir::LogicalResult createBatch(Operation *src, OpBuilder &builder,
IRMapping &mapper,
ArrayRef<int64_t> batchSizes) const {

SmallVector<Type> resultTypes(src->getResultTypes().begin(),
src->getResultTypes().end());
Expand All @@ -54,7 +53,9 @@ struct ArithConstantOpBatchInterface
auto cop = mlir::Operation::create(
src->getLoc(), src->getName(), resultTypes, {}, std::move(attrs),
OpaqueProperties(nullptr), mlir::BlockRange(), 0);
return cop;
builder.insert(cop);
mapper.map(src->getResult(0), cop->getResult(0));
return success();
}
};

Expand Down
8 changes: 4 additions & 4 deletions enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def ActivityOpInterface
def ADDataFlowOpInterface
: OpInterface<"ADDataFlowOpInterface"> {
let cppNamespace = "::mlir::enzyme";

let methods = [
InterfaceMethod<
/*desc=*/[{
Expand Down Expand Up @@ -171,11 +171,11 @@ def BatchOpInterface : OpInterface<"BatchOpInterface"> {
let methods = [
InterfaceMethod<
/*desc=*/[{
Emits a batched version of a given operation.
Emits a batched version of a given operation and maps the newly created batched results to their correspondents in the original version.
}],
/*retTy=*/"::mlir::Operation*",
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"createBatch",
/*args=*/(ins "::mlir::IRMapping &":$mapper, "::mlir::Operation::CloneOptions":$options, "::std::map<mlir::Operation*, mlir::Operation*>&":$opMap, "::llvm::ArrayRef<int64_t>":$batchSizes)
/*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::IRMapping &":$mapper, "::llvm::ArrayRef<int64_t>":$batchSizes)
>
];
}
Expand Down
70 changes: 13 additions & 57 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,14 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows,
llvm::ArrayRef<DIFFE_TYPE> ReturnActivity,
llvm::ArrayRef<DIFFE_TYPE> ArgActivity,
llvm::ArrayRef<int64_t> batchSizes) {
llvm::ArrayRef<DIFFE_TYPE> ArgActivity) {

SmallVector<mlir::Type, 4> RetTypes;

for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip(
FTy.getResults(), returnPrimals, returnShadows, ReturnActivity)) {
if (returnPrimal) {
if (batchSizes.size()) {
auto T = cast<TensorType>(Ty);
SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
shape.append(T.getShape().begin(), T.getShape().end());
auto T2 = T.clone(shape);
RetTypes.push_back(T2);
} else {
RetTypes.push_back(Ty);
}
RetTypes.push_back(Ty);
}
if (returnShadow) {
assert(activity != DIFFE_TYPE::CONSTANT);
Expand All @@ -48,15 +39,7 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,
SmallVector<mlir::Type, 4> ArgTypes;

for (auto &&[ITy, act] : llvm::zip(FTy.getInputs(), ArgActivity)) {
if (batchSizes.size()) {
auto T = cast<TensorType>(ITy);
SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
shape.append(T.getShape().begin(), T.getShape().end());
auto T2 = T.clone(shape);
ArgTypes.push_back(T2);
} else {
ArgTypes.push_back(ITy);
}
ArgTypes.push_back(ITy);
if (act == DIFFE_TYPE::DUP_ARG || act == DIFFE_TYPE::DUP_NONEED) {
ArgTypes.push_back(getShadowType(ITy, width));
} else if (act == DIFFE_TYPE::OUT_DIFF) {
Expand All @@ -81,8 +64,7 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,

Operation *clone(Operation *src, IRMapping &mapper,
Operation::CloneOptions options,
std::map<Operation *, Operation *> &opMap,
llvm::ArrayRef<int64_t> batchSizes) {
std::map<Operation *, Operation *> &opMap) {
SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;

Expand All @@ -101,31 +83,17 @@ Operation *clone(Operation *src, IRMapping &mapper,
// Create the new operation.
Operation *newOp = nullptr;

if (batchSizes.size())
if (auto ifaceOp = dyn_cast<BatchOpInterface>(src)) {
newOp = ifaceOp.createBatch(mapper, options, opMap, batchSizes);
}

if (!newOp) {
SmallVector<Type> resultTypes(src->getResultTypes().begin(),
src->getResultTypes().end());
if (batchSizes.size()) {
for (auto &Ty : resultTypes) {
auto T = cast<TensorType>(Ty);
SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
shape.append(T.getShape().begin(), T.getShape().end());
Ty = T.clone(shape);
}
}
newOp = Operation::create(
src->getLoc(), src->getName(), resultTypes, operands, src->getAttrs(),
OpaqueProperties(nullptr), successors, src->getNumRegions());

// Clone the regions.
if (options.shouldCloneRegions()) {
for (unsigned i = 0; i != src->getNumRegions(); ++i)
cloneInto(&src->getRegion(i), &newOp->getRegion(i), mapper, opMap,
batchSizes);
cloneInto(&src->getRegion(i), &newOp->getRegion(i), mapper, opMap);
}
}

Expand All @@ -138,15 +106,13 @@ Operation *clone(Operation *src, IRMapping &mapper,
}

void cloneInto(Region *src, Region *dest, IRMapping &mapper,
std::map<Operation *, Operation *> &opMap,
llvm::ArrayRef<int64_t> batchSizes) {
cloneInto(src, dest, dest->end(), mapper, opMap, batchSizes);
std::map<Operation *, Operation *> &opMap) {
cloneInto(src, dest, dest->end(), mapper, opMap);
}

/// Clone this region into 'dest' before the given position in 'dest'.
void cloneInto(Region *src, Region *dest, Region::iterator destPos,
IRMapping &mapper, std::map<Operation *, Operation *> &opMap,
llvm::ArrayRef<int64_t> batchSizes) {
IRMapping &mapper, std::map<Operation *, Operation *> &opMap) {
assert(src);
assert(dest && "expected valid region to clone into");
assert(src != dest && "cannot clone region into itself");
Expand Down Expand Up @@ -177,12 +143,6 @@ void cloneInto(Region *src, Region *dest, Region::iterator destPos,
for (auto arg : block.getArguments())
if (!mapper.contains(arg)) {
auto Ty = arg.getType();
if (batchSizes.size()) {
auto T = cast<TensorType>(Ty);
SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
shape.append(T.getShape().begin(), T.getShape().end());
Ty = T.clone(shape);
}
mapper.map(arg, newBlock->addArgument(Ty, arg.getLoc()));
}

Expand All @@ -205,8 +165,7 @@ void cloneInto(Region *src, Region *dest, Region::iterator destPos,
Block &clonedBlock = std::get<1>(zippedBlocks);
// Clone and remap the operations within this block.
for (Operation &op : sourceBlock) {
clonedBlock.push_back(
clone(&op, mapper, cloneOptions, opMap, batchSizes));
clonedBlock.push_back(clone(&op, mapper, cloneOptions, opMap));
}
}

Expand All @@ -226,8 +185,7 @@ void cloneInto(Region *src, Region *dest, Region::iterator destPos,
clone.setOperands(operands);

for (auto regions : llvm::zip(source.getRegions(), clone.getRegions()))
cloneInto(&std::get<0>(regions), &std::get<1>(regions), mapper, opMap,
batchSizes);
cloneInto(&std::get<0>(regions), &std::get<1>(regions), mapper, opMap);
}
}
}
Expand All @@ -241,14 +199,13 @@ FunctionOpInterface CloneFunctionWithReturns(
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows, ArrayRef<DIFFE_TYPE> RetActivity,
Twine name, IRMapping &VMap, std::map<Operation *, Operation *> &OpMap,
mlir::Type additionalArg, llvm::ArrayRef<int64_t> batchSizes) {
mlir::Type additionalArg) {
assert(!F.getFunctionBody().empty());
// F = preprocessForClone(F, mode);
// llvm::ValueToValueMapTy VMap;
auto FTy = getFunctionTypeForClone(
F.getFunctionType().cast<mlir::FunctionType>(), mode, width,
additionalArg, returnPrimals, returnShadows, RetActivity, ArgActivity,
batchSizes);
additionalArg, returnPrimals, returnShadows, RetActivity, ArgActivity);

/*
for (Block &BB : F.getFunctionBody().getBlocks()) {
Expand All @@ -271,8 +228,7 @@ FunctionOpInterface CloneFunctionWithReturns(
table.insert(NewF);
SymbolTable::setSymbolVisibility(NewF, SymbolTable::Visibility::Private);

cloneInto(&F.getFunctionBody(), &NewF.getFunctionBody(), VMap, OpMap,
batchSizes);
cloneInto(&F.getFunctionBody(), &NewF.getFunctionBody(), VMap, OpMap);

{
auto &blk = NewF.getFunctionBody().front();
Expand Down
14 changes: 5 additions & 9 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,17 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,
llvm::ArrayRef<bool> returnPrimals,
llvm::ArrayRef<bool> returnShadows,
llvm::ArrayRef<DIFFE_TYPE> ReturnActivity,
llvm::ArrayRef<DIFFE_TYPE> ArgActivity,
llvm::ArrayRef<int64_t> batchSizes = {});
llvm::ArrayRef<DIFFE_TYPE> ArgActivity);

void cloneInto(Region *src, Region *dest, Region::iterator destPos,
IRMapping &mapper, std::map<Operation *, Operation *> &opMap,
llvm::ArrayRef<int64_t> batchSizes);
IRMapping &mapper, std::map<Operation *, Operation *> &opMap);

void cloneInto(Region *src, Region *dest, IRMapping &mapper,
std::map<mlir::Operation *, mlir::Operation *> &opMap,
llvm::ArrayRef<int64_t> batchSizes);
std::map<mlir::Operation *, mlir::Operation *> &opMap);

Operation *clone(Operation *src, IRMapping &mapper,
Operation::CloneOptions options,
std::map<Operation *, Operation *> &opMap,
llvm::ArrayRef<int64_t> batchSizes);
std::map<Operation *, Operation *> &opMap);

FunctionOpInterface CloneFunctionWithReturns(
DerivativeMode mode, unsigned width, FunctionOpInterface F,
Expand All @@ -55,4 +51,4 @@ FunctionOpInterface CloneFunctionWithReturns(
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows, ArrayRef<DIFFE_TYPE> ReturnActivity,
Twine name, IRMapping &VMap, std::map<Operation *, Operation *> &OpMap,
mlir::Type additionalArg, llvm::ArrayRef<int64_t> batchSizes = {});
mlir::Type additionalArg);
Loading

0 comments on commit f1f4d8e

Please sign in to comment.