Skip to content

Commit

Permalink
feat: provide maximum count to routing.FindProviders
Browse files Browse the repository at this point in the history
  • Loading branch information
hacdias committed May 12, 2023
1 parent 33e3f0c commit 9c2b259
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
6 changes: 3 additions & 3 deletions routing/http/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import (

type mockContentRouter struct{ mock.Mock }

func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key)
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, count int) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key, count)
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
}
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) {
Expand Down Expand Up @@ -302,7 +302,7 @@ func TestClient_FindProviders(t *testing.T) {

findProvsIter := iter.FromSlice(c.routerProvs)

router.On("FindProviders", mock.Anything, cid).
router.On("FindProviders", mock.Anything, cid, 20).
Return(findProvsIter, c.routerErr)

provsIter, err := client.FindProviders(ctx, cid)
Expand Down
36 changes: 31 additions & 5 deletions routing/http/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ type FindProvidersAsyncResponse struct {
}

type ContentRouter interface {
FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error)
// FindProviders searches for peers who are able to provide a given key. Count
// indicates the maximum amount of providers we are looking for. If count is 0,
// the implementer can return an unbounded number of results.
FindProviders(ctx context.Context, key cid.Cid, count int) (iter.ResultIter[types.ProviderResponse], error)
ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error)
Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error)
}
Expand Down Expand Up @@ -69,9 +72,27 @@ func WithStreamingResultsDisabled() Option {
}
}

// WithRecordsCount changes the amount of records asked for non-streaming requests.
// Default is 20.
func WithRecordsCount(count int) Option {
return func(s *server) {
s.recordsCount = count
}
}

// WithStreamingRecordsCount changes the amount of records asked for streaming requests.
// Default is 0 (unbounded).
func WithStreamingRecordsCount(count int) Option {
return func(s *server) {
s.streamingRecordsCount = count
}
}

func Handler(svc ContentRouter, opts ...Option) http.Handler {
server := &server{
svc: svc,
svc: svc,
recordsCount: 20,
streamingRecordsCount: 0,
}

for _, opt := range opts {
Expand All @@ -86,8 +107,10 @@ func Handler(svc ContentRouter, opts ...Option) http.Handler {
}

type server struct {
svc ContentRouter
disableNDJSON bool
svc ContentRouter
disableNDJSON bool
recordsCount int
streamingRecordsCount int
}

func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) {
Expand Down Expand Up @@ -170,6 +193,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {

var supportsNDJSON bool
var supportsJSON bool
var count int
acceptHeaders := httpReq.Header.Values("Accept")
if len(acceptHeaders) == 0 {
handlerFunc = s.findProvidersJSON
Expand All @@ -185,8 +209,10 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
switch mediaType {
case mediaTypeJSON, mediaTypeWildcard:
supportsJSON = true
count = s.recordsCount
case mediaTypeNDJSON:
supportsNDJSON = true
count = s.streamingRecordsCount
}
}
}
Expand All @@ -201,7 +227,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
}
}

provIter, err := s.svc.FindProviders(httpReq.Context(), cid)
provIter, err := s.svc.FindProviders(httpReq.Context(), cid, count)
if err != nil {
writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err))
return
Expand Down
6 changes: 3 additions & 3 deletions routing/http/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestHeaders(t *testing.T) {
cb, err := cid.Decode(c)
require.NoError(t, err)

router.On("FindProviders", mock.Anything, cb).
router.On("FindProviders", mock.Anything, cb, 0).
Return(results, nil)

resp, err := http.Get(serverAddr + ProvidePath + c)
Expand All @@ -50,8 +50,8 @@ func TestHeaders(t *testing.T) {

type mockContentRouter struct{ mock.Mock }

func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key)
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, count int) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key, count)
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
}
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) {
Expand Down

0 comments on commit 9c2b259

Please sign in to comment.