Skip to content

Commit

Permalink
feat(avm-simulator): make storage work across enqueued calls (#6181)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcarreiro authored May 7, 2024
1 parent 151d3a3 commit 8e218a2
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 132 deletions.
15 changes: 14 additions & 1 deletion yarn-project/end-to-end/src/e2e_avm_simulator.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type AccountWallet, AztecAddress, Fr, FunctionSelector, TxStatus } from '@aztec/aztec.js';
import { type AccountWallet, AztecAddress, BatchCall, Fr, FunctionSelector, TxStatus } from '@aztec/aztec.js';
import { GasSettings } from '@aztec/circuits.js';
import {
AvmAcvmInteropTestContract,
Expand Down Expand Up @@ -60,6 +60,19 @@ describe('e2e_avm_simulator', () => {
await avmContract.methods.add_storage_map(address, 100).send().wait();
expect(await avmContract.methods.view_storage_map(address).simulate()).toEqual(200n);
});

it('Preserves storage across enqueued public calls', async () => {
const address = AztecAddress.fromBigInt(9090n);
// This will create 1 tx with 2 public calls in it.
await new BatchCall(wallet, [
avmContract.methods.set_storage_map(address, 100).request(),
avmContract.methods.add_storage_map(address, 100).request(),
])
.send()
.wait();
// On a separate tx, we check the result.
expect(await avmContract.methods.view_storage_map(address).simulate()).toEqual(200n);
});
});

