fix(compiler): use dyn sized tensors in CAPI func definitions
youben11 committed Jul 21, 2023
1 parent 7e138bf commit 545bda9
Showing 1 changed file with 34 additions and 11 deletions.
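
Note on the change (editor's gloss, not part of the commit page): the simulation C API wrappers were previously forward-declared with the concrete static tensor shapes taken from a call site, so a declaration only matched tensors of that exact size. The fix declares them over dynamically sized tensors instead and inserts tensor.cast ops at each call site, presumably so one declaration covers every LUT and buffer size. A minimal sketch of the shape erasure performed by the new toDynamicTensorType helper, using only standard MLIR builtin type APIs (the snippet is illustrative, not taken from the commit):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

int main() {
  mlir::MLIRContext ctx;
  auto i64 = mlir::IntegerType::get(&ctx, 64);
  // A statically sized LUT type, e.g. tensor<128xi64>.
  auto staticLut = mlir::RankedTensorType::get({128}, i64);
  // Replace every dimension with kDynamic to get tensor<?xi64>; a single
  // forward declaration over this type matches LUTs of any length.
  llvm::SmallVector<int64_t> dynShape(staticLut.getShape().size(),
                                      mlir::ShapedType::kDynamic);
  auto dynLut =
      mlir::RankedTensorType::get(dynShape, staticLut.getElementType());
  dynLut.dump(); // prints: tensor<?xi64>
  return 0;
}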
@@ -54,6 +54,13 @@ class SimulateTFHETypeConverter : public mlir::TypeConverter {
 
 namespace {
 
+mlir::RankedTensorType toDynamicTensorType(mlir::TensorType staticSizedTensor) {
+  std::vector<int64_t> dynSizedShape(staticSizedTensor.getShape().size(),
+                                     mlir::ShapedType::kDynamic);
+  return mlir::RankedTensorType::get(dynSizedShape,
+                                     staticSizedTensor.getElementType());
+}
+
 struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {
 
   NegOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
@@ -131,27 +138,37 @@ struct EncodeExpandLutForBootstrapOpPattern
         eeOp.getResult().getType().cast<mlir::RankedTensorType>(),
         mlir::ValueRange{});
 
+    auto dynamicResultType = toDynamicTensorType(eeOp.getResult().getType());
+    auto dynamicLutType =
+        toDynamicTensorType(eeOp.getInputLookupTable().getType());
+
+    mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
+        eeOp.getLoc(), dynamicResultType, outputBuffer);
+
+    mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
+        eeOp.getLoc(),
+        toDynamicTensorType(eeOp.getInputLookupTable().getType()),
+        adaptor.getInputLookupTable());
+
     // sim_encode_expand_lut_for_boostrap(uint64_t *out_allocated, uint64_t
     // *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t
     // out_stride, uint64_t *in_allocated, uint64_t *in_aligned, uint64_t
     // in_offset, uint64_t in_size, uint64_t in_stride, uint32_t poly_size,
     // uint32_t output_bits, bool is_signed)
     if (insertForwardDeclaration(
             eeOp, rewriter, funcName,
-            rewriter.getFunctionType({eeOp.getResult().getType(),
-                                      eeOp.getInputLookupTable().getType(),
-                                      rewriter.getIntegerType(32),
-                                      rewriter.getIntegerType(32),
-                                      rewriter.getIntegerType(1)},
-                                     {}))
+            rewriter.getFunctionType(
+                {dynamicResultType, dynamicLutType, rewriter.getIntegerType(32),
+                 rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
+                {}))
             .failed()) {
       return mlir::failure();
     }
 
     rewriter.create<mlir::func::CallOp>(
         eeOp.getLoc(), funcName, mlir::TypeRange{},
-        mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(),
-                          polySizeCst, outputBitsCst, isSignedCst}));
+        mlir::ValueRange({castedOutputBuffer, castedLUT, polySizeCst,
+                          outputBitsCst, isSignedCst}));
 
     rewriter.replaceOp(eeOp, outputBuffer);
 
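Aside (illustrative, not from the commit): the pattern above still replaces eeOp with the original statically typed outputBuffer; the tensor.cast values exist only so the call operands match the dynamically shaped callee. A hedged sketch of that call-site idiom as a reusable helper (the name castToDynamic is hypothetical):

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

// Cast a statically shaped tensor value to the fully dynamic version of
// its type so it can be passed to a callee declared over dynamic shapes.
// tensor.cast costs nothing at runtime; it only erases static shape
// information in the type system.
static mlir::Value castToDynamic(mlir::OpBuilder &builder, mlir::Location loc,
                                 mlir::Value value) {
  auto staticType = value.getType().cast<mlir::RankedTensorType>();
  llvm::SmallVector<int64_t> dynShape(staticType.getShape().size(),
                                      mlir::ShapedType::kDynamic);
  auto dynType =
      mlir::RankedTensorType::get(dynShape, staticType.getElementType());
  return builder.create<mlir::tensor::CastOp>(loc, dynType, value);
}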
@@ -244,14 +261,19 @@ struct BootstrapGLWEOpPattern
     auto inputLweDimensionCst = rewriter.create<mlir::arith::ConstantIntOp>(
         bsOp.getLoc(), inputLweDimension, 32);
 
+    auto dynamicLutType = toDynamicTensorType(bsOp.getLookupTable().getType());
+
+    mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
+        bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable());
+
     // uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t
     // *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t
     // tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t
     // poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim)
     if (insertForwardDeclaration(
             bsOp, rewriter, funcName,
             rewriter.getFunctionType(
-                {rewriter.getIntegerType(64), bsOp.getLookupTable().getType(),
+                {rewriter.getIntegerType(64), dynamicLutType,
                  rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                  rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                  rewriter.getIntegerType(32)},
@@ -262,7 +284,7 @@
 
     rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
         bsOp, funcName, this->getTypeConverter()->convertType(resultType),
-        mlir::ValueRange({adaptor.getCiphertext(), adaptor.getLookupTable(),
+        mlir::ValueRange({adaptor.getCiphertext(), castedLUT,
                           inputLweDimensionCst, polySizeCst, levelsCst,
                           baseLogCst, glweDimensionCst}));
 
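Context for the signature comments above (a gloss, not from the commit): under MLIR's C interface, each 1-D tensor argument expands into five scalar parameters — allocated pointer, aligned pointer, offset, size, stride — which is where the tlu_* quintuple comes from. With the LUT type now dynamic, the size and stride carry the real dimensions at runtime instead of being fixed by the type. A C++ sketch of the declaration the comment describes (the exact spelling in the runtime may differ):

#include <cstdint>

extern "C" uint64_t sim_bootstrap_lwe_u64(
    uint64_t plaintext, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
    uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
    uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
    uint32_t base_log, uint32_t glwe_dim);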
@@ -376,7 +398,8 @@ void SimulateTFHEPass::runOnOperation() {
   SimulateTFHETypeConverter converter;
 
   target.addLegalDialect<mlir::arith::ArithDialect>();
-  target.addLegalOp<mlir::func::CallOp, mlir::bufferization::AllocTensorOp>();
+  target.addLegalOp<mlir::func::CallOp, mlir::bufferization::AllocTensorOp,
+                    mlir::tensor::CastOp>();
   // Make sure that no ops from `TFHE` remain after the lowering
   target.addIllegalDialect<TFHE::TFHEDialect>();
 
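Why tensor::CastOp joins the legal-op list (explanatory note, not from the commit): in MLIR's dialect conversion, ops created by rewrite patterns are themselves subject to legalization, so the freshly inserted casts must be declared legal or the pass would report them as remaining illegal ops and fail. The target setup above, restated with that reasoning as comments:

// Ops that the rewrite patterns create must be legal on the target;
// tensor.cast is newly created by the patterns in this commit, hence
// its addition to the legal-op list.
mlir::ConversionTarget target(getContext());
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::func::CallOp, mlir::bufferization::AllocTensorOp,
                  mlir::tensor::CastOp>();
// No TFHE ops may survive the lowering.
target.addIllegalDialect<TFHE::TFHEDialect>();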