Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/kms: call first version of KMS for reencrypt and decrypt #47

Merged
merged 14 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 232 additions & 32 deletions fhevm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package fhevm

import (
"bytes"
"encoding/hex"
"errors"
"math/big"
"strings"
"testing"
Expand Down Expand Up @@ -65,6 +68,81 @@ func toPrecompileInput(isScalar bool, hashes ...common.Hash) []byte {
return ret
}

func toPrecompileInputNoScalar(isScalar bool, hashes ...common.Hash) []byte {
ret := make([]byte, 0)
for _, hash := range hashes {
ret = append(ret, hash.Bytes()...)
}
return ret
}

func evaluateRemainingOptimisticRequiresWithoutKms(environment EVMEnvironment) (bool, error) {
requires := environment.FhevmData().optimisticRequires
len := len(requires)
defer func() { environment.FhevmData().optimisticRequires = make([]*tfheCiphertext, 0) }()
if len != 0 {
var cumulative *tfheCiphertext = requires[0]
var err error
for i := 1; i < len; i++ {
cumulative, err = cumulative.bitand(requires[i])
if err != nil {
environment.GetLogger().Error("evaluateRemainingOptimisticRequires bitand failed", "err", err)
return false, err
}
}
result, err := cumulative.decrypt()
return result.Uint64() != 0, err
}
return true, nil
}

func decryptRunWithoutKms(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
logger := environment.GetLogger()
// if not gas estimation and not view function fail if decryptions are disabled in transactions
if environment.IsCommitting() && !environment.IsEthCall() && environment.FhevmParams().DisableDecryptionsInTransaction {
msg := "decryptions during transaction are disabled"
logger.Error(msg, "input", hex.EncodeToString(input))
return nil, errors.New(msg)
}
if len(input) != 32 {
msg := "decrypt input len must be 32 bytes"
logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input))
return nil, errors.New(msg)
}
ct := getVerifiedCiphertext(environment, common.BytesToHash(input))
if ct == nil {
msg := "decrypt unverified handle"
logger.Error(msg, "input", hex.EncodeToString(input))
return nil, errors.New(msg)
}

// If we are doing gas estimation, skip decryption and make sure we return the maximum possible value.
// We need that, because non-zero bytes cost more than zero bytes in some contexts (e.g. SSTORE or memory operations).
if !environment.IsCommitting() && !environment.IsEthCall() {
return bytes.Repeat([]byte{0xFF}, 32), nil
}
// Make sure we don't decrypt before any optimistic requires are checked.
optReqResult, optReqErr := evaluateRemainingOptimisticRequiresWithoutKms(environment)
if optReqErr != nil {
return nil, optReqErr
} else if !optReqResult {
return nil, ErrExecutionReverted
}

plaintext, err := ct.ciphertext.decrypt()
if err != nil {
logger.Error("decrypt failed", "err", err)
return nil, err
}

logger.Info("decrypt success", "plaintext", plaintext)

// Always return a 32-byte big-endian integer.
ret := make([]byte, 32)
plaintext.FillBytes(ret)
return ret, nil
}

var scalarBytePadding = make([]byte, 31)

func toLibPrecompileInput(method string, isScalar bool, hashes ...common.Hash) []byte {
Expand Down Expand Up @@ -1249,6 +1327,44 @@ func FheLibRandBounded(t *testing.T, fheUintType FheUintType, upperBound64 uint6
}
}

func FheLibIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) {
var second, third uint64
switch fheUintType {
case FheUint8:
second = 2
third = 1
case FheUint16:
second = 4283
third = 1337
case FheUint32:
second = 1333337
third = 133337
}
signature := "fheIfThenElse(uint256,uint256,uint256)"
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
addr := common.Address{}
readOnly := false
firstHash := verifyCiphertextInTestMemory(environment, condition, depth, FheUint8).getHash()
secondHash := verifyCiphertextInTestMemory(environment, second, depth, fheUintType).getHash()
thirdHash := verifyCiphertextInTestMemory(environment, third, depth, fheUintType).getHash()
input := toLibPrecompileInputNoScalar(signature, firstHash, secondHash, thirdHash)
out, err := FheLibRun(environment, addr, addr, input, readOnly)
if err != nil {
t.Fatalf("VALUE %v", len(input))
// t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.decrypt()
if err != nil || condition == 1 && decrypted.Uint64() != second || condition == 0 && decrypted.Uint64() != third {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1)
}
}

