diff --git a/rpc/trace_test.go b/rpc/trace_test.go index 75d1f75c..4707be40 100644 --- a/rpc/trace_test.go +++ b/rpc/trace_test.go @@ -32,14 +32,14 @@ func TestTransactionTrace(t *testing.T) { type testSetType struct { TransactionHash *felt.Felt - ExpectedResp *InvokeTxnTrace - ExpectedError *RPCError + ExpectedResp TxnTrace + ExpectedError error } testSet := map[string][]testSetType{ "mock": { testSetType{ TransactionHash: utils.TestHexToFelt(t, "0x6a4a9c4f1a530f7d6dd7bba9b71f090a70d1e3bbde80998fde11a08aab8b282"), - ExpectedResp: &expectedResp, + ExpectedResp: expectedResp, ExpectedError: nil, }, testSetType{ @@ -61,7 +61,7 @@ func TestTransactionTrace(t *testing.T) { "testnet": { testSetType{ TransactionHash: utils.TestHexToFelt(t, "0x6a4a9c4f1a530f7d6dd7bba9b71f090a70d1e3bbde80998fde11a08aab8b282"), - ExpectedResp: &expectedResp, + ExpectedResp: expectedResp, ExpectedError: nil, }, }, @@ -70,12 +70,8 @@ func TestTransactionTrace(t *testing.T) { for _, test := range testSet { resp, err := testConfig.provider.TraceTransaction(context.Background(), test.TransactionHash) - if err != nil { - require.Equal(t, test.ExpectedError, err) - } else { - invokeTrace := resp.(InvokeTxnTrace) - require.Equal(t, invokeTrace, *test.ExpectedResp) - } + require.Equal(t, test.ExpectedError, err) + compareTraceTxs(t, test.ExpectedResp, resp) } } @@ -144,8 +140,11 @@ func TestSimulateTransaction(t *testing.T) { test.SimulateTxnInput.Txns, test.SimulateTxnInput.SimulationFlags) require.NoError(t, err) - require.Equal(t, test.ExpectedResp.Txns[0].FeeEstimate, resp[0].FeeEstimate) - require.Len(t, test.ExpectedResp.Txns, len(resp)) + + for i, trace := range resp { + require.Equal(t, test.ExpectedResp.Txns[i].FeeEstimate, trace.FeeEstimate) + compareTraceTxs(t, test.ExpectedResp.Txns[i].TxnTrace, trace.TxnTrace) + } } } @@ -163,12 +162,13 @@ func TestSimulateTransaction(t *testing.T) { // none func TestTraceBlockTransactions(t *testing.T) { testConfig := beforeEach(t) + require := require.New(t) var blockTraceSepolia []Trace expectedrespRaw, err := os.ReadFile("./tests/trace/sepoliaBlockTrace_0x42a4c6a4c3dffee2cce78f04259b499437049b0084c3296da9fbbec7eda79b2.json") - require.NoError(t, err, "Error ReadFile for TestTraceBlockTransactions") - require.NoError(t, json.Unmarshal(expectedrespRaw, &blockTraceSepolia), "Error unmarshalling testdata TestTraceBlockTransactions") + require.NoError(err, "Error ReadFile for TestTraceBlockTransactions") + require.NoError(json.Unmarshal(expectedrespRaw, &blockTraceSepolia), "Error unmarshalling testdata TestTraceBlockTransactions") type testSetType struct { BlockID BlockID @@ -178,12 +178,12 @@ func TestTraceBlockTransactions(t *testing.T) { testSet := map[string][]testSetType{ "devnet": {}, // devenet doesn't support TraceBlockTransactions https://0xspaceshard.github.io/starknet-devnet/docs/guide/json-rpc-api#trace-api "mainnet": {}, - "testnet": { // TODO: there is a conflict between the test data and the rpc data, even though the data came from the same source... - // testSetType{ - // BlockID: WithBlockNumber(99433), - // ExpectedResp: blockTraceSepolia, - // ExpectedErr: nil, - // }, + "testnet": { + testSetType{ + BlockID: WithBlockNumber(99433), + ExpectedResp: blockTraceSepolia, + ExpectedErr: nil, + }, }, "mock": { testSetType{ @@ -202,10 +202,79 @@ func TestTraceBlockTransactions(t *testing.T) { resp, err := testConfig.provider.TraceBlockTransactions(context.Background(), test.BlockID) if err != nil { - require.Equal(t, test.ExpectedErr, err) + require.Equal(test.ExpectedErr, err) } else { - require.EqualValues(t, test.ExpectedResp, resp) + for i, trace := range resp { + require.Equal(test.ExpectedResp[i].TxnHash, trace.TxnHash) + compareTraceTxs(t, test.ExpectedResp[i].TraceRoot, trace.TraceRoot) + } + } + + } +} + +func compareTraceTxs(t *testing.T, traceTx1, traceTx2 TxnTrace) { + require := require.New(t) + + switch traceTx := traceTx1.(type) { + case DeclareTxnTrace: + require.Equal(traceTx.ValidateInvocation, traceTx2.(DeclareTxnTrace).ValidateInvocation) + require.Equal(traceTx.FeeTransferInvocation, traceTx2.(DeclareTxnTrace).FeeTransferInvocation) + compareStateDiffs(t, traceTx.StateDiff, traceTx2.(DeclareTxnTrace).StateDiff) + require.Equal(traceTx.Type, traceTx2.(DeclareTxnTrace).Type) + require.Equal(traceTx.ExecutionResources, traceTx2.(DeclareTxnTrace).ExecutionResources) + case DeployAccountTxnTrace: + require.Equal(traceTx.ValidateInvocation, traceTx2.(DeployAccountTxnTrace).ValidateInvocation) + require.Equal(traceTx.ConstructorInvocation, traceTx2.(DeployAccountTxnTrace).ConstructorInvocation) + require.Equal(traceTx.FeeTransferInvocation, traceTx2.(DeployAccountTxnTrace).FeeTransferInvocation) + compareStateDiffs(t, traceTx.StateDiff, traceTx2.(DeployAccountTxnTrace).StateDiff) + require.Equal(traceTx.Type, traceTx2.(DeployAccountTxnTrace).Type) + require.Equal(traceTx.ExecutionResources, traceTx2.(DeployAccountTxnTrace).ExecutionResources) + case InvokeTxnTrace: + require.Equal(traceTx.ValidateInvocation, traceTx2.(InvokeTxnTrace).ValidateInvocation) + require.Equal(traceTx.ExecuteInvocation, traceTx2.(InvokeTxnTrace).ExecuteInvocation) + require.Equal(traceTx.FeeTransferInvocation, traceTx2.(InvokeTxnTrace).FeeTransferInvocation) + compareStateDiffs(t, traceTx.StateDiff, traceTx2.(InvokeTxnTrace).StateDiff) + require.Equal(traceTx.Type, traceTx2.(InvokeTxnTrace).Type) + require.Equal(traceTx.ExecutionResources, traceTx2.(InvokeTxnTrace).ExecutionResources) + case L1HandlerTxnTrace: + require.Equal(traceTx.FunctionInvocation, traceTx2.(L1HandlerTxnTrace).FunctionInvocation) + compareStateDiffs(t, traceTx.StateDiff, traceTx2.(L1HandlerTxnTrace).StateDiff) + require.Equal(traceTx.Type, traceTx2.(L1HandlerTxnTrace).Type) + } +} + +func compareStateDiffs(t *testing.T, stateDiff1, stateDiff2 StateDiff) { + require.ElementsMatch(t, stateDiff1.DeprecatedDeclaredClasses, stateDiff2.DeprecatedDeclaredClasses) + require.ElementsMatch(t, stateDiff1.DeclaredClasses, stateDiff2.DeclaredClasses) + require.ElementsMatch(t, stateDiff1.DeployedContracts, stateDiff2.DeployedContracts) + require.ElementsMatch(t, stateDiff1.ReplacedClasses, stateDiff2.ReplacedClasses) + require.ElementsMatch(t, stateDiff1.Nonces, stateDiff2.Nonces) + + // compares storage diffs (they come in a random order) + rawStorageDiff, err := json.Marshal(stateDiff2.StorageDiffs) + require.NoError(t, err) + var mapDiff []map[string]interface{} + require.NoError(t, json.Unmarshal(rawStorageDiff, &mapDiff)) + + for _, diff1 := range stateDiff1.StorageDiffs { + var diff2 ContractStorageDiffItem + + for _, diffElem := range mapDiff { + address, ok := diffElem["address"] + require.True(t, ok) + addressFelt := utils.TestHexToFelt(t, address.(string)) + + if *addressFelt != *diff1.Address { + continue + } + + err = remarshal(diffElem, &diff2) + require.NoError(t, err) } + require.NotEmpty(t, diff2) + require.Equal(t, diff1.Address, diff2.Address) + require.ElementsMatch(t, diff1.StorageEntries, diff2.StorageEntries) } } diff --git a/rpc/types_trace.go b/rpc/types_trace.go index 49a2c604..5004a7aa 100644 --- a/rpc/types_trace.go +++ b/rpc/types_trace.go @@ -1,6 +1,12 @@ package rpc -import "github.com/NethermindEth/juno/core/felt" +import ( + "encoding/json" + "fmt" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/starknet.go/utils" +) type SimulateTransactionInput struct { //a sequence of transactions to simulate, running each transaction on the state resulting from applying all the previous ones @@ -24,8 +30,8 @@ type SimulateTransactionOutput struct { } type SimulatedTransaction struct { - TxnTrace `json:"transaction_trace"` - FeeEstimate + TxnTrace `json:"transaction_trace"` + FeeEstimate `json:"fee_estimation"` } type TxnTrace interface{} @@ -130,3 +136,127 @@ type ExecInvocation struct { FunctionInvocation FnInvocation `json:"function_invocation,omitempty"` RevertReason string `json:"revert_reason,omitempty"` } + +// UnmarshalJSON unmarshals the data into a SimulatedTransaction object. +// +// It takes a byte slice as the parameter, representing the JSON data to be unmarshalled. +// The function returns an error if the unmarshalling process fails. +// +// Parameters: +// - data: The JSON data to be unmarshalled +// Returns: +// - error: An error if the unmarshalling process fails +func (txn *SimulatedTransaction) UnmarshalJSON(data []byte) error { + var dec map[string]interface{} + if err := json.Unmarshal(data, &dec); err != nil { + return err + } + + // SimulatedTransaction wraps transactions in the TxnTrace field. + rawTxnTrace, err := utils.UnwrapJSON(dec, "transaction_trace") + if err != nil { + return err + } + + trace, err := unmarshalTraceTxn(rawTxnTrace) + if err != nil { + return err + } + + var feeEstimate FeeEstimate + + if feeEstimateData, ok := dec["fee_estimation"]; ok { + err = remarshal(feeEstimateData, &feeEstimate) + if err != nil { + return err + } + } else { + return fmt.Errorf("fee estimate not found") + } + + *txn = SimulatedTransaction{ + TxnTrace: trace, + FeeEstimate: feeEstimate, + } + return nil +} + +// UnmarshalJSON unmarshals the data into a Trace object. +// +// It takes a byte slice as the parameter, representing the JSON data to be unmarshalled. +// The function returns an error if the unmarshalling process fails. +// +// Parameters: +// - data: The JSON data to be unmarshalled +// Returns: +// - error: An error if the unmarshalling process fails +func (txn *Trace) UnmarshalJSON(data []byte) error { + var dec map[string]interface{} + if err := json.Unmarshal(data, &dec); err != nil { + return err + } + + // Trace wrap trace transactions in the TraceRoot field. + rawTraceTx, err := utils.UnwrapJSON(dec, "trace_root") + if err != nil { + return err + } + + t, err := unmarshalTraceTxn(rawTraceTx) + if err != nil { + return err + } + + var txHash *felt.Felt + if txHashData, ok := dec["transaction_hash"]; ok { + txHashString, ok := txHashData.(string) + if !ok { + return fmt.Errorf("failed to unmarshal transaction hash, transaction_hash is not a string") + } + txHash, err = utils.HexToFelt(txHashString) + if err != nil { + return err + } + } else { + return fmt.Errorf("failed to unmarshal transaction hash, transaction_hash not found") + } + + *txn = Trace{ + TraceRoot: t, + TxnHash: txHash, + } + return nil +} + +// unmarshalTraceTxn unmarshals a given interface and returns a TxnTrace. +// +// Parameter: +// - t: The interface{} to be unmarshalled +// Returns: +// - TxnTrace: a TxnTrace +// - error: an error if the unmarshaling process fails +func unmarshalTraceTxn(t interface{}) (TxnTrace, error) { + switch casted := t.(type) { + case map[string]interface{}: + switch TransactionType(casted["type"].(string)) { + case TransactionType_Declare: + var txn DeclareTxnTrace + err := remarshal(casted, &txn) + return txn, err + case TransactionType_DeployAccount: + var txn DeployAccountTxnTrace + err := remarshal(casted, &txn) + return txn, err + case TransactionType_Invoke: + var txn InvokeTxnTrace + err := remarshal(casted, &txn) + return txn, err + case TransactionType_L1Handler: + var txn L1HandlerTxnTrace + err := remarshal(casted, &txn) + return txn, err + } + } + + return nil, fmt.Errorf("unknown transaction type: %v", t) +}