diff --git a/go.mod b/go.mod index 47c652364..2f49a4736 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,8 @@ require ( github.com/libp2p/go-libp2p-kad-dht v0.28.1 github.com/libp2p/go-libp2p-pubsub v0.12.0 github.com/mitchellh/mapstructure v1.5.0 + github.com/libp2p/go-libp2p/p2p/net/mock v0.37.0 + github.com/libp2p/go-libp2p/core/protocol v0.37.0 github.com/multiformats/go-multiaddr v0.14.0 github.com/olekukonko/tablewriter v0.0.5 github.com/pkg/errors v0.9.1 diff --git a/p2p/p2p.go b/p2p/p2p.go index 49633f49e..21923cddf 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -142,7 +142,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai } } - p2pdht, err := makeDHT(p2phost, peersAddrInfoS) + p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork.L2ChainID) if err != nil { return nil, err } @@ -164,9 +164,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai return s, nil } -func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) { +func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, chainID string) (*dht.IpfsDHT, error) { return dht.New(context.Background(), p2phost, - dht.ProtocolPrefix(starknet.Prefix), + dht.ProtocolPrefix(starknet.ChainPID(chainID)), dht.BootstrapPeers(addrInfos...), dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod), dht.Mode(dht.ModeServer), @@ -282,11 +282,11 @@ func (s *Service) Run(ctx context.Context) error { } func (s *Service) setProtocolHandlers() { - s.SetProtocolHandler(starknet.HeadersPID(), s.handler.HeadersHandler) - s.SetProtocolHandler(starknet.EventsPID(), s.handler.EventsHandler) - s.SetProtocolHandler(starknet.TransactionsPID(), s.handler.TransactionsHandler) - s.SetProtocolHandler(starknet.ClassesPID(), s.handler.ClassesHandler) - s.SetProtocolHandler(starknet.StateDiffPID(), s.handler.StateDiffHandler) + s.SetProtocolHandler(starknet.HeadersPID(s.network.L2ChainID), s.handler.HeadersHandler) + s.SetProtocolHandler(starknet.EventsPID(s.network.L2ChainID), s.handler.EventsHandler) + s.SetProtocolHandler(starknet.TransactionsPID(s.network.L2ChainID), s.handler.TransactionsHandler) + s.SetProtocolHandler(starknet.ClassesPID(s.network.L2ChainID), s.handler.ClassesHandler) + s.SetProtocolHandler(starknet.StateDiffPID(s.network.L2ChainID), s.handler.StateDiffHandler) } func (s *Service) callAndLogErr(f func() error, msg string) { diff --git a/p2p/p2p_test.go b/p2p/p2p_test.go index 070a9eedb..24a351c3a 100644 --- a/p2p/p2p_test.go +++ b/p2p/p2p_test.go @@ -17,6 +17,7 @@ import ( "github.com/libp2p/go-libp2p/core/protocol" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -206,3 +207,36 @@ func TestLoadAndPersistPeers(t *testing.T) { ) require.NoError(t, err) } + +func TestMakeDHTProtocolName(t *testing.T) { + net, err := mocknet.FullMeshLinked(1) + require.NoError(t, err) + testHost := net.Hosts()[0] + + testCases := []struct { + name string + network *utils.Network + expected string + }{ + { + name: "sepolia network", + network: &utils.Sepolia, + expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0", + }, + { + name: "mainnet network", + network: &utils.Mainnet, + expected: "/starknet/SN_MAIN/sync/kad/1.0.0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dht, err := p2p.MakeDHT(testHost, nil, tc.network.L2ChainID) + require.NoError(t, err) + + protocols := dht.Host().Mux().Protocols() + assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols) + }) + } +} diff --git a/p2p/starknet/client.go b/p2p/starknet/client.go index bfeed7ab7..adfa79bd9 100644 --- a/p2p/starknet/client.go +++ b/p2p/starknet/client.go @@ -104,22 +104,24 @@ func (c *Client) RequestBlockHeaders( ctx context.Context, req *spec.BlockHeadersRequest, ) (iter.Seq[*spec.BlockHeadersResponse], error) { return requestAndReceiveStream[*spec.BlockHeadersRequest, *spec.BlockHeadersResponse]( - ctx, c.newStream, HeadersPID(), req, c.log) + ctx, c.newStream, HeadersPID(c.network.L2ChainID), req, c.log) } func (c *Client) RequestEvents(ctx context.Context, req *spec.EventsRequest) (iter.Seq[*spec.EventsResponse], error) { - return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log) + return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(c.network.L2ChainID), req, c.log) } func (c *Client) RequestClasses(ctx context.Context, req *spec.ClassesRequest) (iter.Seq[*spec.ClassesResponse], error) { - return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log) + return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(c.network.L2ChainID), req, c.log) } func (c *Client) RequestStateDiffs(ctx context.Context, req *spec.StateDiffsRequest) (iter.Seq[*spec.StateDiffsResponse], error) { - return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log) + return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse]( + ctx, c.newStream, StateDiffPID(c.network.L2ChainID), req, c.log, + ) } func (c *Client) RequestTransactions(ctx context.Context, req *spec.TransactionsRequest) (iter.Seq[*spec.TransactionsResponse], error) { return requestAndReceiveStream[*spec.TransactionsRequest, *spec.TransactionsResponse]( - ctx, c.newStream, TransactionsPID(), req, c.log) + ctx, c.newStream, TransactionsPID(c.network.L2ChainID), req, c.log) } diff --git a/p2p/starknet/ids.go b/p2p/starknet/ids.go index d1b97b0ad..14de5dbf5 100644 --- a/p2p/starknet/ids.go +++ b/p2p/starknet/ids.go @@ -6,22 +6,26 @@ import ( const Prefix = "/starknet" -func HeadersPID() protocol.ID { - return Prefix + "/headers/0.1.0-rc.0" +func HeadersPID(chainID string) protocol.ID { + return protocol.ID(Prefix + "/" + chainID + "/sync/headers/0.1.0-rc.0") } -func EventsPID() protocol.ID { - return Prefix + "/events/0.1.0-rc.0" +func EventsPID(chainID string) protocol.ID { + return protocol.ID(Prefix + "/" + chainID + "/sync/events/0.1.0-rc.0") } -func TransactionsPID() protocol.ID { - return Prefix + "/transactions/0.1.0-rc.0" +func TransactionsPID(chainID string) protocol.ID { + return protocol.ID(Prefix + "/" + chainID + "/sync/transactions/0.1.0-rc.0") } -func ClassesPID() protocol.ID { - return Prefix + "/classes/0.1.0-rc.0" +func ClassesPID(chainID string) protocol.ID { + return protocol.ID(Prefix + "/" + chainID + "/sync/classes/0.1.0-rc.0") } -func StateDiffPID() protocol.ID { - return Prefix + "/state_diffs/0.1.0-rc.0" +func StateDiffPID(chainID string) protocol.ID { + return protocol.ID(Prefix + "/" + chainID + "/sync/state_diffs/0.1.0-rc.0") +} + +func ChainPID(chainID string) protocol.ID { + return protocol.ID(Prefix + "/" + chainID + "/sync") } diff --git a/p2p/starknet/ids_test.go b/p2p/starknet/ids_test.go new file mode 100644 index 000000000..eb4adbc1f --- /dev/null +++ b/p2p/starknet/ids_test.go @@ -0,0 +1,66 @@ +package starknet + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProtocolIDs(t *testing.T) { + testCases := []struct { + name string + chainID string + pidFunc func(string) string + expected string + }{ + { + name: "HeadersPID with SN_MAIN", + chainID: "SN_MAIN", + pidFunc: func(c string) string { return string(HeadersPID(c)) }, + expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0", + }, + { + name: "EventsPID with SN_MAIN", + chainID: "SN_MAIN", + pidFunc: func(c string) string { return string(EventsPID(c)) }, + expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0", + }, + { + name: "TransactionsPID with SN_MAIN", + chainID: "SN_MAIN", + pidFunc: func(c string) string { return string(TransactionsPID(c)) }, + expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0", + }, + { + name: "ClassesPID with SN_MAIN", + chainID: "SN_MAIN", + pidFunc: func(c string) string { return string(ClassesPID(c)) }, + expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0", + }, + { + name: "StateDiffPID with SN_MAIN", + chainID: "SN_MAIN", + pidFunc: func(c string) string { return string(StateDiffPID(c)) }, + expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0", + }, + { + name: "ChainPID with SN_MAIN", + chainID: "SN_MAIN", + pidFunc: func(c string) string { return string(ChainPID(c)) }, + expected: "/starknet/SN_MAIN/sync", + }, + { + name: "HeadersPID with SN_SEPOLIA", + chainID: "SN_SEPOLIA", + pidFunc: func(c string) string { return string(HeadersPID(c)) }, + expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.pidFunc(tc.chainID) + assert.Equal(t, tc.expected, result) + }) + } +}