func LibTrivialEncrypt(t *testing.T, fheUintType FheUintType) {
var value big.Int
switch fheUintType {
Expand Down Expand Up @@ -1346,29 +1462,30 @@ func TestLibVerifyCiphertextInvalidType(t *testing.T) {
}
}

func TestLibReencrypt(t *testing.T) {
signature := "reencrypt(uint256,uint256)"
hashRes := crypto.Keccak256([]byte(signature))
signatureBytes := hashRes[0:4]
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
environment.ethCall = true
toEncrypt := 7
fheUintType := FheUint8
encCiphertext := verifyCiphertextInTestMemory(environment, uint64(toEncrypt), depth, fheUintType).getHash()
addr := common.Address{}
readOnly := false
input := make([]byte, 0)
input = append(input, signatureBytes...)
input = append(input, encCiphertext.Bytes()...)
// just append twice not to generate public key
input = append(input, encCiphertext.Bytes()...)
_, err := FheLibRun(environment, addr, addr, input, readOnly)
if err != nil {
t.Fatalf("Reencrypt error: %s", err.Error())
}
}
// TODO: can be enabled if mocking kms or running a kms during tests
// func TestLibReencrypt(t *testing.T) {
// signature := "reencrypt(uint256,uint256)"
// hashRes := crypto.Keccak256([]byte(signature))
// signatureBytes := hashRes[0:4]
// depth := 1
// environment := newTestEVMEnvironment()
// environment.depth = depth
// environment.ethCall = true
// toEncrypt := 7
// fheUintType := FheUint8
// encCiphertext := verifyCiphertextInTestMemory(environment, uint64(toEncrypt), depth, fheUintType).getHash()
// addr := common.Address{}
// readOnly := false
// input := make([]byte, 0)
// input = append(input, signatureBytes...)
// input = append(input, encCiphertext.Bytes()...)
// // just append twice not to generate public key
// input = append(input, encCiphertext.Bytes()...)
// _, err := FheLibRun(environment, addr, addr, input, readOnly)
// if err != nil {
// t.Fatalf("Reencrypt error: %s", err.Error())
// }
// }

func TestLibCast(t *testing.T) {
signature := "cast(uint256,bytes1)"
Expand All @@ -1389,7 +1506,7 @@ func TestLibCast(t *testing.T) {
input = append(input, byte(FheUint32))
_, err := FheLibRun(environment, addr, addr, input, readOnly)
if err != nil {
t.Fatalf("Reencrypt error: %s", err.Error())
t.Fatalf("Cast error: %s", err.Error())
}
}

Expand Down Expand Up @@ -2352,6 +2469,43 @@ func FheNot(t *testing.T, fheUintType FheUintType, scalar bool) {
}
}

func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) {
var lhs, rhs uint64
switch fheUintType {
case FheUint8:
lhs = 2
rhs = 1
case FheUint16:
lhs = 4283
rhs = 1337
case FheUint32:
lhs = 1333337
rhs = 133337
}
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
addr := common.Address{}
readOnly := false
conditionHash := verifyCiphertextInTestMemory(environment, condition, depth, fheUintType).getHash()
lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).getHash()
rhsHash := verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).getHash()

input1 := toPrecompileInputNoScalar(false, conditionHash, lhsHash, rhsHash)
out, err := fheIfThenElseRun(environment, addr, addr, input1, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(environment, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.decrypt()
if err != nil || condition == 1 && decrypted.Uint64() != lhs || condition == 0 && decrypted.Uint64() != rhs {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0)
}
}

