Skip to content

Commit

Permalink
Merge pull request #167 from zama-ai/petar/check-reserved-protected-o…
Browse files Browse the repository at this point in the history
…n-sload

Check reserved slots in protected storage on SLOAD
  • Loading branch information
dartdart26 committed Sep 19, 2023
2 parents a9bd207 + c6824b2 commit 52c7d94
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 5 deletions.
20 changes: 16 additions & 4 deletions core/vm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,11 @@ func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, cont
return nil
}

// If a reserved slot, do not try treat it as ciphertext metadata.
if isReservedSlot(val) {
return nil
}

protectedStorage := crypto.CreateProtectedStorageContractAddress(contractAddress)
metadataInt := newInt(interpreter.evm.StateDB.GetState(protectedStorage, val).Bytes())
if !metadataInt.IsZero() {
Expand Down Expand Up @@ -680,13 +685,20 @@ func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Addres
// TODO: This list will be removed when we change the way we handle ciphertext handles and refcounts.
var reservedProtectedStorageSlots []common.Hash = make([]common.Hash, 0)

func isReservedSlot(key common.Hash) bool {
for _, slot := range reservedProtectedStorageSlots {
if bytes.Equal(key.Bytes(), slot.Bytes()) {
return true
}
}
return false
}

// If references are still left, reduce refCount by 1. Otherwise, zero out the metadata and the ciphertext slots.
func garbageCollectProtectedStorage(metadataKey common.Hash, protectedStorage common.Address, interpreter *EVMInterpreter) {
// If a reserved slot, do not try to garbage collect it.
for _, slot := range reservedProtectedStorageSlots {
if bytes.Equal(metadataKey.Bytes(), slot.Bytes()) {
return
}
if isReservedSlot(metadataKey) {
return
}
existingMetadataHash := interpreter.evm.StateDB.GetState(protectedStorage, metadataKey)
existingMetadataInt := newInt(existingMetadataHash.Bytes())
Expand Down
94 changes: 93 additions & 1 deletion core/vm/instructions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,98 @@ func TestProtectedStorageGarbageCollection(t *testing.T) {
}
}

func TestProtectedStorageGarbageCollectionOnReservedSlot(t *testing.T) {
scope := newTestScopeConext()
protectedStorage := crypto.CreateProtectedStorageContractAddress(scope.Contract.Address())
interpreter := newTestInterpreter()
pc := uint64(0)
depth := 1
interpreter.evm.depth = depth

// Simulate metadata for a ciphertext at a reserved protected storage slot.
metadata := ciphertextMetadata{}
metadata.fheUintType = FheUint8
metadata.refCount = 1
metadata.length = 3
metadataSer := metadata.serialize()
metadataSlot := reservedProtectedStorageSlots[0]
interpreter.evm.StateDB.SetState(protectedStorage, metadataSlot, common.BytesToHash(metadataSer[:]))

// Simulate a ciphertext in protected storage.
slot := uint256FromBig(metadataSlot.Big())
nonZero := common.Hash{}
nonZero[0] = 1
for i := uint64(1); i < metadata.length+1; i++ {
slot = slot.AddUint64(slot, i)
interpreter.evm.StateDB.SetState(protectedStorage, common.BytesToHash(slot.Bytes()), nonZero)
}

// Simulate SSTORE with a new value that is different the reserved slot.
valueHash := metadataSlot
valueHash[0]++
loc := uint256.NewInt(10)
value := uint256FromBig(valueHash.Big())

// Call SSTORE.
scope.Stack.push(value)
scope.Stack.push(loc)
_, err := opSstore(&pc, interpreter, scope)
if err != nil {
t.Fatalf(err.Error())
}

// Verify that garbage collection hasn't happened for a reserved protected storage slot.
slot = uint256FromBig(metadataSlot.Big())
for i := uint64(0); i < metadata.length+1; i++ {
slot = slot.AddUint64(slot, i)
res := interpreter.evm.StateDB.GetState(protectedStorage, common.BytesToHash(slot.Bytes()))
if bytes.Equal(res.Bytes(), common.Hash{}.Bytes()) {
t.Fatalf("garbage collection must not have happened")
}
}
}

func TestProtectedStorageSloadOnReservedSlot(t *testing.T) {
scope := newTestScopeConext()
interpreter := newTestInterpreter()
pc := uint64(0)
depth := 1
interpreter.evm.depth = depth

handle := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint8).getHash()
loc := uint256.NewInt(10)
value := uint256FromBig(handle.Big())

// Consider the returned handle as a reserved slot.
reservedProtectedStorageSlots = append(reservedProtectedStorageSlots, handle)

// Persist the ciphertext in protected storage.
scope.Stack.push(value)
scope.Stack.push(loc)
_, err := opSstore(&pc, interpreter, scope)
if err != nil {
t.Fatalf(err.Error())
}

// Clear verified ciphertexts.
interpreter.verifiedCiphertexts = make(map[common.Hash]*verifiedCiphertext)

// Call SLOAD.
scope.Stack.push(loc)
_, err = opSload(&pc, interpreter, scope)
if err != nil {
t.Fatalf(err.Error())
}

// Remove the handle from reserved slots.
reservedProtectedStorageSlots = reservedProtectedStorageSlots[:len(reservedProtectedStorageSlots)-1]

// Expect no verified ciphertexts.
if len(interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("expected no verified ciphetexts")
}
}

func TestProtectedStorageSloadDoesNotVerifyNonHandle(t *testing.T) {
pc := uint64(0)
interpreter := newTestInterpreter()
Expand All @@ -857,7 +949,7 @@ func TestProtectedStorageSloadDoesNotVerifyNonHandle(t *testing.T) {
t.Fatalf(err.Error())
}

// Expect no verified ciphertexts
// Expect no verified ciphertexts.
if len(interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("expected no verified ciphetexts")
}
Expand Down

0 comments on commit 52c7d94

Please sign in to comment.