diff --git a/fhevm/evm.go b/fhevm/evm.go index ae473ce..826b17e 100644 --- a/fhevm/evm.go +++ b/fhevm/evm.go @@ -44,7 +44,7 @@ func makeKeccakSignature(input string) uint32 { return binary.BigEndian.Uint32(crypto.Keccak256([]byte(input))[0:4]) } -func isScalarOp(environment *EVMEnvironment, input []byte) (bool, error) { +func isScalarOp(input []byte) (bool, error) { if len(input) != 65 { return false, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value") } @@ -52,11 +52,11 @@ func isScalarOp(environment *EVMEnvironment, input []byte) (bool, error) { return isScalar, nil } -func getVerifiedCiphertext(environment *EVMEnvironment, ciphertextHash common.Hash) *verifiedCiphertext { - return getVerifiedCiphertextFromEVM(*environment, ciphertextHash) +func getVerifiedCiphertext(environment EVMEnvironment, ciphertextHash common.Hash) *verifiedCiphertext { + return getVerifiedCiphertextFromEVM(environment, ciphertextHash) } -func get2VerifiedOperands(environment *EVMEnvironment, input []byte) (lhs *verifiedCiphertext, rhs *verifiedCiphertext, err error) { +func get2VerifiedOperands(environment EVMEnvironment, input []byte) (lhs *verifiedCiphertext, rhs *verifiedCiphertext, err error) { if len(input) != 65 { return nil, nil, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value") } @@ -72,7 +72,7 @@ func get2VerifiedOperands(environment *EVMEnvironment, input []byte) (lhs *verif return } -func getScalarOperands(environment *EVMEnvironment, input []byte) (lhs *verifiedCiphertext, rhs *big.Int, err error) { +func getScalarOperands(environment EVMEnvironment, input []byte) (lhs *verifiedCiphertext, rhs *big.Int, err error) { if len(input) != 65 { return nil, nil, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value") } @@ -85,8 +85,8 @@ func getScalarOperands(environment *EVMEnvironment, input []byte) (lhs *verified return } -func importCiphertextToEVMAtDepth(environment *EVMEnvironment, ct *tfheCiphertext, depth int) *verifiedCiphertext { - existing, ok := (*environment).GetFhevmData().verifiedCiphertexts[ct.getHash()] +func importCiphertextToEVMAtDepth(environment EVMEnvironment, ct *tfheCiphertext, depth int) *verifiedCiphertext { + existing, ok := environment.GetFhevmData().verifiedCiphertexts[ct.getHash()] if ok { existing.verifiedDepths.add(depth) return existing @@ -97,21 +97,21 @@ func importCiphertextToEVMAtDepth(environment *EVMEnvironment, ct *tfheCiphertex verifiedDepths, ct, } - (*environment).GetFhevmData().verifiedCiphertexts[ct.getHash()] = new + environment.GetFhevmData().verifiedCiphertexts[ct.getHash()] = new return new } } -func importCiphertextToEVM(environment *EVMEnvironment, ct *tfheCiphertext) *verifiedCiphertext { - return importCiphertextToEVMAtDepth(environment, ct, (*environment).GetDepth()) +func importCiphertextToEVM(environment EVMEnvironment, ct *tfheCiphertext) *verifiedCiphertext { + return importCiphertextToEVMAtDepth(environment, ct, environment.GetDepth()) } -func importCiphertext(environment *EVMEnvironment, ct *tfheCiphertext) *verifiedCiphertext { +func importCiphertext(environment EVMEnvironment, ct *tfheCiphertext) *verifiedCiphertext { return importCiphertextToEVM(environment, ct) } -func importRandomCiphertext(environment *EVMEnvironment, t fheUintType) []byte { - nextCtHash := &(*environment).GetFhevmData().nextCiphertextHashOnGasEst +func importRandomCiphertext(environment EVMEnvironment, t fheUintType) []byte { + nextCtHash := &environment.GetFhevmData().nextCiphertextHashOnGasEst ctHashBytes := crypto.Keccak256(nextCtHash.Bytes()) handle := common.BytesToHash(ctHashBytes) ct := new(tfheCiphertext) @@ -129,3 +129,40 @@ func minInt(a int, b int) int { } return b } + +// Return a memory with a layout that matches the `bytes` EVM type, namely: +// - 32 byte integer in big-endian order as length +// - the actual bytes in the `bytes` value +// - add zero byte padding until nearest multiple of 32 +func toEVMBytes(input []byte) []byte { + arrLen := uint64(len(input)) + lenBytes32 := uint256.NewInt(arrLen).Bytes32() + ret := make([]byte, 0, arrLen+32) + ret = append(ret, lenBytes32[:]...) + ret = append(ret, input...) + return ret +} + +func InitFhevm(accessibleState EVMEnvironment) { + persistFhePubKeyHash(accessibleState) +} + +func persistFhePubKeyHash(accessibleState EVMEnvironment) { + existing := accessibleState.GetState(fhePubKeyHashPrecompile, fhePubKeyHashSlot) + if newInt(existing[:]).IsZero() { + accessibleState.SetState(fhePubKeyHashPrecompile, fhePubKeyHashSlot, pksHash) + } +} + +// apply padding to slice to the multiple of 32 +func padArrayTo32Multiple(input []byte) []byte { + modRes := len(input) % 32 + if modRes > 0 { + padding := 32 - modRes + for padding > 0 { + padding-- + input = append(input, 0x0) + } + } + return input +} diff --git a/fhevm/interface.go b/fhevm/interface.go index 8046ad9..2c9a644 100644 --- a/fhevm/interface.go +++ b/fhevm/interface.go @@ -33,3 +33,10 @@ type FhevmData struct { nextCiphertextHashOnGasEst uint256.Int } + +func NewFhevmData() FhevmData { + return FhevmData{ + verifiedCiphertexts: make(map[common.Hash]*verifiedCiphertext), + optimisticRequires: make([]*tfheCiphertext, 0), + } +} diff --git a/fhevm/precompiles.go b/fhevm/precompiles.go index d58e97e..757ca2d 100644 --- a/fhevm/precompiles.go +++ b/fhevm/precompiles.go @@ -1,9 +1,11 @@ package fhevm import ( + "bytes" "encoding/binary" "encoding/hex" "errors" + "math/big" "github.com/ethereum/go-ethereum/common" "github.com/zama-ai/fhevm-go/params" @@ -17,16 +19,16 @@ type PrecompiledContract interface { Run(environment *EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) (ret []byte, err error) } -var PrecompiledContracts = map[common.Address]PrecompiledContract{ - common.BytesToAddress([]byte{93}): &fheLib{}, -} +var ErrExecutionReverted = errors.New("execution reverted") var signatureFheAdd = makeKeccakSignature("fheAdd(uint256,uint256,bytes1)") +var signatureCast = makeKeccakSignature("cast(uint256,bytes1)") +var signatureDecrypt = makeKeccakSignature("decrypt(uint256)") +var signatureFhePubKey = makeKeccakSignature("fhePubKey(bytes1)") +var signatureTrivialEncrypt = makeKeccakSignature("trivialEncrypt(uint256,bytes1)") -type fheLib struct{} - -func (e *fheLib) RequiredGas(environment *EVMEnvironment, input []byte) uint64 { - logger := (*environment).GetLogger() +func FheLibRequiredGas(environment EVMEnvironment, input []byte) uint64 { + logger := environment.GetLogger() if len(input) < 4 { err := errors.New("input must contain at least 4 bytes for method signature") logger.Error("fheLib precompile error", "err", err, "input", hex.EncodeToString(input)) @@ -37,6 +39,18 @@ func (e *fheLib) RequiredGas(environment *EVMEnvironment, input []byte) uint64 { case signatureFheAdd: bwCompatBytes := input[4:minInt(69, len(input))] return fheAddRequiredGas(environment, bwCompatBytes) + case signatureCast: + bwCompatBytes := input[4:minInt(37, len(input))] + return castRequiredGas(environment, bwCompatBytes) + case signatureDecrypt: + bwCompatBytes := input[4:minInt(36, len(input))] + return decryptRequiredGas(environment, bwCompatBytes) + case signatureFhePubKey: + bwCompatBytes := input[4:minInt(5, len(input))] + return fhePubKeyRequiredGas(environment, bwCompatBytes) + case signatureTrivialEncrypt: + bwCompatBytes := input[4:minInt(37, len(input))] + return trivialEncryptRequiredGas(environment, bwCompatBytes) default: err := errors.New("precompile method not found") logger.Error("fheLib precompile error", "err", err, "input", hex.EncodeToString(input)) @@ -44,8 +58,8 @@ func (e *fheLib) RequiredGas(environment *EVMEnvironment, input []byte) uint64 { } } -func (e *fheLib) Run(environment *EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { - logger := (*environment).GetLogger() +func FheLibRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := environment.GetLogger() if len(input) < 4 { err := errors.New("input must contain at least 4 bytes for method signature") logger.Error("fheLib precompile error", "err", err, "input", hex.EncodeToString(input)) @@ -56,6 +70,26 @@ func (e *fheLib) Run(environment *EVMEnvironment, caller common.Address, addr co case signatureFheAdd: bwCompatBytes := input[4:minInt(69, len(input))] return fheAddRun(environment, caller, addr, bwCompatBytes, readOnly) + case signatureCast: + bwCompatBytes := input[4:minInt(37, len(input))] + return castRun(environment, caller, addr, bwCompatBytes, readOnly) + case signatureDecrypt: + bwCompatBytes := input[4:minInt(36, len(input))] + return decryptRun(environment, caller, addr, bwCompatBytes, readOnly) + case signatureFhePubKey: + bwCompatBytes := input[4:minInt(5, len(input))] + precompileBytes, err := fhePubKeyRun(environment, caller, addr, bwCompatBytes, readOnly) + if err != nil { + return precompileBytes, err + } + // pad according to abi specification, first add offset to the dynamic bytes argument + outputBytes := make([]byte, 32, len(precompileBytes)+32) + outputBytes[31] = 0x20 + outputBytes = append(outputBytes, precompileBytes...) + return padArrayTo32Multiple(outputBytes), nil + case signatureTrivialEncrypt: + bwCompatBytes := input[4:minInt(37, len(input))] + return trivialEncryptRun(environment, caller, addr, bwCompatBytes, readOnly) default: err := errors.New("precompile method not found") logger.Error("fheLib precompile error", "err", err, "input", hex.EncodeToString(input)) @@ -69,9 +103,16 @@ var fheAddSubGasCosts = map[fheUintType]uint64{ FheUint32: params.FheUint32AddSubGas, } -func fheAddRequiredGas(environment *EVMEnvironment, input []byte) uint64 { - logger := (*environment).GetLogger() - isScalar, err := isScalarOp(environment, input) +var fheDecryptGasCosts = map[fheUintType]uint64{ + FheUint8: params.FheUint8DecryptGas, + FheUint16: params.FheUint16DecryptGas, + FheUint32: params.FheUint32DecryptGas, +} + +// Gas costs +func fheAddRequiredGas(environment EVMEnvironment, input []byte) uint64 { + logger := environment.GetLogger() + isScalar, err := isScalarOp(input) if err != nil { logger.Error("fheAdd/Sub RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return 0 @@ -98,10 +139,55 @@ func fheAddRequiredGas(environment *EVMEnvironment, input []byte) uint64 { return fheAddSubGasCosts[lhs.ciphertext.fheUintType] } -func fheAddRun(environment *EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { - logger := (*environment).GetLogger() +func castRequiredGas(environment EVMEnvironment, input []byte) uint64 { + if len(input) != 33 { + environment.GetLogger().Error( + "cast RequiredGas() input needs to contain a ciphertext and one byte for its type", + "len", len(input)) + return 0 + } + return params.FheCastGas +} + +func decryptRequiredGas(environment EVMEnvironment, input []byte) uint64 { + logger := environment.GetLogger() + if len(input) != 32 { + logger.Error("decrypt RequiredGas() input len must be 32 bytes", "input", hex.EncodeToString(input), "len", len(input)) + return 0 + } + ct := getVerifiedCiphertext(environment, common.BytesToHash(input)) + if ct == nil { + logger.Error("decrypt RequiredGas() input doesn't point to verified ciphertext", "input", hex.EncodeToString(input)) + return 0 + } + return fheDecryptGasCosts[ct.ciphertext.fheUintType] +} + +func fhePubKeyRequiredGas(accessibleState EVMEnvironment, input []byte) uint64 { + return params.FhePubKeyGas +} + +var fheTrivialEncryptGasCosts = map[fheUintType]uint64{ + FheUint8: params.FheUint8TrivialEncryptGas, + FheUint16: params.FheUint16TrivialEncryptGas, + FheUint32: params.FheUint32TrivialEncryptGas, +} - isScalar, err := isScalarOp(environment, input) +func trivialEncryptRequiredGas(accessibleState EVMEnvironment, input []byte) uint64 { + logger := accessibleState.GetLogger() + if len(input) != 33 { + logger.Error("trivialEncrypt RequiredGas() input len must be 33 bytes", "input", hex.EncodeToString(input), "len", len(input)) + return 0 + } + encryptToType := fheUintType(input[32]) + return fheTrivialEncryptGasCosts[encryptToType] +} + +// Implementations +func fheAddRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := environment.GetLogger() + + isScalar, err := isScalarOp(input) if err != nil { logger.Error("fheAdd can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input)) return nil, err @@ -120,7 +206,7 @@ func fheAddRun(environment *EVMEnvironment, caller common.Address, addr common.A } // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !(*environment).IsCommitting() && !(*environment).IsEthCall() { + if !environment.IsCommitting() && !environment.IsEthCall() { return importRandomCiphertext(environment, lhs.ciphertext.fheUintType), nil } @@ -143,7 +229,7 @@ func fheAddRun(environment *EVMEnvironment, caller common.Address, addr common.A } // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !(*environment).IsCommitting() && !(*environment).IsEthCall() { + if !environment.IsCommitting() && !environment.IsEthCall() { return importRandomCiphertext(environment, lhs.ciphertext.fheUintType), nil } @@ -159,3 +245,154 @@ func fheAddRun(environment *EVMEnvironment, caller common.Address, addr common.A return resultHash[:], nil } } + +func decryptRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := environment.GetLogger() + if len(input) != 32 { + msg := "decrypt input len must be 32 bytes" + logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input)) + return nil, errors.New(msg) + } + ct := getVerifiedCiphertext(environment, common.BytesToHash(input)) + if ct == nil { + msg := "decrypt unverified handle" + logger.Error(msg, "input", hex.EncodeToString(input)) + return nil, errors.New(msg) + } + // If we are doing gas estimation, skip decryption and make sure we return the maximum possible value. + // We need that, because non-zero bytes cost more than zero bytes in some contexts (e.g. SSTORE or memory operations). + if !environment.IsCommitting() && !environment.IsEthCall() { + return bytes.Repeat([]byte{0xFF}, 32), nil + } + // Make sure we don't decrypt before any optimistic requires are checked. + optReqResult, optReqErr := evaluateRemainingOptimisticRequires(environment) + if optReqErr != nil { + return nil, optReqErr + } else if !optReqResult { + return nil, ErrExecutionReverted + } + plaintext, err := decryptValue(ct.ciphertext) + if err != nil { + logger.Error("decrypt failed", "err", err) + return nil, err + } + // Always return a 32-byte big-endian integer. + ret := make([]byte, 32) + bigIntValue := big.NewInt(0) + bigIntValue.SetUint64(plaintext) + bigIntValue.FillBytes(ret) + return ret, nil +} + +func decryptValue(ct *tfheCiphertext) (uint64, error) { + v, err := ct.decrypt() + return v.Uint64(), err +} + +// If there are optimistic requires, check them by doing bitwise AND on all of them. +// That works, because we assume their values are either 0 or 1. If there is at least +// one 0, the result will be 0 (false). +func evaluateRemainingOptimisticRequires(environment EVMEnvironment) (bool, error) { + requires := environment.GetFhevmData().optimisticRequires + len := len(requires) + defer func() { requires = make([]*tfheCiphertext, 0) }() + if len != 0 { + var cumulative *tfheCiphertext = requires[0] + var err error + for i := 1; i < len; i++ { + cumulative, err = cumulative.bitand(requires[i]) + if err != nil { + environment.GetLogger().Error("evaluateRemainingOptimisticRequires bitand failed", "err", err) + return false, err + } + } + result, err := decryptValue(cumulative) + return result != 0, err + } + return true, nil +} + +func castRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := environment.GetLogger() + if len(input) != 33 { + msg := "cast Run() input needs to contain a ciphertext and one byte for its type" + logger.Error(msg, "len", len(input)) + return nil, errors.New(msg) + } + + ct := getVerifiedCiphertext(environment, common.BytesToHash(input[0:32])) + if ct == nil { + logger.Error("cast input not verified") + return nil, errors.New("unverified ciphertext handle") + } + + if !isValidType(input[32]) { + logger.Error("invalid type to cast to") + return nil, errors.New("invalid type provided") + } + castToType := fheUintType(input[32]) + + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. + if !environment.IsCommitting() && !environment.IsEthCall() { + return importRandomCiphertext(environment, castToType), nil + } + + res, err := ct.ciphertext.castTo(castToType) + if err != nil { + msg := "cast Run() error casting ciphertext to" + logger.Error(msg, "type", castToType) + return nil, errors.New(msg) + } + + resHash := res.getHash() + + importCiphertext(environment, res) + if environment.IsCommitting() { + logger.Info("cast success", + "ctHash", resHash.Hex(), + ) + } + + return resHash.Bytes(), nil +} + +var fhePubKeyHashPrecompile = common.BytesToAddress([]byte{93}) +var fhePubKeyHashSlot = common.Hash{} + +func fhePubKeyRun(accessibleState EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + existing := accessibleState.GetState(fhePubKeyHashPrecompile, fhePubKeyHashSlot) + if existing != pksHash { + msg := "fhePubKey FHE public key hash doesn't match one stored in state" + accessibleState.GetLogger().Error(msg, "existing", existing.Hex(), "pksHash", pksHash.Hex()) + return nil, errors.New(msg) + } + // If we have a single byte with the value of 1, return as an EVM array. Otherwise, returh the raw bytes. + if len(input) == 1 && input[0] == 1 { + return toEVMBytes(pksBytes), nil + } else { + return pksBytes, nil + } +} + +func trivialEncryptRun(accessibleState EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.GetLogger() + if len(input) != 33 { + msg := "trivialEncrypt input len must be 33 bytes" + logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input)) + return nil, errors.New(msg) + } + + valueToEncrypt := *new(big.Int).SetBytes(input[0:32]) + encryptToType := fheUintType(input[32]) + + ct := new(tfheCiphertext).trivialEncrypt(valueToEncrypt, encryptToType) + + ctHash := ct.getHash() + importCiphertext(accessibleState, ct) + if accessibleState.IsCommitting() { + logger.Info("trivialEncrypt success", + "ctHash", ctHash.Hex(), + "valueToEncrypt", valueToEncrypt.Uint64()) + } + return ctHash.Bytes(), nil +} diff --git a/fhevm/test-fhevm-keys/cks b/fhevm/test-fhevm-keys/cks new file mode 100644 index 0000000..1ea9026 Binary files /dev/null and b/fhevm/test-fhevm-keys/cks differ diff --git a/fhevm/test-fhevm-keys/pks b/fhevm/test-fhevm-keys/pks new file mode 100644 index 0000000..4a075d0 Binary files /dev/null and b/fhevm/test-fhevm-keys/pks differ diff --git a/fhevm/test-fhevm-keys/sks b/fhevm/test-fhevm-keys/sks new file mode 100644 index 0000000..a1682a8 Binary files /dev/null and b/fhevm/test-fhevm-keys/sks differ diff --git a/fhevm/tfhe.go b/fhevm/tfhe.go index 685767f..3eb6616 100644 --- a/fhevm/tfhe.go +++ b/fhevm/tfhe.go @@ -18,7 +18,7 @@ package fhevm /* #cgo CFLAGS: -O3 -I../tfhe-rs/target/release -#cgo LDFLAGS: -L../tfhe-rs/target/release -l:libtfhe.a -lm +#cgo LDFLAGS: -L../tfhe-rs/target/release -ltfhe -lm #include @@ -1432,10 +1432,10 @@ void* cast_32_16(void* ct, void* sks) { import "C" import ( + _ "embed" "errors" "fmt" "math/big" - "os" "unsafe" "github.com/ethereum/go-ethereum/common" @@ -1449,14 +1449,6 @@ func toBufferView(in []byte) C.BufferView { } } -func homeDir() string { - home, err := os.UserHomeDir() - if err != nil { - panic(err) - } - return home -} - // Expanded TFHE ciphertext sizes by type, in bytes. var expandedFheCiphertextSize map[fheUintType]uint @@ -1466,36 +1458,30 @@ var compactFheCiphertextSize map[fheUintType]uint var sks unsafe.Pointer var cks unsafe.Pointer var pks unsafe.Pointer -var pksBytes []byte var pksHash common.Hash var networkKeysDir string var usersKeysDir string +//go:embed test-fhevm-keys/sks +var sksBytes []byte + +//go:embed test-fhevm-keys/pks +var pksBytes []byte + +//go:embed test-fhevm-keys/cks +var cksBytes []byte + func init() { expandedFheCiphertextSize = make(map[fheUintType]uint) compactFheCiphertextSize = make(map[fheUintType]uint) - home := homeDir() - networkKeysDir = home + "/.evmosd/zama/keys/network-fhe-keys/" - usersKeysDir = home + "/.evmosd/zama/keys/users-fhe-keys/" - - sksBytes, err := os.ReadFile(networkKeysDir + "sks") - if err != nil { - fmt.Println("WARNING: file sks not found.") - return - } + fmt.Println("TODO: fhevm keys are hardcoded into subnet evm, find ways to store keys in production") sks = C.deserialize_server_key(toBufferView(sksBytes)) expandedFheCiphertextSize[FheUint8] = uint(len(new(tfheCiphertext).trivialEncrypt(*big.NewInt(0), FheUint8).serialize())) expandedFheCiphertextSize[FheUint16] = uint(len(new(tfheCiphertext).trivialEncrypt(*big.NewInt(0), FheUint16).serialize())) expandedFheCiphertextSize[FheUint32] = uint(len(new(tfheCiphertext).trivialEncrypt(*big.NewInt(0), FheUint32).serialize())) - pksBytes, err = os.ReadFile(networkKeysDir + "pks") - if err != nil { - pksBytes = nil - fmt.Println("WARNING: file pks not found.") - return - } pksHash = crypto.Keccak256Hash(pksBytes) pks = C.deserialize_compact_public_key(toBufferView(pksBytes)) @@ -1503,11 +1489,6 @@ func init() { compactFheCiphertextSize[FheUint16] = uint(len(encryptAndSerializeCompact(0, FheUint16))) compactFheCiphertextSize[FheUint32] = uint(len(encryptAndSerializeCompact(0, FheUint32))) - cksBytes, err := os.ReadFile(networkKeysDir + "cks") - if err != nil { - fmt.Println("WARNING: file cks not found.") - return - } cks = C.deserialize_client_key(toBufferView(cksBytes)) }