Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP) Batched autodiff #2181

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft

(WIP) Batched autodiff #2181

wants to merge 12 commits into from

Conversation

jumerckx
Copy link
Collaborator

@jumerckx jumerckx commented Nov 28, 2024

Added some type conversions to tensor types if width != 1. The simple test case seems correct now.
Corresponding Enzyme-JAX PR: EnzymeAD/Enzyme-JAX#197

@@ -27,7 +27,11 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,
for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip(
FTy.getResults(), returnPrimals, returnShadows, ReturnActivity)) {
if (returnPrimal) {
RetTypes.push_back(Ty);
if (width != 1) {
RetTypes.push_back(mlir::RankedTensorType::get({width}, Ty));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn’t need changing since the primal is always unmodified, only Derivatives are changed (and we should be pushing the getshadow types for those below)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, then I'm confused of what batched autodiff is.
How should my testcase change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nvm, it clicked. It's just the shadow that's batched 😅

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tho perhaps mul will be more illustrative, https://github.com/EnzymeAD/Enzyme/blob/main/enzyme/test/Enzyme/ForwardModeVector/mul.ll (and obviously feel free to look at any/all of the other examples

@jumerckx
Copy link
Collaborator Author

jumerckx commented Dec 2, 2024

I haven't yet fully made the changes in enzyme-tblgen.cpp, and either way this just works for the simple test case.
But I added the following manually in ArithDerivatives.inc.

mlir::Value itmp = ({
  // Computing MulFOp
  auto fwdarg_0 = dif;
  auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
  if (gutils->width != 1)
  {
    fwdarg_1 = builder.create<tensor::SplatOp>(
        op.getLoc(),
        mlir::RankedTensorType::get({gutils->width},
                                    fwdarg_1.getType()),
        fwdarg_1);
  }
  builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
});

But this is the MLIR code that is generated for this simple test:

  func.func private @fwddiffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
    %splat = tensor.splat %arg0 : tensor<2xf64>
    %0 = arith.mulf %arg1, %splat : tensor<2xf64>
    %splat_0 = tensor.splat %arg0 : tensor<2xf64>
    %1 = arith.mulf %arg1, %splat_0 : tensor<2xf64>
    %2 = arith.addf %0, %1 : tensor<2xf64>
    %3 = arith.mulf %arg0, %arg0 : f64
    return %2 : tensor<2xf64>
  }

This still requires changes in the tblgenerated derivative files. For example, createForwardModeTangent in MulFOpFwdDerivative could be altered like this:
```
  LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder, MGradientUtils *gutils) const
  {
    auto op = cast<arith::MulFOp>(op0);
    if (gutils->width != 1) {
      auto newop = gutils->getNewFromOriginal(op0);
      for (auto res : newop->getResults()) {
        res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType()));
      }
    }
    gutils->eraseIfUnused(op);
    if (gutils->isConstantInstruction(op))
      return success();
    mlir::Value res = nullptr;
    if (!gutils->isConstantValue(op->getOperand(0)))
    {
      auto dif = gutils->invertPointerM(op->getOperand(0), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          dif.dump();
          // TODO: gutils->makeBatched(...)
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        itmp.dump();
        if (!res)
          res = itmp;
        else
        {
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    if (!gutils->isConstantValue(op->getOperand(1)))
    {
      auto dif = gutils->invertPointerM(op->getOperand(1), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          dif.dump();
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        if (!res)
          res = itmp;
        else
        {
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    assert(res);
    gutils->setDiffe(op->getResult(0), res, builder);
    return success();
  }
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants