diff --git a/src/dbnode/network/server/tchannelthrift/node/service.go b/src/dbnode/network/server/tchannelthrift/node/service.go index 08ca29c435..ce95b1dc98 100644 --- a/src/dbnode/network/server/tchannelthrift/node/service.go +++ b/src/dbnode/network/server/tchannelthrift/node/service.go @@ -255,6 +255,7 @@ type Service interface { rpc.TChanNode // FetchTaggedIter returns an iterator for the results of FetchTagged. + // It is the responsibility of the caller to close the returned iterator. FetchTaggedIter(ctx context.Context, req *rpc.FetchTaggedRequest) (FetchTaggedResultsIter, error) // Only safe to be called one time once the service has started. @@ -722,35 +723,19 @@ func (s *service) readDatapoints( } func (s *service) FetchTagged(tctx thrift.Context, req *rpc.FetchTaggedRequest) (*rpc.FetchTaggedResult_, error) { - callStart := s.nowFn() - - ctx := addSourceToContext(tctx, req.Source) - ctx, sp, sampled := ctx.StartSampledTraceSpan(tracepoint.FetchTagged) - if sampled { - sp.LogFields( - opentracinglog.String("query", string(req.Query)), - opentracinglog.String("namespace", string(req.NameSpace)), - xopentracing.Time("start", time.Unix(0, req.RangeStart)), - xopentracing.Time("end", time.Unix(0, req.RangeEnd)), - ) - } - - result, err := s.fetchTagged(ctx, req) - if sampled && err != nil { - sp.LogFields(opentracinglog.Error(err)) - } - sp.Finish() - - s.metrics.fetchTagged.ReportSuccessOrError(err, s.nowFn().Sub(callStart)) - return result, err -} - -func (s *service) fetchTagged(ctx context.Context, req *rpc.FetchTaggedRequest) (*rpc.FetchTaggedResult_, error) { + ctx := tchannelthrift.Context(tctx) iter, err := s.FetchTaggedIter(ctx, req) if err != nil { return nil, err } + result, err := s.buildFetchTaggedResult(ctx, iter) + iter.Close(err) + + return result, err +} +func (s *service) buildFetchTaggedResult(ctx context.Context, iter FetchTaggedResultsIter) (*rpc.FetchTaggedResult_, + error) { response := &rpc.FetchTaggedResult_{ Elements: make([]*rpc.FetchTaggedIDResult_, 0, iter.NumIDs()), Exhaustive: iter.Exhaustive(), @@ -781,6 +766,35 @@ func (s *service) fetchTagged(ctx context.Context, req *rpc.FetchTaggedRequest) } func (s *service) FetchTaggedIter(ctx context.Context, req *rpc.FetchTaggedRequest) (FetchTaggedResultsIter, error) { + callStart := s.nowFn() + ctx = addSourceToM3Context(ctx, req.Source) + ctx, sp, sampled := ctx.StartSampledTraceSpan(tracepoint.FetchTagged) + if sampled { + sp.LogFields( + opentracinglog.String("query", string(req.Query)), + opentracinglog.String("namespace", string(req.NameSpace)), + xopentracing.Time("start", time.Unix(0, req.RangeStart)), + xopentracing.Time("end", time.Unix(0, req.RangeEnd)), + ) + } + + instrumentClose := func(err error) { + if sampled && err != nil { + sp.LogFields(opentracinglog.Error(err)) + } + sp.Finish() + + s.metrics.fetchTagged.ReportSuccessOrError(err, s.nowFn().Sub(callStart)) + } + iter, err := s.fetchTaggedIter(ctx, req, instrumentClose) + if err != nil { + instrumentClose(err) + } + return iter, err +} + +func (s *service) fetchTaggedIter(ctx context.Context, req *rpc.FetchTaggedRequest, instrumentClose func(error)) ( + FetchTaggedResultsIter, error) { db, err := s.startReadRPCWithDB() if err != nil { return nil, err @@ -803,14 +817,15 @@ func (s *service) FetchTaggedIter(ctx context.Context, req *rpc.FetchTaggedReque ctx.RegisterFinalizer(tagEncoder) return newFetchTaggedResultsIter(fetchTaggedResultsIterOpts{ - queryResult: queryResult, - queryOpts: opts, - fetchData: fetchData, - db: db, - docReader: docs.NewEncodedDocumentReader(), - nsID: ns, - tagEncoder: tagEncoder, - iOpts: s.opts.InstrumentOptions(), + queryResult: queryResult, + queryOpts: opts, + fetchData: fetchData, + db: db, + docReader: docs.NewEncodedDocumentReader(), + nsID: ns, + tagEncoder: tagEncoder, + iOpts: s.opts.InstrumentOptions(), + instrumentClose: instrumentClose, }), nil } @@ -837,6 +852,10 @@ type FetchTaggedResultsIter interface { // Current returns the current IDResult fetched with Next. The result is only valid if Err is nil. Current() IDResult + + // Close closes the iterator. The provided error is non-nil if the client of the Iterator encountered an error + // while iterating. + Close(err error) } type fetchTaggedResultsIter struct { @@ -855,17 +874,19 @@ type fetchTaggedResultsIter struct { docReader *docs.EncodedDocumentReader tagEncoder serialize.TagEncoder iOpts instrument.Options + instrumentClose func(error) } type fetchTaggedResultsIterOpts struct { - queryResult index.QueryResult - queryOpts index.QueryOptions - fetchData bool - db storage.Database - docReader *docs.EncodedDocumentReader - nsID ident.ID - tagEncoder serialize.TagEncoder - iOpts instrument.Options + queryResult index.QueryResult + queryOpts index.QueryOptions + fetchData bool + db storage.Database + docReader *docs.EncodedDocumentReader + nsID ident.ID + tagEncoder serialize.TagEncoder + iOpts instrument.Options + instrumentClose func(error) } func newFetchTaggedResultsIter(opts fetchTaggedResultsIterOpts) FetchTaggedResultsIter { //nolint: gocritic @@ -881,6 +902,7 @@ func newFetchTaggedResultsIter(opts fetchTaggedResultsIterOpts) FetchTaggedResul docReader: opts.docReader, tagEncoder: opts.tagEncoder, iOpts: opts.iOpts, + instrumentClose: opts.instrumentClose, } return iter @@ -903,10 +925,10 @@ func (i *fetchTaggedResultsIter) Next(ctx context.Context) bool { if i.idx == 0 { for _, entry := range i.queryResults.Iter() { // nolint: gocritic result := idResult{ - queryResult: entry, - docReader: i.docReader, - tagEncoder: i.tagEncoder, - iOpts: i.iOpts, + queryResult: entry, + docReader: i.docReader, + tagEncoder: i.tagEncoder, + iOpts: i.iOpts, } if i.fetchData { // NB(r): Use a bytes ID here so that this ID doesn't need to be @@ -959,6 +981,10 @@ func (i *fetchTaggedResultsIter) Current() IDResult { return i.cur } +func (i *fetchTaggedResultsIter) Close(err error) { + i.instrumentClose(err) +} + // IDResult is the FetchTagged result for a series ID. type IDResult interface { // ID returns the series ID. @@ -2790,7 +2816,10 @@ func finalizeAnnotationFn(b []byte) { } func addSourceToContext(tctx thrift.Context, source []byte) context.Context { - ctx := tchannelthrift.Context(tctx) + return addSourceToM3Context(tchannelthrift.Context(tctx), source) +} + +func addSourceToM3Context(ctx context.Context, source []byte) context.Context { if len(source) > 0 { if base, ok := ctx.GoContext(); ok { ctx.SetGoContext(goctx.WithValue(base, limits.SourceContextKey, source))