Skip to content

Commit

Permalink
Reinstate plaintext fheRand with type support
Browse files Browse the repository at this point in the history
Use ChaCha20 for PRNG. Have a separate PRNG per contract by seeding it
with a global plaintext seed as:
`contractSeed = Keccack256(globalSeed || contractAddress)`.
Also, use a counter as a nonce that is persisted in the contract's
protected storage, ensuring every contract has its own nonce.

Can only be called in transactions. Calling it in view functions
(i.e. EthCall) will fail.

Make sure we don't garbage collect the nonce slot (slot 0) in protected
storage by defining it as a reserved slot. That is a temporary solution
that we will revise soon by only running garbage collection on actual
ciphertext handles.
  • Loading branch information
dartdart26 committed Sep 6, 2023
1 parent cc574ae commit 84487ce
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 85 deletions.
193 changes: 111 additions & 82 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/ethereum/go-ethereum/params"
"github.com/holiman/uint256"
"github.com/naoina/toml"
"golang.org/x/crypto/chacha20"
"golang.org/x/crypto/nacl/box"
"golang.org/x/crypto/ripemd160"
)
Expand Down Expand Up @@ -75,7 +76,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -119,7 +120,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -164,7 +165,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -209,7 +210,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -254,7 +255,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
Expand Down Expand Up @@ -1471,6 +1472,12 @@ var fheTrivialEncryptGasCosts = map[fheUintType]uint64{
FheUint32: params.FheUint32TrivialEncryptGas,
}

var fheRandGasCosts = map[fheUintType]uint64{
FheUint8: params.FheUint8RandGas,
FheUint16: params.FheUint16RandGas,
FheUint32: params.FheUint32RandGas,
}

func writeResult(ct *tfheCiphertext, fileName string, logger Logger) {
os.WriteFile("/tmp/"+fileName, ct.serialize(), 0644)
}
Expand Down Expand Up @@ -3317,83 +3324,105 @@ func (e *fheNot) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return resultHash[:], nil
}

// type fheRand struct{}

// var globalRngSeed []byte

// var rngNonceKey [32]byte = uint256.NewInt(0).Bytes32()

// func init() {
// if chacha20.NonceSizeX != 24 {
// panic("expected 24 bytes for NonceSizeX")
// }

// // TODO: Since the current implementation is not FHE-based and, hence, not private,
// // we just initialize the global seed with non-random public data. We will change
// // that once the FHE version is available.
// globalRngSeed = make([]byte, chacha20.KeySize)
// for i := range globalRngSeed {
// globalRngSeed[i] = byte(1 + i)
// }
// }

// func (e *fheRand) RequiredGas(input []byte) uint64 {
// // TODO
// return 8
// }

// func (e *fheRand) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
// // If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
// if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
// return importRandomCiphertext(accessibleState), nil
// }

// // Get the RNG nonce.
// protectedStorage := crypto.CreateProtectedStorageContractAddress(caller)
// currentRngNonceBytes := accessibleState.Interpreter().evm.StateDB.GetState(protectedStorage, rngNonceKey).Bytes()

// // Increment the RNG nonce by 1.
// nextRngNonce := newInt(currentRngNonceBytes)
// nextRngNonce = nextRngNonce.AddUint64(nextRngNonce, 1)
// accessibleState.Interpreter().evm.StateDB.SetState(protectedStorage, rngNonceKey, nextRngNonce.Bytes32())

// // Compute the seed and use it to create a new cipher.
// hasher := crypto.NewKeccakState()
// hasher.Write(globalRngSeed)
// hasher.Write(caller.Bytes())
// hasher.Write(currentRngNonceBytes)
// seed := common.Hash{}
// _, err := hasher.Read(seed[:])
// if err != nil {
// return nil, err
// }
// // The RNG nonce bytes are of size chacha20.NonceSizeX, which is assumed to be 24 bytes (see init() above).
// // Since uint256.Int.z[0] is the least significant byte and since uint256.Int.Bytes32() serializes
// // in order of z[3], z[2], z[1], z[0], we want to essentially ignore the first byte, i.e. z[3], because
// // it will always be 0 as the nonce size is 24.
// cipher, err := chacha20.NewUnauthenticatedCipher(seed.Bytes(), currentRngNonceBytes[32-chacha20.NonceSizeX:32])
// if err != nil {
// return nil, err
// }

// // XOR a byte array of 0s with the stream from the cipher and receive the result in the same array.
// randBytes := make([]byte, 8)
// cipher.XORKeyStream(randBytes, randBytes)

