Skip to content

Commit

Permalink
Merge pull request #171 from zama-ai/petar/gc-attack-mitigation
Browse files Browse the repository at this point in the history
Mitigate attacks on garbage collection
  • Loading branch information
dartdart26 authored Oct 3, 2023
2 parents 8181b69 + 3cb61ec commit 328447d
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 145 deletions.
27 changes: 17 additions & 10 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -3619,9 +3619,6 @@ func init() {
for i := range globalRngSeed {
globalRngSeed[i] = byte(1 + i)
}

// Make sure we mark the RNG nonce key as a reserved slot in protected storage.
reservedProtectedStorageSlots = append(reservedProtectedStorageSlots, common.BytesToHash(rngNonceKey[:]))
}

func (e *fheRand) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 {
Expand Down Expand Up @@ -3710,6 +3707,11 @@ func (e *cast) RequiredGas(accessibleState PrecompileAccessibleState, input []by
"len", len(input))
return 0
}
ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if ct == nil {
accessibleState.Interpreter().evm.Logger.Error("cast input not verified")
return 0
}
return params.FheCastGas
}

Expand All @@ -3722,18 +3724,18 @@ func (e *cast) Run(accessibleState PrecompileAccessibleState, caller common.Addr
return nil, errors.New(msg)
}

ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if ct == nil {
logger.Error("cast input not verified")
return nil, errors.New("unverified ciphertext handle")
}

if !isValidType(input[32]) {
logger.Error("invalid type to cast to")
return nil, errors.New("invalid type provided")
}
castToType := fheUintType(input[32])

ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if ct == nil {
logger.Error("cast input not verified")
return nil, errors.New("unverified ciphertext handle")
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
return importRandomCiphertext(accessibleState, castToType), nil
Expand Down Expand Up @@ -3865,8 +3867,13 @@ func (e *trivialEncrypt) Run(accessibleState PrecompileAccessibleState, caller c
return nil, errors.New(msg)
}

valueToEncrypt := *new(big.Int).SetBytes(input[0:32])
if !isValidType(input[32]) {
msg := "trivialEncrypt ciphertext type is invalid"
logger.Error(msg, "type", input[32])
return nil, errors.New(msg)
}
encryptToType := fheUintType(input[32])
valueToEncrypt := *new(big.Int).SetBytes(input[0:32])

ct := new(tfheCiphertext).trivialEncrypt(valueToEncrypt, encryptToType)

Expand Down
34 changes: 34 additions & 0 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2957,6 +2957,40 @@ func TestVerifyCiphertextInvalidType(t *testing.T) {
}
}

