From b385dd79c31175ef0157bf43d584caa268b38ba7 Mon Sep 17 00:00:00 2001 From: rianhughes Date: Tue, 22 Aug 2023 13:22:33 +0300 Subject: [PATCH] methods --- rpc/contract.go | 38 ++++++++++++++++++++++++++++++++------ rpc/provider.go | 4 ++-- rpc/types_contract.go | 5 +++++ 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/rpc/contract.go b/rpc/contract.go index 8aa12310..de44b375 100644 --- a/rpc/contract.go +++ b/rpc/contract.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "encoding/json" "errors" "fmt" @@ -10,8 +11,8 @@ import ( ) // Class gets the contract class definition associated with the given hash. -func (provider *Provider) Class(ctx context.Context, blockID BlockID, classHash string) (*DepcreatedContractClass, error) { - var rawClass DepcreatedContractClass +func (provider *Provider) Class(ctx context.Context, blockID BlockID, classHash string) (GetClassOutput, error) { + var rawClass map[string]any if err := do(ctx, provider.c, "starknet_getClass", &rawClass, blockID, classHash); err != nil { switch { case errors.Is(err, ErrClassHashNotFound): @@ -21,12 +22,14 @@ func (provider *Provider) Class(ctx context.Context, blockID BlockID, classHash } return nil, err } - return &rawClass, nil + + return typecastClassOutut(&rawClass) + } // ClassAt get the contract class definition at the given address. -func (provider *Provider) ClassAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*DepcreatedContractClass, error) { - var rawClass DepcreatedContractClass +func (provider *Provider) ClassAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (GetClassOutput, error) { + var rawClass map[string]any if err := do(ctx, provider.c, "starknet_getClassAt", &rawClass, blockID, contractAddress); err != nil { switch { case errors.Is(err, ErrContractNotFound): @@ -36,7 +39,30 @@ func (provider *Provider) ClassAt(ctx context.Context, blockID BlockID, contract } return nil, err } - return &rawClass, nil + return typecastClassOutut(&rawClass) +} + +func typecastClassOutut(rawClass *map[string]any) (GetClassOutput, error) { + rawClassByte, err := json.Marshal(rawClass) + if err != nil { + return nil, err + } + + // if contract_class_version exists, then it's a ContractClass type + if _, exists := (*rawClass)["contract_class_version"]; exists { + var contractClass ContractClass + err = json.Unmarshal(rawClassByte, &contractClass) + if err != nil { + return nil, err + } + return &contractClass, nil + } + var depContractClass DepcreatedContractClass + err = json.Unmarshal(rawClassByte, &depContractClass) + if err != nil { + return nil, err + } + return &depContractClass, nil } // ClassHashAt gets the contract class hash for the contract deployed at the given address. diff --git a/rpc/provider.go b/rpc/provider.go index d96b050c..536c5524 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -32,8 +32,8 @@ type api interface { BlockWithTxs(ctx context.Context, blockID BlockID) (interface{}, error) Call(ctx context.Context, call FunctionCall, block BlockID) ([]*felt.Felt, error) ChainID(ctx context.Context) (string, error) - Class(ctx context.Context, blockID BlockID, classHash string) (*DepcreatedContractClass, error) - ClassAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*DepcreatedContractClass, error) + Class(ctx context.Context, blockID BlockID, classHash string) (GetClassOutput, error) + ClassAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (GetClassOutput, error) ClassHashAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*string, error) EstimateFee(ctx context.Context, requests []BroadcastedTransaction, blockID BlockID) ([]FeeEstimate, error) Events(ctx context.Context, input EventsInput) (*EventsOutput, error) diff --git a/rpc/types_contract.go b/rpc/types_contract.go index fc1b1e2c..ebd3e01c 100644 --- a/rpc/types_contract.go +++ b/rpc/types_contract.go @@ -19,6 +19,11 @@ type DeprecatedCairoEntryPoint struct { Selector *felt.Felt `json:"selector"` } +type GetClassOutput interface{} + +var _ GetClassOutput = &DepcreatedContractClass{} +var _ GetClassOutput = &ContractClass{} + type ABI []ABIEntry type DeprecatedEntryPointsByType struct {