// // Trivially encrypt the random integer.
// randInt := binary.BigEndian.Uint64(randBytes) % math.BigPow(2, 3).Uint64()
// randCt := new(tfheCiphertext)
// randCt.trivialEncrypt(randInt)
// importCiphertext(accessibleState, randCt)

// // TODO: for testing
// err = os.WriteFile("/tmp/rand_result", randCt.serialize(), 0644)
// if err != nil {
// return nil, err
// }
// ctHash := randCt.getHash()
// return ctHash[:], nil
// }
type fheRand struct{}

var globalRngSeed []byte

var rngNonceKey [32]byte = uint256.NewInt(0).Bytes32()

func init() {
if chacha20.NonceSizeX != 24 {
panic("expected 24 bytes for NonceSizeX")
}

// TODO: Since the current implementation is not FHE-based and, hence, not private,
// we just initialize the global seed with non-random public data. We will change
// that once the FHE version is available.
globalRngSeed = make([]byte, chacha20.KeySize)
for i := range globalRngSeed {
globalRngSeed[i] = byte(1 + i)
}

// Make sure we mark the RNG nonce key as a reserved slot in protected storage.
reservedProtectedStorageSlots = append(reservedProtectedStorageSlots, common.BytesToHash(rngNonceKey[:]))
}

func (e *fheRand) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 {
logger := accessibleState.Interpreter().evm.Logger
if len(input) != 1 || !isValidType(input[0]) {
logger.Error("fheRand RequiredGas() input len must be at least 1 byte and be a valid FheUint type", "input", hex.EncodeToString(input), "len", len(input))
return 0
}
t := fheUintType(input[0])
return fheRandGasCosts[t]
}

func (e *fheRand) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
logger := accessibleState.Interpreter().evm.Logger
if accessibleState.Interpreter().evm.EthCall {
msg := "fheRand cannot be called via EthCall, because it needs to mutate internal state"
logger.Error(msg)
return nil, errors.New(msg)
}
if len(input) != 1 || !isValidType(input[0]) {
msg := "fheRand input len must be at least 1 byte and be a valid FheUint type"
logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input))
return nil, errors.New(msg)
}

t := fheUintType(input[0])
// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit {
return importRandomCiphertext(accessibleState, t), nil
}

// Get the RNG nonce.
protectedStorage := crypto.CreateProtectedStorageContractAddress(caller)
currentRngNonceBytes := accessibleState.Interpreter().evm.StateDB.GetState(protectedStorage, rngNonceKey).Bytes()

// Increment the RNG nonce by 1.
nextRngNonce := newInt(currentRngNonceBytes)
nextRngNonce = nextRngNonce.AddUint64(nextRngNonce, 1)
accessibleState.Interpreter().evm.StateDB.SetState(protectedStorage, rngNonceKey, nextRngNonce.Bytes32())

// Compute the seed and use it to create a new cipher.
hasher := crypto.NewKeccakState()
hasher.Write(globalRngSeed)
hasher.Write(caller.Bytes())
seed := common.Hash{}
_, err := hasher.Read(seed[:])
if err != nil {
return nil, err
}
// The RNG nonce bytes are of size chacha20.NonceSizeX, which is assumed to be 24 bytes (see init() above).
// Since uint256.Int.z[0] is the least significant byte and since uint256.Int.Bytes32() serializes
// in order of z[3], z[2], z[1], z[0], we want to essentially ignore the first byte, i.e. z[3], because
// it will always be 0 as the nonce size is 24.
cipher, err := chacha20.NewUnauthenticatedCipher(seed.Bytes(), currentRngNonceBytes[32-chacha20.NonceSizeX:32])
if err != nil {
return nil, err
}

// XOR a byte array of 0s with the stream from the cipher and receive the result in the same array.
randBytes := make([]byte, 8)
cipher.XORKeyStream(randBytes, randBytes)

// Trivially encrypt the random integer.
randUint64 := binary.BigEndian.Uint64(randBytes)
randCt := new(tfheCiphertext)
randBigInt := big.NewInt(0)
randBigInt.SetUint64(randUint64)
randCt.trivialEncrypt(*randBigInt, t)
importCiphertext(accessibleState, randCt)

// TODO: for testing
err = os.WriteFile("/tmp/rand_result", randCt.serialize(), 0644)
if err != nil {
return nil, err
}
ctHash := randCt.getHash()
return ctHash[:], nil
}

type cast struct{}

