Skip to content

Commit

Permalink
Merge pull request #159 from zama-ai/petar/fhe-rand
Browse files Browse the repository at this point in the history
Reinstate plaintext fheRand with type support
  • Loading branch information
dartdart26 committed Sep 7, 2023
2 parents cc574ae + 84487ce commit 0fe94b1
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 85 deletions.
193 changes: 111 additions & 82 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/ethereum/go-ethereum/params"
"github.com/holiman/uint256"
"github.com/naoina/toml"
"golang.org/x/crypto/chacha20"
"golang.org/x/crypto/nacl/box"
"golang.org/x/crypto/ripemd160"
)
Expand Down Expand Up @@ -75,7 +76,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -119,7 +120,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -164,7 +165,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -209,7 +210,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -254,7 +255,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -1471,6 +1472,12 @@ var fheTrivialEncryptGasCosts = map[fheUintType]uint64{
FheUint32: params.FheUint32TrivialEncryptGas,
}

var fheRandGasCosts = map[fheUintType]uint64{
FheUint8: params.FheUint8RandGas,
FheUint16: params.FheUint16RandGas,
FheUint32: params.FheUint32RandGas,
}

func writeResult(ct *tfheCiphertext, fileName string, logger Logger) {
os.WriteFile("/tmp/"+fileName, ct.serialize(), 0644)
}
Expand Down Expand Up @@ -3317,83 +3324,105 @@ func (e *fheNot) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return resultHash[:], nil
}

// type fheRand struct{}

// var globalRngSeed []byte

// var rngNonceKey [32]byte = uint256.NewInt(0).Bytes32()

// func init() {
// if chacha20.NonceSizeX != 24 {
// panic("expected 24 bytes for NonceSizeX")
// }

// // TODO: Since the current implementation is not FHE-based and, hence, not private,
// // we just initialize the global seed with non-random public data. We will change
// // that once the FHE version is available.
// globalRngSeed = make([]byte, chacha20.KeySize)
// for i := range globalRngSeed {
// globalRngSeed[i] = byte(1 + i)
// }
// }

// func (e *fheRand) RequiredGas(input []byte) uint64 {
// // TODO
// return 8
// }

// func (e *fheRand) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
// // If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
// if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
// return importRandomCiphertext(accessibleState), nil
// }

// // Get the RNG nonce.
// protectedStorage := crypto.CreateProtectedStorageContractAddress(caller)
// currentRngNonceBytes := accessibleState.Interpreter().evm.StateDB.GetState(protectedStorage, rngNonceKey).Bytes()

// // Increment the RNG nonce by 1.
// nextRngNonce := newInt(currentRngNonceBytes)
// nextRngNonce = nextRngNonce.AddUint64(nextRngNonce, 1)
// accessibleState.Interpreter().evm.StateDB.SetState(protectedStorage, rngNonceKey, nextRngNonce.Bytes32())

// // Compute the seed and use it to create a new cipher.
// hasher := crypto.NewKeccakState()
// hasher.Write(globalRngSeed)
// hasher.Write(caller.Bytes())
// hasher.Write(currentRngNonceBytes)
// seed := common.Hash{}
// _, err := hasher.Read(seed[:])
// if err != nil {
// return nil, err
// }
// // The RNG nonce bytes are of size chacha20.NonceSizeX, which is assumed to be 24 bytes (see init() above).
// // Since uint256.Int.z[0] is the least significant byte and since uint256.Int.Bytes32() serializes
// // in order of z[3], z[2], z[1], z[0], we want to essentially ignore the first byte, i.e. z[3], because
// // it will always be 0 as the nonce size is 24.
// cipher, err := chacha20.NewUnauthenticatedCipher(seed.Bytes(), currentRngNonceBytes[32-chacha20.NonceSizeX:32])
// if err != nil {
// return nil, err
// }

// // XOR a byte array of 0s with the stream from the cipher and receive the result in the same array.
// randBytes := make([]byte, 8)
// cipher.XORKeyStream(randBytes, randBytes)

// // Trivially encrypt the random integer.
// randInt := binary.BigEndian.Uint64(randBytes) % math.BigPow(2, 3).Uint64()
// randCt := new(tfheCiphertext)
// randCt.trivialEncrypt(randInt)
// importCiphertext(accessibleState, randCt)

