From 6f68c19638ab2d6c465d00ff2ba802f462a1c419 Mon Sep 17 00:00:00 2001 From: Petar Ivanov <29689712+dartdart26@users.noreply.github.com> Date: Thu, 22 Jun 2023 10:47:52 +0300 Subject: [PATCH 1/8] Return raw FHE public key if called directly If the fhePubKey precompile is called directly (e.g. from eth_call), return the raw key bytes. If called from the EVM (i.e. Solidity lib), return it as an EVM array that can be used in Solidity. --- core/vm/contracts.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index d274d5df2..56c1d6266 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -2064,7 +2064,12 @@ func (e *fhePubKey) Run(accessibleState PrecompileAccessibleState, caller common accessibleState.Interpreter().evm.Logger.Error(msg, "existing", existing.Hex(), "pksHash", pksHash.Hex()) return nil, errors.New(msg) } - return toEVMBytes(pksBytes), nil + // If we have a single byte with the value of 1, return as an EVM array. Otherwise, returh the raw bytes. + if len(input) == 1 && input[0] == 1 { + return toEVMBytes(pksBytes), nil + } else { + return pksBytes, nil + } } type trivialEncrypt struct{} From 1038691f506840797dd6b739a7e509c57a05eb3e Mon Sep 17 00:00:00 2001 From: Louis Tremblay Thibault Date: Thu, 22 Jun 2023 13:39:55 +0200 Subject: [PATCH 2/8] feat: add support for casting (#118) * feat(tfhe): add support for casting * feat: add `cast` precompile * feat(cast): add type validity check --- core/vm/contracts.go | 71 +++++++++++++++++++------ core/vm/tfhe.go | 107 ++++++++++++++++++++++++++++++++++++++ core/vm/tfhe_test.go | 67 +++++++++++++++++++++++- params/protocol_params.go | 2 + 4 files changed, 229 insertions(+), 18 deletions(-) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 56c1d6266..ff7c05be2 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -77,7 +77,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -106,7 +106,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -136,7 +136,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -166,7 +166,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -196,7 +196,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -1438,7 +1438,7 @@ func (e *verifyCiphertext) RequiredGas(accessibleState PrecompileAccessibleState func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger if len(input) <= 1 { - msg := "verifyCiphertext RequiredGas() input needs to contain a ciphertext and one byte for its type" + msg := "verifyCiphertext Run() input needs to contain a ciphertext and one byte for its type" logger.Error(msg, "len", len(input)) return nil, errors.New(msg) } @@ -2026,18 +2026,57 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add // return ctHash[:], nil // } -// type cast struct{} +type cast struct{} -// func (e *cast) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { -// return 0 -// } +func (e *cast) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + if len(input) != 33 { + accessibleState.Interpreter().evm.Logger.Error( + "cast RequiredGas() input needs to contain a ciphertext and one byte for its type", + "len", len(input)) + return 0 + } + return params.FheCastGas +} -// // Implementation of the following is pending and will be completed once TFHE-rs add type casts to their high-level C API. -// func (e *cast) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { -// // var ctHandle = common.BytesToHash(input[0:31]) -// // var toType = input[32] -// return nil, nil -// } +// Implementation of the following is pending and will be completed once TFHE-rs add type casts to their high-level C API. +func (e *cast) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + if len(input) != 33 { + msg := "cast Run() input needs to contain a ciphertext and one byte for its type" + logger.Error(msg, "len", len(input)) + return nil, errors.New(msg) + } + + ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32])) + if ct == nil { + logger.Error("cast input not verified") + return nil, errors.New("unverified ciphertext handle") + } + + castToType := fheUintType(input[32]) + if !castToType.isValid() { + logger.Error("invalid type to cast to") + return nil, errors.New("invalid type provided") + } + + res, err := ct.ciphertext.castTo(castToType) + if err != nil { + msg := "cast Run() error casting ciphertext to" + logger.Error(msg, "type", castToType) + return nil, errors.New(msg) + } + + resHash := res.getHash() + + importCiphertext(accessibleState, res) + if accessibleState.Interpreter().evm.Commit { + logger.Info("cast success", + "ctHash", resHash.Hex(), + ) + } + + return resHash.Bytes(), nil +} type faucet struct{} diff --git a/core/vm/tfhe.go b/core/vm/tfhe.go index df6242535..00ba1cde5 100644 --- a/core/vm/tfhe.go +++ b/core/vm/tfhe.go @@ -480,6 +480,66 @@ void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value, assert(r == 0); } +void* cast_8_16(void* ct, void* sks) { + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_cast_into_fhe_uint16(ct, &result); + assert(r == 0); + return result; +} + +void* cast_8_32(void* ct, void* sks) { + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_cast_into_fhe_uint32(ct, &result); + assert(r == 0); + return result; +} + +void* cast_16_8(void* ct, void* sks) { + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_cast_into_fhe_uint8(ct, &result); + assert(r == 0); + return result; +} + +void* cast_16_32(void* ct, void* sks) { + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_cast_into_fhe_uint32(ct, &result); + assert(r == 0); + return result; +} + +void* cast_32_8(void* ct, void* sks) { + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_cast_into_fhe_uint8(ct, &result); + assert(r == 0); + return result; +} + +void* cast_32_16(void* ct, void* sks) { + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_cast_into_fhe_uint16(ct, &result); + assert(r == 0); + return result; +} + */ import "C" @@ -810,6 +870,49 @@ func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } +func (ct *tfheCiphertext) castTo(castToType fheUintType) (*tfheCiphertext, error) { + if !ct.availableForOps() { + panic("cannot cast a non-initialized ciphertext") + } + + if ct.fheUintType == castToType { + return nil, errors.New("casting to same type is not supported") + } + + if !castToType.isValid() { + return nil, errors.New("invalid type to cast to") + } + + res := new(tfheCiphertext) + res.fheUintType = castToType + + switch ct.fheUintType { + case FheUint8: + switch castToType { + case FheUint16: + res.setPtr(C.cast_8_16(ct.ptr, sks)) + case FheUint32: + res.setPtr(C.cast_8_32(ct.ptr, sks)) + } + case FheUint16: + switch castToType { + case FheUint8: + res.setPtr(C.cast_16_8(ct.ptr, sks)) + case FheUint32: + res.setPtr(C.cast_16_32(ct.ptr, sks)) + } + case FheUint32: + switch castToType { + case FheUint8: + res.setPtr(C.cast_32_8(ct.ptr, sks)) + case FheUint16: + res.setPtr(C.cast_32_16(ct.ptr, sks)) + } + } + + return res, nil +} + func (ct *tfheCiphertext) decrypt() big.Int { if !ct.availableForOps() { panic("cannot decrypt a null ciphertext") @@ -869,6 +972,10 @@ func (ct *tfheCiphertext) initialized() bool { return (ct.ptr != nil) } +func (t *fheUintType) isValid() bool { + return (*t <= 2) +} + // Used for testing. func encryptAndSerializeCompact(value uint32, fheUintType fheUintType) []byte { out := &C.Buffer{} diff --git a/core/vm/tfhe_test.go b/core/vm/tfhe_test.go index e7d05ac6d..58c02d709 100644 --- a/core/vm/tfhe_test.go +++ b/core/vm/tfhe_test.go @@ -18,6 +18,7 @@ package vm import ( "bytes" + "math" "math/big" "testing" ) @@ -301,9 +302,9 @@ func TfheLt(t *testing.T, fheUintType fheUintType) { b.SetUint64(133337) } ctA := new(tfheCiphertext) - ctA.encrypt(a, FheUint8) + ctA.encrypt(a, fheUintType) ctB := new(tfheCiphertext) - ctB.encrypt(b, FheUint8) + ctB.encrypt(b, fheUintType) ctRes1, _ := ctA.lte(ctB) ctRes2, _ := ctB.lte(ctA) res1 := ctRes1.decrypt() @@ -316,6 +317,44 @@ func TfheLt(t *testing.T, fheUintType fheUintType) { } } +func TfheCast(t *testing.T, fheUintTypeFrom fheUintType, fheUintTypeTo fheUintType) { + var a big.Int + switch fheUintTypeFrom { + case FheUint8: + a.SetUint64(2) + case FheUint16: + a.SetUint64(4283) + case FheUint32: + a.SetUint64(1333337) + } + + var modulus uint64 + switch fheUintTypeTo { + case FheUint8: + modulus = uint64(math.Pow(2, 8)) + case FheUint16: + modulus = uint64(math.Pow(2, 16)) + case FheUint32: + modulus = uint64(math.Pow(2, 32)) + } + + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintTypeFrom) + ctRes, err := ctA.castTo(fheUintTypeTo) + if err != nil { + t.Fatal(err) + } + + if ctRes.fheUintType != fheUintTypeTo { + t.Fatalf("type %d != type %d", ctA.fheUintType, fheUintTypeTo) + } + res := ctRes.decrypt() + expected := a.Uint64() % modulus + if res.Uint64() != expected { + t.Fatalf("%d != %d", res.Uint64(), expected) + } +} + func TestTfheEncryptDecrypt8(t *testing.T) { TfheEncryptDecrypt(t, FheUint8) } @@ -470,3 +509,27 @@ func TestTfheLte32(t *testing.T) { func TestTfheLt32(t *testing.T) { TfheLte(t, FheUint32) } + +func TestTfhe8Cast16(t *testing.T) { + TfheCast(t, FheUint8, FheUint16) +} + +func TestTfhe8Cast32(t *testing.T) { + TfheCast(t, FheUint8, FheUint32) +} + +func TestTfhe16Cast8(t *testing.T) { + TfheCast(t, FheUint16, FheUint8) +} + +func TestTfhe16Cast32(t *testing.T) { + TfheCast(t, FheUint16, FheUint32) +} + +func TestTfhe32Cast8(t *testing.T) { + TfheCast(t, FheUint16, FheUint8) +} + +func TestTfhe32Cast16(t *testing.T) { + TfheCast(t, FheUint16, FheUint8) +} diff --git a/params/protocol_params.go b/params/protocol_params.go index 4be3bb685..73bc68de5 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -206,6 +206,8 @@ const ( FheUint16ProtectedStorageSloadGas uint64 = FheUint8ProtectedStorageSloadGas * 2 FheUint32ProtectedStorageSloadGas uint64 = FheUint16ProtectedStorageSloadGas * 4 + FheCastGas uint64 = 100 + FhePubKeyGas uint64 = 2 FheUint8TrivialEncryptGas uint64 = 100 From 73cae1cd775b670fc70c5f67d10eae000eaa8aaa Mon Sep 17 00:00:00 2001 From: Levent Demir Date: Fri, 23 Jun 2023 14:59:01 +0200 Subject: [PATCH 3/8] build: update tfhe-rs version to 0.3.0-beta.0 --- install_thfe_rs_api.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install_thfe_rs_api.sh b/install_thfe_rs_api.sh index 8f3b0768f..2b9256dbc 100755 --- a/install_thfe_rs_api.sh +++ b/install_thfe_rs_api.sh @@ -1,7 +1,7 @@ #!/bin/bash git clone https://github.com/zama-ai/tfhe-rs.git -git checkout 1d817c45d5234bcf33638406191b656998b30c2a +git checkout 0.3.0-beta.0 mkdir -p core/vm/lib cd tfhe-rs make build_c_api From 4c688232045e5e0f6d3b9b19ff9f67b5d60deeb3 Mon Sep 17 00:00:00 2001 From: Levent Demir Date: Fri, 23 Jun 2023 15:26:59 +0200 Subject: [PATCH 4/8] chore(ci): bump tfhe-rs verison in ci workflow --- .github/workflows/publish_geth_testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish_geth_testing.yml b/.github/workflows/publish_geth_testing.yml index b90ad79b0..3ab226514 100644 --- a/.github/workflows/publish_geth_testing.yml +++ b/.github/workflows/publish_geth_testing.yml @@ -15,7 +15,7 @@ jobs: uses: actions/checkout@v3 with: repository: zama-ai/tfhe-rs - ref: 1d817c45d5234bcf33638406191b656998b30c2a + ref: 0.3.0-beta.0 path: tfhe-rs - name: Checkout zbc-fhe-tool From c6b8db0dfc73462629647b6a89c444c7ad1531fe Mon Sep 17 00:00:00 2001 From: Levent Demir Date: Fri, 23 Jun 2023 16:20:30 +0200 Subject: [PATCH 5/8] chore(build c api): call deterministic build of C api as in the workflow (ci test) --- install_thfe_rs_api.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install_thfe_rs_api.sh b/install_thfe_rs_api.sh index 2b9256dbc..effe423cd 100755 --- a/install_thfe_rs_api.sh +++ b/install_thfe_rs_api.sh @@ -4,6 +4,6 @@ git clone https://github.com/zama-ai/tfhe-rs.git git checkout 0.3.0-beta.0 mkdir -p core/vm/lib cd tfhe-rs -make build_c_api +make build_c_api_experimental_deterministic_fft cp target/release/libtfhe.* ../core/vm/lib cp target/release/tfhe.h ../core/vm From e1a615660db757c6d775a62055dacf28bee62ae6 Mon Sep 17 00:00:00 2001 From: Louis Tremblay Thibault Date: Mon, 26 Jun 2023 10:49:52 +0200 Subject: [PATCH 6/8] Add all available FHE ops (#120) * feat(tfhe): add support for casting * feat: add `cast` precompile * feat(tfhe): add `&`, `|`, `^`, `==`, `>`, `>=` * feat(precompiles): add missing ops * feat(precompiles): add tests * fix: add precompile contract addresses * fit(precompiles): change precompile order --- core/vm/contracts.go | 324 ++++++++++++++++++++++++++++++++- core/vm/contracts_test.go | 368 +++++++++++++++++++++++++++++++++++--- core/vm/tfhe.go | 330 ++++++++++++++++++++++++++++++++++ core/vm/tfhe_test.go | 276 +++++++++++++++++++++++++--- params/protocol_params.go | 21 ++- 5 files changed, 1268 insertions(+), 51 deletions(-) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index ff7c05be2..c24ac45ae 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -79,6 +79,12 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{75}): &optimisticRequire{}, common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, + common.BytesToAddress([]byte{78}): &fheBitAnd{}, + common.BytesToAddress([]byte{79}): &fheBitOr{}, + common.BytesToAddress([]byte{80}): &fheBitXor{}, + common.BytesToAddress([]byte{81}): &fheEq{}, + common.BytesToAddress([]byte{82}): &fheGe{}, + common.BytesToAddress([]byte{83}): &fheGt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -108,6 +114,12 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{75}): &optimisticRequire{}, common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, + common.BytesToAddress([]byte{78}): &fheBitAnd{}, + common.BytesToAddress([]byte{79}): &fheBitOr{}, + common.BytesToAddress([]byte{80}): &fheBitXor{}, + common.BytesToAddress([]byte{81}): &fheEq{}, + common.BytesToAddress([]byte{82}): &fheGe{}, + common.BytesToAddress([]byte{83}): &fheGt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -138,6 +150,12 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{75}): &optimisticRequire{}, common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, + common.BytesToAddress([]byte{78}): &fheBitAnd{}, + common.BytesToAddress([]byte{79}): &fheBitOr{}, + common.BytesToAddress([]byte{80}): &fheBitXor{}, + common.BytesToAddress([]byte{81}): &fheEq{}, + common.BytesToAddress([]byte{82}): &fheGe{}, + common.BytesToAddress([]byte{83}): &fheGt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -168,6 +186,12 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{75}): &optimisticRequire{}, common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, + common.BytesToAddress([]byte{78}): &fheBitAnd{}, + common.BytesToAddress([]byte{79}): &fheBitOr{}, + common.BytesToAddress([]byte{80}): &fheBitXor{}, + common.BytesToAddress([]byte{81}): &fheEq{}, + common.BytesToAddress([]byte{82}): &fheGe{}, + common.BytesToAddress([]byte{83}): &fheGt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -198,6 +222,12 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{75}): &optimisticRequire{}, common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, + common.BytesToAddress([]byte{78}): &fheBitAnd{}, + common.BytesToAddress([]byte{79}): &fheBitOr{}, + common.BytesToAddress([]byte{80}): &fheBitXor{}, + common.BytesToAddress([]byte{81}): &fheEq{}, + common.BytesToAddress([]byte{82}): &fheGe{}, + common.BytesToAddress([]byte{83}): &fheGt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -1305,6 +1335,12 @@ var fheAddSubGasCosts = map[fheUintType]uint64{ FheUint32: params.FheUint32AddSubGas, } +var fheBitwiseOpGasCosts = map[fheUintType]uint64{ + FheUint8: params.FheUint8BitwiseGas, + FheUint16: params.FheUint16BitwiseGas, + FheUint32: params.FheUint32BitwiseGas, +} + var fheMulGasCosts = map[fheUintType]uint64{ FheUint8: params.FheUint8MulGas, FheUint16: params.FheUint16MulGas, @@ -1754,11 +1790,11 @@ type fheLte struct{} func (e *fheLte) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { lhs, rhs, err := get2VerifiedOperands(accessibleState, input) if err != nil { - accessibleState.Interpreter().evm.Logger.Error("fheLt/Lte RequiredGas() inputs not verified", "err", err) + accessibleState.Interpreter().evm.Logger.Error("fheLte (comparison) RequiredGas() inputs not verified", "err", err) return 0 } if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { - accessibleState.Interpreter().evm.Logger.Error("fheLt/Lte RequiredGas() operand type mismatch", "lhs", + accessibleState.Interpreter().evm.Logger.Error("fheLte (comparison) RequiredGas() operand type mismatch", "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) return 0 } @@ -1902,6 +1938,290 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad return ctHash[:], nil } +type fheBitAnd struct{} + +func (e *fheBitAnd) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + logger := accessibleState.Interpreter().evm.Logger + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("Bitwise op RequiredGas() inputs not verified", "err", err, "input", hex.EncodeToString(input)) + return 0 + } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + logger.Error("Bitwise op RequiredGas() operand type mismatch", "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return 0 + } + return fheBitwiseOpGasCosts[lhs.ciphertext.fheUintType] +} + +func (e *fheBitAnd) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheBitAnd inputs not verified", "err", err) + return nil, err + } + + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheBitAnd operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.bitand(rhs.ciphertext) + if err != nil { + logger.Error("fheBitAnd failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/bitand_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheBitAnd failed to write /tmp/bitand_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheBitAnd success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil +} + +type fheBitOr struct{} + +func (e *fheBitOr) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of bitAnd, because bitwise op costs are currently the same. + and := fheBitAnd{} + return and.RequiredGas(accessibleState, input) +} + +func (e *fheBitOr) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheBitOr inputs not verified", "err", err) + return nil, err + } + + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheBitOr operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.bitor(rhs.ciphertext) + if err != nil { + logger.Error("fheBitOr failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/bitor_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheBitOr failed to write /tmp/bitor_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheBitOr success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil +} + +type fheBitXor struct{} + +func (e *fheBitXor) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of bitAnd, because bitwise op costs are currently the same. + and := fheBitAnd{} + return and.RequiredGas(accessibleState, input) +} + +func (e *fheBitXor) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheBitXor inputs not verified", "err", err) + return nil, err + } + + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheBitXor operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.bitxor(rhs.ciphertext) + if err != nil { + logger.Error("fheBitXor failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/bitxor_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheBitXor failed to write /tmp/bitxor_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheBitXor success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil +} + +type fheEq struct{} + +func (e *fheEq) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of lte, because comparison costs are currently the same. + lte := fheLte{} + return lte.RequiredGas(accessibleState, input) +} + +func (e *fheEq) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheEq inputs not verified", "err", err) + return nil, err + } + + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheEq operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.eq(rhs.ciphertext) + if err != nil { + logger.Error("fheEq failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/eq_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheEq failed to write /tmp/eq_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheEq success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil +} + +type fheGe struct{} + +func (e *fheGe) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of lte, because comparison costs are currently the same. + lte := fheLte{} + return lte.RequiredGas(accessibleState, input) +} + +func (e *fheGe) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheGe inputs not verified", "err", err) + return nil, err + } + + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheGe operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.ge(rhs.ciphertext) + if err != nil { + logger.Error("fheGe failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/ge_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheGe failed to write /tmp/ge_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheGt success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil +} + +type fheGt struct{} + +func (e *fheGt) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + // Implement in terms of lte, because comparison costs are currently the same. + lte := fheLte{} + return lte.RequiredGas(accessibleState, input) +} + +func (e *fheGt) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + lhs, rhs, err := get2VerifiedOperands(accessibleState, input) + if err != nil { + logger.Error("fheGt inputs not verified", "err", err) + return nil, err + } + + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + msg := "fheGt operand type mismatch" + logger.Error(msg, "lhs", lhs.ciphertext.fheUintType, "rhs", rhs.ciphertext.fheUintType) + return nil, errors.New(msg) + } + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + result, err := lhs.ciphertext.gt(rhs.ciphertext) + if err != nil { + logger.Error("fheGt failed", "err", err) + return nil, err + } + importCiphertext(accessibleState, result) + + // TODO: for testing + err = os.WriteFile("/tmp/gt_result", result.serialize(), 0644) + if err != nil { + logger.Error("fheGt failed to write /tmp/gt_result", "err", err) + return nil, err + } + + resultHash := result.getHash() + logger.Info("fheGt success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.ciphertext.getHash().Hex(), "result", resultHash.Hex()) + return resultHash[:], nil +} + type fheLt struct{} func (e *fheLt) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index fe5467a92..5e1822907 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -665,6 +665,262 @@ func FheMul(t *testing.T, fheUintType fheUintType) { } } +func FheBitAnd(t *testing.T, fheUintType fheUintType) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + expected := lhs & rhs + c := &fheBitAnd{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + input := toPrecompileInput(lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheBitOr(t *testing.T, fheUintType fheUintType) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + expected := lhs | rhs + c := &fheBitOr{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + input := toPrecompileInput(lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheBitXor(t *testing.T, fheUintType fheUintType) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + expected := lhs ^ rhs + c := &fheBitXor{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + input := toPrecompileInput(lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheEq(t *testing.T, fheUintType fheUintType) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + c := &fheEq{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + + // lhs == rhs + input1 := toPrecompileInput(lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } +} + +func FheGe(t *testing.T, fheUintType fheUintType) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + c := &fheGe{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + + // lhs >= rhs + input1 := toPrecompileInput(lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } + + // rhs >= lhs + input2 := toPrecompileInput(rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted = res.ciphertext.decrypt() + if decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } +} + +func FheGt(t *testing.T, fheUintType fheUintType) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + + c := &fheGt{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + + // lhs > rhs + input1 := toPrecompileInput(lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted := res.ciphertext.decrypt() + if decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } + + // rhs > lhs + input2 := toPrecompileInput(rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted = res.ciphertext.decrypt() + if decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } +} + func FheLte(t *testing.T, fheUintType fheUintType) { var lhs, rhs uint64 switch fheUintType { @@ -831,58 +1087,130 @@ func TestFheAdd8(t *testing.T) { FheAdd(t, FheUint8) } +func TestFheAdd16(t *testing.T) { + FheAdd(t, FheUint16) +} + +func TestFheAdd32(t *testing.T) { + FheAdd(t, FheUint32) +} + func TestFheSub8(t *testing.T) { FheSub(t, FheUint8) } +func TestFheSub16(t *testing.T) { + FheSub(t, FheUint16) +} + +func TestFheSub32(t *testing.T) { + FheSub(t, FheUint32) +} + func TestFheMul8(t *testing.T) { FheMul(t, FheUint8) } -func TestFheLte8(t *testing.T) { - FheLte(t, FheUint8) +func TestFheMul16(t *testing.T) { + FheMul(t, FheUint16) } -func TestFheLt8(t *testing.T) { - FheLt(t, FheUint8) +func TestFheMul32(t *testing.T) { + FheMul(t, FheUint32) } -func TestFheAdd16(t *testing.T) { - FheAdd(t, FheUint16) +func TestFheBitAnd8(t *testing.T) { + FheBitAnd(t, FheUint8) } -func TestFheSub16(t *testing.T) { - FheSub(t, FheUint16) +func TestFheBitAnd16(t *testing.T) { + FheBitAnd(t, FheUint16) } -func TestFheMul16(t *testing.T) { - FheMul(t, FheUint16) +func TestFheBitAnd32(t *testing.T) { + FheBitAnd(t, FheUint32) } -func TestFheLte16(t *testing.T) { - FheLte(t, FheUint16) +func TestFheBitOr8(t *testing.T) { + FheBitOr(t, FheUint8) } -func TestFheLt16(t *testing.T) { - FheLt(t, FheUint16) +func TestFheBitOr16(t *testing.T) { + FheBitOr(t, FheUint16) } -func TestFheAdd32(t *testing.T) { - FheAdd(t, FheUint32) +func TestFheBitOr32(t *testing.T) { + FheBitOr(t, FheUint32) } -func TestFheSub32(t *testing.T) { - FheSub(t, FheUint32) +func TestFheBitXor8(t *testing.T) { + FheBitXor(t, FheUint8) } -func TestFheMul32(t *testing.T) { - FheMul(t, FheUint32) +func TestFheBitXor16(t *testing.T) { + FheBitXor(t, FheUint16) +} + +func TestFheBitXor32(t *testing.T) { + FheBitXor(t, FheUint32) +} + +func TestFheEq8(t *testing.T) { + FheEq(t, FheUint8) +} + +func TestFheEq16(t *testing.T) { + FheEq(t, FheUint16) +} + +func TestFheEq32(t *testing.T) { + FheEq(t, FheUint32) +} + +func TestFheGe8(t *testing.T) { + FheGe(t, FheUint8) +} + +func TestFheGe16(t *testing.T) { + FheGe(t, FheUint16) +} + +func TestFheGe32(t *testing.T) { + FheGe(t, FheUint32) +} + +func TestFheGt8(t *testing.T) { + FheGt(t, FheUint8) +} + +func TestFheGt16(t *testing.T) { + FheGt(t, FheUint16) +} + +func TestFheGt32(t *testing.T) { + FheGt(t, FheUint32) +} + +func TestFheLte8(t *testing.T) { + FheLte(t, FheUint8) +} + +func TestFheLte16(t *testing.T) { + FheLte(t, FheUint16) } func TestFheLte32(t *testing.T) { FheLte(t, FheUint32) } +func TestFheLt8(t *testing.T) { + FheLt(t, FheUint8) +} + +func TestFheLt16(t *testing.T) { + FheLt(t, FheUint16) +} + func TestFheLt32(t *testing.T) { FheLt(t, FheUint32) } diff --git a/core/vm/tfhe.go b/core/vm/tfhe.go index 00ba1cde5..0f7c95df7 100644 --- a/core/vm/tfhe.go +++ b/core/vm/tfhe.go @@ -279,6 +279,204 @@ void* mul_fhe_uint32(void* ct1, void* ct2, void* sks) return result; } +void* bitand_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_bitand(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* bitand_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_bitand(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* bitand_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_bitand(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* bitor_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_bitor(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* bitor_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_bitor(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* bitor_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_bitor(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* bitxor_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_bitxor(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* bitxor_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_bitxor(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* bitxor_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_bitxor(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* eq_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_eq(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* eq_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_eq(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* eq_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_eq(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* ge_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_ge(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* ge_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_ge(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* ge_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_ge(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* gt_fhe_uint8(void* ct1, void* ct2, void* sks) +{ + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_gt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* gt_fhe_uint16(void* ct1, void* ct2, void* sks) +{ + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_gt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* gt_fhe_uint32(void* ct1, void* ct2, void* sks) +{ + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_gt(ct1, ct2, &result); + assert(r == 0); + return result; +} + void* le_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -826,6 +1024,138 @@ func (lhs *tfheCiphertext) mul(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } +func (lhs *tfheCiphertext) bitand(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot bitwise AND on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.bitand_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.bitand_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.bitand_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) bitor(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot bitwise OR on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.bitor_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.bitor_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.bitor_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) bitxor(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot bitwise XOR on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.bitxor_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.bitxor_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.bitxor_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) eq(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot eq on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.eq_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.eq_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.eq_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) ge(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot ge on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.ge_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.ge_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.ge_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + +func (lhs *tfheCiphertext) gt(rhs *tfheCiphertext) (*tfheCiphertext, error) { + if !lhs.availableForOps() || !rhs.availableForOps() { + panic("cannot gt on a non-initialized ciphertext") + } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + res := new(tfheCiphertext) + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.gt_fhe_uint8(lhs.ptr, rhs.ptr, sks)) + case FheUint16: + res.setPtr(C.gt_fhe_uint16(lhs.ptr, rhs.ptr, sks)) + case FheUint32: + res.setPtr(C.gt_fhe_uint32(lhs.ptr, rhs.ptr, sks)) + } + return res, nil +} + func (lhs *tfheCiphertext) lte(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot lte on a non-initialized ciphertext") diff --git a/core/vm/tfhe_test.go b/core/vm/tfhe_test.go index 58c02d709..5bc442356 100644 --- a/core/vm/tfhe_test.go +++ b/core/vm/tfhe_test.go @@ -259,6 +259,170 @@ func TfheMul(t *testing.T, fheUintType fheUintType) { } } +func TfheBitAnd(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := a.Uint64() & b.Uint64() + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.bitand(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheBitOr(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := a.Uint64() | b.Uint64() + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.bitor(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheBitXor(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(17) + } + expected := a.Uint64() ^ b.Uint64() + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.bitxor(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheEq(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(2) + case FheUint16: + a.SetUint64(169) + b.SetUint64(5) + case FheUint32: + a.SetUint64(137) + b.SetInt64(137) + } + var expected uint64 + expectedBool := a.Uint64() == b.Uint64() + if expectedBool { + expected = 1 + } else { + expected = 0 + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes, _ := ctA.eq(ctB) + res := ctRes.decrypt() + if res.Uint64() != expected { + t.Fatalf("%d != %d", expected, res.Uint64()) + } +} + +func TfheGe(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes1, _ := ctA.ge(ctB) + ctRes2, _ := ctB.ge(ctA) + res1 := ctRes1.decrypt() + res2 := ctRes2.decrypt() + if res1.Uint64() != 1 { + t.Fatalf("%d != %d", 0, res1.Uint64()) + } + if res2.Uint64() != 0 { + t.Fatalf("%d != %d", 0, res2.Uint64()) + } +} + +func TfheGt(t *testing.T, fheUintType fheUintType) { + var a, b big.Int + switch fheUintType { + case FheUint8: + a.SetUint64(2) + b.SetUint64(1) + case FheUint16: + a.SetUint64(4283) + b.SetUint64(1337) + case FheUint32: + a.SetUint64(1333337) + b.SetUint64(133337) + } + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintType) + ctB := new(tfheCiphertext) + ctB.encrypt(b, fheUintType) + ctRes1, _ := ctA.gt(ctB) + ctRes2, _ := ctB.gt(ctA) + res1 := ctRes1.decrypt() + res2 := ctRes2.decrypt() + if res1.Uint64() != 1 { + t.Fatalf("%d != %d", 0, res1.Uint64()) + } + if res2.Uint64() != 0 { + t.Fatalf("%d != %d", 0, res2.Uint64()) + } +} + func TfheLte(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { @@ -455,57 +619,129 @@ func TestTfheAdd8(t *testing.T) { TfheAdd(t, FheUint8) } +func TestTfheAdd16(t *testing.T) { + TfheAdd(t, FheUint16) +} + +func TestTfheAdd32(t *testing.T) { + TfheAdd(t, FheUint32) +} + func TestTfheSub8(t *testing.T) { TfheSub(t, FheUint8) } +func TestTfheSub16(t *testing.T) { + TfheSub(t, FheUint16) +} + +func TestTfheSub32(t *testing.T) { + TfheSub(t, FheUint32) +} + func TestTfheMul8(t *testing.T) { TfheMul(t, FheUint8) } -func TestTfheLte8(t *testing.T) { - TfheLte(t, FheUint8) +func TestTfheMul16(t *testing.T) { + TfheMul(t, FheUint16) } -func TestTfheLt8(t *testing.T) { - TfheLte(t, FheUint8) +func TestTfheMul32(t *testing.T) { + TfheMul(t, FheUint32) } -func TestTfheAdd16(t *testing.T) { - TfheAdd(t, FheUint16) + +func TestTfheBitAnd8(t *testing.T) { + TfheBitAnd(t, FheUint8) } -func TestTfheSub16(t *testing.T) { - TfheSub(t, FheUint16) +func TestTfheBitAnd16(t *testing.T) { + TfheBitAnd(t, FheUint16) } -func TestTfheMul16(t *testing.T) { - TfheMul(t, FheUint16) +func TestTfheBitAnd32(t *testing.T) { + TfheBitAnd(t, FheUint32) } -func TestTfheLte16(t *testing.T) { - TfheLte(t, FheUint16) +func TestTfheBitOr8(t *testing.T) { + TfheBitOr(t, FheUint8) } -func TestTfheLt16(t *testing.T) { - TfheLte(t, FheUint16) +func TestTfheBitOr16(t *testing.T) { + TfheBitOr(t, FheUint16) } -func TestTfheAdd32(t *testing.T) { - TfheAdd(t, FheUint32) +func TestTfheBitOr32(t *testing.T) { + TfheBitOr(t, FheUint32) } -func TestTfheSub32(t *testing.T) { - TfheSub(t, FheUint32) +func TestTfheBitXor8(t *testing.T) { + TfheBitXor(t, FheUint8) } -func TestTfheMul32(t *testing.T) { - TfheMul(t, FheUint32) +func TestTfheBitXor16(t *testing.T) { + TfheBitXor(t, FheUint16) +} + +func TestTfheBitXor32(t *testing.T) { + TfheBitXor(t, FheUint32) +} + +func TestTfheEq8(t *testing.T) { + TfheEq(t, FheUint8) +} + +func TestTfheEq16(t *testing.T) { + TfheEq(t, FheUint16) +} + +func TestTfheEq32(t *testing.T) { + TfheEq(t, FheUint32) +} + +func TestTfheGe8(t *testing.T) { + TfheGe(t, FheUint8) +} + +func TestTfheGe16(t *testing.T) { + TfheGe(t, FheUint16) +} + +func TestTfheGe32(t *testing.T) { + TfheGe(t, FheUint32) +} + +func TestTfheGt8(t *testing.T) { + TfheGt(t, FheUint8) +} + +func TestTfheGt16(t *testing.T) { + TfheGt(t, FheUint16) +} + +func TestTfheGt32(t *testing.T) { + TfheGt(t, FheUint32) +} + +func TestTfheLte8(t *testing.T) { + TfheLte(t, FheUint8) +} + +func TestTfheLte16(t *testing.T) { + TfheLte(t, FheUint16) } func TestTfheLte32(t *testing.T) { TfheLte(t, FheUint32) } +func TestTfheLt8(t *testing.T) { + TfheLte(t, FheUint8) +} + +func TestTfheLt16(t *testing.T) { + TfheLte(t, FheUint16) +} func TestTfheLt32(t *testing.T) { TfheLte(t, FheUint32) } diff --git a/params/protocol_params.go b/params/protocol_params.go index 73bc68de5..70ff92d97 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -160,15 +160,18 @@ const ( RefundQuotientEIP3529 uint64 = 5 // FHE operation costs depend on tfhe-rs performance and hardware acceleration. These values will most certainly change. - FheUint8AddSubGas uint64 = 5000 - FheUint16AddSubGas uint64 = FheUint8AddSubGas * 2 - FheUint32AddSubGas uint64 = FheUint16AddSubGas * 4 - FheUint8MulGas uint64 = 9000 - FheUint16MulGas uint64 = FheUint8MulGas * 3 - FheUint32MulGas uint64 = FheUint16MulGas * 10 - FheUint8LteGas uint64 = 3300 - FheUint16LteGas uint64 = 5000 - FheUint32LteGas uint64 = 11000 + FheUint8AddSubGas uint64 = 5000 + FheUint16AddSubGas uint64 = FheUint8AddSubGas * 2 + FheUint32AddSubGas uint64 = FheUint16AddSubGas * 4 + FheUint8MulGas uint64 = 9000 + FheUint16MulGas uint64 = FheUint8MulGas * 3 + FheUint32MulGas uint64 = FheUint16MulGas * 10 + FheUint8BitwiseGas uint64 = 2000 + FheUint16BitwiseGas uint64 = FheUint8BitwiseGas * 2 + FheUint32BitwiseGas uint64 = FheUint8BitwiseGas * 4 + FheUint8LteGas uint64 = 3300 + FheUint16LteGas uint64 = 5000 + FheUint32LteGas uint64 = 11000 // TODO: Cost will depend on the complexity of doing reencryption by the oracle. FheUint8ReencryptGas uint64 = 15000 From 831a6f58b7c0345e9e0ae6f07d8c4747dbc15914 Mon Sep 17 00:00:00 2001 From: Levent Demir Date: Mon, 26 Jun 2023 11:52:11 +0200 Subject: [PATCH 7/8] ci: fix the failing error in the ci --- core/vm/tfhe_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/vm/tfhe_test.go b/core/vm/tfhe_test.go index 58c02d709..5ec189898 100644 --- a/core/vm/tfhe_test.go +++ b/core/vm/tfhe_test.go @@ -148,7 +148,9 @@ func TfheTrivialSerializeDeserialize(t *testing.T, fheUintType fheUintType) { func TfheDeserializeFailure(t *testing.T, fheUintType fheUintType) { ct := new(tfheCiphertext) - err := ct.deserialize(make([]byte, 10), fheUintType) + input := make([]byte, 1) + input[0] = 42 + err := ct.deserialize(input, fheUintType) if err == nil { t.Fatalf("deserialization must have failed") } From 88f7c50ae5875f3cd4fd7d56bf319c331bd40087 Mon Sep 17 00:00:00 2001 From: Petar Ivanov <29689712+dartdart26@users.noreply.github.com> Date: Mon, 26 Jun 2023 14:13:31 +0300 Subject: [PATCH 8/8] Optimize gas estimation running time (#124) Optimize gas estimation running time with two approaches: 1. Don't call any storage methods on SSTORE during gas estimation. Instead, return early if we are in gas estimation (Commit = false). For that to work, make sure that on SLOAD, we first lookup the ciphertext hash in memory and, if found, use it without going to storage. That improves performance by not doing heavy work of splitting a ciphertext to 32 byte slots and persisting it to a memory-backed store for no reason. 2. When importing a random ciphertext during gas estimation, do not use public key encryption. Instead, just genearate an unique hash for the random ciphertext by utilizing a counter that is part of the EVM itself. That improves performance by not doing public key encryption and serialization just to compute a ciphertext hash (handle). --- core/vm/contracts.go | 16 +++++++++++++--- core/vm/evm.go | 18 +++++++++++------- core/vm/instructions.go | 7 ++++++- core/vm/tfhe.go | 12 +++++++----- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index c24ac45ae..b9d144d3d 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -1306,11 +1306,16 @@ func importCiphertext(accessibleState PrecompileAccessibleState, ct *tfheCiphert // Used when we want to skip FHE computation, e.g. gas estimation. func importRandomCiphertext(accessibleState PrecompileAccessibleState, t fheUintType) []byte { + nextCtHash := &accessibleState.Interpreter().evm.nextCiphertextHashOnGasEst + ctHashBytes := crypto.Keccak256(nextCtHash.Bytes()) + handle := common.BytesToHash(ctHashBytes) ct := new(tfheCiphertext) - ct.encrypt(*big.NewInt(0), t) + ct.fheUintType = t + ct.hash = &handle importCiphertext(accessibleState, ct) - ctHash := ct.getHash() - return ctHash[:] + temp := nextCtHash.Clone() + nextCtHash.Add(temp, uint256.NewInt(1)) + return ct.getHash().Bytes() } func get2VerifiedOperands(accessibleState PrecompileAccessibleState, input []byte) (lhs *verifiedCiphertext, rhs *verifiedCiphertext, err error) { @@ -1482,6 +1487,11 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller ctBytes := input[:len(input)-1] ctType := fheUintType(input[len(input)-1]) + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { + return importRandomCiphertext(accessibleState, ctType), nil + } + ct := new(tfheCiphertext) err := ct.deserializeCompact(ctBytes, ctType) if err != nil { diff --git a/core/vm/evm.go b/core/vm/evm.go index 9be07c6ab..b4c9ec52d 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -159,19 +159,23 @@ type EVM struct { // The logger allows the EVM to report information during execution. Logger Logger + + // An integer used as a counter for unique ciphertext hashes during gas estimation. + nextCiphertextHashOnGasEst uint256.Int } // NewEVM returns a new EVM. The returned EVM is not thread safe and should // only ever be used *once*. func NewEVM(blockCtx BlockContext, txCtx TxContext, statedb StateDB, chainConfig *params.ChainConfig, config Config) *EVM { evm := &EVM{ - Context: blockCtx, - TxContext: txCtx, - StateDB: statedb, - Config: config, - chainConfig: chainConfig, - chainRules: chainConfig.Rules(blockCtx.BlockNumber, blockCtx.Random != nil), - Logger: &defaultLogger{}, + Context: blockCtx, + TxContext: txCtx, + StateDB: statedb, + Config: config, + chainConfig: chainConfig, + chainRules: chainConfig.Rules(blockCtx.BlockNumber, blockCtx.Random != nil), + Logger: &defaultLogger{}, + nextCiphertextHashOnGasEst: *uint256.NewInt(0), } evm.interpreter = NewEVMInterpreter(evm, config) return evm diff --git a/core/vm/instructions.go b/core/vm/instructions.go index ca89c21f7..1516b9aa7 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -563,6 +563,10 @@ func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, cont ct, ok := interpreter.verifiedCiphertexts[val] if ok { // If already existing in memory, skip storage and import the same ciphertext at the current depth. + // + // Also works for gas estimation - we don't persist anything to protected storage during gas estimation. + // However, ciphertexts remain in memory for the duration of the call, allowing for this lookup to find it. + // Note that even if a ciphertext has an empty verification depth set, it still remains in memory. importCiphertextToEVM(interpreter, ct.ciphertext) return } @@ -721,7 +725,8 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b newValHash := common.BytesToHash(newValBytes) oldValHash := interpreter.evm.StateDB.GetState(scope.Contract.Address(), common.Hash(loc.Bytes32())) protectedStorage := crypto.CreateProtectedStorageContractAddress(scope.Contract.Address()) - if newValHash != oldValHash { + // If the value is the same or if we are not going to commit, don't do anything to protected storage. + if newValHash != oldValHash && interpreter.evm.Commit { // Since the old value is no longer stored in actual contract storage, run garbage collection on protected storage. garbageCollectProtectedStorage(oldValHash, protectedStorage, interpreter) // If a verified ciphertext, persist to protected storage. diff --git a/core/vm/tfhe.go b/core/vm/tfhe.go index 0f7c95df7..22faa5a8b 100644 --- a/core/vm/tfhe.go +++ b/core/vm/tfhe.go @@ -848,7 +848,7 @@ const ( type tfheCiphertext struct { ptr unsafe.Pointer serialization []byte - hash []byte + hash *common.Hash value *big.Int fheUintType fheUintType } @@ -1285,13 +1285,15 @@ func (ct *tfheCiphertext) setPtr(ptr unsafe.Pointer) { } func (ct *tfheCiphertext) getHash() common.Hash { + if ct.hash != nil { + return *ct.hash + } if !ct.initialized() { panic("cannot get hash of non-initialized ciphertext") } - if ct.hash == nil { - ct.hash = crypto.Keccak256(ct.serialize()) - } - return common.BytesToHash(ct.hash) + hash := common.BytesToHash(crypto.Keccak256(ct.serialize())) + ct.hash = &hash + return *ct.hash } func (ct *tfheCiphertext) availableForOps() bool {