Expand Down
85 changes: 85 additions & 0 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,30 @@ func Decrypt(t *testing.T, fheUintType fheUintType) {
}
}

func FheRand(t *testing.T, fheUintType fheUintType) {
c := &fheRand{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
out, err := c.Run(state, addr, addr, []byte{byte(fheUintType)}, readOnly)
if err != nil {
t.Fatalf(err.Error())
} else if len(out) != 32 {
t.Fatalf("fheRand expected output len of 32, got %v", len(out))
}
if len(state.interpreter.verifiedCiphertexts) != 1 {
t.Fatalf("fheRand expected 1 verified ciphertext")
}

hash := common.BytesToHash(out)
_, err = state.interpreter.verifiedCiphertexts[hash].ciphertext.decrypt()
if err != nil {
t.Fatalf(err.Error())
}
}

func newStopOpcodeContract() *Contract {
addr := AccountRef{}
c := NewContract(addr, addr, big.NewInt(0), 100000)
Expand Down Expand Up @@ -2348,6 +2372,18 @@ func TestDecrypt32(t *testing.T) {
Decrypt(t, FheUint32)
}

func TestFheRand8(t *testing.T) {
FheRand(t, FheUint8)
}

func TestFheRand16(t *testing.T) {
FheRand(t, FheUint16)
}

func TestFheRand32(t *testing.T) {
FheRand(t, FheUint32)
}

func TestUnknownCiphertextHandle(t *testing.T) {
depth := 1
state := newTestState()
Expand Down Expand Up @@ -2414,3 +2450,52 @@ func TestCiphertextVerificationConditions(t *testing.T) {
t.Fatalf("expected that ciphertext is not verified at verifiedDepth - 1 (%d)", verifiedDepth-1)
}
}

func TestFheRandInvalidInput(t *testing.T) {
c := &fheRand{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
_, err := c.Run(state, addr, addr, []byte{}, readOnly)
if err == nil {
t.Fatalf("fheRand expected failure on invalid type")
}
if len(state.interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("fheRand expected 0 verified ciphertexts on invalid input")
}
}

func TestFheRandInvalidType(t *testing.T) {
c := &fheRand{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
_, err := c.Run(state, addr, addr, []byte{byte(254)}, readOnly)
if err == nil {
t.Fatalf("fheRand expected failure on invalid type")
}
if len(state.interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("fheRand expected 0 verified ciphertexts on invalid type")
}
}

func TestFheRandEthCall(t *testing.T) {
c := &fheRand{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.EthCall = true
addr := common.Address{}
readOnly := true
_, err := c.Run(state, addr, addr, []byte{byte(FheUint8)}, readOnly)
if err == nil {
t.Fatalf("fheRand expected failure on EthCall")
}
if len(state.interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("fheRand expected 0 verified ciphertexts on EthCall")
}
}
14 changes: 13 additions & 1 deletion core/vm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package vm

import (
"bytes"
"encoding/hex"
"errors"
"sync/atomic"
Expand Down Expand Up @@ -674,16 +675,27 @@ func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Addres
interpreter.evm.StateDB.SetState(protectedStorage, val, metadata.serialize())
}

// A list of slots that we consider reserved in protected storage.
// Namely, we won't treat them as ciphertext metadata and we won't garbage collect them.
// TODO: This list will be removed when we change the way we handle ciphertext handles and refcounts.
var reservedProtectedStorageSlots []common.Hash = make([]common.Hash, 0)

// If references are still left, reduce refCount by 1. Otherwise, zero out the metadata and the ciphertext slots.
func garbageCollectProtectedStorage(metadataKey common.Hash, protectedStorage common.Address, interpreter *EVMInterpreter) {
// If a reserved slot, do not try to garbage collect it.
for _, slot := range reservedProtectedStorageSlots {
if bytes.Equal(metadataKey.Bytes(), slot.Bytes()) {
return
}
}
existingMetadataHash := interpreter.evm.StateDB.GetState(protectedStorage, metadataKey)
existingMetadataInt := newInt(existingMetadataHash.Bytes())
if !existingMetadataInt.IsZero() {
logger := interpreter.evm.Logger
metadata := newCiphertextMetadata(existingMetadataInt.Bytes32())
if metadata.refCount == 1 {
if interpreter.evm.Commit {
logger.Info("opSstore garbage-collecting ciphertext",
logger.Info("opSstore garbage collecting ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"metadataKey", hex.EncodeToString(metadataKey[:]),
"type", metadata.fheUintType,
Expand Down
Loading

0 comments on commit 84487ce

Please sign in to comment.