Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(compiler): use dyn sized tensors in CAPI func definitions #518

Merged
merged 1 commit into from
Jul 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ class SimulateTFHETypeConverter : public mlir::TypeConverter {

namespace {

mlir::RankedTensorType toDynamicTensorType(mlir::TensorType staticSizedTensor) {
std::vector<int64_t> dynSizedShape(staticSizedTensor.getShape().size(),
mlir::ShapedType::kDynamic);
return mlir::RankedTensorType::get(dynSizedShape,
staticSizedTensor.getElementType());
}

struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {

NegOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
Expand Down Expand Up @@ -131,27 +138,37 @@ struct EncodeExpandLutForBootstrapOpPattern
eeOp.getResult().getType().cast<mlir::RankedTensorType>(),
mlir::ValueRange{});

auto dynamicResultType = toDynamicTensorType(eeOp.getResult().getType());
auto dynamicLutType =
toDynamicTensorType(eeOp.getInputLookupTable().getType());

mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
eeOp.getLoc(), dynamicResultType, outputBuffer);

mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
eeOp.getLoc(),
toDynamicTensorType(eeOp.getInputLookupTable().getType()),
adaptor.getInputLookupTable());

// sim_encode_expand_lut_for_boostrap(uint64_t *out_allocated, uint64_t
// *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t
// out_stride, uint64_t *in_allocated, uint64_t *in_aligned, uint64_t
// in_offset, uint64_t in_size, uint64_t in_stride, uint32_t poly_size,
// uint32_t output_bits, bool is_signed)
if (insertForwardDeclaration(
eeOp, rewriter, funcName,
rewriter.getFunctionType({eeOp.getResult().getType(),
eeOp.getInputLookupTable().getType(),
rewriter.getIntegerType(32),
rewriter.getIntegerType(32),
rewriter.getIntegerType(1)},
{}))
rewriter.getFunctionType(
{dynamicResultType, dynamicLutType, rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
{}))
.failed()) {
return mlir::failure();
}

rewriter.create<mlir::func::CallOp>(
eeOp.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(),
polySizeCst, outputBitsCst, isSignedCst}));
mlir::ValueRange({castedOutputBuffer, castedLUT, polySizeCst,
outputBitsCst, isSignedCst}));

rewriter.replaceOp(eeOp, outputBuffer);

Expand Down Expand Up @@ -244,14 +261,19 @@ struct BootstrapGLWEOpPattern
auto inputLweDimensionCst = rewriter.create<mlir::arith::ConstantIntOp>(
bsOp.getLoc(), inputLweDimension, 32);

auto dynamicLutType = toDynamicTensorType(bsOp.getLookupTable().getType());

mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable());

// uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t
// *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t
// tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t
// poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim)
if (insertForwardDeclaration(
bsOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), bsOp.getLookupTable().getType(),
{rewriter.getIntegerType(64), dynamicLutType,
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32)},
Expand All @@ -262,7 +284,7 @@ struct BootstrapGLWEOpPattern

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
bsOp, funcName, this->getTypeConverter()->convertType(resultType),
mlir::ValueRange({adaptor.getCiphertext(), adaptor.getLookupTable(),
mlir::ValueRange({adaptor.getCiphertext(), castedLUT,
inputLweDimensionCst, polySizeCst, levelsCst,
baseLogCst, glweDimensionCst}));

Expand Down Expand Up @@ -376,7 +398,8 @@ void SimulateTFHEPass::runOnOperation() {
SimulateTFHETypeConverter converter;

target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::func::CallOp, mlir::bufferization::AllocTensorOp>();
target.addLegalOp<mlir::func::CallOp, mlir::bufferization::AllocTensorOp,
mlir::tensor::CastOp>();
// Make sure that no ops from `TFHE` remain after the lowering
target.addIllegalDialect<TFHE::TFHEDialect>();

Expand Down