diff --git a/rpc/contract.go b/rpc/contract.go index 08eeb074..59754e46 100644 --- a/rpc/contract.go +++ b/rpc/contract.go @@ -108,3 +108,20 @@ func (provider *Provider) EstimateFee(ctx context.Context, requests []Broadcaste } return raw, nil } + +// EstimateMessageFee estimates the L2 fee of a message sent on L1 +func (provider *Provider) EstimateMessageFee(ctx context.Context, msg MsgFromL1, blockID BlockID) (*FeeEstimate, error) { + var raw FeeEstimate + if err := do(ctx, provider.c, "starknet_estimateMessageFee", &raw, msg, blockID); err != nil { + switch { + case errors.Is(err, ErrContractNotFound): + return nil, ErrContractNotFound + case errors.Is(err, ErrContractError): + return nil, ErrContractError + case errors.Is(err, ErrBlockNotFound): + return nil, ErrBlockNotFound + } + return nil, err + } + return &raw, nil +} diff --git a/rpc/contract_test.go b/rpc/contract_test.go index 828ec500..f4459bff 100644 --- a/rpc/contract_test.go +++ b/rpc/contract_test.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/utils" + "github.com/test-go/testify/require" ) // TestClassAt tests code for a class. @@ -261,3 +262,40 @@ func TestNonce(t *testing.T) { } } } + +// TestEstimateMessageFee tests EstimateMesssageFee +func TestEstimateMessageFee(t *testing.T) { + testConfig := beforeEach(t) + + type testSetType struct { + MsgFromL1 + BlockID + ExpectedFeeEst FeeEstimate + } + testSet := map[string][]testSetType{ + "mock": { + { + MsgFromL1: MsgFromL1{FromAddress: &felt.Zero, ToAddress: &felt.Zero, Selector: &felt.Zero, Payload: []*felt.Felt{&felt.Zero}}, + BlockID: BlockID{Tag: "latest"}, + ExpectedFeeEst: FeeEstimate{ + GasConsumed: NumAsHex("0x1"), + GasPrice: NumAsHex("0x2"), + OverallFee: NumAsHex("0x3"), + }, + }, + }, + "testnet": {}, + "mainnet": {}, + }[testEnv] + + for _, test := range testSet { + spy := NewSpy(testConfig.provider.c) + testConfig.provider.c = spy + value, err := testConfig.provider.EstimateMessageFee(context.Background(), test.MsgFromL1, test.BlockID) + if err != nil { + t.Fatal(err) + } + require.Equal(t, *value, test.ExpectedFeeEst) + + } +} diff --git a/rpc/mock_test.go b/rpc/mock_test.go index afb55748..ee964693 100644 --- a/rpc/mock_test.go +++ b/rpc/mock_test.go @@ -64,6 +64,8 @@ func (r *rpcMock) CallContext(ctx context.Context, result interface{}, method st return mock_starknet_addInvokeTransaction(result, method, args...) case "starknet_estimateFee": return mock_starknet_estimateFee(result, method, args...) + case "starknet_estimateMessageFee": + return mock_starknet_estimateMessageFee(result, method, args...) default: return errNotFound } @@ -437,6 +439,36 @@ func mock_starknet_estimateFee(result interface{}, method string, args ...interf return nil } +func mock_starknet_estimateMessageFee(result interface{}, method string, args ...interface{}) error { + r, ok := result.(*json.RawMessage) + if !ok { + return errWrongType + } + if len(args) != 2 { + fmt.Printf("args: %d\n", len(args)) + return errWrongArgs + } + _, ok = args[0].(MsgFromL1) + if !ok { + fmt.Printf("args[0] should be MsgFromL1, got %T\n", args[0]) + return errWrongArgs + } + _, ok = args[1].(BlockID) + if !ok { + fmt.Printf("args[1] should be *blockID, got %T\n", args[1]) + return errWrongArgs + } + + output := FeeEstimate{ + GasConsumed: NumAsHex("0x1"), + GasPrice: NumAsHex("0x2"), + OverallFee: NumAsHex("0x3"), + } + outputContent, _ := json.Marshal(output) + json.Unmarshal(outputContent, r) + return nil +} + func mock_starknet_addInvokeTransaction(result interface{}, method string, args ...interface{}) error { r, ok := result.(*json.RawMessage) if !ok { diff --git a/rpc/provider.go b/rpc/provider.go index fd8c877b..980fd4f4 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -36,6 +36,7 @@ type api interface { ClassAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*ContractClass, error) ClassHashAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*string, error) EstimateFee(ctx context.Context, requests []BroadcastedTransaction, blockID BlockID) ([]FeeEstimate, error) + EstimateMessageFee(ctx context.Context, msg MsgFromL1, blockID BlockID) (*FeeEstimate, error) Events(ctx context.Context, input EventsInput) (*EventsOutput, error) Nonce(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*string, error) StateUpdate(ctx context.Context, blockID BlockID) (*StateUpdateOutput, error) diff --git a/rpc/types_transaction_receipt.go b/rpc/types_transaction_receipt.go index fb61ba7e..08fa85a4 100644 --- a/rpc/types_transaction_receipt.go +++ b/rpc/types_transaction_receipt.go @@ -134,6 +134,17 @@ type MsgToL1 struct { Payload []*felt.Felt `json:"payload"` } +type MsgFromL1 struct { + // FromAddress The address of the L1 contract sending the message + FromAddress *felt.Felt `json:"from_address"` + // ToAddress The target L2 address the message is sent to + ToAddress *felt.Felt `json:"to_address"` + // EntryPointSelector The selector of the l1_handler in invoke in the target contract + Selector *felt.Felt `json:"entry_point_selector"` + //Payload The payload of the message + Payload []*felt.Felt `json:"payload"` +} + type UnknownTransactionReceipt struct{ TransactionReceipt } func (tr *UnknownTransactionReceipt) UnmarshalJSON(data []byte) error {