From c78c37636bcb1fe9af83cd829c3c4368be442b94 Mon Sep 17 00:00:00 2001 From: Louis Tremblay Thibault Date: Tue, 27 Jun 2023 16:29:14 +0200 Subject: [PATCH] Add scalar ops, bitshift, min/max (#130) * feat(tfhe): add support for casting * feat: add `cast` precompile * feat(tfhe): add `&`, `|`, `^`, `==`, `>`, `>=` * feat(precompiles): add missing ops * feat(precompiles): add tests * fix: add precompile contract addresses * fit(precompiles): change precompile order * feat(tfhe): add scalar ops * fix(tfhe): add `isValid` for ciphertext types * fix(tfhe): check type validity * feat(contracts): add fhe operators * fix(test): restore deserialize failure test --- core/vm/contracts.go | 1490 +++++++++++++++++++++++++++++++------ core/vm/contracts_test.go | 936 ++++++++++++++++++++--- core/vm/tfhe.go | 1325 +++++++++++++++++++++++++++++---- core/vm/tfhe_test.go | 930 ++++++++++++++++++++--- params/protocol_params.go | 15 +- 5 files changed, 4122 insertions(+), 574 deletions(-) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index b9d144d3d..2cd30e67e 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -71,7 +71,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{67}): &reencrypt{}, common.BytesToAddress([]byte{68}): &fhePubKey{}, common.BytesToAddress([]byte{69}): &require{}, - common.BytesToAddress([]byte{70}): &fheLte{}, + common.BytesToAddress([]byte{70}): &fheLe{}, common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, @@ -85,6 +85,13 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{81}): &fheEq{}, common.BytesToAddress([]byte{82}): &fheGe{}, common.BytesToAddress([]byte{83}): &fheGt{}, + common.BytesToAddress([]byte{84}): &fheShl{}, + common.BytesToAddress([]byte{85}): &fheShr{}, + common.BytesToAddress([]byte{86}): &fheNe{}, + common.BytesToAddress([]byte{87}): &fheMin{}, + common.BytesToAddress([]byte{88}): &fheMax{}, + common.BytesToAddress([]byte{89}): &fheNeg{}, + common.BytesToAddress([]byte{90}): &fheNot{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -106,7 +113,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{67}): &reencrypt{}, common.BytesToAddress([]byte{68}): &fhePubKey{}, common.BytesToAddress([]byte{69}): &require{}, - common.BytesToAddress([]byte{70}): &fheLte{}, + common.BytesToAddress([]byte{70}): &fheLe{}, common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, @@ -120,6 +127,13 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{81}): &fheEq{}, common.BytesToAddress([]byte{82}): &fheGe{}, common.BytesToAddress([]byte{83}): &fheGt{}, + common.BytesToAddress([]byte{84}): &fheShl{}, + common.BytesToAddress([]byte{85}): &fheShr{}, + common.BytesToAddress([]byte{86}): &fheNe{}, + common.BytesToAddress([]byte{87}): &fheMin{}, + common.BytesToAddress([]byte{88}): &fheMax{}, + common.BytesToAddress([]byte{89}): &fheNeg{}, + common.BytesToAddress([]byte{90}): &fheNot{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -142,7 +156,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{67}): &reencrypt{}, common.BytesToAddress([]byte{68}): &fhePubKey{}, common.BytesToAddress([]byte{69}): &require{}, - common.BytesToAddress([]byte{70}): &fheLte{}, + common.BytesToAddress([]byte{70}): &fheLe{}, common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, @@ -156,6 +170,13 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{81}): &fheEq{}, common.BytesToAddress([]byte{82}): &fheGe{}, common.BytesToAddress([]byte{83}): &fheGt{}, + common.BytesToAddress([]byte{84}): &fheShl{}, + common.BytesToAddress([]byte{85}): &fheShr{}, + common.BytesToAddress([]byte{86}): &fheNe{}, + common.BytesToAddress([]byte{87}): &fheMin{}, + common.BytesToAddress([]byte{88}): &fheMax{}, + common.BytesToAddress([]byte{89}): &fheNeg{}, + common.BytesToAddress([]byte{90}): &fheNot{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -178,7 +199,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{67}): &reencrypt{}, common.BytesToAddress([]byte{68}): &fhePubKey{}, common.BytesToAddress([]byte{69}): &require{}, - common.BytesToAddress([]byte{70}): &fheLte{}, + common.BytesToAddress([]byte{70}): &fheLe{}, common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, @@ -192,6 +213,13 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{81}): &fheEq{}, common.BytesToAddress([]byte{82}): &fheGe{}, common.BytesToAddress([]byte{83}): &fheGt{}, + common.BytesToAddress([]byte{84}): &fheShl{}, + common.BytesToAddress([]byte{85}): &fheShr{}, + common.BytesToAddress([]byte{86}): &fheNe{}, + common.BytesToAddress([]byte{87}): &fheMin{}, + common.BytesToAddress([]byte{88}): &fheMax{}, + common.BytesToAddress([]byte{89}): &fheNeg{}, + common.BytesToAddress([]byte{90}): &fheNot{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -214,7 +242,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{67}): &reencrypt{}, common.BytesToAddress([]byte{68}): &fhePubKey{}, common.BytesToAddress([]byte{69}): &require{}, - common.BytesToAddress([]byte{70}): &fheLte{}, + common.BytesToAddress([]byte{70}): &fheLe{}, common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, @@ -228,6 +256,13 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{81}): &fheEq{}, common.BytesToAddress([]byte{82}): &fheGe{}, common.BytesToAddress([]byte{83}): &fheGt{}, + common.BytesToAddress([]byte{84}): &fheShl{}, + common.BytesToAddress([]byte{85}): &fheShr{}, + common.BytesToAddress([]byte{86}): &fheNe{}, + common.BytesToAddress([]byte{87}): &fheMin{}, + common.BytesToAddress([]byte{88}): &fheMax{}, + common.BytesToAddress([]byte{89}): &fheNeg{}, + common.BytesToAddress([]byte{90}): &fheNot{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -1319,8 +1354,8 @@ func importRandomCiphertext(accessibleState PrecompileAccessibleState, t fheUint } func get2VerifiedOperands(accessibleState PrecompileAccessibleState, input []byte) (lhs *verifiedCiphertext, rhs *verifiedCiphertext, err error) { - if len(input) != 64 { - return nil, nil, errors.New("input needs to contain two 256-bit sized values") + if len(input) != 65 { + return nil, nil, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value") } lhs = getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32])) if lhs == nil { @@ -1334,6 +1369,27 @@ func get2VerifiedOperands(accessibleState PrecompileAccessibleState, input []byt return } +func getScalarOperands(accessibleState PrecompileAccessibleState, input []byte) (lhs *verifiedCiphertext, rhs *big.Int, err error) { + if len(input) != 65 { + return nil, nil, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value") + } + lhs = getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32])) + if lhs == nil { + return nil, nil, errors.New("unverified ciphertext handle") + } + rhs = &big.Int{} + rhs.SetBytes(input[32:64]) + return +} + +func isScalarOp(accessibleState PrecompileAccessibleState, input []byte) (bool, error) { + if len(input) != 65 { + return false, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value") + } + isScalar := (input[64] == 1) + return isScalar, nil +} + var fheAddSubGasCosts = map[fheUintType]uint64{ FheUint8: params.FheUint8AddSubGas, FheUint16: params.FheUint16AddSubGas, @@ -1352,10 +1408,28 @@ var fheMulGasCosts = map[fheUintType]uint64{ FheUint32: params.FheUint32MulGas, } -var fheLteGasCosts = map[fheUintType]uint64{ - FheUint8: params.FheUint8LteGas, - FheUint16: params.FheUint16LteGas, - FheUint32: params.FheUint32LteGas, +var fheShiftGasCosts = map[fheUintType]uint64{ + FheUint8: params.FheUint8ShiftGas, + FheUint16: params.FheUint16ShiftGas, + FheUint32: params.FheUint32ShiftGas, +} + +var fheLeGasCosts = map[fheUintType]uint64{ + FheUint8: params.FheUint8LeGas, + FheUint16: params.FheUint16LeGas, + FheUint32: params.FheUint32LeGas, +} + +var fheMinMaxGasCosts = map[fheUintType]uint64{ + FheUint8: params.FheUint8MinMaxGas, + FheUint16: params.FheUint16MinMaxGas, + FheUint32: params.FheUint32MinMaxGas, +} + +var fheNegNotGasCosts = map[fheUintType]uint64{ + FheUint8: params.FheUint8NegNotGas, + FheUint16: params.FheUint16NegNotGas, + FheUint32: params.FheUint32NegNotGas, } var fheReencryptGasCosts = map[fheUintType]uint64{ @@ -1386,54 +1460,106 @@ type fheAdd struct{} func (e *fheAdd) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { logger := accessibleState.Interpreter().evm.Logger - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheAdd/Sub RequiredGas() inputs not verified", "err", err, "input", hex.EncodeToString(input)) + logger.Error("fheAdd/Sub RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return 0 } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - logger.Error("fheAdd/Sub RequiredGas() operand type mismatch", "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) - return 0 + var lhs *verifiedCiphertext + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheAdd/Sub RequiredGas() ciphertext inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + logger.Error("fheAdd/Sub RequiredGas() operand type mismatch", "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return 0 + } + } else { + lhs, _, err = getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheAdd/Sub RequiredGas() scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } } return fheAddSubGasCosts[lhs.ciphertext.fheUintType] } func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheAdd inputs not verified", "err", err, "input", hex.EncodeToString(input)) + logger.Error("fheAdd can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - msg := "fheAdd operand type mismatch" - logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) - return nil, errors.New(msg) - } + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheAdd inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheAdd operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } - // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil - } + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } - result, err := lhs.ciphertext.add(rhs.ciphertext) - if err != nil { - logger.Error("fheAdd failed", "err", err) - return nil, err - } - importCiphertext(accessibleState, result) + result, err := lhs.ciphertext.add(rhs.ciphertext) + if err != nil { + logger.Error("fheAdd failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) - // TODO: for testing - err = os.WriteFile("/tmp/add_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheAdd failed to write /tmp/add_result", "err", err) - return nil, err - } + // TODO: for testing + err = os.WriteFile("/tmp/add_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheAdd failed to write /tmp/add_result", "err", err) + return nil, err + } - resultHash := result.getHash() - logger.Info("fheAdd success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) - return resultHash[:], nil + resultHash := result.getHash() + logger.Info("fheAdd success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheAdd scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarAdd(rhs.Uint64()) + if err != nil { + logger.Error("fheAdd failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/add_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheAdd scalar failed to write /tmp/add_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheAdd scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } } func classicalPublicKeyEncrypt(value *big.Int, userPublicKey []byte) ([]byte, error) { @@ -1486,7 +1612,10 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller ctBytes := input[:len(input)-1] ctType := fheUintType(input[len(input)-1]) - + if !ctType.isValid() { + logger.Error("invalid type to cast to") + return nil, errors.New("invalid type provided") + } // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { return importRandomCiphertext(accessibleState, ctType), nil @@ -1795,58 +1924,110 @@ func (e *optimisticRequire) Run(accessibleState PrecompileAccessibleState, calle return nil, nil } -type fheLte struct{} +type fheLe struct{} -func (e *fheLte) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) +func (e *fheLe) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + logger := accessibleState.Interpreter().evm.Logger + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - accessibleState.Interpreter().evm.Logger.Error("fheLte (comparison) RequiredGas() inputs not verified", "err", err) + logger.Error("comparison RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return 0 } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - accessibleState.Interpreter().evm.Logger.Error("fheLte (comparison) RequiredGas() operand type mismatch", "lhs", - lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) - return 0 + var lhs *verifiedCiphertext + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("comparison RequiredGas() ciphertext inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + logger.Error("comparison RequiredGas() operand type mismatch", "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return 0 + } + } else { + lhs, _, err = getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("comparison RequiredGas() scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } } - return fheLteGasCosts[lhs.ciphertext.fheUintType] + return fheLeGasCosts[lhs.ciphertext.fheUintType] } -func (e *fheLte) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func (e *fheLe) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheLte inputs not verified", "err", err) + logger.Error("fheLe can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - msg := "fheLte operand type mismatch" - logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) - return nil, errors.New(msg) - } + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheLe inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheLe operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } - // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil - } + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } - result, err := lhs.ciphertext.lte(rhs.ciphertext) - if err != nil { - logger.Error("fheLte failed", "err", err) - return nil, err - } - importCiphertext(accessibleState, result) + result, err := lhs.ciphertext.le(rhs.ciphertext) + if err != nil { + logger.Error("fheLe failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) - // TODO: for testing - err = os.WriteFile("/tmp/lte_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheAdd failed to write /tmp/lte_result", "err", err) - return nil, err - } + // TODO: for testing + err = os.WriteFile("/tmp/le_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheLe failed to write /tmp/le_result", "err", err) + return nil, err + } - resultHash := result.getHash() - logger.Info("fheLte success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) - return resultHash[:], nil + resultHash := result.getHash() + logger.Info("fheLe success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheLe scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarLe(rhs.Uint64()) + if err != nil { + logger.Error("fheLe failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/le_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheLe scalar failed to write /tmp/le_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheLe scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } } type fheSub struct{} @@ -1859,99 +2040,203 @@ func (e *fheSub) RequiredGas(accessibleState PrecompileAccessibleState, input [] func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheSub inputs not verified", "err", err) + logger.Error("fheSub can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - msg := "fheSub operand type mismatch" - logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) - return nil, errors.New(msg) - } + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheSub inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheSub operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } - // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil - } + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } - result, err := lhs.ciphertext.sub(rhs.ciphertext) - if err != nil { - logger.Error("fheSub failed", "err", err) - return nil, err - } - importCiphertext(accessibleState, result) + result, err := lhs.ciphertext.sub(rhs.ciphertext) + if err != nil { + logger.Error("fheSub failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) - // TODO: for testing - err = os.WriteFile("/tmp/sub_result", result.serialize(), 0644) - if err != nil { - logger.Error("fheSub failed to write /tmp/sub_result", "err", err) - return nil, err - } + // TODO: for testing + err = os.WriteFile("/tmp/sub_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheSub failed to write /tmp/sub_result", "err", err) + return nil, err + } - resultHash := result.getHash() - logger.Info("fheSub success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) - return resultHash[:], nil + resultHash := result.getHash() + logger.Info("fheSub success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheSub scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarSub(rhs.Uint64()) + if err != nil { + logger.Error("fheSub failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/sub_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheSub scalar failed to write /tmp/sub_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheSub scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } } type fheMul struct{} func (e *fheMul) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + logger := accessibleState.Interpreter().evm.Logger + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - accessibleState.Interpreter().evm.Logger.Error("fheMul RequiredGas() inputs not verified", "err", err) + logger.Error("fheMul RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return 0 } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - accessibleState.Interpreter().evm.Logger.Error("fheMul RequiredGas() operand type mismatch", "lhs", - lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) - return 0 + var lhs *verifiedCiphertext + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheMul RequiredGas() ciphertext inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + logger.Error("fheMul RequiredGas() operand type mismatch", "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return 0 + } + } else { + lhs, _, err = getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheMul RequiredGas() scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } } return fheMulGasCosts[lhs.ciphertext.fheUintType] } func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - accessibleState.Interpreter().evm.Logger.Error("fheMul inputs not verified", "err", err) + logger.Error("fheMul can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - msg := "fheMul operand type mismatch" - accessibleState.Interpreter().evm.Logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) - return nil, errors.New(msg) - } + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheMul inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheMul operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } - // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil - } + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } - result, err := lhs.ciphertext.mul(rhs.ciphertext) - if err != nil { - accessibleState.Interpreter().evm.Logger.Error("fheMul failed", "err", err) - return nil, err - } - importCiphertext(accessibleState, result) + result, err := lhs.ciphertext.mul(rhs.ciphertext) + if err != nil { + logger.Error("fheMul failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) - // TODO: for testing - err = os.WriteFile("/tmp/mul_result", result.serialize(), 0644) - if err != nil { - accessibleState.Interpreter().evm.Logger.Error("fheMul failed to write /tmp/mul_result", "err", err) - return nil, err - } + // TODO: for testing + err = os.WriteFile("/tmp/mul_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheMul failed to write /tmp/mul_result", "err", err) + return nil, err + } - ctHash := result.getHash() + resultHash := result.getHash() + logger.Info("fheMul success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil - return ctHash[:], nil + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheMul scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarMul(rhs.Uint64()) + if err != nil { + logger.Error("fheMul failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/mul_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheMul scalar failed to write /tmp/mul_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheMul scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } } type fheBitAnd struct{} func (e *fheBitAnd) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) + if err != nil { + logger.Error("Bitwise op RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + + if isScalar { + msg := "Bitwise op RequiredGas() scalar op not supported" + logger.Error(msg) + return 0 + } + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) if err != nil { logger.Error("Bitwise op RequiredGas() inputs not verified", "err", err, "input", hex.EncodeToString(input)) @@ -1966,15 +2251,28 @@ func (e *fheBitAnd) RequiredGas(accessibleState PrecompileAccessibleState, input func (e *fheBitAnd) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheBitAnd inputs not verified", "err", err) + logger.Error("fheBitAnd can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - msg := "fheBitAnd operand type mismatch" - logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + if isScalar { + msg := "fheBitAnd scalar op not supported" + logger.Error(msg) + return nil, errors.New(msg) + } + + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheBitAnd inputs not verified", "err", err) + return nil, err + } + + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheBitAnd operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) return nil, errors.New(msg) } @@ -2012,6 +2310,19 @@ func (e *fheBitOr) RequiredGas(accessibleState PrecompileAccessibleState, input func (e *fheBitOr) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) + if err != nil { + logger.Error("fheBitOr can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + if isScalar { + msg := "fheBitOr scalar op not supported" + logger.Error(msg) + return nil, errors.New(msg) + } + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) if err != nil { logger.Error("fheBitOr inputs not verified", "err", err) @@ -2058,6 +2369,19 @@ func (e *fheBitXor) RequiredGas(accessibleState PrecompileAccessibleState, input func (e *fheBitXor) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) + if err != nil { + logger.Error("fheBitXor can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + if isScalar { + msg := "fheBitXor scalar op not supported" + logger.Error(msg) + return nil, errors.New(msg) + } + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) if err != nil { logger.Error("fheBitXor inputs not verified", "err", err) @@ -2094,187 +2418,909 @@ func (e *fheBitXor) Run(accessibleState PrecompileAccessibleState, caller common return resultHash[:], nil } -type fheEq struct{} +type fheShl struct{} -func (e *fheEq) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { - // Implement in terms of lte, because comparison costs are currently the same. - lte := fheLte{} - return lte.RequiredGas(accessibleState, input) +func (e *fheShl) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + logger := accessibleState.Interpreter().evm.Logger + isScalar, err := isScalarOp(accessibleState, input) + if err != nil { + logger.Error("fheShift RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + var lhs *verifiedCiphertext + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheShift RequiredGas() ciphertext inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + logger.Error("fheShift RequiredGas() operand type mismatch", "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return 0 + } + } else { + lhs, _, err = getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheShift RequiredGas() scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + } + return fheShiftGasCosts[lhs.ciphertext.fheUintType] } -func (e *fheEq) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func (e *fheShl) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheEq inputs not verified", "err", err) + logger.Error("fheShl can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - msg := "fheEq operand type mismatch" - logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) - return nil, errors.New(msg) + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheShl inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheShl operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.shl(rhs.ciphertext) + if err != nil { + logger.Error("fheShl failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/shl_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheShl failed to write /tmp/shl_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheShl success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheShl scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarShl(rhs.Uint64()) + if err != nil { + logger.Error("fheShl failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/shl_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheShl scalar failed to write /tmp/shl_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheShl scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil } +} - // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil +type fheShr struct{} + +func (e *fheShr) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of shl, because comparison costs are currently the same. + shl := fheShl{} + return shl.RequiredGas(accessibleState, input) +} + +func (e *fheShr) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) + if err != nil { + logger.Error("fheShr can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) + return nil, err } - result, err := lhs.ciphertext.eq(rhs.ciphertext) + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheShr inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheShr operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.shr(rhs.ciphertext) + if err != nil { + logger.Error("fheShr failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/shr_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheShr failed to write /tmp/shr_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheShr success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheShr scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarShr(rhs.Uint64()) + if err != nil { + logger.Error("fheShr failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/shr_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheShr scalar failed to write /tmp/shr_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheShr scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } +} + +type fheEq struct{} + +func (e *fheEq) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of le, because comparison costs are currently the same. + le := fheLe{} + return le.RequiredGas(accessibleState, input) +} + +func (e *fheEq) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheEq failed", "err", err) + logger.Error("fheEq can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - importCiphertext(accessibleState, result) - // TODO: for testing - err = os.WriteFile("/tmp/eq_result", result.serialize(), 0644) + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheEq inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheEq operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.eq(rhs.ciphertext) + if err != nil { + logger.Error("fheEq failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/eq_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheEq failed to write /tmp/eq_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheEq success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheEq scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarEq(rhs.Uint64()) + if err != nil { + logger.Error("fheEq failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/eq_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheEq scalar failed to write /tmp/eq_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheEq scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } +} + +type fheNe struct{} + +func (e *fheNe) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of le, because comparison costs are currently the same. + le := fheLe{} + return le.RequiredGas(accessibleState, input) +} + +func (e *fheNe) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheEq failed to write /tmp/eq_result", "err", err) + logger.Error("fheNe can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - resultHash := result.getHash() - logger.Info("fheEq success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) - return resultHash[:], nil + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheNe inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheNe operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.ne(rhs.ciphertext) + if err != nil { + logger.Error("fheNe failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/ne_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheNe failed to write /tmp/ne_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheNe success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheNe scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarNe(rhs.Uint64()) + if err != nil { + logger.Error("fheNe failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/ne_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheNe scalar failed to write /tmp/ne_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheNe scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } } type fheGe struct{} func (e *fheGe) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { - // Implement in terms of lte, because comparison costs are currently the same. - lte := fheLte{} - return lte.RequiredGas(accessibleState, input) + // Implement in terms of le, because comparison costs are currently the same. + le := fheLe{} + return le.RequiredGas(accessibleState, input) } func (e *fheGe) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheGe inputs not verified", "err", err) + logger.Error("fheGe can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - msg := "fheGe operand type mismatch" - logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) - return nil, errors.New(msg) + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheGe inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheGe operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.ge(rhs.ciphertext) + if err != nil { + logger.Error("fheGe failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/ge_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheGe failed to write /tmp/ge_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheGe success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheGe scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarGe(rhs.Uint64()) + if err != nil { + logger.Error("fheGe failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/ge_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheGe scalar failed to write /tmp/ge_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheGe scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil } +} - // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil +type fheGt struct{} + +func (e *fheGt) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of le, because comparison costs are currently the same. + le := fheLe{} + return le.RequiredGas(accessibleState, input) +} + +func (e *fheGt) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) + if err != nil { + logger.Error("fheGt can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheGt inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheGt operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.gt(rhs.ciphertext) + if err != nil { + logger.Error("fheGt failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/gt_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheGt failed to write /tmp/gt_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheGt success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheGt scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarGt(rhs.Uint64()) + if err != nil { + logger.Error("fheGt failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/gt_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheGt scalar failed to write /tmp/gt_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheGt scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil } +} + +type fheLt struct{} + +func (e *fheLt) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of le, because le and lt costs are currently the same. + le := fheLe{} + return le.RequiredGas(accessibleState, input) +} - result, err := lhs.ciphertext.ge(rhs.ciphertext) +func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheGe failed", "err", err) + logger.Error("fheLt can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - importCiphertext(accessibleState, result) - // TODO: for testing - err = os.WriteFile("/tmp/ge_result", result.serialize(), 0644) + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheLt inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheLt operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.lt(rhs.ciphertext) + if err != nil { + logger.Error("fheLt failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/lt_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheLt failed to write /tmp/lt_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheLt success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheLt scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarLt(rhs.Uint64()) + if err != nil { + logger.Error("fheLt failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/lt_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheLt scalar failed to write /tmp/lt_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheLt scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } +} + +type fheMin struct{} + +func (e *fheMin) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + logger := accessibleState.Interpreter().evm.Logger + isScalar, err := isScalarOp(accessibleState, input) + if err != nil { + logger.Error("fheMin/Max RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + var lhs *verifiedCiphertext + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheMin/Max RequiredGas() ciphertext inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + logger.Error("fheMin/Max RequiredGas() operand type mismatch", "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return 0 + } + } else { + lhs, _, err = getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheMin/Max RequiredGas() scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + } + return fheMinMaxGasCosts[lhs.ciphertext.fheUintType] +} + +func (e *fheMin) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheGe failed to write /tmp/ge_result", "err", err) + logger.Error("fheMin can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - resultHash := result.getHash() - logger.Info("fheGt success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) - return resultHash[:], nil + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheMin inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheMin operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.min(rhs.ciphertext) + if err != nil { + logger.Error("fheMin failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/min_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheMin failed to write /tmp/min_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheMin success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheMin scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarMin(rhs.Uint64()) + if err != nil { + logger.Error("fheMin failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/min_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheMin scalar failed to write /tmp/min_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheMin scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } } -type fheGt struct{} +type fheMax struct{} -func (e *fheGt) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { - // Implement in terms of lte, because comparison costs are currently the same. - lte := fheLte{} - return lte.RequiredGas(accessibleState, input) +func (e *fheMax) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of min, because costs are currently the same. + min := fheMin{} + return min.RequiredGas(accessibleState, input) } -func (e *fheGt) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func (e *fheMax) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + + isScalar, err := isScalarOp(accessibleState, input) if err != nil { - logger.Error("fheGt inputs not verified", "err", err) + logger.Error("fheMax can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - msg := "fheGt operand type mismatch" - logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + if !isScalar { + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheMax inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheMax operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.max(rhs.ciphertext) + if err != nil { + logger.Error("fheMax failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/max_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheMax failed to write /tmp/max_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheMax success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil + + } else { + lhs, rhs, err := getScalarOperands(accessibleState, input) + if err != nil { + logger.Error("fheMax scalar inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.scalarMax(rhs.Uint64()) + if err != nil { + logger.Error("fheMax failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/max_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheMax scalar failed to write /tmp/max_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheMax scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex()) + return resultHash[:], nil + } +} + +type fheNeg struct{} + +func (e *fheNeg) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + logger := accessibleState.Interpreter().evm.Logger + if len(input) != 32 { + logger.Error("fheNeg input needs to contain one 256-bit sized value", "input", hex.EncodeToString(input)) + return 0 + } + ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32])) + if ct == nil { + logger.Error("fheNeg input not verified", "input", hex.EncodeToString(input)) + return 0 + } + return fheNegNotGasCosts[ct.ciphertext.fheUintType] +} + +func (e *fheNeg) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + + if len(input) != 32 { + msg := "fheMax input needs to contain one 256-bit sized value" + logger.Error(msg, "input", hex.EncodeToString(input)) + return nil, errors.New(msg) + + } + + ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32])) + if ct == nil { + msg := "fheNeg input not verified" + logger.Error(msg, msg, "input", hex.EncodeToString(input)) return nil, errors.New(msg) } // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + return importRandomCiphertext(accessibleState, ct.ciphertext.fheUintType), nil } - result, err := lhs.ciphertext.gt(rhs.ciphertext) + result, err := ct.ciphertext.neg() if err != nil { - logger.Error("fheGt failed", "err", err) + logger.Error("fheNeg failed", "err", err) return nil, err } importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/gt_result", result.serialize(), 0644) + err = os.WriteFile("/tmp/neg_result", result.serialize(), 0644) if err != nil { - logger.Error("fheGt failed to write /tmp/gt_result", "err", err) + logger.Error("fheNeg failed to write /tmp/neg_result", "err", err) return nil, err } resultHash := result.getHash() - logger.Info("fheGt success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + logger.Info("fheNeg success", "ct", ct.ciphertext.getHash().Hex(), "result", resultHash.Hex()) return resultHash[:], nil } -type fheLt struct{} +type fheNot struct{} -func (e *fheLt) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { - // Implement in terms of lte, because lte and lt costs are currently the same. - lte := fheLte{} - return lte.RequiredGas(accessibleState, input) +func (e *fheNot) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of neg, because costs are currently the same. + neg := fheNeg{} + return neg.RequiredGas(accessibleState, input) } -func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +func (e *fheNot) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger - lhs, rhs, err := get2VerifiedOperands(accessibleState, input) - if err != nil { - logger.Error("fheLt inputs not verified", "err", err) - return nil, err + + if len(input) != 32 { + msg := "fheMax input needs to contain one 256-bit sized value" + logger.Error(msg, "input", hex.EncodeToString(input)) + return nil, errors.New(msg) + } - if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - msg := "fheLt operand type mismatch" - logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32])) + if ct == nil { + msg := "fheNot input not verified" + logger.Error(msg, msg, "input", hex.EncodeToString(input)) return nil, errors.New(msg) } // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + return importRandomCiphertext(accessibleState, ct.ciphertext.fheUintType), nil } - result, err := lhs.ciphertext.lt(rhs.ciphertext) + result, err := ct.ciphertext.not() if err != nil { - logger.Error("fheLt failed", "err", err) + logger.Error("fheNot failed", "err", err) return nil, err } importCiphertext(accessibleState, result) // TODO: for testing - err = os.WriteFile("/tmp/lt_result", result.serialize(), 0644) + err = os.WriteFile("/tmp/not_result", result.serialize(), 0644) if err != nil { - logger.Error("fheLt failed to write /tmp/lt_result", "err", err) + logger.Error("fheNot failed to write /tmp/not_result", "err", err) return nil, err } resultHash := result.getHash() - logger.Info("fheLt success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + logger.Info("fheNot success", "ct", ct.ciphertext.getHash().Hex(), "result", resultHash.Hex()) return resultHash[:], nil } diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index 5e1822907..83793577b 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -452,11 +452,18 @@ func verifyTfheCiphertextInTestMemory(interpreter *EVMInterpreter, ct *tfheCiphe return verifiedCiphertext.ciphertext } -func toPrecompileInput(hashes ...common.Hash) []byte { +func toPrecompileInput(isScalar bool, hashes ...common.Hash) []byte { ret := make([]byte, 0) for _, hash := range hashes { ret = append(ret, hash.Bytes()...) } + var isScalarByte byte + if isScalar { + isScalarByte = 1 + } else { + isScalarByte = 0 + } + ret = append(ret, isScalarByte) return ret } @@ -554,7 +561,7 @@ func TrivialEncrypt(t *testing.T, fheUintType fheUintType) { } } -func FheAdd(t *testing.T, fheUintType fheUintType) { +func FheAdd(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -575,8 +582,13 @@ func FheAdd(t *testing.T, fheUintType fheUintType) { addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() - input := toPrecompileInput(lhsHash, rhsHash) + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toPrecompileInput(scalar, lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -591,7 +603,7 @@ func FheAdd(t *testing.T, fheUintType fheUintType) { } } -func FheSub(t *testing.T, fheUintType fheUintType) { +func FheSub(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -612,8 +624,13 @@ func FheSub(t *testing.T, fheUintType fheUintType) { addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() - input := toPrecompileInput(lhsHash, rhsHash) + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toPrecompileInput(scalar, lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -628,7 +645,7 @@ func FheSub(t *testing.T, fheUintType fheUintType) { } } -func FheMul(t *testing.T, fheUintType fheUintType) { +func FheMul(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -649,8 +666,13 @@ func FheMul(t *testing.T, fheUintType fheUintType) { addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() - input := toPrecompileInput(lhsHash, rhsHash) + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toPrecompileInput(scalar, lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -665,7 +687,7 @@ func FheMul(t *testing.T, fheUintType fheUintType) { } } -func FheBitAnd(t *testing.T, fheUintType fheUintType) { +func FheBitAnd(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -686,23 +708,34 @@ func FheBitAnd(t *testing.T, fheUintType fheUintType) { addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() - input := toPrecompileInput(lhsHash, rhsHash) - out, err := c.Run(state, addr, addr, input, readOnly) - if err != nil { - t.Fatalf(err.Error()) + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() } - res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) - if res == nil { - t.Fatalf("output ciphertext is not found in verifiedCiphertexts") - } - decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != expected { - t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + input := toPrecompileInput(scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if scalar { + if err == nil { + t.Fatalf("scalar bit and should have failed") + } + } else { + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } } } -func FheBitOr(t *testing.T, fheUintType fheUintType) { +func FheBitOr(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -723,8 +756,109 @@ func FheBitOr(t *testing.T, fheUintType fheUintType) { addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() - input := toPrecompileInput(lhsHash, rhsHash) + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toPrecompileInput(scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if scalar { + if err == nil { + t.Fatalf("scalar bit or should have failed") + } + } else { + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } + } +} + +func FheBitXor(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + expected := lhs ^ rhs + c := &fheBitXor{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toPrecompileInput(scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if scalar { + if err == nil { + t.Fatalf("scalar bit xor should have failed") + } + } else { + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } + } +} + +func FheShl(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 2 + case FheUint32: + lhs = 1333337 + rhs = 3 + } + expected := lhs << rhs + c := &fheShl{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toPrecompileInput(scalar, lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -739,7 +873,7 @@ func FheBitOr(t *testing.T, fheUintType fheUintType) { } } -func FheBitXor(t *testing.T, fheUintType fheUintType) { +func FheShr(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -747,21 +881,26 @@ func FheBitXor(t *testing.T, fheUintType fheUintType) { rhs = 1 case FheUint16: lhs = 4283 - rhs = 1337 + rhs = 2 case FheUint32: lhs = 1333337 - rhs = 133337 + rhs = 3 } - expected := lhs ^ rhs - c := &fheBitXor{} + expected := lhs >> rhs + c := &fheShr{} depth := 1 state := newTestState() state.interpreter.evm.depth = depth addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() - input := toPrecompileInput(lhsHash, rhsHash) + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toPrecompileInput(scalar, lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -776,7 +915,7 @@ func FheBitXor(t *testing.T, fheUintType fheUintType) { } } -func FheEq(t *testing.T, fheUintType fheUintType) { +func FheEq(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -796,10 +935,14 @@ func FheEq(t *testing.T, fheUintType fheUintType) { addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() - + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } // lhs == rhs - input1 := toPrecompileInput(lhsHash, rhsHash) + input1 := toPrecompileInput(scalar, lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input1, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -814,7 +957,7 @@ func FheEq(t *testing.T, fheUintType fheUintType) { } } -func FheGe(t *testing.T, fheUintType fheUintType) { +func FheNe(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -827,17 +970,21 @@ func FheGe(t *testing.T, fheUintType fheUintType) { lhs = 1333337 rhs = 133337 } - c := &fheGe{} + c := &fheNe{} depth := 1 state := newTestState() state.interpreter.evm.depth = depth addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() - - // lhs >= rhs - input1 := toPrecompileInput(lhsHash, rhsHash) + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + // lhs == rhs + input1 := toPrecompileInput(scalar, lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input1, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -850,24 +997,69 @@ func FheGe(t *testing.T, fheUintType fheUintType) { if decrypted.Uint64() != 1 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) } +} - // rhs >= lhs - input2 := toPrecompileInput(rhsHash, lhsHash) - out, err = c.Run(state, addr, addr, input2, readOnly) +func FheGe(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + c := &fheGe{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + // lhs >= rhs + input1 := toPrecompileInput(scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) if err != nil { t.Fatalf(err.Error()) } - res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() - if decrypted.Uint64() != 0 { - t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + // rhs >= lhs + input2 := toPrecompileInput(false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted = res.ciphertext.decrypt() + if decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } } } -func FheGt(t *testing.T, fheUintType fheUintType) { +func FheGt(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -888,10 +1080,14 @@ func FheGt(t *testing.T, fheUintType fheUintType) { addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() - + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } // lhs > rhs - input1 := toPrecompileInput(lhsHash, rhsHash) + input1 := toPrecompileInput(scalar, lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input1, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -905,23 +1101,89 @@ func FheGt(t *testing.T, fheUintType fheUintType) { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) } - // rhs > lhs - input2 := toPrecompileInput(rhsHash, lhsHash) - out, err = c.Run(state, addr, addr, input2, readOnly) + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + // rhs > lhs + input2 := toPrecompileInput(false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted = res.ciphertext.decrypt() + if decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } + } +} + +func FheLe(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + c := &fheLe{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + + // lhs <= rhs + input1 := toPrecompileInput(scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) if err != nil { t.Fatalf(err.Error()) } - res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() + decrypted := res.ciphertext.decrypt() if decrypted.Uint64() != 0 { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) } + + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + // rhs <= lhs + input2 := toPrecompileInput(false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted = res.ciphertext.decrypt() + if decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } + } } -func FheLte(t *testing.T, fheUintType fheUintType) { +func FheLt(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -934,17 +1196,23 @@ func FheLte(t *testing.T, fheUintType fheUintType) { lhs = 1333337 rhs = 133337 } - c := &fheLte{} + + c := &fheLt{} depth := 1 state := newTestState() state.interpreter.evm.depth = depth addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } - // lhs <= rhs - input1 := toPrecompileInput(lhsHash, rhsHash) + // lhs < rhs + input1 := toPrecompileInput(scalar, lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input1, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -958,23 +1226,88 @@ func FheLte(t *testing.T, fheUintType fheUintType) { t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) } - // rhs <= lhs - input2 := toPrecompileInput(rhsHash, lhsHash) - out, err = c.Run(state, addr, addr, input2, readOnly) + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + // rhs < lhs + input2 := toPrecompileInput(false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted = res.ciphertext.decrypt() + if decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } + } +} + +func FheMin(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + + c := &fheMin{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + + input := toPrecompileInput(scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) } - res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() - if decrypted.Uint64() != 1 { - t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != rhs { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), rhs) + } + + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + input2 := toPrecompileInput(false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted = res.ciphertext.decrypt() + if decrypted.Uint64() != rhs { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), rhs) + } } } -func FheLt(t *testing.T, fheUintType fheUintType) { +func FheMax(t *testing.T, fheUintType fheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { case FheUint8: @@ -988,18 +1321,22 @@ func FheLt(t *testing.T, fheUintType fheUintType) { rhs = 133337 } - c := &fheLt{} + c := &fheMax{} depth := 1 state := newTestState() state.interpreter.evm.depth = depth addr := common.Address{} readOnly := false lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } - // lhs < rhs - input1 := toPrecompileInput(lhsHash, rhsHash) - out, err := c.Run(state, addr, addr, input1, readOnly) + input := toPrecompileInput(scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) } @@ -1008,23 +1345,102 @@ func FheLt(t *testing.T, fheUintType fheUintType) { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } decrypted := res.ciphertext.decrypt() - if decrypted.Uint64() != 0 { - t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + if decrypted.Uint64() != lhs { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), lhs) + } + + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + input2 := toPrecompileInput(false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted = res.ciphertext.decrypt() + if decrypted.Uint64() != lhs { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), lhs) + } } +} - // rhs < lhs - input2 := toPrecompileInput(rhsHash, lhsHash) - out, err = c.Run(state, addr, addr, input2, readOnly) +func FheNeg(t *testing.T, fheUintType fheUintType, scalar bool) { + var pt, expected uint64 + switch fheUintType { + case FheUint8: + pt = 2 + expected = uint64(-uint8(pt)) + case FheUint16: + pt = 4283 + expected = uint64(-uint16(pt)) + case FheUint32: + pt = 1333337 + expected = uint64(-uint32(pt)) + } + + c := &fheNeg{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + ptHash := verifyCiphertextInTestMemory(state.interpreter, pt, depth, fheUintType).getHash() + + input := make([]byte, 0) + input = append(input, ptHash.Bytes()...) + out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) } - res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) if res == nil { t.Fatalf("output ciphertext is not found in verifiedCiphertexts") } - decrypted = res.ciphertext.decrypt() - if decrypted.Uint64() != 1 { - t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheNot(t *testing.T, fheUintType fheUintType, scalar bool) { + var pt, expected uint64 + switch fheUintType { + case FheUint8: + pt = 2 + expected = uint64(^uint8(pt)) + case FheUint16: + pt = 4283 + expected = uint64(^uint16(pt)) + case FheUint32: + pt = 1333337 + expected = uint64(^uint32(pt)) + } + + c := &fheNot{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + ptHash := verifyCiphertextInTestMemory(state.interpreter, pt, depth, fheUintType).getHash() + + input := make([]byte, 0) + input = append(input, ptHash.Bytes()...) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) } } @@ -1084,135 +1500,411 @@ func TestVerifyCiphertextBadCiphertext(t *testing.T) { } func TestFheAdd8(t *testing.T) { - FheAdd(t, FheUint8) + FheAdd(t, FheUint8, false) } func TestFheAdd16(t *testing.T) { - FheAdd(t, FheUint16) + FheAdd(t, FheUint16, false) } func TestFheAdd32(t *testing.T) { - FheAdd(t, FheUint32) + FheAdd(t, FheUint32, false) +} + +func TestFheScalarAdd8(t *testing.T) { + FheAdd(t, FheUint8, true) +} + +func TestFheScalarAdd16(t *testing.T) { + FheAdd(t, FheUint16, true) +} + +func TestFheScalarAdd32(t *testing.T) { + FheAdd(t, FheUint32, true) } func TestFheSub8(t *testing.T) { - FheSub(t, FheUint8) + FheSub(t, FheUint8, false) } func TestFheSub16(t *testing.T) { - FheSub(t, FheUint16) + FheSub(t, FheUint16, false) } func TestFheSub32(t *testing.T) { - FheSub(t, FheUint32) + FheSub(t, FheUint32, false) +} + +func TestFheScalarSub8(t *testing.T) { + FheSub(t, FheUint8, true) +} + +func TestFheScalarSub16(t *testing.T) { + FheSub(t, FheUint16, true) +} + +func TestFheScalarSub32(t *testing.T) { + FheSub(t, FheUint32, true) } func TestFheMul8(t *testing.T) { - FheMul(t, FheUint8) + FheMul(t, FheUint8, false) } func TestFheMul16(t *testing.T) { - FheMul(t, FheUint16) + FheMul(t, FheUint16, false) } func TestFheMul32(t *testing.T) { - FheMul(t, FheUint32) + FheMul(t, FheUint32, false) +} + +func TestFheScalarMul8(t *testing.T) { + FheMul(t, FheUint8, true) +} + +func TestFheScalarMul16(t *testing.T) { + FheMul(t, FheUint16, true) +} + +func TestFheScalarMul32(t *testing.T) { + FheMul(t, FheUint32, true) } func TestFheBitAnd8(t *testing.T) { - FheBitAnd(t, FheUint8) + FheBitAnd(t, FheUint8, false) } func TestFheBitAnd16(t *testing.T) { - FheBitAnd(t, FheUint16) + FheBitAnd(t, FheUint16, false) } func TestFheBitAnd32(t *testing.T) { - FheBitAnd(t, FheUint32) + FheBitAnd(t, FheUint32, false) +} + +func TestFheScalarBitAnd8(t *testing.T) { + FheBitAnd(t, FheUint8, true) +} + +func TestFheScalarBitAnd16(t *testing.T) { + FheBitAnd(t, FheUint16, true) +} + +func TestFheScalarBitAnd32(t *testing.T) { + FheBitAnd(t, FheUint32, true) } func TestFheBitOr8(t *testing.T) { - FheBitOr(t, FheUint8) + FheBitOr(t, FheUint8, false) } func TestFheBitOr16(t *testing.T) { - FheBitOr(t, FheUint16) + FheBitOr(t, FheUint16, false) } func TestFheBitOr32(t *testing.T) { - FheBitOr(t, FheUint32) + FheBitOr(t, FheUint32, false) +} + +func TestFheScalarBitOr8(t *testing.T) { + FheBitOr(t, FheUint8, true) +} + +func TestFheScalarBitOr16(t *testing.T) { + FheBitOr(t, FheUint16, true) +} + +func TestFheScalarBitOr32(t *testing.T) { + FheBitOr(t, FheUint32, true) } func TestFheBitXor8(t *testing.T) { - FheBitXor(t, FheUint8) + FheBitXor(t, FheUint8, false) } func TestFheBitXor16(t *testing.T) { - FheBitXor(t, FheUint16) + FheBitXor(t, FheUint16, false) } func TestFheBitXor32(t *testing.T) { - FheBitXor(t, FheUint32) + FheBitXor(t, FheUint32, false) +} + +func TestFheScalarBitXor8(t *testing.T) { + FheBitXor(t, FheUint8, true) +} + +func TestFheScalarBitXor16(t *testing.T) { + FheBitXor(t, FheUint16, true) +} + +func TestFheScalarBitXor32(t *testing.T) { + FheBitXor(t, FheUint32, true) +} + +func TestFheShl8(t *testing.T) { + FheShl(t, FheUint8, false) +} + +func TestFheShl16(t *testing.T) { + FheShl(t, FheUint16, false) +} + +func TestFheShl32(t *testing.T) { + FheShl(t, FheUint32, false) +} + +func TestFheScalarShl8(t *testing.T) { + FheShl(t, FheUint8, true) +} + +func TestFheScalarShl16(t *testing.T) { + FheShl(t, FheUint16, true) +} + +func TestFheScalarShl32(t *testing.T) { + FheShl(t, FheUint32, true) +} + +func TestFheShr8(t *testing.T) { + FheShr(t, FheUint8, false) +} + +func TestFheShr16(t *testing.T) { + FheShr(t, FheUint16, false) +} + +func TestFheShr32(t *testing.T) { + FheShr(t, FheUint32, false) +} + +func TestFheScalarShr8(t *testing.T) { + FheShr(t, FheUint8, true) +} + +func TestFheScalarShr16(t *testing.T) { + FheShr(t, FheUint16, true) +} + +func TestFheScalarShr32(t *testing.T) { + FheShr(t, FheUint32, true) } func TestFheEq8(t *testing.T) { - FheEq(t, FheUint8) + FheEq(t, FheUint8, false) } func TestFheEq16(t *testing.T) { - FheEq(t, FheUint16) + FheEq(t, FheUint16, false) } func TestFheEq32(t *testing.T) { - FheEq(t, FheUint32) + FheEq(t, FheUint32, false) +} + +func TestFheScalarEq8(t *testing.T) { + FheEq(t, FheUint8, true) +} + +func TestFheScalarEq16(t *testing.T) { + FheEq(t, FheUint16, true) +} + +func TestFheScalarEq32(t *testing.T) { + FheEq(t, FheUint32, true) +} + +func TestFheNe8(t *testing.T) { + FheNe(t, FheUint8, false) +} + +func TestFheNe16(t *testing.T) { + FheNe(t, FheUint16, false) +} + +func TestFheNe32(t *testing.T) { + FheNe(t, FheUint32, false) +} + +func TestFheScalarNe8(t *testing.T) { + FheNe(t, FheUint8, true) +} + +func TestFheScalarNe16(t *testing.T) { + FheNe(t, FheUint16, true) +} + +func TestFheScalarNe32(t *testing.T) { + FheNe(t, FheUint32, true) } func TestFheGe8(t *testing.T) { - FheGe(t, FheUint8) + FheGe(t, FheUint8, false) } func TestFheGe16(t *testing.T) { - FheGe(t, FheUint16) + FheGe(t, FheUint16, false) } func TestFheGe32(t *testing.T) { - FheGe(t, FheUint32) + FheGe(t, FheUint32, false) +} + +func TestFheScalarGe8(t *testing.T) { + FheGe(t, FheUint8, true) +} + +func TestFheScalarGe16(t *testing.T) { + FheGe(t, FheUint16, true) +} + +func TestFheScalarGe32(t *testing.T) { + FheGe(t, FheUint32, true) } func TestFheGt8(t *testing.T) { - FheGt(t, FheUint8) + FheGt(t, FheUint8, false) } func TestFheGt16(t *testing.T) { - FheGt(t, FheUint16) + FheGt(t, FheUint16, false) } func TestFheGt32(t *testing.T) { - FheGt(t, FheUint32) + FheGt(t, FheUint32, false) +} + +func TestFheScalarGt8(t *testing.T) { + FheGt(t, FheUint8, true) +} + +func TestFheScalarGt16(t *testing.T) { + FheGt(t, FheUint16, true) } -func TestFheLte8(t *testing.T) { - FheLte(t, FheUint8) +func TestFheScalarGt32(t *testing.T) { + FheGt(t, FheUint32, true) } -func TestFheLte16(t *testing.T) { - FheLte(t, FheUint16) +func TestFheLe8(t *testing.T) { + FheLe(t, FheUint8, false) } -func TestFheLte32(t *testing.T) { - FheLte(t, FheUint32) +func TestFheLe16(t *testing.T) { + FheLe(t, FheUint16, false) +} + +func TestFheLe32(t *testing.T) { + FheLe(t, FheUint32, false) +} + +func TestFheScalarLe8(t *testing.T) { + FheLe(t, FheUint8, true) +} + +func TestFheScalarLe16(t *testing.T) { + FheLe(t, FheUint16, true) +} + +func TestFheScalarLe32(t *testing.T) { + FheLe(t, FheUint32, true) } func TestFheLt8(t *testing.T) { - FheLt(t, FheUint8) + FheLt(t, FheUint8, false) } func TestFheLt16(t *testing.T) { - FheLt(t, FheUint16) + FheLt(t, FheUint16, false) } func TestFheLt32(t *testing.T) { - FheLt(t, FheUint32) + FheLt(t, FheUint32, false) +} + +func TestFheScalarLt8(t *testing.T) { + FheLt(t, FheUint8, true) +} + +func TestFheScalarLt16(t *testing.T) { + FheLt(t, FheUint16, true) +} + +func TestFheScalarLt32(t *testing.T) { + FheLt(t, FheUint32, true) +} + +func TestFheMin8(t *testing.T) { + FheMin(t, FheUint8, false) +} + +func TestFheMin16(t *testing.T) { + FheMin(t, FheUint16, false) +} + +func TestFheMin32(t *testing.T) { + FheMin(t, FheUint32, false) +} + +func TestFheScalarMin8(t *testing.T) { + FheMin(t, FheUint8, true) +} + +func TestFheScalarMin16(t *testing.T) { + FheMin(t, FheUint16, true) +} + +func TestFheScalarMin32(t *testing.T) { + FheMin(t, FheUint32, true) +} + +func TestFheMax8(t *testing.T) { + FheMax(t, FheUint8, false) +} + +func TestFheMax16(t *testing.T) { + FheMax(t, FheUint16, false) +} + +func TestFheMax32(t *testing.T) { + FheMax(t, FheUint32, false) +} + +func TestFheNeg8(t *testing.T) { + FheNeg(t, FheUint8, false) +} + +func TestFheNeg16(t *testing.T) { + FheNeg(t, FheUint16, false) +} + +func TestFheNeg32(t *testing.T) { + FheNeg(t, FheUint32, false) +} + +func TestFheNot8(t *testing.T) { + FheNot(t, FheUint8, false) +} + +func TestFheNot16(t *testing.T) { + FheNot(t, FheUint16, false) +} + +func TestFheNot32(t *testing.T) { + FheNot(t, FheUint32, false) +} + +func TestFheScalarMax8(t *testing.T) { + FheMax(t, FheUint8, true) +} + +func TestFheScalarMax16(t *testing.T) { + FheMax(t, FheUint16, true) +} + +func TestFheScalarMax32(t *testing.T) { + FheMax(t, FheUint32, true) } func TestUnknownCiphertextHandle(t *testing.T) { diff --git a/core/vm/tfhe.go b/core/vm/tfhe.go index 22faa5a8b..888106fb6 100644 --- a/core/vm/tfhe.go +++ b/core/vm/tfhe.go @@ -213,6 +213,39 @@ void* add_fhe_uint32(void* ct1, void* ct2, void* sks) return result; } +void* scalar_add_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_add(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_add_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_add(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_add_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_add(ct, pt, &result); + assert(r == 0); + return result; +} + void* sub_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -246,6 +279,39 @@ void* sub_fhe_uint32(void* ct1, void* ct2, void* sks) return result; } +void* scalar_sub_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_sub(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_sub_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_sub(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_sub_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_sub(ct, pt, &result); + assert(r == 0); + return result; +} + void* mul_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -279,6 +345,39 @@ void* mul_fhe_uint32(void* ct1, void* ct2, void* sks) return result; } +void* scalar_mul_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_mul(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_mul_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_mul(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_mul_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_mul(ct, pt, &result); + assert(r == 0); + return result; +} + void* bitand_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -378,337 +477,892 @@ void* bitxor_fhe_uint32(void* ct1, void* ct2, void* sks) return result; } -void* eq_fhe_uint8(void* ct1, void* ct2, void* sks) +void* shl_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_eq(ct1, ct2, &result); + const int r = fhe_uint8_shl(ct1, ct2, &result); assert(r == 0); return result; } -void* eq_fhe_uint16(void* ct1, void* ct2, void* sks) +void* shl_fhe_uint16(void* ct1, void* ct2, void* sks) { FheUint16* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_eq(ct1, ct2, &result); + const int r = fhe_uint16_shl(ct1, ct2, &result); assert(r == 0); return result; } -void* eq_fhe_uint32(void* ct1, void* ct2, void* sks) +void* shl_fhe_uint32(void* ct1, void* ct2, void* sks) { FheUint32* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_eq(ct1, ct2, &result); + const int r = fhe_uint32_shl(ct1, ct2, &result); assert(r == 0); return result; } -void* ge_fhe_uint8(void* ct1, void* ct2, void* sks) +void* scalar_shl_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_ge(ct1, ct2, &result); + const int r = fhe_uint8_scalar_shl(ct, pt, &result); assert(r == 0); return result; } -void* ge_fhe_uint16(void* ct1, void* ct2, void* sks) +void* scalar_shl_fhe_uint16(void* ct, uint16_t pt, void* sks) { FheUint16* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_ge(ct1, ct2, &result); + const int r = fhe_uint16_scalar_shl(ct, pt, &result); assert(r == 0); return result; } -void* ge_fhe_uint32(void* ct1, void* ct2, void* sks) +void* scalar_shl_fhe_uint32(void* ct, uint32_t pt, void* sks) { FheUint32* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_ge(ct1, ct2, &result); + const int r = fhe_uint32_scalar_shl(ct, pt, &result); assert(r == 0); return result; } -void* gt_fhe_uint8(void* ct1, void* ct2, void* sks) +void* shr_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_gt(ct1, ct2, &result); + const int r = fhe_uint8_shr(ct1, ct2, &result); assert(r == 0); return result; } -void* gt_fhe_uint16(void* ct1, void* ct2, void* sks) +void* shr_fhe_uint16(void* ct1, void* ct2, void* sks) { FheUint16* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_gt(ct1, ct2, &result); + const int r = fhe_uint16_shr(ct1, ct2, &result); assert(r == 0); return result; } -void* gt_fhe_uint32(void* ct1, void* ct2, void* sks) +void* shr_fhe_uint32(void* ct1, void* ct2, void* sks) { FheUint32* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_gt(ct1, ct2, &result); + const int r = fhe_uint32_shr(ct1, ct2, &result); assert(r == 0); return result; } -void* le_fhe_uint8(void* ct1, void* ct2, void* sks) +void* scalar_shr_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_le(ct1, ct2, &result); + const int r = fhe_uint8_scalar_shr(ct, pt, &result); assert(r == 0); return result; } -void* le_fhe_uint16(void* ct1, void* ct2, void* sks) +void* scalar_shr_fhe_uint16(void* ct, uint16_t pt, void* sks) { FheUint16* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_le(ct1, ct2, &result); + const int r = fhe_uint16_scalar_shr(ct, pt, &result); assert(r == 0); return result; } -void* le_fhe_uint32(void* ct1, void* ct2, void* sks) +void* scalar_shr_fhe_uint32(void* ct, uint32_t pt, void* sks) { FheUint32* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_le(ct1, ct2, &result); + const int r = fhe_uint32_scalar_shr(ct, pt, &result); assert(r == 0); return result; } -void* lt_fhe_uint8(void* ct1, void* ct2, void* sks) +void* eq_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_lt(ct1, ct2, &result); + const int r = fhe_uint8_eq(ct1, ct2, &result); assert(r == 0); return result; } -void* lt_fhe_uint16(void* ct1, void* ct2, void* sks) +void* eq_fhe_uint16(void* ct1, void* ct2, void* sks) { FheUint16* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_lt(ct1, ct2, &result); + const int r = fhe_uint16_eq(ct1, ct2, &result); assert(r == 0); return result; } -void* lt_fhe_uint32(void* ct1, void* ct2, void* sks) +void* eq_fhe_uint32(void* ct1, void* ct2, void* sks) { FheUint32* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_lt(ct1, ct2, &result); + const int r = fhe_uint32_eq(ct1, ct2, &result); assert(r == 0); return result; } -uint8_t decrypt_fhe_uint8(void* cks, void* ct) +void* scalar_eq_fhe_uint8(void* ct, uint8_t pt, void* sks) { - uint8_t res = 0; - const int r = fhe_uint8_decrypt(ct, cks, &res); - assert(r == 0); - return res; -} + FheUint8* result = NULL; -uint16_t decrypt_fhe_uint16(void* cks, void* ct) -{ - uint16_t res = 0; - const int r = fhe_uint16_decrypt(ct, cks, &res); - assert(r == 0); - return res; -} + checked_set_server_key(sks); -uint32_t decrypt_fhe_uint32(void* cks, void* ct) -{ - uint32_t res = 0; - const int r = fhe_uint32_decrypt(ct, cks, &res); + const int r = fhe_uint8_scalar_eq(ct, pt, &result); assert(r == 0); - return res; + return result; } -void* public_key_encrypt_fhe_uint8(void* pks, uint8_t value) { - CompactFheUint8List* list = NULL; - FheUint8* ct = NULL; - - int r = compact_fhe_uint8_list_try_encrypt_with_compact_public_key_u8(&value, 1, pks, &list); - assert(r == 0); +void* scalar_eq_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; - r = compact_fhe_uint8_list_expand(list, &ct, 1); - assert(r == 0); + checked_set_server_key(sks); - r = compact_fhe_uint8_list_destroy(list); + const int r = fhe_uint16_scalar_eq(ct, pt, &result); assert(r == 0); - - return ct; + return result; } -void* public_key_encrypt_fhe_uint16(void* pks, uint16_t value) { - CompactFheUint16List* list = NULL; - FheUint16* ct = NULL; - - int r = compact_fhe_uint16_list_try_encrypt_with_compact_public_key_u16(&value, 1, pks, &list); - assert(r == 0); +void* scalar_eq_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; - r = compact_fhe_uint16_list_expand(list, &ct, 1); - assert(r == 0); + checked_set_server_key(sks); - r = compact_fhe_uint16_list_destroy(list); + const int r = fhe_uint32_scalar_eq(ct, pt, &result); assert(r == 0); - - return ct; + return result; } -void* public_key_encrypt_fhe_uint32(void* pks, uint32_t value) { - CompactFheUint32List* list = NULL; - FheUint32* ct = NULL; - - int r = compact_fhe_uint32_list_try_encrypt_with_compact_public_key_u32(&value, 1, pks, &list); - assert(r == 0); +void* ne_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; - r = compact_fhe_uint32_list_expand(list, &ct, 1); - assert(r == 0); + checked_set_server_key(sks); - r = compact_fhe_uint32_list_destroy(list); + const int r = fhe_uint8_ne(ct1, ct2, &result); assert(r == 0); - - return ct; + return result; } -void* trivial_encrypt_fhe_uint8(void* sks, uint8_t value) { - FheUint8* ct = NULL; +void* ne_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; checked_set_server_key(sks); - int r = fhe_uint8_try_encrypt_trivial_u8(value, &ct); - assert(r == 0); - - return ct; + const int r = fhe_uint16_ne(ct1, ct2, &result); + assert(r == 0); + return result; } -void* trivial_encrypt_fhe_uint16(void* sks, uint16_t value) { - FheUint16* ct = NULL; +void* ne_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; checked_set_server_key(sks); - int r = fhe_uint16_try_encrypt_trivial_u16(value, &ct); - assert(r == 0); - - return ct; + const int r = fhe_uint32_ne(ct1, ct2, &result); + assert(r == 0); + return result; } -void* trivial_encrypt_fhe_uint32(void* sks, uint32_t value) { - FheUint32* ct = NULL; +void* scalar_ne_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; checked_set_server_key(sks); - int r = fhe_uint32_try_encrypt_trivial_u32(value, &ct); - assert(r == 0); - - return ct; + const int r = fhe_uint8_scalar_ne(ct, pt, &result); + assert(r == 0); + return result; } -void public_key_encrypt_and_serialize_fhe_uint8_list(void* pks, uint8_t value, Buffer* out) { - CompactFheUint8List* list = NULL; +void* scalar_ne_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; - int r = compact_fhe_uint8_list_try_encrypt_with_compact_public_key_u8(&value, 1, pks, &list); - assert(r == 0); + checked_set_server_key(sks); - r = compact_fhe_uint8_list_serialize(list, out); + const int r = fhe_uint16_scalar_ne(ct, pt, &result); assert(r == 0); + return result; } -void public_key_encrypt_and_serialize_fhe_uint16_list(void* pks, uint16_t value, Buffer* out) { - CompactFheUint16List* list = NULL; +void* scalar_ne_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; - int r = compact_fhe_uint16_list_try_encrypt_with_compact_public_key_u16(&value, 1, pks, &list); - assert(r == 0); + checked_set_server_key(sks); - r = compact_fhe_uint16_list_serialize(list, out); + const int r = fhe_uint32_scalar_ne(ct, pt, &result); assert(r == 0); + return result; } -void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value, Buffer* out) { - CompactFheUint32List* list = NULL; +void* ge_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; - int r = compact_fhe_uint32_list_try_encrypt_with_compact_public_key_u32(&value, 1, pks, &list); - assert(r == 0); + checked_set_server_key(sks); - r = compact_fhe_uint32_list_serialize(list, out); + const int r = fhe_uint8_ge(ct1, ct2, &result); assert(r == 0); + return result; } -void* cast_8_16(void* ct, void* sks) { +void* ge_fhe_uint16(void* ct1, void* ct2, void* sks) +{ FheUint16* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_cast_into_fhe_uint16(ct, &result); + const int r = fhe_uint16_ge(ct1, ct2, &result); assert(r == 0); return result; } -void* cast_8_32(void* ct, void* sks) { +void* ge_fhe_uint32(void* ct1, void* ct2, void* sks) +{ FheUint32* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_cast_into_fhe_uint32(ct, &result); + const int r = fhe_uint32_ge(ct1, ct2, &result); assert(r == 0); return result; } -void* cast_16_8(void* ct, void* sks) { +void* scalar_ge_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ FheUint8* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_cast_into_fhe_uint8(ct, &result); + const int r = fhe_uint8_scalar_ge(ct, pt, &result); assert(r == 0); return result; } -void* cast_16_32(void* ct, void* sks) { +void* scalar_ge_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_ge(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_ge_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_ge(ct, pt, &result); + assert(r == 0); + return result; +} + +void* gt_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_gt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* gt_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_gt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* gt_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_gt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* scalar_gt_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_gt(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_gt_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_gt(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_gt_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_gt(ct, pt, &result); + assert(r == 0); + return result; +} + +void* le_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_le(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* le_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_le(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* le_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_le(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* lt_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_lt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* scalar_le_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_le(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_le_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_le(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_le_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_le(ct, pt, &result); + assert(r == 0); + return result; +} + +void* lt_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_lt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* lt_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_lt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* scalar_lt_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_lt(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_lt_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_lt(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_lt_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_lt(ct, pt, &result); + assert(r == 0); + return result; +} + +void* min_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_min(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* min_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_min(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* min_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_min(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* scalar_min_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_min(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_min_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_min(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_min_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_min(ct, pt, &result); + assert(r == 0); + return result; +} + +void* max_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_max(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* max_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_max(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* max_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_max(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* scalar_max_fhe_uint8(void* ct, uint8_t pt, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_scalar_max(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_max_fhe_uint16(void* ct, uint16_t pt, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_scalar_max(ct, pt, &result); + assert(r == 0); + return result; +} + +void* scalar_max_fhe_uint32(void* ct, uint32_t pt, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_scalar_max(ct, pt, &result); + assert(r == 0); + return result; +} + +void* neg_fhe_uint8(void* ct, void* sks) { + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_neg(ct, &result); + assert(r == 0); + return result; +} + +void* neg_fhe_uint16(void* ct, void* sks) { + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_neg(ct, &result); + assert(r == 0); + return result; +} + +void* neg_fhe_uint32(void* ct, void* sks) { + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_neg(ct, &result); + assert(r == 0); + return result; +} + +void* not_fhe_uint8(void* ct, void* sks) { + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_not(ct, &result); + assert(r == 0); + return result; +} + +void* not_fhe_uint16(void* ct, void* sks) { + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_not(ct, &result); + assert(r == 0); + return result; +} + +void* not_fhe_uint32(void* ct, void* sks) { + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_not(ct, &result); + assert(r == 0); + return result; +} + +uint8_t decrypt_fhe_uint8(void* cks, void* ct) +{ + uint8_t res = 0; + const int r = fhe_uint8_decrypt(ct, cks, &res); + assert(r == 0); + return res; +} + +uint16_t decrypt_fhe_uint16(void* cks, void* ct) +{ + uint16_t res = 0; + const int r = fhe_uint16_decrypt(ct, cks, &res); + assert(r == 0); + return res; +} + +uint32_t decrypt_fhe_uint32(void* cks, void* ct) +{ + uint32_t res = 0; + const int r = fhe_uint32_decrypt(ct, cks, &res); + assert(r == 0); + return res; +} + +void* public_key_encrypt_fhe_uint8(void* pks, uint8_t value) { + CompactFheUint8List* list = NULL; + FheUint8* ct = NULL; + + int r = compact_fhe_uint8_list_try_encrypt_with_compact_public_key_u8(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint8_list_expand(list, &ct, 1); + assert(r == 0); + + r = compact_fhe_uint8_list_destroy(list); + assert(r == 0); + + return ct; +} + +void* public_key_encrypt_fhe_uint16(void* pks, uint16_t value) { + CompactFheUint16List* list = NULL; + FheUint16* ct = NULL; + + int r = compact_fhe_uint16_list_try_encrypt_with_compact_public_key_u16(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint16_list_expand(list, &ct, 1); + assert(r == 0); + + r = compact_fhe_uint16_list_destroy(list); + assert(r == 0); + + return ct; +} + +void* public_key_encrypt_fhe_uint32(void* pks, uint32_t value) { + CompactFheUint32List* list = NULL; + FheUint32* ct = NULL; + + int r = compact_fhe_uint32_list_try_encrypt_with_compact_public_key_u32(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint32_list_expand(list, &ct, 1); + assert(r == 0); + + r = compact_fhe_uint32_list_destroy(list); + assert(r == 0); + + return ct; +} + +void* trivial_encrypt_fhe_uint8(void* sks, uint8_t value) { + FheUint8* ct = NULL; + + checked_set_server_key(sks); + + int r = fhe_uint8_try_encrypt_trivial_u8(value, &ct); + assert(r == 0); + + return ct; +} + +void* trivial_encrypt_fhe_uint16(void* sks, uint16_t value) { + FheUint16* ct = NULL; + + checked_set_server_key(sks); + + int r = fhe_uint16_try_encrypt_trivial_u16(value, &ct); + assert(r == 0); + + return ct; +} + +void* trivial_encrypt_fhe_uint32(void* sks, uint32_t value) { + FheUint32* ct = NULL; + + checked_set_server_key(sks); + + int r = fhe_uint32_try_encrypt_trivial_u32(value, &ct); + assert(r == 0); + + return ct; +} + +void public_key_encrypt_and_serialize_fhe_uint8_list(void* pks, uint8_t value, Buffer* out) { + CompactFheUint8List* list = NULL; + + int r = compact_fhe_uint8_list_try_encrypt_with_compact_public_key_u8(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint8_list_serialize(list, out); + assert(r == 0); +} + +void public_key_encrypt_and_serialize_fhe_uint16_list(void* pks, uint16_t value, Buffer* out) { + CompactFheUint16List* list = NULL; + + int r = compact_fhe_uint16_list_try_encrypt_with_compact_public_key_u16(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint16_list_serialize(list, out); + assert(r == 0); +} + +void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value, Buffer* out) { + CompactFheUint32List* list = NULL; + + int r = compact_fhe_uint32_list_try_encrypt_with_compact_public_key_u32(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint32_list_serialize(list, out); + assert(r == 0); +} + +void* cast_8_16(void* ct, void* sks) { + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_cast_into_fhe_uint16(ct, &result); + assert(r == 0); + return result; +} + +void* cast_8_32(void* ct, void* sks) { + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_cast_into_fhe_uint32(ct, &result); + assert(r == 0); + return result; +} + +void* cast_16_8(void* ct, void* sks) { + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_cast_into_fhe_uint8(ct, &result); + assert(r == 0); + return result; +} + +void* cast_16_32(void* ct, void* sks) { FheUint32* result = NULL; checked_set_server_key(sks); @@ -971,11 +1625,32 @@ func (lhs *tfheCiphertext) add(rhs *tfheCiphertext) (*tfheCiphertext, error) { res.fheUintType = lhs.fheUintType switch lhs.fheUintType { case FheUint8: - res.setPtr(C.add_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + res.setPtr(C.add_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.add_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.add_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) scalarAdd(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar add on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_add_fhe_uint8(lhs.ptr, pt, sks)) case FheUint16: - res.setPtr(C.add_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_add_fhe_uint16(lhs.ptr, pt, sks)) case FheUint32: - res.setPtr(C.add_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_add_fhe_uint32(lhs.ptr, pt, sks)) } return res, nil } @@ -1002,6 +1677,27 @@ func (lhs *tfheCiphertext) sub(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } +func (lhs *tfheCiphertext) scalarSub(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar sub on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_sub_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_sub_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_sub_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + func (lhs *tfheCiphertext) mul(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot mul on a non-initialized ciphertext") @@ -1024,6 +1720,27 @@ func (lhs *tfheCiphertext) mul(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } +func (lhs *tfheCiphertext) scalarMul(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar mul on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_mul_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_mul_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_mul_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + func (lhs *tfheCiphertext) bitand(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot bitwise AND on a non-initialized ciphertext") @@ -1090,6 +1807,92 @@ func (lhs *tfheCiphertext) bitxor(rhs *tfheCiphertext) (*tfheCiphertext, error) return res, nil } +func (lhs *tfheCiphertext) shl(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot shl on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.shl_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.shl_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.shl_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) scalarShl(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar shl on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_shl_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_shl_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_shl_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) shr(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot shr on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.shr_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.shr_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.shr_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) scalarShr(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar shr on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_shr_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_shr_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_shr_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + func (lhs *tfheCiphertext) eq(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot eq on a non-initialized ciphertext") @@ -1112,6 +1915,70 @@ func (lhs *tfheCiphertext) eq(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } +func (lhs *tfheCiphertext) scalarEq(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar eq on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_eq_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_eq_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_eq_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) ne(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot ne on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.ne_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.ne_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.ne_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) scalarNe(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar ne on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_ne_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_ne_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_ne_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + func (lhs *tfheCiphertext) ge(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot ge on a non-initialized ciphertext") @@ -1134,6 +2001,27 @@ func (lhs *tfheCiphertext) ge(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } +func (lhs *tfheCiphertext) scalarGe(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar ge on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_ge_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_ge_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_ge_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + func (lhs *tfheCiphertext) gt(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot gt on a non-initialized ciphertext") @@ -1156,9 +2044,30 @@ func (lhs *tfheCiphertext) gt(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } -func (lhs *tfheCiphertext) lte(rhs *tfheCiphertext) (*tfheCiphertext, error) { +func (lhs *tfheCiphertext) scalarGt(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar gt on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_gt_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_gt_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_gt_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) le(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { - panic("cannot lte on a non-initialized ciphertext") + panic("cannot le on a non-initialized ciphertext") } if lhs.fheUintType != rhs.fheUintType { @@ -1178,6 +2087,27 @@ func (lhs *tfheCiphertext) lte(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } +func (lhs *tfheCiphertext) scalarLe(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar le on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_le_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_le_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_le_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot lt on a non-initialized ciphertext") @@ -1200,6 +2130,149 @@ func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } +func (lhs *tfheCiphertext) scalarLt(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar lt on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_lt_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_lt_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_lt_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) min(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot min on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.min_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.min_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.min_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) scalarMin(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar min on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_min_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_min_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_min_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) max(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot max on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.max_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.max_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.max_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) scalarMax(rhs uint64) (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot scalar max on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + pt := C.uint8_t(rhs) + res.setPtr(C.scalar_max_fhe_uint8(lhs.ptr, pt, sks)) + case FheUint16: + pt := C.uint16_t(rhs) + res.setPtr(C.scalar_max_fhe_uint16(lhs.ptr, pt, sks)) + case FheUint32: + pt := C.uint32_t(rhs) + res.setPtr(C.scalar_max_fhe_uint32(lhs.ptr, pt, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) neg() (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot neg on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.neg_fhe_uint8(lhs.ptr, sks)) + case FheUint16: + res.setPtr(C.neg_fhe_uint16(lhs.ptr, sks)) + case FheUint32: + res.setPtr(C.neg_fhe_uint32(lhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) not() (*tfheCiphertext, error) { + if !lhs.availableForOps() { + panic("cannot not on a non-initialized ciphertext") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.not_fhe_uint8(lhs.ptr, sks)) + case FheUint16: + res.setPtr(C.not_fhe_uint16(lhs.ptr, sks)) + case FheUint32: + res.setPtr(C.not_fhe_uint32(lhs.ptr, sks)) + } + return res, nil +} + func (ct *tfheCiphertext) castTo(castToType fheUintType) (*tfheCiphertext, error) { if !ct.availableForOps() { panic("cannot cast a non-initialized ciphertext") diff --git a/core/vm/tfhe_test.go b/core/vm/tfhe_test.go index 7a798b34e..3231d5e96 100644 --- a/core/vm/tfhe_test.go +++ b/core/vm/tfhe_test.go @@ -211,6 +211,29 @@ func TfheAdd(t *testing.T, fheUintType fheUintType) { } } +func TfheScalarAdd(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + expected := new(big.Int).Add(&a, &b) + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctRes, _ := ctA.scalarAdd(b.Uint64()) + res := ctRes.decrypt() + if res.Uint64() != expected.Uint64() { + t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) + } +} + func TfheSub(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { @@ -236,6 +259,29 @@ func TfheSub(t *testing.T, fheUintType fheUintType) { } } +func TfheScalarSub(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + expected := new(big.Int).Sub(&a, &b) + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctRes, _ := ctA.scalarSub(b.Uint64()) + res := ctRes.decrypt() + if res.Uint64() != expected.Uint64() { + t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) + } +} + func TfheMul(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { @@ -261,113 +307,554 @@ func TfheMul(t *testing.T, fheUintType fheUintType) { } } -func TfheBitAnd(t *testing.T, fheUintType fheUintType) { +func TfheScalarMul(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := new(big.Int).Mul(&a, &b) + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctRes, _ := ctA.scalarMul(b.Uint64()) + res := ctRes.decrypt() + if res.Uint64() != expected.Uint64() { + t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) + } +} + +func TfheBitAnd(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := a.Uint64() & b.Uint64() + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.bitand(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheBitOr(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := a.Uint64() | b.Uint64() + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.bitor(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheBitXor(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := a.Uint64() ^ b.Uint64() + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.bitxor(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheShl(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := new(big.Int).Lsh(&a, uint(b.Uint64())) + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.shl(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected.Uint64() { + t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) + } +} + +func TfheScalarShl(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := new(big.Int).Lsh(&a, uint(b.Uint64())) + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctRes, _ := ctA.scalarShl(b.Uint64()) + res := ctRes.decrypt() + if res.Uint64() != expected.Uint64() { + t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) + } +} + +func TfheShr(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := new(big.Int).Rsh(&a, uint(b.Uint64())) + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.shr(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected.Uint64() { + t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) + } +} + +func TfheScalarShr(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := new(big.Int).Rsh(&a, uint(b.Uint64())) + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctRes, _ := ctA.scalarShr(b.Uint64()) + res := ctRes.decrypt() + if res.Uint64() != expected.Uint64() { + t.Fatalf("%d != %d", expected.Uint64(), res.Uint64()) + } +} + +func TfheEq(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(2) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(137) + } + var expected uint64 + expectedBool := a.Uint64() == b.Uint64() + if expectedBool { + expected = 1 + } else { + expected = 0 + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.eq(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheScalarEq(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + var expected uint64 + expectedBool := a.Uint64() == b.Uint64() + if expectedBool { + expected = 1 + } else { + expected = 0 + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctRes, _ := ctA.scalarEq(b.Uint64()) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheNe(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(2) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(137) + } + var expected uint64 + expectedBool := a.Uint64() != b.Uint64() + if expectedBool { + expected = 1 + } else { + expected = 0 + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.ne(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheScalarNe(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + var expected uint64 + expectedBool := a.Uint64() != b.Uint64() + if expectedBool { + expected = 1 + } else { + expected = 0 + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctRes, _ := ctA.scalarNe(b.Uint64()) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheGe(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes1, _ := ctA.ge(ctB) + ctRes2, _ := ctB.ge(ctA) + res1 := ctRes1.decrypt() + res2 := ctRes2.decrypt() + if res1.Uint64() != 1 { + t.Fatalf("%d != %d", 0, res1.Uint64()) + } + if res2.Uint64() != 0 { + t.Fatalf("%d != %d", 0, res2.Uint64()) + } +} + +func TfheScalarGe(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctRes1, _ := ctA.scalarGe(b.Uint64()) + res1 := ctRes1.decrypt() + if res1.Uint64() != 1 { + t.Fatalf("%d != %d", 0, res1.Uint64()) + } +} + +func TfheGt(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes1, _ := ctA.gt(ctB) + ctRes2, _ := ctB.gt(ctA) + res1 := ctRes1.decrypt() + res2 := ctRes2.decrypt() + if res1.Uint64() != 1 { + t.Fatalf("%d != %d", 0, res1.Uint64()) + } + if res2.Uint64() != 0 { + t.Fatalf("%d != %d", 0, res2.Uint64()) + } +} + +func TfheScalarGt(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctRes1, _ := ctA.scalarGt(b.Uint64()) + res1 := ctRes1.decrypt() + if res1.Uint64() != 1 { + t.Fatalf("%d != %d", 0, res1.Uint64()) + } +} + +func TfheLe(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes1, _ := ctA.le(ctB) + ctRes2, _ := ctB.le(ctA) + res1 := ctRes1.decrypt() + res2 := ctRes2.decrypt() + if res1.Uint64() != 0 { + t.Fatalf("%d != %d", 0, res1.Uint64()) + } + if res2.Uint64() != 1 { + t.Fatalf("%d != %d", 0, res2.Uint64()) + } +} + +func TfheScalarLe(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { case FheUint8: a.SetUint64(2) b.SetUint64(1) case FheUint16: - a.SetUint64(169) - b.SetUint64(5) + a.SetUint64(4283) + b.SetUint64(1337) case FheUint32: - a.SetUint64(137) - b.SetInt64(17) + a.SetUint64(1333337) + b.SetUint64(133337) } - expected := a.Uint64() & b.Uint64() ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) - ctB := new(tfheCiphertext) - ctB.encrypt(b, fheUintType) - ctRes, _ := ctA.bitand(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected { - t.Fatalf("%d != %d", expected, res.Uint64()) + ctRes1, _ := ctA.scalarLe(b.Uint64()) + res1 := ctRes1.decrypt() + if res1.Uint64() != 0 { + t.Fatalf("%d != %d", 0, res1.Uint64()) } } -func TfheBitOr(t *testing.T, fheUintType fheUintType) { +func TfheLt(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { case FheUint8: a.SetUint64(2) b.SetUint64(1) case FheUint16: - a.SetUint64(169) - b.SetUint64(5) + a.SetUint64(4283) + b.SetUint64(1337) case FheUint32: - a.SetUint64(137) - b.SetInt64(17) + a.SetUint64(1333337) + b.SetUint64(133337) } - expected := a.Uint64() | b.Uint64() ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) - ctRes, _ := ctA.bitor(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected { - t.Fatalf("%d != %d", expected, res.Uint64()) + ctRes1, _ := ctA.lt(ctB) + ctRes2, _ := ctB.lt(ctA) + res1 := ctRes1.decrypt() + res2 := ctRes2.decrypt() + if res1.Uint64() != 0 { + t.Fatalf("%d != %d", 0, res1.Uint64()) + } + if res2.Uint64() != 1 { + t.Fatalf("%d != %d", 0, res2.Uint64()) } } -func TfheBitXor(t *testing.T, fheUintType fheUintType) { +func TfheScalarLt(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { case FheUint8: a.SetUint64(2) b.SetUint64(1) case FheUint16: - a.SetUint64(169) - b.SetUint64(5) + a.SetUint64(4283) + b.SetUint64(1337) case FheUint32: - a.SetUint64(137) - b.SetInt64(17) + a.SetUint64(1333337) + b.SetUint64(133337) } - expected := a.Uint64() ^ b.Uint64() ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) - ctB := new(tfheCiphertext) - ctB.encrypt(b, fheUintType) - ctRes, _ := ctA.bitxor(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected { - t.Fatalf("%d != %d", expected, res.Uint64()) + ctRes1, _ := ctA.scalarLt(b.Uint64()) + res1 := ctRes1.decrypt() + if res1.Uint64() != 0 { + t.Fatalf("%d != %d", 0, res1.Uint64()) } } -func TfheEq(t *testing.T, fheUintType fheUintType) { +func TfheMin(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { case FheUint8: a.SetUint64(2) - b.SetUint64(2) + b.SetUint64(1) case FheUint16: - a.SetUint64(169) - b.SetUint64(5) + a.SetUint64(4283) + b.SetUint64(1337) case FheUint32: - a.SetUint64(137) - b.SetInt64(137) - } - var expected uint64 - expectedBool := a.Uint64() == b.Uint64() - if expectedBool { - expected = 1 - } else { - expected = 0 + a.SetUint64(1333337) + b.SetUint64(133337) } ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) - ctRes, _ := ctA.eq(ctB) - res := ctRes.decrypt() - if res.Uint64() != expected { - t.Fatalf("%d != %d", expected, res.Uint64()) + ctRes1, _ := ctA.min(ctB) + ctRes2, _ := ctB.min(ctA) + res1 := ctRes1.decrypt() + res2 := ctRes2.decrypt() + if res1.Uint64() != b.Uint64() { + t.Fatalf("%d != %d", b.Uint64(), res1.Uint64()) + } + if res2.Uint64() != b.Uint64() { + t.Fatalf("%d != %d", b.Uint64(), res2.Uint64()) } } -func TfheGe(t *testing.T, fheUintType fheUintType) { +func TfheScalarMin(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { case FheUint8: @@ -382,21 +869,14 @@ func TfheGe(t *testing.T, fheUintType fheUintType) { } ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) - ctB := new(tfheCiphertext) - ctB.encrypt(b, fheUintType) - ctRes1, _ := ctA.ge(ctB) - ctRes2, _ := ctB.ge(ctA) + ctRes1, _ := ctA.scalarMin(b.Uint64()) res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1.Uint64() != 1 { + if res1.Uint64() != b.Uint64() { t.Fatalf("%d != %d", 0, res1.Uint64()) } - if res2.Uint64() != 0 { - t.Fatalf("%d != %d", 0, res2.Uint64()) - } } -func TfheGt(t *testing.T, fheUintType fheUintType) { +func TfheMax(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { case FheUint8: @@ -413,19 +893,19 @@ func TfheGt(t *testing.T, fheUintType fheUintType) { ctA.encrypt(a, fheUintType) ctB := new(tfheCiphertext) ctB.encrypt(b, fheUintType) - ctRes1, _ := ctA.gt(ctB) - ctRes2, _ := ctB.gt(ctA) + ctRes1, _ := ctA.max(ctB) + ctRes2, _ := ctB.max(ctA) res1 := ctRes1.decrypt() res2 := ctRes2.decrypt() - if res1.Uint64() != 1 { - t.Fatalf("%d != %d", 0, res1.Uint64()) + if res1.Uint64() != a.Uint64() { + t.Fatalf("%d != %d", b.Uint64(), res1.Uint64()) } - if res2.Uint64() != 0 { - t.Fatalf("%d != %d", 0, res2.Uint64()) + if res2.Uint64() != a.Uint64() { + t.Fatalf("%d != %d", b.Uint64(), res2.Uint64()) } } -func TfheLte(t *testing.T, fheUintType fheUintType) { +func TfheScalarMax(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { case FheUint8: @@ -440,46 +920,58 @@ func TfheLte(t *testing.T, fheUintType fheUintType) { } ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) - ctB := new(tfheCiphertext) - ctB.encrypt(b, fheUintType) - ctRes1, _ := ctA.lte(ctB) - ctRes2, _ := ctB.lte(ctA) + ctRes1, _ := ctA.scalarMax(b.Uint64()) res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1.Uint64() != 0 { + if res1.Uint64() != a.Uint64() { t.Fatalf("%d != %d", 0, res1.Uint64()) } - if res2.Uint64() != 1 { - t.Fatalf("%d != %d", 0, res2.Uint64()) - } } -func TfheLt(t *testing.T, fheUintType fheUintType) { - var a, b big.Int +func TfheNeg(t *testing.T, fheUintType fheUintType) { + var a big.Int + var expected uint64 + switch fheUintType { case FheUint8: a.SetUint64(2) - b.SetUint64(1) + expected = uint64(-uint8(a.Uint64())) case FheUint16: a.SetUint64(4283) - b.SetUint64(1337) + expected = uint64(-uint16(a.Uint64())) case FheUint32: a.SetUint64(1333337) - b.SetUint64(133337) + expected = uint64(-uint32(a.Uint64())) } ctA := new(tfheCiphertext) ctA.encrypt(a, fheUintType) - ctB := new(tfheCiphertext) - ctB.encrypt(b, fheUintType) - ctRes1, _ := ctA.lte(ctB) - ctRes2, _ := ctB.lte(ctA) + ctRes1, _ := ctA.neg() res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1.Uint64() != 0 { - t.Fatalf("%d != %d", 0, res1.Uint64()) + if res1.Uint64() != expected { + t.Fatalf("%d != %d", res1.Uint64(), expected) } - if res2.Uint64() != 1 { - t.Fatalf("%d != %d", 0, res2.Uint64()) +} + +func TfheNot(t *testing.T, fheUintType fheUintType) { + var a big.Int + var expected uint64 + switch fheUintType { + case FheUint8: + a.SetUint64(2) + expected = uint64(^uint8(a.Uint64())) + case FheUint16: + a.SetUint64(4283) + expected = uint64(^uint16(a.Uint64())) + case FheUint32: + a.SetUint64(1333337) + expected = uint64(^uint32(a.Uint64())) + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + + ctRes1, _ := ctA.not() + res1 := ctRes1.decrypt() + if res1.Uint64() != expected { + t.Fatalf("%d != %d", res1.Uint64(), expected) } } @@ -629,6 +1121,18 @@ func TestTfheAdd32(t *testing.T) { TfheAdd(t, FheUint32) } +func TestTfheScalarAdd8(t *testing.T) { + TfheScalarAdd(t, FheUint8) +} + +func TestTfheScalarAdd16(t *testing.T) { + TfheScalarAdd(t, FheUint16) +} + +func TestTfheScalarAdd32(t *testing.T) { + TfheScalarAdd(t, FheUint32) +} + func TestTfheSub8(t *testing.T) { TfheSub(t, FheUint8) } @@ -641,6 +1145,18 @@ func TestTfheSub32(t *testing.T) { TfheSub(t, FheUint32) } +func TestTfheScalarSub8(t *testing.T) { + TfheScalarSub(t, FheUint8) +} + +func TestTfheScalarSub16(t *testing.T) { + TfheScalarSub(t, FheUint16) +} + +func TestTfheScalarSub32(t *testing.T) { + TfheScalarSub(t, FheUint32) +} + func TestTfheMul8(t *testing.T) { TfheMul(t, FheUint8) } @@ -653,6 +1169,18 @@ func TestTfheMul32(t *testing.T) { TfheMul(t, FheUint32) } +func TestTfheScalarMul8(t *testing.T) { + TfheScalarMul(t, FheUint8) +} + +func TestTfheScalarMul16(t *testing.T) { + TfheScalarMul(t, FheUint16) +} + +func TestTfheScalarMul32(t *testing.T) { + TfheScalarMul(t, FheUint32) +} + func TestTfheBitAnd8(t *testing.T) { TfheBitAnd(t, FheUint8) } @@ -689,6 +1217,54 @@ func TestTfheBitXor32(t *testing.T) { TfheBitXor(t, FheUint32) } +func TestTfheShl8(t *testing.T) { + TfheShl(t, FheUint8) +} + +func TestTfheShl16(t *testing.T) { + TfheShl(t, FheUint16) +} + +func TestTfheShl32(t *testing.T) { + TfheShl(t, FheUint32) +} + +func TestTfheScalarShl8(t *testing.T) { + TfheScalarShl(t, FheUint8) +} + +func TestTfheScalarShl16(t *testing.T) { + TfheScalarShl(t, FheUint16) +} + +func TestTfheScalarShl32(t *testing.T) { + TfheScalarShl(t, FheUint32) +} + +func TestTfheShr8(t *testing.T) { + TfheShr(t, FheUint8) +} + +func TestTfheShr16(t *testing.T) { + TfheShr(t, FheUint16) +} + +func TestTfheShr32(t *testing.T) { + TfheShr(t, FheUint32) +} + +func TestTfheScalarShr8(t *testing.T) { + TfheScalarShr(t, FheUint8) +} + +func TestTfheScalarShr16(t *testing.T) { + TfheScalarShr(t, FheUint16) +} + +func TestTfheScalarShr32(t *testing.T) { + TfheScalarShr(t, FheUint32) +} + func TestTfheEq8(t *testing.T) { TfheEq(t, FheUint8) } @@ -701,6 +1277,42 @@ func TestTfheEq32(t *testing.T) { TfheEq(t, FheUint32) } +func TestTfheScalarEq8(t *testing.T) { + TfheScalarEq(t, FheUint8) +} + +func TestTfheScalarEq16(t *testing.T) { + TfheScalarEq(t, FheUint16) +} + +func TestTfheScalarEq32(t *testing.T) { + TfheScalarEq(t, FheUint32) +} + +func TestTfheNe8(t *testing.T) { + TfheNe(t, FheUint8) +} + +func TestTfheNe16(t *testing.T) { + TfheNe(t, FheUint16) +} + +func TestTfheNe32(t *testing.T) { + TfheNe(t, FheUint32) +} + +func TestTfheScalarNe8(t *testing.T) { + TfheScalarNe(t, FheUint8) +} + +func TestTfheScalarNe16(t *testing.T) { + TfheScalarNe(t, FheUint16) +} + +func TestTfheScalarNe32(t *testing.T) { + TfheScalarNe(t, FheUint32) +} + func TestTfheGe8(t *testing.T) { TfheGe(t, FheUint8) } @@ -713,6 +1325,18 @@ func TestTfheGe32(t *testing.T) { TfheGe(t, FheUint32) } +func TestTfheScalarGe8(t *testing.T) { + TfheScalarGe(t, FheUint8) +} + +func TestTfheScalarGe16(t *testing.T) { + TfheScalarGe(t, FheUint16) +} + +func TestTfheScalarGe32(t *testing.T) { + TfheScalarGe(t, FheUint32) +} + func TestTfheGt8(t *testing.T) { TfheGt(t, FheUint8) } @@ -725,27 +1349,131 @@ func TestTfheGt32(t *testing.T) { TfheGt(t, FheUint32) } -func TestTfheLte8(t *testing.T) { - TfheLte(t, FheUint8) +func TestTfheScalarGt8(t *testing.T) { + TfheScalarGt(t, FheUint8) +} + +func TestTfheScalarGt16(t *testing.T) { + TfheScalarGt(t, FheUint16) } -func TestTfheLte16(t *testing.T) { - TfheLte(t, FheUint16) +func TestTfheScalarGt32(t *testing.T) { + TfheScalarGt(t, FheUint32) } -func TestTfheLte32(t *testing.T) { - TfheLte(t, FheUint32) +func TestTfheLe8(t *testing.T) { + TfheLe(t, FheUint8) +} + +func TestTfheLe16(t *testing.T) { + TfheLe(t, FheUint16) +} + +func TestTfheLe32(t *testing.T) { + TfheLe(t, FheUint32) +} + +func TestTfheScalarLe8(t *testing.T) { + TfheScalarLe(t, FheUint8) +} + +func TestTfheScalarLe16(t *testing.T) { + TfheScalarLe(t, FheUint16) +} + +func TestTfheScalarLe32(t *testing.T) { + TfheScalarLe(t, FheUint32) } func TestTfheLt8(t *testing.T) { - TfheLte(t, FheUint8) + TfheLt(t, FheUint8) } func TestTfheLt16(t *testing.T) { - TfheLte(t, FheUint16) + TfheLt(t, FheUint16) } func TestTfheLt32(t *testing.T) { - TfheLte(t, FheUint32) + TfheLt(t, FheUint32) +} + +func TestTfheScalarLt8(t *testing.T) { + TfheScalarLt(t, FheUint8) +} + +func TestTfheScalarLt16(t *testing.T) { + TfheScalarLt(t, FheUint16) +} + +func TestTfheScalarLt32(t *testing.T) { + TfheScalarLt(t, FheUint32) +} + +func TestTfheMin8(t *testing.T) { + TfheMin(t, FheUint8) +} + +func TestTfheMin16(t *testing.T) { + TfheMin(t, FheUint16) +} +func TestTfheMin32(t *testing.T) { + TfheMin(t, FheUint32) +} + +func TestTfheScalarMin8(t *testing.T) { + TfheScalarMin(t, FheUint8) +} + +func TestTfheScalarMin16(t *testing.T) { + TfheScalarMin(t, FheUint16) +} + +func TestTfheScalarMin32(t *testing.T) { + TfheScalarMin(t, FheUint32) +} + +func TestTfheMax8(t *testing.T) { + TfheMax(t, FheUint8) +} + +func TestTfheMax16(t *testing.T) { + TfheMax(t, FheUint16) +} +func TestTfheMax32(t *testing.T) { + TfheMax(t, FheUint32) +} + +func TestTfheScalarMax8(t *testing.T) { + TfheScalarMax(t, FheUint8) +} + +func TestTfheScalarMax16(t *testing.T) { + TfheScalarMax(t, FheUint16) +} + +func TestTfheScalarMax32(t *testing.T) { + TfheScalarMax(t, FheUint32) +} + +func TestTfheNeg8(t *testing.T) { + TfheNeg(t, FheUint8) +} + +func TestTfheNeg16(t *testing.T) { + TfheNeg(t, FheUint16) +} +func TestTfheNeg32(t *testing.T) { + TfheNeg(t, FheUint32) +} + +func TestTfheNot8(t *testing.T) { + TfheNot(t, FheUint8) +} + +func TestTfheNot16(t *testing.T) { + TfheNot(t, FheUint16) +} +func TestTfheNot32(t *testing.T) { + TfheNot(t, FheUint32) } func TestTfhe8Cast16(t *testing.T) { diff --git a/params/protocol_params.go b/params/protocol_params.go index 70ff92d97..edd2bd766 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -169,9 +169,18 @@ const ( FheUint8BitwiseGas uint64 = 2000 FheUint16BitwiseGas uint64 = FheUint8BitwiseGas * 2 FheUint32BitwiseGas uint64 = FheUint8BitwiseGas * 4 - FheUint8LteGas uint64 = 3300 - FheUint16LteGas uint64 = 5000 - FheUint32LteGas uint64 = 11000 + FheUint8ShiftGas uint64 = 1000 + FheUint16ShiftGas uint64 = FheUint8ShiftGas * 2 + FheUint32ShiftGas uint64 = FheUint8ShiftGas * 4 + FheUint8LeGas uint64 = 3300 + FheUint16LeGas uint64 = 5000 + FheUint32LeGas uint64 = 11000 + FheUint8MinMaxGas uint64 = 3000 + FheUint16MinMaxGas uint64 = FheUint8MinMaxGas * 2 + FheUint32MinMaxGas uint64 = FheUint8MinMaxGas * 4 + FheUint8NegNotGas uint64 = 500 + FheUint16NegNotGas uint64 = FheUint8NegNotGas * 2 + FheUint32NegNotGas uint64 = FheUint8NegNotGas * 4 // TODO: Cost will depend on the complexity of doing reencryption by the oracle. FheUint8ReencryptGas uint64 = 15000