func TestTrivialEncryptInvalidType(t *testing.T) {
c := &trivialEncrypt{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
invalidType := fheUintType(255)
input := make([]byte, 32)
input = append(input, byte(invalidType))
_, err := c.Run(state, addr, addr, input, readOnly)
if err == nil {
t.Fatalf("trivialEncrypt must have failed on invalid ciphertext type")
}
}

func TestCastInvalidType(t *testing.T) {
c := &cast{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
invalidType := fheUintType(255)
hash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash()
input := make([]byte, 0)
input = append(input, hash.Bytes()...)
input = append(input, byte(invalidType))
_, err := c.Run(state, addr, addr, input, readOnly)
if err == nil {
t.Fatalf("cast must have failed on invalid ciphertext type")
}
}

func TestVerifyCiphertextInvalidSize(t *testing.T) {
c := &verifyCiphertext{}
depth := 1
Expand Down
99 changes: 57 additions & 42 deletions core/vm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -559,10 +559,10 @@ func newInt(buf []byte) *uint256.Int {
return i.SetBytes(buf)
}

var zero = uint256.NewInt(0).Bytes32()
var zero = common.BytesToHash(uint256.NewInt(0).Bytes())

func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, contractAddress common.Address) error {
ct, ok := interpreter.verifiedCiphertexts[val]
func verifyIfCiphertextHandle(handle common.Hash, interpreter *EVMInterpreter, contractAddress common.Address) error {
ct, ok := interpreter.verifiedCiphertexts[handle]
if ok {
// If already existing in memory, skip storage and import the same ciphertext at the current depth.
//
Expand All @@ -573,18 +573,14 @@ func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, cont
return nil
}

// If a reserved slot, do not try treat it as ciphertext metadata.
if isReservedSlot(val) {
return nil
}

metadataKey := crypto.Keccak256Hash(handle.Bytes())
protectedStorage := crypto.CreateProtectedStorageContractAddress(contractAddress)
metadataInt := newInt(interpreter.evm.StateDB.GetState(protectedStorage, val).Bytes())
metadataInt := newInt(interpreter.evm.StateDB.GetState(protectedStorage, metadataKey).Bytes())
if !metadataInt.IsZero() {
metadata := newCiphertextMetadata(metadataInt.Bytes32())
ctBytes := make([]byte, 0)
left := metadata.length
protectedSlotIdx := newInt(val.Bytes())
protectedSlotIdx := newInt(metadataKey.Bytes())
protectedSlotIdx.AddUint64(protectedSlotIdx, 1)
for {
if left == 0 {
Expand Down Expand Up @@ -620,29 +616,38 @@ func opSload(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]by
return nil, nil
}

// An arbitrary constant value to flag locations in protected storage.
var flag = common.HexToHash("0xa145ffde0100a145ffde0100a145ffde0100a145ffde0100a145ffde0100fab3")

// If a verified ciphertext:
// * if the ciphertext does not exist in protected storage, persist it with a refCount = 1
// * if the ciphertexts exists in protected, bump its refCount by 1
func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Address, interpreter *EVMInterpreter) {
verifiedCiphertext := getVerifiedCiphertextFromEVM(interpreter, val)
func persistIfVerifiedCiphertext(flagHandleLocation common.Hash, handle common.Hash, protectedStorage common.Address, interpreter *EVMInterpreter) {
verifiedCiphertext := getVerifiedCiphertextFromEVM(interpreter, handle)
if verifiedCiphertext == nil {
return
}
logger := interpreter.evm.Logger

// Try to read ciphertext metadata from protected storage.
metadataInt := newInt(interpreter.evm.StateDB.GetState(protectedStorage, val).Bytes())
metadataKey := crypto.Keccak256Hash(handle.Bytes())
metadataInt := newInt(interpreter.evm.StateDB.GetState(protectedStorage, metadataKey).Bytes())
metadata := ciphertextMetadata{}

// Set flag in protected storage to mark the location as containing a handle.
interpreter.evm.StateDB.SetState(protectedStorage, flagHandleLocation, flag)

if metadataInt.IsZero() {
// If no metadata, it means this ciphertext itself hasn't been persisted to protected storage yet. We do that as part of SSTORE.
metadata.refCount = 1
metadata.length = uint64(expandedFheCiphertextSize[verifiedCiphertext.ciphertext.fheUintType])
metadata.fheUintType = verifiedCiphertext.ciphertext.fheUintType
ciphertextSlot := newInt(val.Bytes())
ciphertextSlot := newInt(metadataKey.Bytes())
ciphertextSlot.AddUint64(ciphertextSlot, 1)
if interpreter.evm.Commit {
logger.Info("opSstore persisting new ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"handle", hex.EncodeToString(val.Bytes()),
"handle", hex.EncodeToString(handle.Bytes()),
"type", metadata.fheUintType,
"len", metadata.length,
"ciphertextSlot", hex.EncodeToString(ciphertextSlot.Bytes()))
Expand All @@ -665,45 +670,46 @@ func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Addres
}
} else {
// If metadata exists, bump the refcount by 1.
metadata = *newCiphertextMetadata(interpreter.evm.StateDB.GetState(protectedStorage, val))
metadata = *newCiphertextMetadata(interpreter.evm.StateDB.GetState(protectedStorage, metadataKey))
metadata.refCount++
if interpreter.evm.Commit {
logger.Info("opSstore bumping refcount of existing ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"handle", hex.EncodeToString(val.Bytes()),
"handle", hex.EncodeToString(handle.Bytes()),
"type", metadata.fheUintType,
"len", metadata.length,
"refCount", metadata.refCount)
}
}
// Save the metadata in protected storage.
interpreter.evm.StateDB.SetState(protectedStorage, val, metadata.serialize())
}

// A list of slots that we consider reserved in protected storage.
// Namely, we won't treat them as ciphertext metadata and we won't garbage collect them.
// TODO: This list will be removed when we change the way we handle ciphertext handles and refcounts.
var reservedProtectedStorageSlots []common.Hash = make([]common.Hash, 0)

func isReservedSlot(key common.Hash) bool {
for _, slot := range reservedProtectedStorageSlots {
if bytes.Equal(key.Bytes(), slot.Bytes()) {
return true
}
}
return false
interpreter.evm.StateDB.SetState(protectedStorage, metadataKey, metadata.serialize())
}

// If references are still left, reduce refCount by 1. Otherwise, zero out the metadata and the ciphertext slots.
func garbageCollectProtectedStorage(metadataKey common.Hash, protectedStorage common.Address, interpreter *EVMInterpreter) {
// If a reserved slot, do not try to garbage collect it.
if isReservedSlot(metadataKey) {
return
}
func garbageCollectProtectedStorage(flagHandleLocation common.Hash, handle common.Hash, protectedStorage common.Address, interpreter *EVMInterpreter) {
// The location of ciphertext metadata is at Keccak256(handle). Doing so avoids attacks from users trying to garbage
// collect arbitrary locations in protected storage. Hashing the handle makes it hard to find a preimage such that
// it ends up in arbitrary non-zero places in protected stroage.
metadataKey := crypto.Keccak256Hash(handle.Bytes())

existingMetadataHash := interpreter.evm.StateDB.GetState(protectedStorage, metadataKey)
existingMetadataInt := newInt(existingMetadataHash.Bytes())
if !existingMetadataInt.IsZero() {
logger := interpreter.evm.Logger

// If no flag in protected storage for the location, ignore garbage collection.
// Else, set the value at the location to zero.
foundFlag := interpreter.evm.StateDB.GetState(protectedStorage, flagHandleLocation)
if !bytes.Equal(foundFlag.Bytes(), flag.Bytes()) {
logger.Error("opSstore location flag not found for a ciphertext handle, ignoring garbage collection",
"expectedFlag", hex.EncodeToString(flag[:]),
"foundFlag", hex.EncodeToString(foundFlag[:]),
"flagHandleLocation", hex.EncodeToString(flagHandleLocation[:]))
return
} else {
interpreter.evm.StateDB.SetState(protectedStorage, flagHandleLocation, zero)
}

metadata := newCiphertextMetadata(existingMetadataInt.Bytes32())
if metadata.refCount == 1 {
if interpreter.evm.Commit {
Expand Down Expand Up @@ -749,17 +755,26 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b
return nil, ErrWriteProtection
}
loc := scope.Stack.pop()
locHash := common.BytesToHash(loc.Bytes())
newVal := scope.Stack.pop()
newValBytes := newVal.Bytes()
newValHash := common.BytesToHash(newValBytes)
newValHash := common.BytesToHash(newVal.Bytes())
oldValHash := interpreter.evm.StateDB.GetState(scope.Contract.Address(), common.Hash(loc.Bytes32()))
protectedStorage := crypto.CreateProtectedStorageContractAddress(scope.Contract.Address())
// If the value is the same or if we are not going to commit, don't do anything to protected storage.
if newValHash != oldValHash && interpreter.evm.Commit {
protectedStorage := crypto.CreateProtectedStorageContractAddress(scope.Contract.Address())

// Define flag location as keccak256(keccak256(loc)) in protected storage. Used to mark the location as containing a handle.
// Note: We apply the hash function twice to make sure a flag location in protected storage cannot clash with a ciphertext
// metadata location that is keccak256(keccak256(ciphertext)). Since a location is 32 bytes, it cannot clash with a well-formed
// ciphertext. Therefore, there needs to be a hash collistion for a clash to happen. If hash is applied only once, there could
// be a collision, since malicous users could store at loc = keccak256(ciphertext), making the flag clash with metadata.
flagHandleLocation := crypto.Keccak256Hash(crypto.Keccak256Hash(locHash[:]).Bytes())

// Since the old value is no longer stored in actual contract storage, run garbage collection on protected storage.
garbageCollectProtectedStorage(oldValHash, protectedStorage, interpreter)
garbageCollectProtectedStorage(flagHandleLocation, oldValHash, protectedStorage, interpreter)

// If a verified ciphertext, persist to protected storage.
persistIfVerifiedCiphertext(newValHash, protectedStorage, interpreter)
persistIfVerifiedCiphertext(flagHandleLocation, newValHash, protectedStorage, interpreter)
}
// Set the SSTORE's value in the actual contract.
interpreter.evm.StateDB.SetState(scope.Contract.Address(),
Expand Down
Loading

0 comments on commit 328447d

Please sign in to comment.