// // TODO: for testing
// err = os.WriteFile("/tmp/rand_result", randCt.serialize(), 0644)
// if err != nil {
// return nil, err
// }
// ctHash := randCt.getHash()
// return ctHash[:], nil
// }
type fheRand struct{}

var globalRngSeed []byte

var rngNonceKey [32]byte = uint256.NewInt(0).Bytes32()

func init() {
if chacha20.NonceSizeX != 24 {
panic("expected 24 bytes for NonceSizeX")
}

// TODO: Since the current implementation is not FHE-based and, hence, not private,
// we just initialize the global seed with non-random public data. We will change
// that once the FHE version is available.
globalRngSeed = make([]byte, chacha20.KeySize)
for i := range globalRngSeed {
globalRngSeed[i] = byte(1 + i)
}

// Make sure we mark the RNG nonce key as a reserved slot in protected storage.
reservedProtectedStorageSlots = append(reservedProtectedStorageSlots, common.BytesToHash(rngNonceKey[:]))
}

func (e *fheRand) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 {
logger := accessibleState.Interpreter().evm.Logger
if len(input) != 1 || !isValidType(input[0]) {
logger.Error("fheRand RequiredGas() input len must be at least 1 byte and be a valid FheUint type", "input", hex.EncodeToString(input), "len", len(input))
return 0
}
t := fheUintType(input[0])
return fheRandGasCosts[t]
}

func (e *fheRand) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
logger := accessibleState.Interpreter().evm.Logger
if accessibleState.Interpreter().evm.EthCall {
msg := "fheRand cannot be called via EthCall, because it needs to mutate internal state"
logger.Error(msg)
return nil, errors.New(msg)
}
if len(input) != 1 || !isValidType(input[0]) {
msg := "fheRand input len must be at least 1 byte and be a valid FheUint type"
logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input))
return nil, errors.New(msg)
}

t := fheUintType(input[0])
// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit {
return importRandomCiphertext(accessibleState, t), nil
}

// Get the RNG nonce.
protectedStorage := crypto.CreateProtectedStorageContractAddress(caller)
currentRngNonceBytes := accessibleState.Interpreter().evm.StateDB.GetState(protectedStorage, rngNonceKey).Bytes()

// Increment the RNG nonce by 1.
nextRngNonce := newInt(currentRngNonceBytes)
nextRngNonce = nextRngNonce.AddUint64(nextRngNonce, 1)
accessibleState.Interpreter().evm.StateDB.SetState(protectedStorage, rngNonceKey, nextRngNonce.Bytes32())

// Compute the seed and use it to create a new cipher.
hasher := crypto.NewKeccakState()
hasher.Write(globalRngSeed)
hasher.Write(caller.Bytes())
seed := common.Hash{}
_, err := hasher.Read(seed[:])
if err != nil {
return nil, err
}
// The RNG nonce bytes are of size chacha20.NonceSizeX, which is assumed to be 24 bytes (see init() above).
// Since uint256.Int.z[0] is the least significant byte and since uint256.Int.Bytes32() serializes
// in order of z[3], z[2], z[1], z[0], we want to essentially ignore the first byte, i.e. z[3], because
// it will always be 0 as the nonce size is 24.
cipher, err := chacha20.NewUnauthenticatedCipher(seed.Bytes(), currentRngNonceBytes[32-chacha20.NonceSizeX:32])
if err != nil {
return nil, err
}

// XOR a byte array of 0s with the stream from the cipher and receive the result in the same array.
randBytes := make([]byte, 8)
cipher.XORKeyStream(randBytes, randBytes)

// Trivially encrypt the random integer.
randUint64 := binary.BigEndian.Uint64(randBytes)
randCt := new(tfheCiphertext)
randBigInt := big.NewInt(0)
randBigInt.SetUint64(randUint64)
randCt.trivialEncrypt(*randBigInt, t)
importCiphertext(accessibleState, randCt)

// TODO: for testing
err = os.WriteFile("/tmp/rand_result", randCt.serialize(), 0644)
if err != nil {
return nil, err
}
ctHash := randCt.getHash()
return ctHash[:], nil
}

type cast struct{}

Expand Down
85 changes: 85 additions & 0 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,30 @@ func Decrypt(t *testing.T, fheUintType fheUintType) {
}
}

