Skip to content

Commit

Permalink
feat(compiler): fusing table lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Feb 20, 2024
1 parent c53985f commit ee70d4e
Show file tree
Hide file tree
Showing 6 changed files with 829 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,9 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure, ConstantNoi
let arguments = (ins FHE_AnyEncryptedInteger:$a,
TensorOf<[AnyInteger]>:$lut);
let results = (outs FHE_AnyEncryptedInteger);

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def FHE_RoundEintOp: FHE_Op<"round", [Pure, UnaryEint, DeclareOpInterfaceMethods<UnaryEint, ["sqMANP"]>]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", [Pure,
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def FHELinalg_ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_table", [Pure, ConstantNoise]> {
Expand Down Expand Up @@ -567,6 +568,7 @@ def FHELinalg_ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_t
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods<Binary, ["sqMANP"]>]> {
Expand Down
201 changes: 201 additions & 0 deletions compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,207 @@ void MulEintIntOp::getCanonicalizationPatterns(
patterns.add<ZeroEncOpPattern>(context);
}

/// Registers the canonicalization patterns of `apply_lookup_table`.
///
/// The only pattern registered here fuses two chained table lookups
/// `tlu2[tlu1[x]]` whose tables are both `arith.constant` operations into a
/// single lookup on `x` with a precomputed composed table.
void ApplyLookupTableEintOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {

  /// Matches an `apply_lookup_table` whose encrypted operand is itself the
  /// result of an `apply_lookup_table`, and replaces the uses of the outer
  /// one with a single lookup applied to the innermost input.
  class AfterTluPattern
      : public mlir::OpRewritePattern<ApplyLookupTableEintOp> {
  public:
    AfterTluPattern(mlir::MLIRContext *context)
        : mlir::OpRewritePattern<ApplyLookupTableEintOp>(context, 0) {}

    mlir::LogicalResult
    matchAndRewrite(ApplyLookupTableEintOp currentOperation,
                    mlir::PatternRewriter &rewriter) const override {

      // The operand of the current lookup must be produced by another lookup.
      auto intermediateValue = currentOperation.getA();
      auto intermediateOperation =
          llvm::dyn_cast_or_null<ApplyLookupTableEintOp>(
              intermediateValue.getDefiningOp());

      if (!intermediateOperation) {
        return mlir::failure();
      }

      // Both tables must be compile-time constants to be composed.
      auto intermediateTableValue = intermediateOperation.getLut();
      auto intermediateTableOperation =
          llvm::dyn_cast_or_null<arith::ConstantOp>(
              intermediateTableValue.getDefiningOp());

      auto currentTableValue = currentOperation.getLut();
      auto currentTableOperation = llvm::dyn_cast_or_null<arith::ConstantOp>(
          currentTableValue.getDefiningOp());

      if (!intermediateTableOperation || !currentTableOperation) {
        // Nothing has been rewritten at this point, so the pattern must
        // report a failure: returning success without modifying the IR
        // violates the rewrite pattern contract.
        return mlir::failure();
      }

      auto intermediateTableContentAttr =
          (intermediateTableOperation.getValueAttr()
               .dyn_cast_or_null<mlir::DenseIntElementsAttr>());
      auto currentTableContentAttr =
          (currentTableOperation.getValueAttr()
               .dyn_cast_or_null<mlir::DenseIntElementsAttr>());

      if (!intermediateTableContentAttr || !currentTableContentAttr) {
        return mlir::failure();
      }

      auto inputValue = intermediateOperation.getA();
      auto inputType = inputValue.getType().dyn_cast<FheIntegerInterface>();
      auto intermediateType =
          (intermediateValue.getType().dyn_cast<FheIntegerInterface>());

      // Do not dereference the interfaces if one of the casts failed.
      if (!inputType || !intermediateType) {
        return mlir::failure();
      }

      auto inputBitWidth = inputType.getWidth();
      auto inputIsSigned = inputType.isSigned();

      auto intermediateBitWidth = intermediateType.getWidth();
      auto intermediateIsSigned = intermediateType.isSigned();

      auto usersOfPreviousOperation = intermediateOperation->getUsers();
      auto numberOfUsersOfPreviousOperation = std::distance(
          usersOfPreviousOperation.begin(), usersOfPreviousOperation.end());

      if (numberOfUsersOfPreviousOperation > 1) {
        // This is a special case.
        //
        // Imagine you have this structure:
        // -----------------
        // x: uint6
        // y: uint3 = tlu[x]
        // z: uint3 = y + 1
        // a: uint3 = tlu[y]
        // b: uint3 = a + z
        // -----------------
        //
        // In this case, it's better not to fuse `a = tlu[tlu[x]]`.
        //
        // The reason is that intermediate `y` is necessary for `z`,
        // so it has to be computed anyway.
        //
        // So to calculate `a`, there are 2 options:
        // - fused tlu on x
        // - regular tlu on y
        //
        // So for such cases, it's only better to fuse if the
        // bit width of `x` is smaller than the bit width of `y`.

        auto shouldFuse = inputBitWidth < intermediateBitWidth;
        if (!shouldFuse) {
          return mlir::failure();
        }
      }

      // Sizes are computed on 64 bits to match the type of the table
      // entries they are compared with and added to below.
      const int64_t intermediateTableSize = int64_t{1} << inputBitWidth;
      const int64_t currentTableSize = int64_t{1} << intermediateBitWidth;

      // Both tables are indexed below assuming they hold exactly
      // 2^bitwidth entries; bail out rather than read out of bounds
      // if that does not hold.
      if (intermediateTableContentAttr.getNumElements() !=
              intermediateTableSize ||
          currentTableContentAttr.getNumElements() != currentTableSize) {
        return mlir::failure();
      }

      // NOTE(review): `getValues<int64_t>()` presumably requires the tables
      // to have 64-bit integer elements -- confirm that the verifier
      // guarantees it.
      auto intermediateTableContent =
          (intermediateTableContentAttr.getValues<int64_t>());
      auto currentTableContent = (currentTableContentAttr.getValues<int64_t>());

      auto newTableContent = std::vector<int64_t>();
      newTableContent.reserve(intermediateTableSize);

      // Computes `currentTable[intermediateTable[index]]`, where a negative
      // `index` addresses the second half of the intermediate table.
      auto lookup = [&](int64_t index) {
        if (index < 0) {
          index += intermediateTableSize;
        }
        auto resultOfFirstLookup = intermediateTableContent[index];

        // If the result of the first lookup is negative
        if (resultOfFirstLookup < 0) {
          // We first add the table size to preserve semantics
          // e.g., table[-1] == last element in the table == tableSize + (-1)
          // e.g., table[-2] == one element before that == tableSize + (-2)
          resultOfFirstLookup += currentTableSize;

          // If it's still negative
          if (resultOfFirstLookup < 0) {
            // e.g., imagine first table resulted in -100_000
            // (which can exist in tables...)
            // then we set it to the smallest possible value
            // of the input to the table

            // So if -100 is encountered on a signed 7-bit tlu
            // corresponding value will be calculated as if -64 is looked up

            // [0, 1, 2, 3, -4, -3, -2, -1]
            //              ^^ smallest value will always be in the middle

            resultOfFirstLookup = currentTableSize / 2;
          }
        } else if (resultOfFirstLookup >= currentTableSize) {
          // Another special case is the result of the first table
          // being bigger than the table itself

          // In this case we approach the value as the
          // biggest possible value of the input to the table

          if (!intermediateIsSigned) {

            // So if 100 is encountered on an unsigned 6-bit tlu
            // corresponding value will be calculated as if 63 is looked up

            // [0, 1, 2, 3, 4, 5, 6, 7]
            //                       ^ biggest value will always be at the end

            resultOfFirstLookup = currentTableSize - 1;

          } else {

            // So if 100 is encountered on a signed 7-bit tlu
            // corresponding value will be calculated as if 63 is looked up

            // [0, 1, 2, 3, -4, -3, -2, -1]
            //           ^ biggest value will always be one before the middle

            resultOfFirstLookup = (currentTableSize / 2) - 1;
          }
        }
        auto resultOfSecondLookup = currentTableContent[resultOfFirstLookup];

        return resultOfSecondLookup;
      };

      if (!inputIsSigned) {
        // unsigned lookup table structure
        // [0, 1, 2, 3, 4, 5, 6, 7]
        // is the identity table

        // for the whole table
        for (int64_t x = 0; x < intermediateTableSize; x++) {
          newTableContent.push_back(lookup(x));
        }
      } else {
        // signed lookup table structure
        // [0, 1, 2, 3, -4, -3, -2, -1]
        // is the identity table

        // for the positive part
        for (int64_t x = 0; x < intermediateTableSize / 2; x++) {
          newTableContent.push_back(lookup(x));
        }
        // for the negative part
        for (int64_t x = -(intermediateTableSize / 2); x < 0; x++) {
          newTableContent.push_back(lookup(x));
        }
      }

      // The fused table is indexed by the innermost input, so it reuses the
      // type of the intermediate table.
      auto newTable = rewriter.create<arith::ConstantOp>(
          currentOperation.getLoc(),
          DenseIntElementsAttr::get(intermediateTableValue.getType(),
                                    newTableContent));

      auto newOperation = rewriter.create<ApplyLookupTableEintOp>(
          currentOperation.getLoc(), currentOperation.getType(), inputValue,
          newTable);

      rewriter.replaceAllUsesWith(currentOperation, newOperation);
      return mlir::success();
    }
  };
  patterns.add<AfterTluPattern>(context);
}

template <typename SignedConvOp>
void getSignedConvCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
mlir::MLIRContext *context) {
Expand Down
Loading

0 comments on commit ee70d4e

Please sign in to comment.