describe('Contract instance', () => {
Expand Down
32 changes: 26 additions & 6 deletions yarn-project/simulator/src/avm/journal/journal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ export type JournalData = {
newLogsHashes: TracedUnencryptedL2Log[];
/** contract address -\> key -\> value */
currentStorageValue: Map<bigint, Map<bigint, Fr>>;

sideEffectCounter: number;
};

// TRANSITIONAL: This should be removed once the kernel handles and entire enqueued call per circuit
Expand Down Expand Up @@ -143,6 +145,15 @@ export class AvmPersistableStateManager {
this.publicStorage.write(storageAddress, slot, value);

// TRANSITIONAL: This should be removed once the kernel handles and entire enqueued call per circuit
// The current info to the kernel clears any previous read or write request.
this.transitionalExecutionResult.contractStorageReads =
this.transitionalExecutionResult.contractStorageReads.filter(
read => !read.storageSlot.equals(slot) || !read.contractAddress!.equals(storageAddress),
);
this.transitionalExecutionResult.contractStorageUpdateRequests =
this.transitionalExecutionResult.contractStorageUpdateRequests.filter(
update => !update.storageSlot.equals(slot) || !update.contractAddress!.equals(storageAddress),
);
this.transitionalExecutionResult.contractStorageUpdateRequests.push(
new ContractStorageUpdateRequest(slot, value, this.trace.accessCounter, storageAddress),
);
Expand All @@ -159,16 +170,24 @@ export class AvmPersistableStateManager {
* @returns the latest value written to slot, or 0 if never written to before
*/
public async readStorage(storageAddress: Fr, slot: Fr): Promise<Fr> {
const [exists, value] = await this.publicStorage.read(storageAddress, slot);
this.log.debug(`storage(${storageAddress})@${slot} ?? value: ${value}, exists: ${exists}.`);
const { value, exists, cached } = await this.publicStorage.read(storageAddress, slot);
this.log.debug(`storage(${storageAddress})@${slot} ?? value: ${value}, exists: ${exists}, cached: ${cached}.`);

// TRANSITIONAL: This should be removed once the kernel handles and entire enqueued call per circuit
this.transitionalExecutionResult.contractStorageReads.push(
new ContractStorageRead(slot, value, this.trace.accessCounter, storageAddress),
);
// The current info to the kernel kernel does not consider cached reads.
if (!cached) {
// The current info to the kernel removes any previous reads to the same slot.
this.transitionalExecutionResult.contractStorageReads =
this.transitionalExecutionResult.contractStorageReads.filter(
read => !read.storageSlot.equals(slot) || !read.contractAddress!.equals(storageAddress),
);
this.transitionalExecutionResult.contractStorageReads.push(
new ContractStorageRead(slot, value, this.trace.accessCounter, storageAddress),
);
}

// We want to keep track of all performed reads (even reverted ones)
this.trace.tracePublicStorageRead(storageAddress, slot, value, exists);
this.trace.tracePublicStorageRead(storageAddress, slot, value, exists, cached);
return Promise.resolve(value);
}

Expand Down Expand Up @@ -348,6 +367,7 @@ export class AvmPersistableStateManager {
currentStorageValue: this.publicStorage.getCache().cachePerContract,
storageReads: this.trace.publicStorageReads,
storageWrites: this.trace.publicStorageWrites,
sideEffectCounter: this.trace.accessCounter,
};
}
}
29 changes: 20 additions & 9 deletions yarn-project/simulator/src/avm/journal/public_storage.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,54 @@ describe('avm public storage', () => {
const contractAddress = new Fr(1);
const slot = new Fr(2);
// never written!
const [exists, gotValue] = await publicStorage.read(contractAddress, slot);
const { exists, value: gotValue, cached } = await publicStorage.read(contractAddress, slot);
// doesn't exist, value is zero
expect(exists).toEqual(false);
expect(gotValue).toEqual(Fr.ZERO);
expect(cached).toEqual(false);
});

it('Should cache storage write, reading works after write', async () => {
const contractAddress = new Fr(1);
const slot = new Fr(2);
const value = new Fr(3);
// Write to cache
publicStorage.write(contractAddress, slot, value);
const [exists, gotValue] = await publicStorage.read(contractAddress, slot);
const { exists, value: gotValue, cached } = await publicStorage.read(contractAddress, slot);
// exists because it was previously written
expect(exists).toEqual(true);
expect(gotValue).toEqual(value);
expect(cached).toEqual(true);
});

it('Reading works on fallback to host (gets value & exists)', async () => {
const contractAddress = new Fr(1);
const slot = new Fr(2);
const storedValue = new Fr(420);
// ensure that fallback to host gets a value
publicDb.storageRead.mockResolvedValue(Promise.resolve(storedValue));

const [exists, gotValue] = await publicStorage.read(contractAddress, slot);
const { exists, value: gotValue, cached } = await publicStorage.read(contractAddress, slot);
// it exists in the host, so it must've been written before
expect(exists).toEqual(true);
expect(gotValue).toEqual(storedValue);
expect(cached).toEqual(false);
});

it('Reading works on fallback to parent (gets value & exists)', async () => {
const contractAddress = new Fr(1);
const slot = new Fr(2);
const value = new Fr(3);
const childStorage = new PublicStorage(publicDb, publicStorage);

publicStorage.write(contractAddress, slot, value);
const [exists, gotValue] = await childStorage.read(contractAddress, slot);
const { exists, value: gotValue, cached } = await childStorage.read(contractAddress, slot);
// exists because it was previously written!
expect(exists).toEqual(true);
expect(gotValue).toEqual(value);
expect(cached).toEqual(true);
});

it('When reading from storage, should check cache, then parent, then host', async () => {
// Store a different value in storage vs the cache, and make sure the cache is returned
const contractAddress = new Fr(1);
Expand All @@ -71,21 +79,24 @@ describe('avm public storage', () => {
const childStorage = new PublicStorage(publicDb, publicStorage);

// Cache miss falls back to host
const [, cacheMissResult] = await childStorage.read(contractAddress, slot);
expect(cacheMissResult).toEqual(storedValue);
const { cached: cachedCacheMiss, value: valueCacheMiss } = await childStorage.read(contractAddress, slot);
expect(valueCacheMiss).toEqual(storedValue);
expect(cachedCacheMiss).toEqual(false);

// Write to storage
publicStorage.write(contractAddress, slot, parentValue);
// Reading from child should give value written in parent
const [, valueFromParent] = await childStorage.read(contractAddress, slot);
const { cached: cachedValueFromParent, value: valueFromParent } = await childStorage.read(contractAddress, slot);
expect(valueFromParent).toEqual(parentValue);
expect(cachedValueFromParent).toEqual(true);

// Now write a value directly in child
childStorage.write(contractAddress, slot, cachedValue);

// Reading should now give the value written in child
const [, cachedResult] = await childStorage.read(contractAddress, slot);
const { cached: cachedChild, value: cachedResult } = await childStorage.read(contractAddress, slot);
expect(cachedResult).toEqual(cachedValue);
expect(cachedChild).toEqual(true);
});
});

Expand All @@ -109,7 +120,7 @@ describe('avm public storage', () => {
publicStorage.acceptAndMerge(childStorage);

// Read from parent gives latest value written in child before merge (valueT1)
const [exists, result] = await publicStorage.read(contractAddress, slot);
const { exists, value: result } = await publicStorage.read(contractAddress, slot);
expect(exists).toEqual(true);
expect(result).toEqual(valueT1);
});
Expand Down
25 changes: 23 additions & 2 deletions yarn-project/simulator/src/avm/journal/public_storage.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import { AztecAddress } from '@aztec/circuits.js';
import { Fr } from '@aztec/foundation/fields';

import type { PublicStateDB } from '../../index.js';

type PublicStorageReadResult = {
value: Fr;
exists: boolean;
cached: boolean;
};

/**
* A class to manage public storage reads and writes during a contract call's AVM simulation.
* Maintains a storage write cache, and ensures that reads fall back to the correct source.
Expand Down Expand Up @@ -39,7 +46,8 @@ export class PublicStorage {
* @param slot - the slot in the contract's storage being read from
* @returns exists: whether the slot has EVER been written to before, value: the latest value written to slot, or 0 if never written to before
*/
public async read(storageAddress: Fr, slot: Fr): Promise<[/*exists=*/ boolean, /*value=*/ Fr]> {
public async read(storageAddress: Fr, slot: Fr): Promise<PublicStorageReadResult> {
let cached = false;
// First try check this storage cache
let value = this.cache.read(storageAddress, slot);
// Then try parent's storage cache (if it exists / written to earlier in this TX)
Expand All @@ -49,11 +57,13 @@ export class PublicStorage {
// Finally try the host's Aztec state (a trip to the database)
if (!value) {
value = await this.hostPublicStorage.storageRead(storageAddress, slot);
} else {
cached = true;
}
// if value is undefined, that means this slot has never been written to!
const exists = value !== undefined;
const valueOrZero = exists ? value : Fr.ZERO;
return Promise.resolve([exists, valueOrZero]);
return Promise.resolve({ value: valueOrZero, exists, cached });
}

/**
Expand All @@ -75,6 +85,17 @@ export class PublicStorage {
public acceptAndMerge(incomingPublicStorage: PublicStorage) {
this.cache.acceptAndMerge(incomingPublicStorage.cache);
}

/**
* Commits ALL staged writes to the host's state.
*/
public async commitToDB() {
for (const [storageAddress, cacheAtContract] of this.cache.cachePerContract) {
for (const [slot, value] of cacheAtContract) {
await this.hostPublicStorage.storageWrite(AztecAddress.fromBigInt(storageAddress), new Fr(slot), value);
}
}
}
}

/**
Expand Down
24 changes: 18 additions & 6 deletions yarn-project/simulator/src/avm/journal/trace.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ describe('world state access trace', () => {
let counter = 0;
trace.tracePublicStorageWrite(contractAddress, slot, value);
counter++;
trace.tracePublicStorageRead(contractAddress, slot, value, /*exists=*/ true);
trace.tracePublicStorageRead(contractAddress, slot, value, /*exists=*/ true, /*cached=*/ true);
counter++;
trace.traceNoteHashCheck(contractAddress, noteHash, noteHashExists, noteHashLeafIndex);
counter++;
Expand All @@ -124,7 +124,7 @@ describe('world state access trace', () => {
counter++;
trace.tracePublicStorageWrite(contractAddress, slot, value);
counter++;
trace.tracePublicStorageRead(contractAddress, slot, value, /*exists=*/ true);
trace.tracePublicStorageRead(contractAddress, slot, value, /*exists=*/ true, /*cached=*/ true);
counter++;
trace.traceNewNoteHash(contractAddress, noteHash);
counter++;
Expand Down Expand Up @@ -178,7 +178,7 @@ describe('world state access trace', () => {
};

trace.tracePublicStorageWrite(contractAddress, slot, value);
trace.tracePublicStorageRead(contractAddress, slot, value, /*exists=*/ true);
trace.tracePublicStorageRead(contractAddress, slot, value, /*exists=*/ true, /*cached=*/ true);
trace.traceNoteHashCheck(contractAddress, noteHash, noteHashExists, noteHashLeafIndex);
trace.traceNewNoteHash(contractAddress, noteHash);
trace.traceNullifierCheck(contractAddress, nullifier, nullifierExists, nullifierIsPending, nullifierLeafIndex);
Expand All @@ -187,7 +187,7 @@ describe('world state access trace', () => {

const childTrace = new WorldStateAccessTrace(trace);
childTrace.tracePublicStorageWrite(contractAddress, slot, valueT1);
childTrace.tracePublicStorageRead(contractAddress, slot, valueT1, /*exists=*/ true);
childTrace.tracePublicStorageRead(contractAddress, slot, valueT1, /*exists=*/ true, /*cached=*/ true);
childTrace.traceNoteHashCheck(contractAddress, noteHashT1, noteHashExistsT1, noteHashLeafIndexT1);
childTrace.traceNewNoteHash(contractAddress, nullifierT1);
childTrace.traceNullifierCheck(
Expand All @@ -205,8 +205,20 @@ describe('world state access trace', () => {
expect(trace.getAccessCounter()).toEqual(childCounterBeforeMerge);

expect(trace.publicStorageReads).toEqual([
expect.objectContaining({ storageAddress: contractAddress, slot: slot, value: value, exists: true }),
expect.objectContaining({ storageAddress: contractAddress, slot: slot, value: valueT1, exists: true }),
expect.objectContaining({
storageAddress: contractAddress,
slot: slot,
value: value,
exists: true,
cached: true,
}),
expect.objectContaining({
storageAddress: contractAddress,
slot: slot,
value: valueT1,
exists: true,
cached: true,
}),
]);
expect(trace.publicStorageWrites).toEqual([
expect.objectContaining({ storageAddress: contractAddress, slot: slot, value: value }),
Expand Down
3 changes: 2 additions & 1 deletion yarn-project/simulator/src/avm/journal/trace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export class WorldStateAccessTrace {
return this.accessCounter;
}

public tracePublicStorageRead(storageAddress: Fr, slot: Fr, value: Fr, exists: boolean) {
public tracePublicStorageRead(storageAddress: Fr, slot: Fr, value: Fr, exists: boolean, cached: boolean) {
// TODO(4805): check if some threshold is reached for max storage reads
// (need access to parent length, or trace needs to be initialized with parent's contents)
const traced: TracedPublicStorageRead = {
Expand All @@ -45,6 +45,7 @@ export class WorldStateAccessTrace {
slot,
value,
exists,
cached,
counter: new Fr(this.accessCounter),
// endLifetime: Fr.ZERO,
};
Expand Down
1 change: 1 addition & 0 deletions yarn-project/simulator/src/avm/journal/trace_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export type TracedPublicStorageRead = {
// callPointer: Fr;
storageAddress: Fr;
exists: boolean;
cached: boolean;
slot: Fr;
value: Fr;
counter: Fr;
Expand Down
18 changes: 14 additions & 4 deletions yarn-project/simulator/src/public/executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export async function executePublicFunction(
}

if (isAvmBytecode(bytecode)) {
return await executeTopLevelPublicFunctionAvm(context);
return await executeTopLevelPublicFunctionAvm(context, bytecode);
} else {
return await executePublicFunctionAcvm(context, bytecode, nested);
}
Expand All @@ -58,6 +58,7 @@ export async function executePublicFunction(
*/
async function executeTopLevelPublicFunctionAvm(
executionContext: PublicExecutionContext,
bytecode: Buffer,
): Promise<PublicExecutionResult> {
const address = executionContext.execution.contractAddress;
const selector = executionContext.execution.functionData.selector;
Expand Down Expand Up @@ -91,16 +92,25 @@ async function executeTopLevelPublicFunctionAvm(
const avmContext = new AvmContext(worldStateJournal, executionEnv, machineState);
const simulator = new AvmSimulator(avmContext);

const avmResult = await simulator.execute();
const avmResult = await simulator.executeBytecode(bytecode);

// Commit the journals state to the DBs since this is a top-level execution.
// Observe that this will write all the state changes to the DBs, not only the latest for each slot.
// However, the underlying DB keep a cache and will only write the latest state to disk.
await avmContext.persistableState.publicStorage.commitToDB();

log.verbose(
`[AVM] ${address.toString()}:${selector} returned, reverted: ${avmResult.reverted}, reason: ${
avmResult.revertReason
}.`,
);

return Promise.resolve(
convertAvmResultsToPxResult(avmResult, startSideEffectCounter, executionContext.execution, startGas, avmContext),
return convertAvmResultsToPxResult(
avmResult,
startSideEffectCounter,
executionContext.execution,
startGas,
avmContext,
);
}

Expand Down
Loading

0 comments on commit 8e218a2

Please sign in to comment.