From 062abc6d635294b8ff79023a1c7ae59f96163773 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Thu, 4 Jan 2024 13:46:48 +0100 Subject: [PATCH] feat() add ifThenElse to the stack --- fhevm/contracts_test.go | 114 ++++++++++++++++++++++++++++++ fhevm/evm.go | 20 ++++++ fhevm/params.go | 6 ++ fhevm/precompiles.go | 57 +++++++++++++++ fhevm/tfhe.go | 153 ++++++++++++++++++++++++++++++++++++++++ fhevm/tfhe_test.go | 46 ++++++++++++ 6 files changed, 396 insertions(+) diff --git a/fhevm/contracts_test.go b/fhevm/contracts_test.go index 54da6dd..3ab6121 100644 --- a/fhevm/contracts_test.go +++ b/fhevm/contracts_test.go @@ -65,6 +65,14 @@ func toPrecompileInput(isScalar bool, hashes ...common.Hash) []byte { return ret } +func toPrecompileInputNoScalar(isScalar bool, hashes ...common.Hash) []byte { + ret := make([]byte, 0) + for _, hash := range hashes { + ret = append(ret, hash.Bytes()...) + } + return ret +} + var scalarBytePadding = make([]byte, 31) func toLibPrecompileInput(method string, isScalar bool, hashes ...common.Hash) []byte { @@ -1249,6 +1257,44 @@ func FheLibRandBounded(t *testing.T, fheUintType FheUintType, upperBound64 uint6 } } +func FheLibIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) { + var second, third uint64 + switch fheUintType { + case FheUint8: + second = 2 + third = 1 + case FheUint16: + second = 4283 + third = 1337 + case FheUint32: + second = 1333337 + third = 133337 + } + signature := "fheIfThenElse(uint256,uint256,uint256)" + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + firstHash := verifyCiphertextInTestMemory(environment, condition, depth, FheUint8).getHash() + secondHash := verifyCiphertextInTestMemory(environment, second, depth, fheUintType).getHash() + thirdHash := verifyCiphertextInTestMemory(environment, third, depth, fheUintType).getHash() + input := toLibPrecompileInputNoScalar(signature, firstHash, secondHash, thirdHash) + out, err := FheLibRun(environment, addr, addr, input, readOnly) + if err != nil { + t.Fatalf("VALUE %v", len(input)) + // t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || condition == 1 && decrypted.Uint64() != second || condition == 0 && decrypted.Uint64() != third { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } +} + func LibTrivialEncrypt(t *testing.T, fheUintType FheUintType) { var value big.Int switch fheUintType { @@ -2352,6 +2398,44 @@ func FheNot(t *testing.T, fheUintType FheUintType, scalar bool) { } } + +func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + depth := 1 + environment := newTestEVMEnvironment() + environment.depth = depth + addr := common.Address{} + readOnly := false + conditionHash := verifyCiphertextInTestMemory(environment, condition, depth, fheUintType).getHash() + lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).getHash() + rhsHash := verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).getHash() + + input1 := toPrecompileInputNoScalar(false, conditionHash, lhsHash, rhsHash) + out, err := fheIfThenElseRun(environment, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || condition == 1 && decrypted.Uint64() != lhs || condition == 0 && decrypted.Uint64() != rhs { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } +} + func Decrypt(t *testing.T, fheUintType FheUintType) { var value uint64 switch fheUintType { @@ -2627,6 +2711,21 @@ func TestFheLibRandBounded32(t *testing.T) { FheLibRandBounded(t, FheUint32, 32) } +func TestFheLibIfThenElse8(t *testing.T) { + FheLibIfThenElse(t, FheUint8, 1) + FheLibIfThenElse(t, FheUint8, 0) +} + +func TestFheLibIfThenElse16(t *testing.T) { + FheLibIfThenElse(t, FheUint16, 1) + FheLibIfThenElse(t, FheUint16, 0) +} + +func TestFheLibIfThenElse32(t *testing.T) { + FheLibIfThenElse(t, FheUint32, 1) + FheLibIfThenElse(t, FheUint32, 0) +} + func TestFheLibTrivialEncrypt8(t *testing.T) { LibTrivialEncrypt(t, FheUint8) } @@ -3079,6 +3178,21 @@ func TestFheNot32(t *testing.T) { FheNot(t, FheUint32, false) } +func TestFheIfThenElse8(t *testing.T) { + FheIfThenElse(t, FheUint8, 1) + FheIfThenElse(t, FheUint8, 0) +} + +func TestFheIfThenElse16(t *testing.T) { + FheIfThenElse(t, FheUint16, 1) + FheIfThenElse(t, FheUint16, 0) +} + +func TestFheIfThenElse32(t *testing.T) { + FheIfThenElse(t, FheUint32, 1) + FheIfThenElse(t, FheUint32, 0) +} + func TestFheScalarMax8(t *testing.T) { FheMax(t, FheUint8, true) } diff --git a/fhevm/evm.go b/fhevm/evm.go index 12822fd..c6edf3f 100644 --- a/fhevm/evm.go +++ b/fhevm/evm.go @@ -73,6 +73,26 @@ func get2VerifiedOperands(environment EVMEnvironment, input []byte) (lhs *verifi return } +func get3VerifiedOperands(environment EVMEnvironment, input []byte) (first *verifiedCiphertext, second *verifiedCiphertext, third *verifiedCiphertext, err error) { + if len(input) != 96 { + return nil, nil, nil, errors.New("input needs to contain three 256-bit sized values") + } + first = getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + if first == nil { + return nil, nil, nil, errors.New("unverified ciphertext handle") + } + second = getVerifiedCiphertext(environment, common.BytesToHash(input[32:64])) + if second == nil { + return nil, nil, nil, errors.New("unverified ciphertext handle") + } + third = getVerifiedCiphertext(environment, common.BytesToHash(input[64:96])) + if third == nil { + return nil, nil, nil, errors.New("unverified ciphertext handle") + } + err = nil + return +} + func getScalarOperands(environment EVMEnvironment, input []byte) (lhs *verifiedCiphertext, rhs *big.Int, err error) { if len(input) != 65 { return nil, nil, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value") diff --git a/fhevm/params.go b/fhevm/params.go index 1147c15..022e9a7 100644 --- a/fhevm/params.go +++ b/fhevm/params.go @@ -70,6 +70,7 @@ type GasCosts struct { FheReencrypt map[FheUintType]uint64 FheTrivialEncrypt map[FheUintType]uint64 FheRand map[FheUintType]uint64 + FheIfThenElse map[FheUintType]uint64 FheVerify map[FheUintType]uint64 FheOptRequire map[FheUintType]uint64 FheOptRequireBitAnd map[FheUintType]uint64 @@ -150,6 +151,11 @@ func DefaultGasCosts() GasCosts { FheUint16: EvmNetSstoreInitGas + 2000, FheUint32: EvmNetSstoreInitGas + 3000, }, + FheIfThenElse: map[FheUintType]uint64{ + FheUint8: 61000, + FheUint16: 83000, + FheUint32: 109000, + }, // TODO: As of now, only support FheUint8. All optimistic require predicates are // downcast to FheUint8 at the solidity level. Eventually move to ebool. // If there is at least one optimistic require, we need to decrypt it as it was a normal FHE require. diff --git a/fhevm/precompiles.go b/fhevm/precompiles.go index 1884224..8475fcd 100644 --- a/fhevm/precompiles.go +++ b/fhevm/precompiles.go @@ -57,6 +57,7 @@ var signatureFheBitOr = makeKeccakSignature("fheBitOr(uint256,uint256,bytes1)") var signatureFheBitXor = makeKeccakSignature("fheBitXor(uint256,uint256,bytes1)") var signatureFheRand = makeKeccakSignature("fheRand(bytes1)") var signatureFheRandBounded = makeKeccakSignature("fheRandBounded(uint256,bytes1)") +var signatureFheIfThenElse = makeKeccakSignature("fheIfThenElse(uint256,uint256,uint256)") var signatureVerifyCiphertext = makeKeccakSignature("verifyCiphertext(bytes)") var signatureReencrypt = makeKeccakSignature("reencrypt(uint256,uint256)") var signatureOptimisticRequire = makeKeccakSignature("optimisticRequire(uint256)") @@ -149,6 +150,9 @@ func FheLibRequiredGas(environment EVMEnvironment, input []byte) uint64 { case signatureFheRandBounded: bwCompatBytes := input[4:minInt(37, len(input))] return fheRandBoundedRequiredGas(environment, bwCompatBytes) + case signatureFheIfThenElse: + bwCompatBytes := input[4:minInt(100, len(input))] + return fheIfThenElseRequiredGas(environment, bwCompatBytes) case signatureVerifyCiphertext: bwCompatBytes := input[4:] return verifyCiphertextRequiredGas(environment, bwCompatBytes) @@ -261,6 +265,9 @@ func FheLibRun(environment EVMEnvironment, caller common.Address, addr common.Ad case signatureFheRandBounded: bwCompatBytes := input[4:minInt(37, len(input))] return fheRandBoundedRun(environment, caller, addr, bwCompatBytes, readOnly) + case signatureFheIfThenElse: + bwCompatBytes := input[4:minInt(100, len(input))] + return fheIfThenElseRun(environment, caller, addr, bwCompatBytes, readOnly) case signatureVerifyCiphertext: // first 32 bytes of the payload is offset, then 32 bytes are size of byte array if len(input) <= 68 { @@ -613,6 +620,24 @@ func fheRandBoundedRequiredGas(environment EVMEnvironment, input []byte) uint64 return environment.FhevmParams().GasCosts.FheRand[randType] } +func fheIfThenElseRequiredGas(environment EVMEnvironment, input []byte) uint64 { + logger := environment.GetLogger() + first, second, third, err := get3VerifiedOperands(environment, input) + if err != nil { + logger.Error("IfThenElse op RequiredGas() inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + if first.ciphertext.fheUintType != FheUint8 { + logger.Error("IfThenElse op RequiredGas() invalid type for condition", "first", first.ciphertext.fheUintType) + return 0 + } + if second.ciphertext.fheUintType != third.ciphertext.fheUintType { + logger.Error("IfThenElse op RequiredGas() operand type mismatch", "second", second.ciphertext.fheUintType, "third", third.ciphertext.fheUintType) + return 0 + } + return environment.FhevmParams().GasCosts.FheIfThenElse[second.ciphertext.fheUintType] +} + func verifyCiphertextRequiredGas(environment EVMEnvironment, input []byte) uint64 { if len(input) <= 1 { environment.GetLogger().Error( @@ -1930,6 +1955,38 @@ func fheRandBoundedRun(environment EVMEnvironment, caller common.Address, addr c return generateRandom(environment, caller, randType, &bound64) } + +func fheIfThenElseRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := environment.GetLogger() + first, second, third, err := get3VerifiedOperands(environment, input) + if err != nil { + logger.Error("fheIfThenElse inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + + if second.ciphertext.fheUintType != third.ciphertext.fheUintType { + msg := "fheIfThenElse operand type mismatch" + logger.Error(msg, "second", second.ciphertext.fheUintType, "third", third.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !environment.IsCommitting() && !environment.IsEthCall() { + return importRandomCiphertext(environment, second.ciphertext.fheUintType), nil + } + + result, err := first.ciphertext.ifThenElse(second.ciphertext, third.ciphertext) + if err != nil { + logger.Error("fheIfThenElse failed", "err", err) + return nil, err + } + importCiphertext(environment, result) + + resultHash := result.getHash() + logger.Info("fheIfThenElse success", "first", first.ciphertext.getHash().Hex(), "second", second.ciphertext.getHash().Hex(), "third", third.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil +} + func verifyCiphertextRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := environment.GetLogger() if len(input) <= 1 { diff --git a/fhevm/tfhe.go b/fhevm/tfhe.go index ae63365..3de5a33 100644 --- a/fhevm/tfhe.go +++ b/fhevm/tfhe.go @@ -1296,6 +1296,39 @@ void* not_fhe_uint32(void* ct, void* sks) { return result; } +void* if_then_else_fhe_uint8(void* condition, void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_if_then_else(condition, ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* if_then_else_fhe_uint16(void* condition, void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_if_then_else(condition, ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* if_then_else_fhe_uint32(void* condition, void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_if_then_else(condition, ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + int decrypt_fhe_uint8(void* cks, void* ct, uint8_t* res) { *res = 0; @@ -1966,6 +1999,113 @@ func (lhs *tfheCiphertext) executeBinaryCiphertextOperation(rhs *tfheCiphertext, return res, nil } + +func (first *tfheCiphertext) executeTernaryCiphertextOperation(lhs *tfheCiphertext, rhs *tfheCiphertext, + op8 func(first unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, + op16 func(first unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, + op32 func(first unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer) (*tfheCiphertext, error) { + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("ternary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + res_ser := &C.Buffer{} + switch lhs.fheUintType { + case FheUint8: + lhs_ptr := C.deserialize_fhe_uint8(toBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("8 bit binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_uint8(toBufferView((rhs.serialization))) + if rhs_ptr == nil { + C.destroy_fhe_uint8(lhs_ptr) + return nil, errors.New("8 bit binary op deserialization failed") + } + first_ptr := C.deserialize_fhe_uint8(toBufferView((first.serialization))) + if first_ptr == nil { + C.destroy_fhe_uint8(lhs_ptr) + C.destroy_fhe_uint8(rhs_ptr) + return nil, errors.New("8 bit binary op deserialization failed") + } + res_ptr := op8(first_ptr, lhs_ptr, rhs_ptr) + C.destroy_fhe_uint8(lhs_ptr) + C.destroy_fhe_uint8(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("8 bit binary op failed") + } + ret := C.serialize_fhe_uint8(res_ptr, res_ser) + C.destroy_fhe_uint8(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit binary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) + case FheUint16: + lhs_ptr := C.deserialize_fhe_uint16(toBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("16 bit binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_uint16(toBufferView((rhs.serialization))) + if rhs_ptr == nil { + C.destroy_fhe_uint16(lhs_ptr) + return nil, errors.New("16 bit binary op deserialization failed") + } + first_ptr := C.deserialize_fhe_uint8(toBufferView((first.serialization))) + if first_ptr == nil { + C.destroy_fhe_uint8(lhs_ptr) + C.destroy_fhe_uint8(rhs_ptr) + return nil, errors.New("8 bit binary op deserialization failed") + } + res_ptr := op16(first_ptr, lhs_ptr, rhs_ptr) + C.destroy_fhe_uint16(lhs_ptr) + C.destroy_fhe_uint16(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("16 bit binary op failed") + } + ret := C.serialize_fhe_uint16(res_ptr, res_ser) + C.destroy_fhe_uint16(res_ptr) + if ret != 0 { + return nil, errors.New("16 bit binary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) + case FheUint32: + lhs_ptr := C.deserialize_fhe_uint32(toBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("32 bit binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_uint32(toBufferView((rhs.serialization))) + if rhs_ptr == nil { + C.destroy_fhe_uint32(lhs_ptr) + return nil, errors.New("32 bit binary op deserialization failed") + } + first_ptr := C.deserialize_fhe_uint8(toBufferView((first.serialization))) + if first_ptr == nil { + C.destroy_fhe_uint8(lhs_ptr) + C.destroy_fhe_uint8(rhs_ptr) + return nil, errors.New("8 bit binary op deserialization failed") + } + res_ptr := op32(first_ptr, lhs_ptr, rhs_ptr) + C.destroy_fhe_uint32(lhs_ptr) + C.destroy_fhe_uint32(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("32 bit binary op failed") + } + ret := C.serialize_fhe_uint32(res_ptr, res_ser) + C.destroy_fhe_uint32(res_ptr) + if ret != 0 { + return nil, errors.New("32 bit binary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_buffer(res_ser) + default: + panic("binary op unexpected ciphertext type") + } + res.computeHash() + return res, nil +} + func (lhs *tfheCiphertext) executeBinaryScalarOperation(rhs uint64, op8 func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer, op16 func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer, @@ -2464,6 +2604,19 @@ func (lhs *tfheCiphertext) not() (*tfheCiphertext, error) { }) } +func (condition *tfheCiphertext) ifThenElse(lhs *tfheCiphertext, rhs *tfheCiphertext) (*tfheCiphertext, error) { + return condition.executeTernaryCiphertextOperation(lhs, rhs, + func(condition unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.if_then_else_fhe_uint8(condition, lhs, rhs, sks) + }, + func(condition unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.if_then_else_fhe_uint16(condition, lhs, rhs, sks) + }, + func(condition unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.if_then_else_fhe_uint32(condition, lhs, rhs, sks) + }) +} + func (ct *tfheCiphertext) castTo(castToType FheUintType) (*tfheCiphertext, error) { if ct.fheUintType == castToType { return nil, errors.New("casting to same type is not supported") diff --git a/fhevm/tfhe_test.go b/fhevm/tfhe_test.go index 10968b2..9b2f1ae 100644 --- a/fhevm/tfhe_test.go +++ b/fhevm/tfhe_test.go @@ -1033,6 +1033,41 @@ func TfheNot(t *testing.T, fheUintType FheUintType) { } } +func TfheIfThenElse(t *testing.T, fheUintType FheUintType) { + var condition, condition2, a, b big.Int + condition.SetUint64(1) + condition2.SetUint64(0) + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + ctCondition := new(tfheCiphertext) + ctCondition.encrypt(condition, fheUintType) + ctCondition2 := new(tfheCiphertext) + ctCondition2.encrypt(condition2, fheUintType) + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes1, _ := ctCondition.ifThenElse(ctA, ctB) + ctRes2, _ := ctCondition2.ifThenElse(ctA, ctB) + res1, err1 := ctRes1.decrypt() + res2, err2 := ctRes2.decrypt() + if err1 != nil || res1.Uint64() != a.Uint64() { + t.Fatalf("%d != %d", 0, res1.Uint64()) + } + if err2 != nil || res2.Uint64() != b.Uint64() { + t.Fatalf("%d != %d", 0, res2.Uint64()) + } +} + func TfheCast(t *testing.T, fheUintTypeFrom FheUintType, fheUintTypeTo FheUintType) { var a big.Int switch fheUintTypeFrom { @@ -1558,6 +1593,17 @@ func TestTfheNot32(t *testing.T) { TfheNot(t, FheUint32) } +func TestTfheIfThenElse8(t *testing.T) { + TfheIfThenElse(t, FheUint8) +} + +func TestTfheIfThenElse16(t *testing.T) { + TfheIfThenElse(t, FheUint16) +} +func TestTfheIfThenElse32(t *testing.T) { + TfheIfThenElse(t, FheUint32) +} + func TestTfhe8Cast16(t *testing.T) { TfheCast(t, FheUint8, FheUint16) }