From 26f1bbfd7af3427f18afa700e6f561cf6a91d63b Mon Sep 17 00:00:00 2001 From: rianhughes Date: Fri, 4 Aug 2023 12:57:20 +0300 Subject: [PATCH] Update starknet_call --- account.go | 14 +++++++++----- accountgw_test.go | 15 ++++++++------- contracts/sessionkey.go | 3 ++- gateway/account.go | 3 ++- gateway/starknet.go | 14 +++++++------- gateway/starknet_test.go | 4 ++-- rpc/call.go | 8 ++------ rpc/call_test.go | 35 +++++++---------------------------- rpc/mock_test.go | 7 +++++-- rpc/provider.go | 2 +- 10 files changed, 45 insertions(+), 60 deletions(-) diff --git a/account.go b/account.go index 4d82e4a2..9a06b8e0 100644 --- a/account.go +++ b/account.go @@ -27,7 +27,7 @@ const ( type account interface { TransactionHash(calls []types.FunctionCall, details types.ExecuteDetails) (*big.Int, error) - Call(ctx context.Context, call types.FunctionCall) ([]string, error) + Call(ctx context.Context, call rpc.FunctionCall) ([]*felt.Felt, error) Nonce(ctx context.Context) (*big.Int, error) EstimateFee(ctx context.Context, calls []types.FunctionCall, details types.ExecuteDetails) (*types.FeeEstimate, error) Execute(ctx context.Context, calls []types.FunctionCall, details types.ExecuteDetails) (*types.AddInvokeTransactionOutput, error) @@ -44,12 +44,12 @@ type AccountPlugin interface { type ProviderType string const ( - ProviderRPC ProviderType = "rpc" + ProviderRPC ProviderType = "rpc" ProviderGateway ProviderType = "gateway" ) type Account struct { - rpc *rpc.Provider + rpc *rpc.Provider sequencer *gateway.GatewayProvider provider ProviderType chainId string @@ -145,7 +145,7 @@ func NewGatewayAccount(sender, address *felt.Felt, ks Keystore, provider *gatewa return account, nil } -func (account *Account) Call(ctx context.Context, call types.FunctionCall) ([]string, error) { +func (account *Account) Call(ctx context.Context, call rpc.FunctionCall) ([]*felt.Felt, error) { switch account.provider { case ProviderRPC: if account.rpc == nil { @@ -162,7 +162,11 @@ func (account *Account) Call(ctx context.Context, call types.FunctionCall) ([]st if account.sequencer == nil { return nil, ErrUnsupportedAccount } - return account.sequencer.Call(ctx, call, "latest") + resp, err := account.sequencer.Call(ctx, call, "latest") + if err != nil { + return nil, err + } + return utils.HexArrToFelt(resp) } return nil, ErrUnsupportedAccount } diff --git a/accountgw_test.go b/accountgw_test.go index a5d09d31..39cbd58f 100644 --- a/accountgw_test.go +++ b/accountgw_test.go @@ -6,22 +6,23 @@ import ( "testing" "time" + "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/types" "github.com/NethermindEth/starknet.go/utils" ) type TestAccountType struct { - PrivateKey string `json:"private_key"` - PublicKey string `json:"public_key"` - Address string `json:"address"` - Transactions []types.FunctionCall `json:"transactions,omitempty"` + PrivateKey string `json:"private_key"` + PublicKey string `json:"public_key"` + Address string `json:"address"` + Transactions []rpc.FunctionCall `json:"transactions,omitempty"` } func TestGatewayAccount_EstimateAndExecute(t *testing.T) { testConfig := beforeGatewayEach(t) type testSetType struct { ExecuteCalls []types.FunctionCall - QueryCall types.FunctionCall + QueryCall rpc.FunctionCall } testSet := map[string][]testSetType{ @@ -30,7 +31,7 @@ func TestGatewayAccount_EstimateAndExecute(t *testing.T) { EntryPointSelector: types.GetSelectorFromNameFelt("increment"), ContractAddress: utils.TestHexToFelt(t, testConfig.CounterAddress), }}, - QueryCall: types.FunctionCall{ + QueryCall: rpc.FunctionCall{ EntryPointSelector: types.GetSelectorFromNameFelt("get_count"), ContractAddress: utils.TestHexToFelt(t, testConfig.CounterAddress), }, @@ -40,7 +41,7 @@ func TestGatewayAccount_EstimateAndExecute(t *testing.T) { EntryPointSelector: types.GetSelectorFromNameFelt("increment"), ContractAddress: utils.TestHexToFelt(t, testConfig.CounterAddress), }}, - QueryCall: types.FunctionCall{ + QueryCall: rpc.FunctionCall{ EntryPointSelector: types.GetSelectorFromNameFelt("get_count"), ContractAddress: utils.TestHexToFelt(t, testConfig.CounterAddress), }, diff --git a/contracts/sessionkey.go b/contracts/sessionkey.go index 15a7efac..468d6c4a 100644 --- a/contracts/sessionkey.go +++ b/contracts/sessionkey.go @@ -12,6 +12,7 @@ import ( starknetgo "github.com/NethermindEth/starknet.go" "github.com/NethermindEth/starknet.go/gateway" "github.com/NethermindEth/starknet.go/plugins/xsessions" + "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/types" "github.com/NethermindEth/starknet.go/utils" ) @@ -138,7 +139,7 @@ func (ap *AccountManager) ExecuteWithGateway(counterAddress *felt.Felt, selector return tx.TransactionHash.String(), nil } -func (ap *AccountManager) CallWithGateway(call types.FunctionCall, provider *gateway.GatewayProvider) ([]string, error) { +func (ap *AccountManager) CallWithGateway(call rpc.FunctionCall, provider *gateway.GatewayProvider) ([]*felt.Felt, error) { // shim in the keystore. while weird and awkward, it's functionally ok because // 1. account manager doesn't seem to be used any where // 2. the account that is created below is scoped to this func diff --git a/gateway/account.go b/gateway/account.go index 68d017e4..6e396a70 100644 --- a/gateway/account.go +++ b/gateway/account.go @@ -12,11 +12,12 @@ import ( "strings" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/types" ) func (sg *Gateway) AccountNonce(ctx context.Context, address *felt.Felt) (*big.Int, error) { - resp, err := sg.Call(ctx, types.FunctionCall{ + resp, err := sg.Call(ctx, rpc.FunctionCall{ ContractAddress: address, EntryPointSelector: types.GetSelectorFromNameFelt("get_nonce"), }, "") diff --git a/gateway/starknet.go b/gateway/starknet.go index 813fc64e..a975118a 100644 --- a/gateway/starknet.go +++ b/gateway/starknet.go @@ -63,7 +63,7 @@ func (f FunctionCall) MarshalJSON() ([]byte, error) { /* 'call_contract' wrapper and can accept a blockId in the hash or height format */ -func (sg *Gateway) Call(ctx context.Context, call types.FunctionCall, blockHashOrTag string) ([]string, error) { +func (sg *Gateway) Call(ctx context.Context, call rpc.FunctionCall, blockHashOrTag string) ([]string, error) { gc := GatewayFunctionCall{ FunctionCall: FunctionCall(call), } @@ -220,12 +220,12 @@ func (sg *Gateway) Declare(ctx context.Context, contract rpc.ContractClass, decl // } type DeclareRequest struct { - Type string `json:"type"` - SenderAddress *felt.Felt `json:"sender_address"` - Version string `json:"version"` - MaxFee string `json:"max_fee"` - Nonce string `json:"nonce"` - Signature []string `json:"signature"` + Type string `json:"type"` + SenderAddress *felt.Felt `json:"sender_address"` + Version string `json:"version"` + MaxFee string `json:"max_fee"` + Nonce string `json:"nonce"` + Signature []string `json:"signature"` ContractClass rpc.ContractClass `json:"contract_class"` } diff --git a/gateway/starknet_test.go b/gateway/starknet_test.go index 062ebbc6..f0c1e00b 100644 --- a/gateway/starknet_test.go +++ b/gateway/starknet_test.go @@ -208,12 +208,12 @@ func TestCall(t *testing.T) { testConfig := beforeEach(t) type testSetType struct { - Call types.FunctionCall + Call rpc.FunctionCall } testSet := map[string][]testSetType{ "devnet": { { - Call: types.FunctionCall{ + Call: rpc.FunctionCall{ ContractAddress: utils.TestHexToFelt(t, counterAddress), EntryPointSelector: types.GetSelectorFromNameFelt("get_count"), Calldata: []*felt.Felt{}, diff --git a/rpc/call.go b/rpc/call.go index 2b95e26a..11b96652 100644 --- a/rpc/call.go +++ b/rpc/call.go @@ -8,20 +8,16 @@ import ( ) // Call a starknet function without creating a StarkNet transaction. -func (provider *Provider) Call(ctx context.Context, request FunctionCall, blockID BlockID) ([]string, error) { +func (provider *Provider) Call(ctx context.Context, request FunctionCall, blockID BlockID) ([]*felt.Felt, error) { if len(request.Calldata) == 0 { request.Calldata = make([]*felt.Felt, 0) } - var result []string + var result []*felt.Felt if err := do(ctx, provider.c, "starknet_call", &result, request, blockID); err != nil { switch { case errors.Is(err, ErrContractNotFound): return nil, ErrContractNotFound - case errors.Is(err, ErrInvalidMessageSelector): - return nil, ErrInvalidMessageSelector - case errors.Is(err, ErrInvalidCallData): - return nil, ErrInvalidCallData case errors.Is(err, ErrContractError): return nil, ErrContractError case errors.Is(err, ErrBlockNotFound): diff --git a/rpc/call_test.go b/rpc/call_test.go index 54009ecb..c8694c7f 100644 --- a/rpc/call_test.go +++ b/rpc/call_test.go @@ -2,12 +2,12 @@ package rpc import ( "context" - "regexp" "testing" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/types" "github.com/NethermindEth/starknet.go/utils" + "github.com/test-go/testify/require" ) // TestCall tests Call @@ -17,7 +17,7 @@ func TestCall(t *testing.T) { type testSetType struct { FunctionCall FunctionCall BlockID BlockID - ExpectedPatternResult string + ExpectedPatternResult *felt.Felt } testSet := map[string][]testSetType{ "devnet": { @@ -29,7 +29,7 @@ func TestCall(t *testing.T) { Calldata: []*felt.Felt{}, }, BlockID: WithBlockTag("latest"), - ExpectedPatternResult: "^0x[0-9a-f]+$", + ExpectedPatternResult: utils.TestHexToFelt(t, "0x6574686572"), }, }, "mock": { @@ -40,7 +40,7 @@ func TestCall(t *testing.T) { Calldata: []*felt.Felt{}, }, BlockID: WithBlockTag("latest"), - ExpectedPatternResult: "^0x12$", + ExpectedPatternResult: utils.TestHexToFelt(t, "0xdeadbeef"), }, }, "testnet": { @@ -51,25 +51,7 @@ func TestCall(t *testing.T) { Calldata: []*felt.Felt{}, }, BlockID: WithBlockTag("latest"), - ExpectedPatternResult: "^0x12$", - }, - { - FunctionCall: FunctionCall{ - ContractAddress: utils.TestHexToFelt(t, TestNetETHAddress), - EntryPointSelector: types.GetSelectorFromNameFelt("balanceOf"), - Calldata: []*felt.Felt{utils.TestHexToFelt(t, "0x0207aCC15dc241e7d167E67e30E769719A727d3E0fa47f9E187707289885Dfde")}, - }, - BlockID: WithBlockTag("latest"), - ExpectedPatternResult: "^0x[0-9a-f]+$", - }, - { - FunctionCall: FunctionCall{ - ContractAddress: utils.TestHexToFelt(t, TestNetAccount032Address), - EntryPointSelector: types.GetSelectorFromNameFelt("get_nonce"), - Calldata: []*felt.Felt{}, - }, - BlockID: WithBlockTag("latest"), - ExpectedPatternResult: "^0x[0-9a-f]+$", + ExpectedPatternResult: utils.TestHexToFelt(t, "0x12"), }, }, "mainnet": { @@ -80,7 +62,7 @@ func TestCall(t *testing.T) { Calldata: []*felt.Felt{}, }, BlockID: WithBlockTag("latest"), - ExpectedPatternResult: "^0x12$", + ExpectedPatternResult: utils.TestHexToFelt(t, "0x12"), }, }, }[testEnv] @@ -99,10 +81,7 @@ func TestCall(t *testing.T) { if len(output) == 0 { t.Fatal("should return an output") } - match, err := regexp.Match(test.ExpectedPatternResult, []byte(output[0])) - if err != nil || !match { - t.Fatalf("checking output(%v) expecting %s, got: %v", err, test.ExpectedPatternResult, output[0]) - } + require.Equal(t, test.ExpectedPatternResult, output[0]) } } diff --git a/rpc/mock_test.go b/rpc/mock_test.go index 7a5af083..7f1e82dd 100644 --- a/rpc/mock_test.go +++ b/rpc/mock_test.go @@ -359,8 +359,11 @@ func mock_starknet_call(result interface{}, method string, args ...interface{}) fmt.Printf("args: %d\n", len(args)) return errWrongArgs } - output := []string{"0x12"} - outputContent, _ := json.Marshal(output) + out, err := new(felt.Felt).SetString("0xdeadbeef") + if err != nil { + return err + } + outputContent, _ := json.Marshal([]*felt.Felt{out}) json.Unmarshal(outputContent, r) return nil } diff --git a/rpc/provider.go b/rpc/provider.go index 75b8419b..0741b3d0 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -30,7 +30,7 @@ type api interface { BlockTransactionCount(ctx context.Context, blockID BlockID) (uint64, error) BlockWithTxHashes(ctx context.Context, blockID BlockID) (interface{}, error) BlockWithTxs(ctx context.Context, blockID BlockID) (interface{}, error) - Call(ctx context.Context, call FunctionCall, block BlockID) ([]string, error) + Call(ctx context.Context, call FunctionCall, block BlockID) ([]*felt.Felt, error) ChainID(ctx context.Context) (string, error) Class(ctx context.Context, blockID BlockID, classHash string) (*ContractClass, error) ClassAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*ContractClass, error)