func Decrypt(t *testing.T, fheUintType FheUintType) {
var value uint64
switch fheUintType {
Expand All @@ -2368,7 +2522,7 @@ func Decrypt(t *testing.T, fheUintType FheUintType) {
addr := common.Address{}
readOnly := false
hash := verifyCiphertextInTestMemory(environment, value, depth, fheUintType).getHash()
out, err := decryptRun(environment, addr, addr, hash.Bytes(), readOnly)
out, err := decryptRunWithoutKms(environment, addr, addr, hash.Bytes(), readOnly)
if err != nil {
t.Fatalf(err.Error())
} else if len(out) != 32 {
Expand Down Expand Up @@ -2627,13 +2781,29 @@ func TestFheLibRandBounded32(t *testing.T) {
FheLibRandBounded(t, FheUint32, 32)
}

func TestFheLibIfThenElse8(t *testing.T) {
FheLibIfThenElse(t, FheUint8, 1)
FheLibIfThenElse(t, FheUint8, 0)
}

func TestFheLibIfThenElse16(t *testing.T) {
FheLibIfThenElse(t, FheUint16, 1)
FheLibIfThenElse(t, FheUint16, 0)
}

func TestFheLibIfThenElse32(t *testing.T) {
FheLibIfThenElse(t, FheUint32, 1)
FheLibIfThenElse(t, FheUint32, 0)
}

func TestFheLibTrivialEncrypt8(t *testing.T) {
LibTrivialEncrypt(t, FheUint8)
}

func TestLibDecrypt8(t *testing.T) {
LibDecrypt(t, FheUint8)
}
// TODO: can be enabled if mocking kms or running a kms during tests
// func TestLibDecrypt8(t *testing.T) {
// LibDecrypt(t, FheUint8)
// }

func TestFheAdd8(t *testing.T) {
FheAdd(t, FheUint8, false)
Expand Down Expand Up @@ -3079,6 +3249,21 @@ func TestFheNot32(t *testing.T) {
FheNot(t, FheUint32, false)
}

func TestFheIfThenElse8(t *testing.T) {
FheIfThenElse(t, FheUint8, 1)
FheIfThenElse(t, FheUint8, 0)
}

func TestFheIfThenElse16(t *testing.T) {
FheIfThenElse(t, FheUint16, 1)
FheIfThenElse(t, FheUint16, 0)
}

func TestFheIfThenElse32(t *testing.T) {
FheIfThenElse(t, FheUint32, 1)
FheIfThenElse(t, FheUint32, 0)
}

func TestFheScalarMax8(t *testing.T) {
FheMax(t, FheUint8, true)
}
Expand Down Expand Up @@ -3316,12 +3501,27 @@ func TestFheRandBoundedEthCall(t *testing.T) {
}
}

func EvalRemOptReqWhenStopTokenWithoutKms(env EVMEnvironment) (err error) {
err = nil
// If we are finishing execution (about to go to from depth 1 to depth 0), evaluate
// any remaining optimistic requires.
if env.GetDepth() == 1 {
result, evalErr := evaluateRemainingOptimisticRequiresWithoutKms(env)
if evalErr != nil {
err = evalErr
} else if !result {
err = ErrExecutionReverted
}
}
return err
}

func interpreterRunWithStopContract(environment *MockEVMEnvironment, interpreter *vm.EVMInterpreter, contract *vm.Contract, input []byte, readOnly bool) (ret []byte, err error) {
ret, _ = interpreter.Run(contract, input, readOnly)
// the following functions are meant to be ran from within interpreter.run so we increment depth to emulate that
environment.depth++
RemoveVerifiedCipherextsAtCurrentDepth(environment)
err = EvalRemOptReqWhenStopToken(environment)
err = EvalRemOptReqWhenStopTokenWithoutKms(environment)
environment.depth--
return ret, err
}
Expand Down Expand Up @@ -3508,7 +3708,7 @@ func TestDecryptWithFalseOptimisticRequire(t *testing.T) {
t.Fatalf("require expected output len of 0, got %v", len(out))
}
// Call decrypt and expect it to fail due to the optimistic require being false.
_, err = decryptRun(environment, addr, addr, hash.Bytes(), readOnly)
_, err = decryptRunWithoutKms(environment, addr, addr, hash.Bytes(), readOnly)
if err == nil {
t.Fatalf("expected decrypt fails due to false optimistic require")
}
Expand All @@ -3533,7 +3733,7 @@ func TestDecryptWithTrueOptimisticRequire(t *testing.T) {
t.Fatalf("require expected output len of 0, got %v", len(out))
}
// Call decrypt and expect it to succeed due to the optimistic require being true.
out, err = decryptRun(environment, addr, addr, hash.Bytes(), readOnly)
out, err = decryptRunWithoutKms(environment, addr, addr, hash.Bytes(), readOnly)
if err != nil {
t.Fatalf(err.Error())
} else if len(out) != 32 {
Expand All @@ -3556,7 +3756,7 @@ func TestDecryptInTransactionDisabled(t *testing.T) {
readOnly := false
hash := verifyCiphertextInTestMemory(environment, 1, depth, FheUint8).getHash()
// Call decrypt and expect it to fail due to disabling of decryptions during commit
_, err := decryptRun(environment, addr, addr, hash.Bytes(), readOnly)
_, err := decryptRunWithoutKms(environment, addr, addr, hash.Bytes(), readOnly)
if err == nil {
t.Fatalf("expected to error out in test")
} else if err.Error() != "decryptions during transaction are disabled" {
Expand Down
20 changes: 20 additions & 0 deletions fhevm/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,26 @@ func get2VerifiedOperands(environment EVMEnvironment, input []byte) (lhs *verifi
return
}

func get3VerifiedOperands(environment EVMEnvironment, input []byte) (first *verifiedCiphertext, second *verifiedCiphertext, third *verifiedCiphertext, err error) {
if len(input) != 96 {
return nil, nil, nil, errors.New("input needs to contain three 256-bit sized values")
}
first = getVerifiedCiphertext(environment, common.BytesToHash(input[0:32]))
if first == nil {
return nil, nil, nil, errors.New("unverified ciphertext handle")
}
second = getVerifiedCiphertext(environment, common.BytesToHash(input[32:64]))
if second == nil {
return nil, nil, nil, errors.New("unverified ciphertext handle")
}
third = getVerifiedCiphertext(environment, common.BytesToHash(input[64:96]))
if third == nil {
return nil, nil, nil, errors.New("unverified ciphertext handle")
}
err = nil
return
}

func getScalarOperands(environment EVMEnvironment, input []byte) (lhs *verifiedCiphertext, rhs *big.Int, err error) {
if len(input) != 65 {
return nil, nil, errors.New("input needs to contain two 256-bit sized values and 1 8-bit value")
Expand Down
Loading