diff --git a/pkg/server/server_sql.go b/pkg/server/server_sql.go index c00a4f60d728..a09a90a7cc01 100644 --- a/pkg/server/server_sql.go +++ b/pkg/server/server_sql.go @@ -905,42 +905,9 @@ func newSQLServer(ctx context.Context, cfg sqlServerArgs) (*SQLServer, error) { } } - // Setup the trace collector that is used to fetch inflight trace spans from - // all nodes in the cluster. - // The collector requires nodeliveness to get a list of all the nodes in the - // cluster. - var getNodes func(ctx context.Context) ([]roachpb.NodeID, error) - if isMixedSQLAndKVNode && hasNodeLiveness { - // TODO(dt): any reason not to just always use the instance reader? And just - // pass it directly instead of making a new closure here? - getNodes = func(ctx context.Context) ([]roachpb.NodeID, error) { - var ns []roachpb.NodeID - ls, err := nodeLiveness.GetLivenessesFromKV(ctx) - if err != nil { - return nil, err - } - for _, l := range ls { - if l.Membership.Decommissioned() { - continue - } - ns = append(ns, l.NodeID) - } - return ns, nil - } - } else { - getNodes = func(ctx context.Context) ([]roachpb.NodeID, error) { - instances, err := cfg.sqlInstanceReader.GetAllInstances(ctx) - if err != nil { - return nil, err - } - instanceIDs := make([]roachpb.NodeID, len(instances)) - for i, instance := range instances { - instanceIDs[i] = roachpb.NodeID(instance.InstanceID) - } - return instanceIDs, err - } - } - traceCollector := collector.New(cfg.Tracer, getNodes, cfg.podNodeDialer) + // Set up the trace collector that is used to fetch inflight trace spans + // from all instances in the cluster. + traceCollector := collector.New(cfg.Tracer, cfg.sqlInstanceReader.GetAllInstances, cfg.podNodeDialer) contentionMetrics := contention.NewMetrics() cfg.registry.AddMetricStruct(contentionMetrics) diff --git a/pkg/sql/crdb_internal.go b/pkg/sql/crdb_internal.go index b8a76105f2f6..b22f3ea3c146 100644 --- a/pkg/sql/crdb_internal.go +++ b/pkg/sql/crdb_internal.go @@ -97,7 +97,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" - "github.com/cockroachdb/cockroach/pkg/util/tracing/collector" "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" "github.com/cockroachdb/errors" "github.com/cockroachdb/redact" @@ -1944,8 +1943,11 @@ CREATE TABLE crdb_internal.cluster_inflight_traces ( } traceCollector := p.ExecCfg().TraceCollector - var iter *collector.Iterator - for iter, err = traceCollector.StartIter(ctx, traceID); err == nil && iter.Valid(); iter.Next(ctx) { + iter, err := traceCollector.StartIter(ctx, traceID) + if err != nil { + return false, err + } + for ; iter.Valid(); iter.Next(ctx) { nodeID, recording := iter.Value() traceString := recording.String() traceJaegerJSON, err := recording.ToJaegerJSON("", "", fmt.Sprintf("node %d", nodeID)) @@ -1960,12 +1962,6 @@ CREATE TABLE crdb_internal.cluster_inflight_traces ( return false, err } } - if err != nil { - return false, err - } - if iter.Error() != nil { - return false, iter.Error() - } return true, nil }}}, diff --git a/pkg/sql/crdb_internal_test.go b/pkg/sql/crdb_internal_test.go index f11335d3a1a9..2f6cb8015e38 100644 --- a/pkg/sql/crdb_internal_test.go +++ b/pkg/sql/crdb_internal_test.go @@ -778,7 +778,12 @@ func TestClusterInflightTracesVirtualTable(t *testing.T) { }, } var rowIdx int - rows := sqlDB.Query(t, `SELECT trace_id, node_id, trace_str, jaeger_json from crdb_internal.cluster_inflight_traces WHERE trace_id=$1`, traceID) + rows := sqlDB.Query(t, ` + SELECT trace_id, node_id, trace_str, jaeger_json + FROM crdb_internal.cluster_inflight_traces + WHERE trace_id = $1 + ORDER BY node_id;`, // sort by node_id in case instances are returned out of order + traceID) defer rows.Close() for rows.Next() { var traceID, nodeID int diff --git a/pkg/util/tracing/collector/BUILD.bazel b/pkg/util/tracing/collector/BUILD.bazel index 7c0903a53316..2c7bc77f7e4a 100644 --- a/pkg/util/tracing/collector/BUILD.bazel +++ b/pkg/util/tracing/collector/BUILD.bazel @@ -6,10 +6,11 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/util/tracing/collector", visibility = ["//visibility:public"], deps = [ - "//pkg/kv/kvserver/liveness/livenesspb", + "//pkg/base", "//pkg/roachpb", "//pkg/rpc", "//pkg/rpc/nodedialer", + "//pkg/sql/sqlinstance", "//pkg/util/log", "//pkg/util/tracing", "//pkg/util/tracing/tracingpb", @@ -34,6 +35,7 @@ go_test( "//pkg/security/securitytest", "//pkg/security/username", "//pkg/server", + "//pkg/sql/sqlinstance", "//pkg/testutils/serverutils", "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", diff --git a/pkg/util/tracing/collector/collector.go b/pkg/util/tracing/collector/collector.go index 6832568b912f..88a91610a77e 100644 --- a/pkg/util/tracing/collector/collector.go +++ b/pkg/util/tracing/collector/collector.go @@ -14,93 +14,89 @@ import ( "context" "sort" - "github.com/cockroachdb/cockroach/pkg/kv/kvserver/liveness/livenesspb" + "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/rpc/nodedialer" + "github.com/cockroachdb/cockroach/pkg/sql/sqlinstance" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingservicepb" ) -// NodeLiveness is the subset of the interface satisfied by CRDB's node liveness -// component that the tracing service relies upon. -type NodeLiveness interface { - GetLivenessesFromKV(context.Context) ([]livenesspb.Liveness, error) -} - // TraceCollector can be used to extract recordings from inflight spans for a -// given traceID, from all nodes of the cluster. +// given traceID, from all SQL instances. type TraceCollector struct { - tracer *tracing.Tracer - getNodes func(ctx context.Context) ([]roachpb.NodeID, error) - dialer *nodedialer.Dialer + tracer *tracing.Tracer + getInstances func(context.Context) ([]sqlinstance.InstanceInfo, error) + dialer *nodedialer.Dialer } // New returns a TraceCollector. +// +// Note that the second argument is not *instancestorage.Reader but an accessor +// method to allow for easier testing setup. func New( tracer *tracing.Tracer, - getNodes func(ctx context.Context) ([]roachpb.NodeID, error), + getInstances func(context.Context) ([]sqlinstance.InstanceInfo, error), dialer *nodedialer.Dialer, ) *TraceCollector { return &TraceCollector{ - tracer: tracer, - getNodes: getNodes, - dialer: dialer, + tracer: tracer, + getInstances: getInstances, + dialer: dialer, } } -// Iterator can be used to return tracing.Recordings from all live nodes in the +// Iterator can be used to return tracing.Recordings from all instances in the // cluster, in a streaming manner. The iterator buffers the tracing.Recordings -// of one node at a time. +// of one instance at a time. type Iterator struct { collector *TraceCollector traceID tracingpb.TraceID - // nodes stores all the nodes in the cluster (either mixed nodes or tenant - // servers) that will be contacted for inflight trace spans by the iterator. - // When they refer to tenant servers, the NodeIDs are really InstanceIDs. - nodes []roachpb.NodeID + // instances stores all SQL instances in the cluster that will be contacted + // for inflight trace spans by the iterator. + instances []sqlinstance.InstanceInfo - // curNodeIndex maintains the index in nodes from which the iterator has - // pulled inflight span recordings and buffered them in `recordedSpans` for - // consumption via the iterator. - curNodeIndex int + // curInstanceIdx maintains the index in instances from which the iterator + // has pulled inflight span recordings and buffered them in `recordedSpans` + // for consumption via the iterator. + curInstanceIdx int - // curNode maintains the node from which the iterator has pulled inflight span - // recordings and buffered them in `recordings` for consumption via the - // iterator. - curNode roachpb.NodeID + // curInstanceID maintains the instance ID from which the iterator has + // pulled inflight span recordings and buffered them in `recordings` for + // consumption via the iterator. + curInstanceID base.SQLInstanceID // recordingIndex maintains the current position of the iterator in the list - // of tracing.Recordings. The tracingpb.Recording that the iterator points to is - // buffered in `recordings`. + // of tracing.Recordings. The tracingpb.Recording that the iterator points + // to is buffered in `recordings`. recordingIndex int - // recordings represent all the tracing.Recordings for a given node currently - // accessed by the iterator. + // recordings represent all the tracing.Recordings for a given SQL instance + // currently accessed by the iterator. recordings []tracingpb.Recording - - iterErr error } -// StartIter fetches the live nodes in the cluster, and configures the underlying -// Iterator that is used to access recorded spans in a streaming fashion. +// StartIter fetches all SQL instances in the cluster, and configures the +// underlying Iterator that is used to access recorded spans in a streaming +// fashion. func (t *TraceCollector) StartIter( ctx context.Context, traceID tracingpb.TraceID, ) (*Iterator, error) { tc := &Iterator{traceID: traceID, collector: t} var err error - tc.nodes, err = t.getNodes(ctx) + tc.instances, err = t.getInstances(ctx) if err != nil { return nil, err } // Calling Next() positions the Iterator in a valid state. It will fetch the - // first set of valid (non-nil) inflight span recordings from the list of live - // nodes. + // first set of valid (non-nil) inflight span recordings from the list of + // SQL instances. tc.Next(ctx) return tc, nil @@ -108,19 +104,7 @@ func (t *TraceCollector) StartIter( // Valid returns whether the Iterator is in a valid state to read values from. func (i *Iterator) Valid() bool { - if i.iterErr != nil { - return false - } - - // If recordingIndex is within recordings and there are some buffered - // recordings, it is valid to return from the buffer. - if i.recordings != nil && i.recordingIndex < len(i.recordings) { - return true - } - - // Otherwise, we have exhausted inflight span recordings from all live nodes - // in the cluster. - return false + return i.recordingIndex < len(i.recordings) } // Next sets the Iterator to point to the next value to be returned. @@ -129,7 +113,7 @@ func (i *Iterator) Next(ctx context.Context) { // If recordingIndex is within recordings and there are some buffered // recordings, it is valid to return from the buffer. - if i.recordings != nil && i.recordingIndex < len(i.recordings) { + if i.recordingIndex < len(i.recordings) { return } @@ -137,57 +121,52 @@ func (i *Iterator) Next(ctx context.Context) { i.recordings = nil i.recordingIndex = 0 - // Either there are no more spans or we have exhausted the recordings from the - // current node, and we need to pull the inflight recordings from another - // node. - // Keep searching for recordings from all live nodes in the cluster. - for i.recordings == nil { - // No more spans to return from any of the live nodes in the cluster. - if !(i.curNodeIndex < len(i.nodes)) { + // Either there are no more spans or we have exhausted the recordings from + // the current instance, and we need to pull the inflight recordings from + // another instance. + // Keep searching for recordings from all SQL instances in the cluster. + for len(i.recordings) == 0 { + // No more spans to return from any of the SQL instances in the cluster. + if !(i.curInstanceIdx < len(i.instances)) { return } - i.curNode = i.nodes[i.curNodeIndex] - i.recordings, i.iterErr = i.collector.getTraceSpanRecordingsForNode(ctx, i.traceID, i.curNode) - // TODO(adityamaru): We might want to consider not failing if a single node - // fails to return span recordings. - if i.iterErr != nil { - return - } - i.curNodeIndex++ + i.curInstanceID = i.instances[i.curInstanceIdx].InstanceID + i.recordings = i.collector.getTraceSpanRecordingsForInstance(ctx, i.traceID, i.curInstanceID) + i.curInstanceIdx++ } } // Value returns the current value pointed to by the Iterator. -func (i *Iterator) Value() (roachpb.NodeID, tracingpb.Recording) { - return i.curNode, i.recordings[i.recordingIndex] +func (i *Iterator) Value() (base.SQLInstanceID, tracingpb.Recording) { + return i.curInstanceID, i.recordings[i.recordingIndex] } -// Error returns the error encountered by the Iterator during iteration. -func (i *Iterator) Error() error { - return i.iterErr -} - -// getTraceSpanRecordingsForNode returns the inflight span recordings for traces -// with traceID from the node with nodeID. The span recordings are sorted by -// StartTime. +// getTraceSpanRecordingsForInstance returns the inflight span recordings for +// traces with traceID from the SQL instance with the given ID. The span +// recordings are sorted by StartTime. If any error is encountered, then nil +// slice is returned. +// // This method does not distinguish between requests for local and remote -// inflight spans, and relies on gRPC short circuiting local requests. -func (t *TraceCollector) getTraceSpanRecordingsForNode( - ctx context.Context, traceID tracingpb.TraceID, nodeID roachpb.NodeID, -) ([]tracingpb.Recording, error) { - log.Infof(ctx, "getting span recordings from node %s", nodeID.String()) - conn, err := t.dialer.Dial(ctx, nodeID, rpc.DefaultClass) +// inflight spans, and relies on gRPC short-circuiting local requests. +func (t *TraceCollector) getTraceSpanRecordingsForInstance( + ctx context.Context, traceID tracingpb.TraceID, instanceID base.SQLInstanceID, +) []tracingpb.Recording { + log.Infof(ctx, "getting span recordings from instance %s", instanceID) + conn, err := t.dialer.Dial(ctx, roachpb.NodeID(instanceID), rpc.DefaultClass) if err != nil { - return nil, err + log.Warningf(ctx, "failed to dial instance %s: %v", instanceID, err) + return nil } traceClient := tracingservicepb.NewTracingClient(conn) - resp, err := traceClient.GetSpanRecordings(ctx, + var resp *tracingservicepb.GetSpanRecordingsResponse + resp, err = traceClient.GetSpanRecordings(ctx, &tracingservicepb.GetSpanRecordingsRequest{TraceID: traceID}) if err != nil { - return nil, err + log.Warningf(ctx, "failed to get span recordings from instance %s: %v", instanceID, err) + return nil } - var res []tracingpb.Recording + res := make([]tracingpb.Recording, 0, len(resp.Recordings)) for _, recording := range resp.Recordings { if recording.RecordedSpans == nil { continue @@ -202,5 +181,5 @@ func (t *TraceCollector) getTraceSpanRecordingsForNode( return res[i][0].StartTime.Before(res[j][0].StartTime) }) - return res, nil + return res } diff --git a/pkg/util/tracing/collector/collector_test.go b/pkg/util/tracing/collector/collector_test.go index 8fbf28f08a09..da9f9e921570 100644 --- a/pkg/util/tracing/collector/collector_test.go +++ b/pkg/util/tracing/collector/collector_test.go @@ -23,6 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc/nodedialer" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/sql/sqlinstance" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" @@ -120,6 +121,7 @@ func setupTraces(t1, t2 *tracing.Tracer) (tracingpb.TraceID, tracingpb.TraceID, func TestTracingCollectorGetSpanRecordings(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -132,34 +134,31 @@ func TestTracingCollectorGetSpanRecordings(t *testing.T) { traceCollector := collector.New( localTracer, - func(ctx context.Context) ([]roachpb.NodeID, error) { - nodeIDs := make([]roachpb.NodeID, len(tc.Servers)) + func(ctx context.Context) ([]sqlinstance.InstanceInfo, error) { + instanceIDs := make([]sqlinstance.InstanceInfo, len(tc.Servers)) for i := range tc.Servers { - nodeIDs[i] = tc.Server(i).NodeID() + instanceIDs[i].InstanceID = tc.Server(i).SQLInstanceID() } - return nodeIDs, nil + return instanceIDs, nil }, tc.Server(0).NodeDialer().(*nodedialer.Dialer)) localTraceID, remoteTraceID, cleanup := setupTraces(localTracer, remoteTracer) defer cleanup() - getSpansFromAllNodes := func(traceID tracingpb.TraceID) map[roachpb.NodeID][]tracingpb.Recording { - res := make(map[roachpb.NodeID][]tracingpb.Recording) - - var iter *collector.Iterator - var err error - for iter, err = traceCollector.StartIter(ctx, traceID); err == nil && iter.Valid(); iter.Next(ctx) { - nodeID, recording := iter.Value() - res[nodeID] = append(res[nodeID], recording) - } + getSpansFromAllInstances := func(traceID tracingpb.TraceID) map[base.SQLInstanceID][]tracingpb.Recording { + res := make(map[base.SQLInstanceID][]tracingpb.Recording) + iter, err := traceCollector.StartIter(ctx, traceID) require.NoError(t, err) - require.NoError(t, iter.Error()) + for ; iter.Valid(); iter.Next(ctx) { + instanceID, recording := iter.Value() + res[instanceID] = append(res[instanceID], recording) + } return res } t.Run("fetch-local-recordings", func(t *testing.T) { - nodeRecordings := getSpansFromAllNodes(localTraceID) - node1Recordings := nodeRecordings[roachpb.NodeID(1)] + nodeRecordings := getSpansFromAllInstances(localTraceID) + node1Recordings := nodeRecordings[tc.Server(0).SQLInstanceID()] require.Equal(t, 1, len(node1Recordings)) require.NoError(t, tracing.CheckRecordedSpans(node1Recordings[0], ` span: root @@ -170,7 +169,7 @@ func TestTracingCollectorGetSpanRecordings(t *testing.T) { span: root.child.remotechilddone tags: _verbose=1 `)) - node2Recordings := nodeRecordings[roachpb.NodeID(2)] + node2Recordings := nodeRecordings[tc.Server(1).SQLInstanceID()] require.Equal(t, 1, len(node2Recordings)) require.NoError(t, tracing.CheckRecordedSpans(node2Recordings[0], ` span: root.child.remotechild @@ -182,8 +181,8 @@ func TestTracingCollectorGetSpanRecordings(t *testing.T) { // The traceCollector is running on node 1, so most of the recordings for this // subtest will be passed back by node 2 over RPC. t.Run("fetch-remote-recordings", func(t *testing.T) { - nodeRecordings := getSpansFromAllNodes(remoteTraceID) - node1Recordings := nodeRecordings[roachpb.NodeID(1)] + nodeRecordings := getSpansFromAllInstances(remoteTraceID) + node1Recordings := nodeRecordings[tc.Server(0).SQLInstanceID()] require.Equal(t, 2, len(node1Recordings)) require.NoError(t, tracing.CheckRecordedSpans(node1Recordings[0], ` span: root2.child.remotechild @@ -194,7 +193,7 @@ func TestTracingCollectorGetSpanRecordings(t *testing.T) { tags: _unfinished=1 _verbose=1 `)) - node2Recordings := nodeRecordings[roachpb.NodeID(2)] + node2Recordings := nodeRecordings[tc.Server(1).SQLInstanceID()] require.Equal(t, 1, len(node2Recordings)) require.NoError(t, tracing.CheckRecordedSpans(node2Recordings[0], ` span: root2