diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index 56242b413c..9ab387127c 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -19,7 +19,7 @@ include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td" class FHE_Op traits = []> : Op; -def FHE_ZeroEintOp : FHE_Op<"zero", [Pure, ConstantNoise]> { +def FHE_ZeroEintOp : FHE_Op<"zero", [Pure, ZeroNoise]> { let summary = "Returns a trivial encrypted integer of 0"; let description = [{ @@ -34,7 +34,7 @@ def FHE_ZeroEintOp : FHE_Op<"zero", [Pure, ConstantNoise]> { let results = (outs FHE_AnyEncryptedInteger:$out); } -def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure, ConstantNoise]> { +def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure, ZeroNoise]> { let summary = "Creates a new tensor with all elements initialized to an encrypted zero."; let description = [{ @@ -52,7 +52,7 @@ def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure, ConstantNoise]> { let results = (outs Type.predicate, HasStaticShapePred]>>:$tensor); } -def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods]> { +def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure, BinaryEintInt, AdditiveNoise, DeclareOpInterfaceMethods]> { let summary = "Adds an encrypted integer and a clear integer"; let description = [{ @@ -85,7 +85,7 @@ def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure, BinaryEintInt, DeclareOpInt let hasFolder = 1; } -def FHE_AddEintOp : FHE_Op<"add_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods]> { +def FHE_AddEintOp : FHE_Op<"add_eint", [Pure, BinaryEint, AdditiveNoise, DeclareOpInterfaceMethods]> { let summary = "Adds two encrypted integers"; let description = [{ @@ -117,7 +117,7 @@ def FHE_AddEintOp : FHE_Op<"add_eint", [Pure, BinaryEint, DeclareOpInterfaceMeth let hasVerifier = 1; } -def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint]> { +def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint, AdditiveNoise]> { let summary = "Subtract an encrypted integer from a clear integer"; let description = [{ @@ -149,7 +149,7 @@ def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint]> { let hasVerifier = 1; } -def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods]> { +def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure, BinaryEintInt, AdditiveNoise, DeclareOpInterfaceMethods]> { let summary = "Subtract a clear integer from an encrypted integer"; let description = [{ @@ -182,7 +182,7 @@ def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure, BinaryEintInt, DeclareOpInt let hasFolder = 1; } -def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods]> { +def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure, BinaryEint, AdditiveNoise, DeclareOpInterfaceMethods]> { let summary = "Subtract an encrypted integer from an encrypted integer"; let description = [{ @@ -214,7 +214,7 @@ def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure, BinaryEint, DeclareOpInterfaceMeth let hasVerifier = 1; } -def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { +def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure, UnaryEint, AdditiveNoise, DeclareOpInterfaceMethods]> { let summary = "Negates an encrypted integer"; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td index dc9f531223..e60be26b8b 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td @@ -37,6 +37,22 @@ def ConstantNoise : OpInterface<"ConstantNoise"> { let cppNamespace = "mlir::concretelang::FHE"; } +def ZeroNoise : OpInterface<"ZeroNoise"> { + let description = [{ + An operation outputs a ciphertext with zero noise. + }]; + + let cppNamespace = "mlir::concretelang::FHE"; +} + +def AdditiveNoise : OpInterface<"AdditiveNoise"> { + let description = [{ + An n-ary operation whose output noise is the unweighted sum of all input noises. + }]; + + let cppNamespace = "mlir::concretelang::FHE"; +} + def UnaryEint : OpInterface<"UnaryEint"> { let description = [{ A unary operation on scalars, with the operand encrypted. @@ -63,7 +79,7 @@ def UnaryEint : OpInterface<"UnaryEint"> { if (auto operandTy = dyn_cast($_op->getOpOperand(0).get().getType())) { return operandTy.getElementType(); } else return $_op->getOpOperand(0).get().getType(); - }]> + }]> ]; } @@ -124,8 +140,8 @@ def Binary : OpInterface<"Binary"> { if (auto cstOp = llvm::dyn_cast_or_null($_op-> getOpOperand(opNum).get().getDefiningOp())) return cstOp->template getAttrOfType("value").template getValues(); - else return {}; - }]>, + else return {}; + }]>, ]; } diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 82313a5af7..3aae289cd1 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -11,6 +11,7 @@ #include "boost/outcome.h" +#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -190,6 +191,11 @@ struct FunctionToDag { } else if (auto matmulEintEint = asMatmulEintEint(op)) { addEncMatMulTensor(matmulEintEint, encrypted_inputs, precision); return; + } else if (auto zero = asZeroNoise(op)) { + // special case as zero are rewritten in several optimizer nodes + index = addZeroNoise(zero); + } else if (auto additive = asAdditiveNoise(op)) { + index = addAdditiveNoise(additive, encrypted_inputs); } else { index = addLevelledOp(op, encrypted_inputs); } @@ -259,6 +265,45 @@ struct FunctionToDag { return loc; } + concrete_optimizer::dag::OperatorIndex + addZeroNoise(concretelang::FHE::ZeroNoise &op) { + auto val = op->getOpResult(0); + auto outShape = getShape(val); + + // Trivial encrypted constants encoding + // There are converted to input + levelledop + auto precision = fhe::utils::getEintPrecision(val); + auto opI = dagBuilder.add_input(precision, slice(outShape)); + auto inputs = Inputs{opI}; + + // Default complexity is negligible + double const fixedCost = NEGLIGIBLE_COMPLEXITY; + double const lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; + auto loc = loc_to_string(op.getLoc()); + auto comment = std::string(op->getName().getStringRef()) + " " + loc; + auto weights = std::vector{1.}; + index[val] = + dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + slice(weights), slice(outShape), comment); + return index[val]; + } + + concrete_optimizer::dag::OperatorIndex + addAdditiveNoise(concretelang::FHE::AdditiveNoise &op, Inputs &inputs) { + auto val = op->getResult(0); + auto out_shape = getShape(val); + // Default complexity is negligible + double fixed_cost = NEGLIGIBLE_COMPLEXITY; + double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY; + auto loc = loc_to_string(op.getLoc()); + auto comment = std::string(op->getName().getStringRef()) + " " + loc; + auto weights = std::vector(inputs.size(), 1.); + index[val] = dagBuilder.add_levelled_op(slice(inputs), lwe_dim_cost_factor, + fixed_cost, slice(weights), + slice(out_shape), comment); + return index[val]; + } + concrete_optimizer::dag::OperatorIndex addLevelledOp(mlir::Operation &op, Inputs &inputs) { auto val = op.getResult(0); @@ -279,9 +324,38 @@ struct FunctionToDag { // TODO: use APIFloat.sqrt when it's available double manp = sqrt(smanp_int.getValue().roundToDouble()); auto comment = std::string(op.getName().getStringRef()) + " " + loc; - index[val] = - dagBuilder.add_levelled_op(slice(inputs), lwe_dim_cost_factor, - fixed_cost, manp, slice(out_shape), comment); + + double maxInputManp = 0.; + size_t n_inputs = 0; + for (auto input : op.getOperands()) { + if (!fhe::utils::isEncryptedValue(input)) { + continue; + } + n_inputs += 1; + if (input.isa()) { + maxInputManp = fmax(1., maxInputManp); + } else { + auto inpSmanpInt = + input.getDefiningOp()->getAttrOfType("SMANP"); + const double inpManp = sqrt(inpSmanpInt.getValue().roundToDouble()); + maxInputManp = fmax(inpManp, maxInputManp); + } + } + assert(inputs.size() == n_inputs); + double weight; + if (maxInputManp == 0) { + // The max input manp is zero, meaning the inputs are all zero tensors + // with no noise. In this case it does not matter the weight since it will + // multiply zero. + weight = 0.; + } else { + weight = manp / maxInputManp; + assert(!std::isnan(weight)); + } + auto weights = std::vector(n_inputs, weight); + index[val] = dagBuilder.add_levelled_op(slice(inputs), lwe_dim_cost_factor, + fixed_cost, slice(weights), + slice(out_shape), comment); return index[val]; } @@ -336,46 +410,20 @@ struct FunctionToDag { mlir::Value result = mulOp.getResult(); const std::vector resultShape = getShape(result); - Operation *xOp = mulOp.getLhs().getDefiningOp(); - Operation *yOp = mulOp.getRhs().getDefiningOp(); - const double fixedCost = NEGLIGIBLE_COMPLEXITY; const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; - llvm::APInt xSmanp = llvm::APInt{1, 1, false}; - if (xOp != nullptr) { - const auto xSmanpAttr = xOp->getAttrOfType("SMANP"); - assert(xSmanpAttr && "Missing SMANP value on a crypto operation"); - xSmanp = xSmanpAttr.getValue(); - } - - llvm::APInt ySmanp = llvm::APInt{1, 1, false}; - if (yOp != nullptr) { - const auto ySmanpAttr = yOp->getAttrOfType("SMANP"); - assert(ySmanpAttr && "Missing SMANP value on a crypto operation"); - ySmanp = ySmanpAttr.getValue(); - } - auto loc = loc_to_string(mulOp.getLoc()); auto comment = std::string(mulOp->getName().getStringRef()) + " " + loc; - // (x + y) and (x - y) - const double addSubManp = - sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble()); - - // tlu(v) - const double tluManp = 1; - - // tlu(v1) - tlu(v2) - const double tluSubManp = sqrt(tluManp + tluManp); - // for tlus const std::vector unknownFunction; // tlu(x + y) - auto addNode = - dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - addSubManp, slice(resultShape), comment); + auto addWeights = std::vector{1, 1}; + auto addNode = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, + fixedCost, slice(addWeights), + slice(resultShape), comment); std::optional lhsCorrectionNode; if (isSignedEint(mulOp.getType())) { // If signed mul we need to add the addition node for correction of the @@ -390,9 +438,10 @@ struct FunctionToDag { dagBuilder.add_lut(addNode, slice(unknownFunction), precision); // tlu(x - y) - auto subNode = - dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - addSubManp, slice(resultShape), comment); + auto subWeights = std::vector{1, 1}; + auto subNode = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, + fixedCost, slice(subWeights), + slice(resultShape), comment); // This is a signed tlu so we need to also add the addition for correction // signed tlu auto rhsCorrectionNode = dagBuilder.add_dot( @@ -403,10 +452,11 @@ struct FunctionToDag { slice(unknownFunction), precision); // tlu(x + y) - tlu(x - y) + auto resultWeights = std::vector{1, 1}; const std::vector subInputs = { lhsTluNode, rhsTluNode}; auto resultNode = dagBuilder.add_levelled_op( - slice(subInputs), lweDimCostFactor, fixedCost, tluSubManp, + slice(subInputs), lweDimCostFactor, fixedCost, slice(resultWeights), slice(resultShape), comment); index[result] = resultNode; @@ -512,35 +562,13 @@ struct FunctionToDag { // 1. (x + y) and (x - y) -> supposing broadcasting is used // to tensorize this operation - - Operation *xOp = innerProductOp.getLhs().getDefiningOp(); - Operation *yOp = innerProductOp.getRhs().getDefiningOp(); - const double fixedCost = NEGLIGIBLE_COMPLEXITY; const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; - llvm::APInt xSmanp = llvm::APInt{1, 1, false}; - if (xOp != nullptr) { - const auto xSmanpAttr = xOp->getAttrOfType("SMANP"); - assert(xSmanpAttr && "Missing SMANP value on a crypto operation"); - xSmanp = xSmanpAttr.getValue(); - } - - llvm::APInt ySmanp = llvm::APInt{1, 1, false}; - if (yOp != nullptr) { - const auto ySmanpAttr = yOp->getAttrOfType("SMANP"); - assert(ySmanpAttr && "Missing SMANP value on a crypto operation"); - ySmanp = ySmanpAttr.getValue(); - } - auto loc = loc_to_string(innerProductOp.getLoc()); auto comment = std::string(innerProductOp->getName().getStringRef()) + " " + loc; - // (x + y) and (x - y) - const double addSubManp = - sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble()); - // tlu(v) const double tluManp = 1; @@ -551,9 +579,10 @@ struct FunctionToDag { const std::vector unknownFunction; // tlu(x + y) - auto addNode = - dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - addSubManp, slice(pairMatrixShape), comment); + auto addWeights = std::vector{1, 1}; + auto addNode = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, + fixedCost, slice(addWeights), + slice(pairMatrixShape), comment); std::optional lhsCorrectionNode; if (isSignedEint(innerProductOp.getType())) { // If signed mul we need to add the addition node for correction of the @@ -568,9 +597,10 @@ struct FunctionToDag { dagBuilder.add_lut(addNode, slice(unknownFunction), precision); // tlu(x - y) - auto subNode = - dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - addSubManp, slice(pairMatrixShape), comment); + auto subWeights = std::vector{1, 1}; + auto subNode = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, + fixedCost, slice(subWeights), + slice(pairMatrixShape), comment); // This is a signed tlu so we need to also add the addition for correction // signed tlu auto rhsCorrectionNode = dagBuilder.add_dot( @@ -581,10 +611,11 @@ struct FunctionToDag { slice(unknownFunction), precision); // tlu(x + y) - tlu(x - y) + auto resultWeights = std::vector{1, 1}; const std::vector subInputs = { lhsTluNode, rhsTluNode}; auto resultNode = dagBuilder.add_levelled_op( - slice(subInputs), lweDimCostFactor, fixedCost, tluSubManp, + slice(subInputs), lweDimCostFactor, fixedCost, slice(resultWeights), slice(pairMatrixShape), comment); // 3. Sum(tlu(x + y) - tlu(x - y)) @@ -606,8 +637,9 @@ struct FunctionToDag { // TODO: use APIFloat.sqrt when it's available double manp = sqrt(smanp_int.getValue().roundToDouble()); + auto weights = std::vector(sumOperands.size(), manp / tluSubManp); index[result] = dagBuilder.add_levelled_op( - slice(sumOperands), lwe_dim_cost_factor, fixed_cost, manp, + slice(sumOperands), lwe_dim_cost_factor, fixed_cost, slice(weights), slice(resultShape), comment); // Create the TFHE.OId attributes @@ -649,46 +681,26 @@ struct FunctionToDag { mlir::Value result = maxOp.getResult(); const std::vector resultShape = getShape(result); - Operation *xOp = maxOp.getX().getDefiningOp(); - Operation *yOp = maxOp.getY().getDefiningOp(); - const double fixedCost = NEGLIGIBLE_COMPLEXITY; const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; - llvm::APInt xSmanp = llvm::APInt{1, 1, false}; - if (xOp != nullptr) { - const auto xSmanpAttr = xOp->getAttrOfType("SMANP"); - assert(xSmanpAttr && "Missing SMANP value on a crypto operation"); - xSmanp = xSmanpAttr.getValue(); - } - - llvm::APInt ySmanp = llvm::APInt{1, 1, false}; - if (yOp != nullptr) { - const auto ySmanpAttr = yOp->getAttrOfType("SMANP"); - assert(ySmanpAttr && "Missing SMANP value on a crypto operation"); - ySmanp = ySmanpAttr.getValue(); - } - - const double subManp = - sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble()); - auto loc = loc_to_string(maxOp.getLoc()); auto comment = std::string(maxOp->getName().getStringRef()) + " " + loc; - auto subNode = - dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - subManp, slice(resultShape), comment); + auto subWeights = std::vector{1, 1}; + auto subNode = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, + fixedCost, slice(subWeights), + slice(resultShape), comment); - const double tluNodeManp = 1; const std::vector unknownFunction; auto tluNode = dagBuilder.add_lut(subNode, slice(unknownFunction), precision); - const double addManp = sqrt(tluNodeManp + ySmanp.roundToDouble()); const std::vector addInputs = { tluNode, inputs[1]}; + auto addWeights = std::vector{1, 1}; auto resultNode = dagBuilder.add_levelled_op( - slice(addInputs), lweDimCostFactor, fixedCost, addManp, + slice(addInputs), lweDimCostFactor, fixedCost, slice(addWeights), slice(resultShape), comment); index[result] = resultNode; @@ -736,9 +748,11 @@ struct FunctionToDag { auto comment = std::string(maxpool2dOp->getName().getStringRef()) + " " + loc; - auto subNode = - dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, - subManp, slice(fakeShape), comment); + auto subWeights = std::vector( + inputs.size(), subManp / sqrt(inputSmanp.roundToDouble())); + auto subNode = dagBuilder.add_levelled_op(slice(inputs), lweDimCostFactor, + fixedCost, slice(subWeights), + slice(fakeShape), comment); const std::vector unknownFunction; auto tluNode = @@ -748,8 +762,10 @@ struct FunctionToDag { const std::vector addInputs = { tluNode, inputs[0]}; + auto resultWeights = std::vector( + addInputs.size(), addManp / sqrt(inputSmanp.roundToDouble())); auto resultNode = dagBuilder.add_levelled_op( - slice(addInputs), lweDimCostFactor, fixedCost, addManp, + slice(addInputs), lweDimCostFactor, fixedCost, slice(resultWeights), slice(resultShape), comment); index[result] = resultNode; // Set attribute on the MLIR node @@ -852,6 +868,14 @@ struct FunctionToDag { return llvm::dyn_cast(op); } + mlir::concretelang::FHE::ZeroNoise asZeroNoise(mlir::Operation &op) { + return llvm::dyn_cast(op); + } + + mlir::concretelang::FHE::AdditiveNoise asAdditiveNoise(mlir::Operation &op) { + return llvm::dyn_cast(op); + } + mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) { return llvm::dyn_cast(op); } diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 5cd15adaec..faaaf51f8d 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -794,14 +794,12 @@ class MANPAnalysis std::optional norm2SqEquivFromOp(Operation *op, ArrayRef operands) { std::optional norm2SqEquiv; - if (auto cstNoiseOp = - llvm::dyn_cast(op)) { - if (llvm::isa(op)) { - norm2SqEquiv = llvm::APInt{1, 0, false}; - } else { - norm2SqEquiv = llvm::APInt{1, 1, false}; - } + if (auto zeroNoiseOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = llvm::APInt{1, 0, false}; + } else if (auto cstNoiseOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = llvm::APInt{1, 1, false}; } else if (llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = getNoOpSqMANP(operands); diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index a3a52c3ed4..e3243a3031 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -649,10 +649,11 @@ impl<'dag> DagBuilder<'dag> { inputs: &[ffi::OperatorIndex], lwe_dim_cost_factor: f64, fixed_cost: f64, - manp: f64, + weights: &[f64], out_shape: &[u64], comment: &str, ) -> ffi::OperatorIndex { + debug_assert!(weights.len() == inputs.len()); let inputs: Vec = inputs.iter().copied().map(Into::into).collect(); let out_shape = Shape { @@ -665,7 +666,7 @@ impl<'dag> DagBuilder<'dag> { }; self.0 - .add_levelled_op(inputs, complexity, manp, out_shape, comment) + .add_levelled_op(inputs, complexity, weights, out_shape, comment) .into() } @@ -781,7 +782,7 @@ mod ffi { inputs: &[OperatorIndex], lwe_dim_cost_factor: f64, fixed_cost: f64, - manp: f64, + weights: &[f64], out_shape: &[u64], comment: &str, ) -> OperatorIndex; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index f20c22d655..3e6807170c 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -996,7 +996,7 @@ struct DagBuilder final : public ::rust::Opaque { ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape) noexcept; ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept; - ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; ::concrete_optimizer::dag::OperatorIndex add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept; @@ -1309,7 +1309,7 @@ ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilde ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_dot(::concrete_optimizer::DagBuilder &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::concrete_optimizer::Weights *weights) noexcept; -::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_levelled_op(::concrete_optimizer::DagBuilder &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_levelled_op(::concrete_optimizer::DagBuilder &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$DagBuilder$add_round_op(::concrete_optimizer::DagBuilder &self, ::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; @@ -1427,8 +1427,8 @@ ::concrete_optimizer::dag::OperatorIndex DagBuilder::add_dot(::rust::Slice<::con return concrete_optimizer$cxxbridge1$DagBuilder$add_dot(*this, inputs, weights.into_raw()); } -::concrete_optimizer::dag::OperatorIndex DagBuilder::add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept { - return concrete_optimizer$cxxbridge1$DagBuilder$add_levelled_op(*this, inputs, lwe_dim_cost_factor, fixed_cost, manp, out_shape, comment); +::concrete_optimizer::dag::OperatorIndex DagBuilder::add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept { + return concrete_optimizer$cxxbridge1$DagBuilder$add_levelled_op(*this, inputs, lwe_dim_cost_factor, fixed_cost, weights, out_shape, comment); } ::concrete_optimizer::dag::OperatorIndex DagBuilder::add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index abd0685d65..375b4e2e8f 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -977,7 +977,7 @@ struct DagBuilder final : public ::rust::Opaque { ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice<::std::uint64_t const> out_shape) noexcept; ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice<::std::uint64_t const> table, ::std::uint8_t out_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept; - ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> inputs, double lwe_dim_cost_factor, double fixed_cost, ::rust::Slice weights, ::rust::Slice<::std::uint64_t const> out_shape, ::rust::Str comment) noexcept; ::concrete_optimizer::dag::OperatorIndex add_round_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept; void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs index 16eaab9a41..f3fce4a406 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs @@ -89,7 +89,7 @@ pub enum Operator { LevelledOp { inputs: Vec, complexity: LevelledComplexity, - manp: f64, + weights: Vec, out_shape: Shape, comment: String, }, @@ -171,7 +171,7 @@ impl fmt::Display for Operator { } Self::LevelledOp { inputs, - manp, + weights, out_shape, .. } => { @@ -182,7 +182,7 @@ impl fmt::Display for Operator { } write!(f, "%{}", input.0)?; } - write!(f, "] : manp={manp} x {out_shape:?}")?; + write!(f, "] : weights={weights:?}, out_shape={out_shape:?}")?; } Self::Round { input, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index cffabfff32..5bbe1520f8 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -194,17 +194,19 @@ impl<'dag> DagBuilder<'dag> { &mut self, inputs: impl Into>, complexity: LevelledComplexity, - manp: f64, + weights: impl Into>, out_shape: impl Into, comment: impl Into, ) -> OperatorIndex { let inputs = inputs.into(); let out_shape = out_shape.into(); let comment = comment.into(); + let weights = weights.into(); + assert_eq!(weights.len(), inputs.len()); let op = Operator::LevelledOp { inputs, complexity, - manp, + weights, out_shape, comment, }; @@ -532,12 +534,12 @@ impl Dag { &mut self, inputs: impl Into>, complexity: LevelledComplexity, - manp: f64, + weights: impl Into>, out_shape: impl Into, comment: impl Into, ) -> OperatorIndex { self.builder(DEFAULT_CIRCUIT) - .add_levelled_op(inputs, complexity, manp, out_shape, comment) + .add_levelled_op(inputs, complexity, weights, out_shape, comment) } pub fn add_unsafe_cast( @@ -797,12 +799,23 @@ mod tests { let input2 = builder.add_input(2, Shape::number()); let cpx_add = LevelledComplexity::ADDITION; - let sum1 = builder.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum"); + let sum1 = builder.add_levelled_op( + [input1, input2], + cpx_add, + [1.0, 1.0], + Shape::number(), + "sum", + ); let lut1 = builder.add_lut(sum1, FunctionTable::UNKWOWN, 1); - let concat = - builder.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat"); + let concat = builder.add_levelled_op( + [input1, lut1], + cpx_add, + [1.0, 1.0], + Shape::vector(2), + "concat", + ); let dot = builder.add_dot([concat], [1, 2]); @@ -827,7 +840,7 @@ mod tests { Operator::LevelledOp { inputs: vec![input1, input2], complexity: cpx_add, - manp: 1.0, + weights: vec![1.0, 1.0], out_shape: Shape::number(), comment: "sum".to_string(), }, @@ -839,7 +852,7 @@ mod tests { Operator::LevelledOp { inputs: vec![input1, lut1], complexity: cpx_add, - manp: 1.0, + weights: vec![1.0, 1.0], out_shape: Shape::vector(2), comment: "concat".to_string(), }, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index 9d9d5207b7..26c35ff82e 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -20,7 +20,7 @@ use super::partitions::Partitions; use super::variance_constraint::VarianceConstraint; use crate::utils::square; -const MAX_FORWARDING: u16 = 1000; +const MAX_FORWARDING: u16 = 500; #[derive(Debug, Clone)] pub struct PartitionedDag { @@ -194,7 +194,6 @@ impl VariancedDag { if operator.operator().is_input() { continue; } - let max_var = |acc: SymbolicVariance, input: SymbolicVariance| acc.max(&input); // Operator variance will be used to override the noise let mut operator_variance = OperatorVariance::nan(nb_partitions); // We first compute the noise in the partition of the operator @@ -207,14 +206,13 @@ impl VariancedDag { nb_partitions, operator.partition().instruction_partition, ), - Operator::LevelledOp { manp, .. } => { - let max_var = operator - .get_inputs_iter() - .map(|a| a.variance()[operator.partition().instruction_partition].clone()) - .reduce(max_var) - .unwrap(); - max_var.after_levelled_op(*manp) - } + Operator::LevelledOp { weights, .. } => operator + .get_inputs_iter() + .zip(weights) + .fold(SymbolicVariance::ZERO, |acc, (inp, &weight)| { + acc + inp.variance()[operator.partition().instruction_partition].clone() + * square(weight) + }), Operator::Dot { kind: DotKind::CompatibleTensor { .. }, .. @@ -758,6 +756,23 @@ pub mod tests { } } + #[test] + #[should_panic(expected = "Forwarding of noise did not reach a fixed point.")] + fn test_decreasing_panics() { + let mut dag = unparametrized::Dag::new(); + let inp = dag.add_input(1, Shape::number()); + let oup = dag.add_levelled_op( + [inp], + LevelledComplexity::ZERO, + [0.5], + Shape::number(), + "comment", + ); + dag.add_composition(oup, inp); + let p_cut = PartitionCut::for_each_precision(&dag); + let _ = super::analyze(&dag, &CONFIG, &Some(p_cut), LOW_PRECISION_PARTITION).unwrap(); + } + #[test] fn test_composition_with_nongrowing_inputs_only() { let mut dag = unparametrized::Dag::new(); @@ -765,7 +780,7 @@ pub mod tests { let oup = dag.add_levelled_op( [inp], LevelledComplexity::ZERO, - 1.0, + [1.0], Shape::number(), "comment", ); @@ -789,7 +804,7 @@ pub mod tests { let oup = dag.add_levelled_op( [inp], LevelledComplexity::ZERO, - 1.1, + [1.1], Shape::number(), "comment", ); @@ -952,7 +967,7 @@ pub mod tests { let _levelled = dag.add_levelled_op( [lut1, input2], LevelledComplexity::ZERO, - manp, + [manp, manp], &out_shape, "comment", ); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index 1df078e1a3..e82683ca66 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -436,7 +436,7 @@ fn optimize_sign_extract() { let input1 = dag.add_levelled_op( [small_input1], complexity, - 1.0, + [1.0], Shape::vector(1_000_000), "comment", ); @@ -587,7 +587,7 @@ fn test_chained_partitions_non_feasible_single_params() { lut_input = dag.add_levelled_op( [lut_input], LevelledComplexity::ZERO, - noise_factor, + [noise_factor], Shape::number(), "", ); @@ -844,8 +844,8 @@ fn test_bug_with_zero_noise() { let out_shape = Shape::number(); let mut dag = unparametrized::Dag::new(); let v0 = dag.add_input(2, &out_shape); - let v1 = dag.add_levelled_op([v0], complexity, 0.0, &out_shape, "comment"); - let v2 = dag.add_levelled_op([v1], complexity, 1.0, &out_shape, "comment"); + let v1 = dag.add_levelled_op([v0], complexity, [0.0], &out_shape, "comment"); + let v2 = dag.add_levelled_op([v1], complexity, [1.0], &out_shape, "comment"); let v3 = dag.add_unsafe_cast(v2, 1); let _ = dag.add_lut(v3, FunctionTable { values: vec![] }, 1); let sol = optimize(&dag, &None, PartitionIndex(0)); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs index 4bf70f5b81..664088bda2 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs @@ -144,38 +144,6 @@ impl SymbolicVariance { new } - #[allow(clippy::float_cmp)] - pub fn after_levelled_op(&self, manp: f64) -> Self { - let new_coeff = manp * manp; - // detect the previous base manp level - // this is the maximum value of fresh base noise and pbs base noise - let mut current_max: f64 = 0.0; - for partition in PartitionIndex::range(0, self.nb_partitions()) { - let fresh_coeff = self.coeff_input(partition); - let pbs_noise_coeff = self.coeff_pbs(partition); - current_max = current_max.max(fresh_coeff).max(pbs_noise_coeff); - } - // assert!(1.0 <= current_max); - // assert!( - // current_max <= new_coeff, - // "Non monotonious levelled op: {current_max} <= {new_coeff}" - // ); - // replace all current_max by new_coeff - // multiply everything else by new_coeff / current_max - let mut new = self.clone(); - if current_max == 0.0 { - return new; - } - for cell in &mut new.coeffs.values { - if *cell == current_max { - *cell = new_coeff; - } else { - *cell *= new_coeff / current_max; - } - } - new - } - pub fn max(&self, other: &Self) -> Self { let mut coeffs = self.coeffs.clone(); for (i, coeff) in coeffs.iter_mut().enumerate() { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 84502b857a..457ba52545 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -1,4 +1,4 @@ -use super::symbolic_variance::{SymbolicVariance, VarianceOrigin}; +use super::symbolic_variance::SymbolicVariance; use crate::dag::operator::{ DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, }; @@ -100,17 +100,6 @@ fn assert_properties_correctness(dag: &SoloKeyDag) { assert_valid_variances(dag); } -fn variance_origin(inputs: &[OperatorIndex], out_variances: &[SymbolicVariance]) -> VarianceOrigin { - let first_origin = first(inputs, out_variances).origin(); - for input in inputs.iter().skip(1) { - let item = &out_variances[input.0]; - if first_origin != item.origin() { - return VarianceOrigin::Mixed; - } - } - first_origin -} - #[derive(Clone, Debug)] pub struct SoloKeyDag { // Collect all operators output variances @@ -146,15 +135,15 @@ fn out_variance( match op { Operator::Input { .. } => SymbolicVariance::INPUT, Operator::Lut { .. } => SymbolicVariance::LUT, - Operator::LevelledOp { inputs, manp, .. } => { - let variance_factor = SymbolicVariance::manp_to_variance_factor(*manp); - let origin = match variance_origin(inputs, out_variances) { - VarianceOrigin::Input => SymbolicVariance::INPUT, - VarianceOrigin::Lut | VarianceOrigin::Mixed /* Mixed: assume the worst */ - => SymbolicVariance::LUT - }; - origin * variance_factor - } + Operator::LevelledOp { + inputs, weights, .. + } => inputs + .iter() + .map(|i| out_variances[i.0]) + .zip(weights) + .fold(SymbolicVariance::ZERO, |acc, (var, &weight)| { + acc + var * square(weight) + }), Operator::Dot { kind: DotKind::CompatibleTensor { .. }, .. @@ -702,14 +691,13 @@ pub mod tests { let cpx_dot = LevelledComplexity::ADDITION; let weights = Weights::vector([1, 2]); #[allow(clippy::imprecise_flops)] - let manp = (1.0 * 1.0 + 2.0 * 2_f64).sqrt(); - let dot = graph.add_levelled_op([input1, input1], cpx_dot, manp, Shape::number(), "dot"); + let dot = + graph.add_levelled_op([input1, input1], cpx_dot, [1., 2.], Shape::number(), "dot"); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); - assert!(analysis.out_variances[dot.0].origin() == VarianceOrigin::Input); assert_eq!(graph.out_precisions[dot.0], 3); let expected_square_norm2 = weights.square_norm2() as f64; let actual_square_norm2 = analysis.out_variances[dot.0].input_coeff; @@ -720,7 +708,6 @@ pub mod tests { let constraint = analysis.constraint(); assert!(constraint.pareto_in_lut.is_empty()); assert!(constraint.pareto_output.len() == 1); - assert_eq!(constraint.pareto_output[0].origin(), VarianceOrigin::Input); assert_f64_eq(constraint.pareto_output[0].input_coeff, 5.0); } @@ -763,10 +750,8 @@ pub mod tests { assert_f64_eq(expected_cost, complexity_cost); let constraint = analysis.constraint(); assert_eq!(constraint.pareto_output.len(), 1); - assert_eq!(constraint.pareto_output[0].origin(), VarianceOrigin::Lut); assert_f64_eq(constraint.pareto_output[0].lut_coeff, 1.0); assert_eq!(constraint.pareto_in_lut.len(), 1); - assert_eq!(constraint.pareto_in_lut[0].origin(), VarianceOrigin::Lut); assert_f64_eq( constraint.pareto_in_lut[0].lut_coeff, weights.square_norm2() as f64, @@ -796,7 +781,6 @@ pub mod tests { assert_eq!(constraint.pareto_output.len(), 1); assert_eq!(constraint.pareto_output[0], SymbolicVariance::LUT); assert_eq!(constraint.pareto_in_lut.len(), 1); - assert_eq!(constraint.pareto_in_lut[0].origin(), VarianceOrigin::Mixed); assert_eq!(constraint.pareto_in_lut[0], expected_mixed); } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 4d4ae204a8..e17765c055 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -385,16 +385,15 @@ pub fn optimize( pub fn add_v0_dag(dag: &mut Dag, sum_size: u64, precision: u64, noise_factor: f64) { use crate::dag::operator::{FunctionTable, Shape}; - let same_scale_manp = 1.0; let manp = noise_factor; let out_shape = &Shape::number(); let complexity = LevelledComplexity::ADDITION * sum_size; let comment = "dot"; let precision = precision as Precision; let input1 = dag.add_input(precision, out_shape); - let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment); + let dot1 = dag.add_levelled_op([input1], complexity, [1.0], out_shape, comment); let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision); - let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment); + let dot2 = dag.add_levelled_op([lut1], complexity, [manp], out_shape, comment); let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); } @@ -432,7 +431,6 @@ pub(crate) mod tests { use crate::dag::operator::{FunctionTable, Shape, Weights}; use crate::noise_estimator::p_error::repeat_p_error; use crate::optimization::config::SearchSpace; - use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin; use crate::optimization::{atomic_pattern, decomposition}; use crate::utils::square; @@ -607,10 +605,8 @@ pub(crate) mod tests { let constraint = dag2.constraint(); assert_eq!(constraint.pareto_output.len(), 1); assert_eq!(constraint.pareto_in_lut.len(), 1); - assert_eq!(constraint.pareto_output[0].origin(), VarianceOrigin::Lut); assert_f64_eq(1.0, constraint.pareto_output[0].lut_coeff); assert!(constraint.pareto_in_lut.len() == 1); - assert_eq!(constraint.pareto_in_lut[0].origin(), VarianceOrigin::Lut); assert_f64_eq(square(weight) as f64, constraint.pareto_in_lut[0].lut_coeff); } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs index e2fc192b29..b81a11d89d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs @@ -25,13 +25,6 @@ pub struct SymbolicVariance { // see pareto sorting and dominate_or_equal } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum VarianceOrigin { - Input, - Lut, - Mixed, -} - impl std::ops::Add for SymbolicVariance { type Output = Self; @@ -93,20 +86,6 @@ impl SymbolicVariance { lut_coeff: 1.0, }; - pub fn origin(&self) -> VarianceOrigin { - if self.lut_coeff == 0.0 { - VarianceOrigin::Input - } else if self.input_coeff == 0.0 { - VarianceOrigin::Lut - } else { - VarianceOrigin::Mixed - } - } - - pub fn manp_to_variance_factor(manp: f64) -> f64 { - manp * manp - } - pub fn dominate_or_equal(&self, other: &Self) -> bool { let extra_other_minimal_base_noise = 0.0_f64.max(other.input_coeff - self.input_coeff); other.lut_coeff + extra_other_minimal_base_noise <= self.lut_coeff diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs index 0ac2dc8e9d..fe7b8bf04d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs @@ -97,8 +97,8 @@ impl<'dag> Viz for crate::dag::unparametrized::DagOperator<'dag> { Operator::Dot { .. } => { format!("{index} [label = \"{{%{index} = Dot({input_string})}}\" fillcolor={color}];") } - Operator::LevelledOp { manp, .. } => { - format!("{index} [label = \"{{%{index} = LevelledOp({input_string}) |{{manp:|{manp:?}}}}}\" fillcolor={color}];") + Operator::LevelledOp { weights, .. } => { + format!("{index} [label = \"{{%{index} = LevelledOp({input_string}) |{{weights:|{weights:?}}}}}\" fillcolor={color}];") } Operator::UnsafeCast { out_precision, .. } => format!( "{index} [label = \"{{%{index} = UnsafeCast({input_string}) |{{out_precision:|{out_precision:?}}}}}\" fillcolor={color}];" diff --git a/docs/explanations/FHEDialect.md b/docs/explanations/FHEDialect.md index 4797da2694..ff76c31cbb 100644 --- a/docs/explanations/FHEDialect.md +++ b/docs/explanations/FHEDialect.md @@ -29,7 +29,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: Binary, BinaryEintInt, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: AdditiveNoise, Binary, BinaryEintInt, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -67,7 +67,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: BinaryEint, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: AdditiveNoise, BinaryEint, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -543,7 +543,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), UnaryEint +Interfaces: AdditiveNoise, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), UnaryEint Effects: MemoryEffects::Effect{} @@ -660,7 +660,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: Binary, BinaryEintInt, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: AdditiveNoise, Binary, BinaryEintInt, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -698,7 +698,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: BinaryEint, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: AdditiveNoise, BinaryEint, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -736,7 +736,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: Binary, BinaryIntEint, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: AdditiveNoise, Binary, BinaryIntEint, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} @@ -870,7 +870,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: ConditionallySpeculatable, ConstantNoise, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ZeroNoise Effects: MemoryEffects::Effect{} @@ -894,7 +894,7 @@ Example: Traits: AlwaysSpeculatableImplTrait -Interfaces: ConditionallySpeculatable, ConstantNoise, NoMemoryEffect (MemoryEffectOpInterface) +Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ZeroNoise Effects: MemoryEffects::Effect{}