Skip to content

Commit

Permalink
Update resolving SSH hosts to use unified API (#45644)
Browse files Browse the repository at this point in the history
A new GetAllUnifiedResources function was added to replace
GetAllResources so that ssh host resolution during tsh ssh used
the more performant API. Additionally, the fallback clientside
filtering in GetSSHTargets that used GetAllResources was removed
as all supported auth instances should implement the GetSSHTargets
RPC.
  • Loading branch information
rosstimothy authored Aug 21, 2024
1 parent a18f674 commit 410207b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 33 deletions.
56 changes: 26 additions & 30 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3788,6 +3788,31 @@ func GetUnifiedResourcePage(ctx context.Context, clt ListUnifiedResourcesClient,
}
}

// GetAllUnifiedResources is a helper for getting all existing resources that match the provided request. In addition to
// iterating pages, it also correctly handles downsizing pages when LimitExceeded errors are encountered.
func GetAllUnifiedResources(ctx context.Context, clt ListUnifiedResourcesClient, req *proto.ListUnifiedResourcesRequest) ([]*types.EnrichedResource, error) {
var out []*types.EnrichedResource

// Set the limit to the default size.
req.Limit = int32(defaults.DefaultChunkSize)
for {
resources, nextKey, err := GetUnifiedResourcePage(ctx, clt, req)
if err != nil {
return nil, trace.Wrap(err)
}

out = append(out, resources...)

if nextKey == "" || len(resources) == 0 {
break
}

req.StartKey = nextKey
}

return out, nil
}

// GetEnrichedResourcePage is a helper for getting a single page of enriched resources.
func GetEnrichedResourcePage(ctx context.Context, clt GetResourcesClient, req *proto.ListResourcesRequest) (ResourcePage[*types.EnrichedResource], error) {
var out ResourcePage[*types.EnrichedResource]
Expand Down Expand Up @@ -4064,36 +4089,7 @@ func GetKubernetesResourcesWithFilters(ctx context.Context, clt kubeproto.KubeSe
// but may result in confusing behavior if it is used outside of those contexts.
func (c *Client) GetSSHTargets(ctx context.Context, req *proto.GetSSHTargetsRequest) (*proto.GetSSHTargetsResponse, error) {
rsp, err := c.grpc.GetSSHTargets(ctx, req)
if err := trace.Wrap(err); !trace.IsNotImplemented(err) {
return rsp, err
}

// if we got a not implemented error, fallback to client-side filtering
servers, err := GetAllResources[*types.ServerV2](ctx, c, &proto.ListResourcesRequest{
ResourceType: types.KindNode,
UseSearchAsRoles: true,
})
if err != nil {
return nil, trace.Wrap(err)
}

// we only get here if we hit a NotImplementedError from GetSSHTargets, which means
// we should be performing client-side filtering with default parameters instead.
routeMatcher := utils.NewSSHRouteMatcher(req.Host, req.Port, false)

// do client-side filtering
filtered := servers[:0]
for _, srv := range servers {
if !routeMatcher.RouteToServer(srv) {
continue
}

filtered = append(filtered, srv)
}

return &proto.GetSSHTargetsResponse{
Servers: filtered,
}, nil
return rsp, trace.Wrap(err)
}

// CreateSessionTracker creates a tracker resource for an active session.
Expand Down
18 changes: 15 additions & 3 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,7 @@ type targetNode struct {

// getTargetNodes returns a list of node addresses this SSH command needs to
// operate on.
func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.GetResourcesClient, options SSHOptions) ([]targetNode, error) {
func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.ListUnifiedResourcesClient, options SSHOptions) ([]targetNode, error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/getTargetNodes",
Expand All @@ -1393,16 +1393,28 @@ func (tc *TeleportClient) getTargetNodes(ctx context.Context, clt client.GetReso

// Query for nodes if labels, fuzzy search, or predicate expressions were provided.
if len(tc.Labels) > 0 || len(tc.SearchKeywords) > 0 || tc.PredicateExpression != "" {
nodes, err := client.GetAllResources[types.Server](ctx, clt, tc.ResourceFilter(types.KindNode))
nodes, err := client.GetAllUnifiedResources(ctx, clt, &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindNode},
SortBy: types.SortBy{Field: types.ResourceMetadataName},
Labels: tc.Labels,
SearchKeywords: tc.SearchKeywords,
PredicateExpression: tc.PredicateExpression,
UseSearchAsRoles: tc.UseSearchAsRoles,
})
if err != nil {
return nil, trace.Wrap(err)
}

retval := make([]targetNode, 0, len(nodes))
for _, resource := range nodes {
server, ok := resource.ResourceWithLabels.(types.Server)
if !ok {
continue
}

// always dial nodes by UUID
retval = append(retval, targetNode{
hostname: resource.GetHostname(),
hostname: server.GetHostname(),
addr: fmt.Sprintf("%s:0", resource.GetName()),
})
}
Expand Down
9 changes: 9 additions & 0 deletions lib/client/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,15 @@ func (f fakeResourceClient) GetResources(ctx context.Context, req *proto.ListRes
return &proto.ListResourcesResponse{Resources: out}, nil
}

func (f fakeResourceClient) ListUnifiedResources(ctx context.Context, req *proto.ListUnifiedResourcesRequest) (*proto.ListUnifiedResourcesResponse, error) {
out := make([]*proto.PaginatedResource, 0, len(f.nodes))
for _, n := range f.nodes {
out = append(out, &proto.PaginatedResource{Resource: &proto.PaginatedResource_Node{Node: n}})
}

return &proto.ListUnifiedResourcesResponse{Resources: out}, nil
}

func TestGetTargetNodes(t *testing.T) {
tests := []struct {
name string
Expand Down

0 comments on commit 410207b

Please sign in to comment.