func FheRand(t *testing.T, fheUintType fheUintType) {
c := &fheRand{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
out, err := c.Run(state, addr, addr, []byte{byte(fheUintType)}, readOnly)
if err != nil {
t.Fatalf(err.Error())
} else if len(out) != 32 {
t.Fatalf("fheRand expected output len of 32, got %v", len(out))
}
if len(state.interpreter.verifiedCiphertexts) != 1 {
t.Fatalf("fheRand expected 1 verified ciphertext")
}

hash := common.BytesToHash(out)
_, err = state.interpreter.verifiedCiphertexts[hash].ciphertext.decrypt()
if err != nil {
t.Fatalf(err.Error())
}
}

func newStopOpcodeContract() *Contract {
addr := AccountRef{}
c := NewContract(addr, addr, big.NewInt(0), 100000)
Expand Down Expand Up @@ -2348,6 +2372,18 @@ func TestDecrypt32(t *testing.T) {
Decrypt(t, FheUint32)
}

func TestFheRand8(t *testing.T) {
FheRand(t, FheUint8)
}

func TestFheRand16(t *testing.T) {
FheRand(t, FheUint16)
}

func TestFheRand32(t *testing.T) {
FheRand(t, FheUint32)
}

func TestUnknownCiphertextHandle(t *testing.T) {
depth := 1
state := newTestState()
Expand Down Expand Up @@ -2414,3 +2450,52 @@ func TestCiphertextVerificationConditions(t *testing.T) {
t.Fatalf("expected that ciphertext is not verified at verifiedDepth - 1 (%d)", verifiedDepth-1)
}
}

func TestFheRandInvalidInput(t *testing.T) {
c := &fheRand{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
_, err := c.Run(state, addr, addr, []byte{}, readOnly)
if err == nil {
t.Fatalf("fheRand expected failure on invalid type")
}
if len(state.interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("fheRand expected 0 verified ciphertexts on invalid input")
}
}

func TestFheRandInvalidType(t *testing.T) {
c := &fheRand{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
_, err := c.Run(state, addr, addr, []byte{byte(254)}, readOnly)
if err == nil {
t.Fatalf("fheRand expected failure on invalid type")
}
if len(state.interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("fheRand expected 0 verified ciphertexts on invalid type")
}
}

func TestFheRandEthCall(t *testing.T) {
c := &fheRand{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.EthCall = true
addr := common.Address{}
readOnly := true
_, err := c.Run(state, addr, addr, []byte{byte(FheUint8)}, readOnly)
if err == nil {
t.Fatalf("fheRand expected failure on EthCall")
}
if len(state.interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("fheRand expected 0 verified ciphertexts on EthCall")
}
}
14 changes: 13 additions & 1 deletion core/vm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package vm

import (
"bytes"
"encoding/hex"
"errors"
"sync/atomic"
Expand Down Expand Up @@ -674,16 +675,27 @@ func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Addres
interpreter.evm.StateDB.SetState(protectedStorage, val, metadata.serialize())
}

// A list of slots that we consider reserved in protected storage.
// Namely, we won't treat them as ciphertext metadata and we won't garbage collect them.
// TODO: This list will be removed when we change the way we handle ciphertext handles and refcounts.
var reservedProtectedStorageSlots []common.Hash = make([]common.Hash, 0)

// If references are still left, reduce refCount by 1. Otherwise, zero out the metadata and the ciphertext slots.
func garbageCollectProtectedStorage(metadataKey common.Hash, protectedStorage common.Address, interpreter *EVMInterpreter) {
// If a reserved slot, do not try to garbage collect it.
for _, slot := range reservedProtectedStorageSlots {
if bytes.Equal(metadataKey.Bytes(), slot.Bytes()) {
return
}
}
existingMetadataHash := interpreter.evm.StateDB.GetState(protectedStorage, metadataKey)
existingMetadataInt := newInt(existingMetadataHash.Bytes())
if !existingMetadataInt.IsZero() {
logger := interpreter.evm.Logger
metadata := newCiphertextMetadata(existingMetadataInt.Bytes32())
if metadata.refCount == 1 {
if interpreter.evm.Commit {
logger.Info("opSstore garbage-collecting ciphertext",
logger.Info("opSstore garbage collecting ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"metadataKey", hex.EncodeToString(metadataKey[:]),
"type", metadata.fheUintType,
Expand Down
Loading

0 comments on commit 0fe94b1

Please sign in to comment.