diff --git a/internal/app/cache/cache.go b/internal/app/cache/cache.go
index 3d411387..e70d6883 100644
--- a/internal/app/cache/cache.go
+++ b/internal/app/cache/cache.go
@@ -113,7 +113,7 @@ func (c *cache) Fetch(key string) (*Resource, error) {
 func (c *cache) SetResponse(key string, resp v2.DiscoveryResponse) (map[*v2.DiscoveryRequest]bool, error) {
 	c.cacheMu.Lock()
 	defer c.cacheMu.Unlock()
-	marshaledResources, err := marshalResources(resp.Resources)
+	marshaledResources, err := MarshalResources(resp.Resources)
 	if err != nil {
 		return nil, fmt.Errorf("failed to marshal resources for key: %s, err %v", key, err)
 	}
@@ -126,6 +126,7 @@ func (c *cache) SetResponse(key string, resp v2.DiscoveryResponse) (map[*v2.Disc
 	resource := Resource{
 		Resp:           response,
 		ExpirationTime: c.getExpirationTime(time.Now()),
+		Requests:       make(map[*v2.DiscoveryRequest]bool),
 	}
 	c.cache.Add(key, resource)
 	return nil, nil
@@ -193,7 +194,9 @@ func (c *cache) getExpirationTime(currentTime time.Time) time.Time {
 	return time.Time{}
 }
 
-func marshalResources(resources []*any.Any) ([]gcp_types.MarshaledResource, error) {
+// MarshalResources converts the raw xDS discovery resources into a serialized
+// form accepted by go-control-plane.
+func MarshalResources(resources []*any.Any) ([]gcp_types.MarshaledResource, error) {
 	var marshaledResources []gcp_types.MarshaledResource
 	for _, resource := range resources {
 		marshaledResource, err := gcp.MarshalResource(resource)
diff --git a/internal/app/cache/cache_test.go b/internal/app/cache/cache_test.go
index 65395a09..3659937b 100644
--- a/internal/app/cache/cache_test.go
+++ b/internal/app/cache/cache_test.go
@@ -74,7 +74,8 @@ var testResponse = Response{
 }
 
 var testResource = Resource{
-	Resp: &testResponse,
+	Resp:     &testResponse,
+	Requests: make(map[*v2.DiscoveryRequest]bool),
 }
 
 func TestAddRequestAndFetch(t *testing.T) {
diff --git a/internal/app/orchestrator/downstream.go b/internal/app/orchestrator/downstream.go
new file mode 100644
index 00000000..bf11d268
--- /dev/null
+++ b/internal/app/orchestrator/downstream.go
@@ -0,0 +1,80 @@
+// Package orchestrator is responsible for instrumenting inbound xDS client
+// requests to the correct aggregated key, forwarding a representative request
+// to the upstream origin server, and managing the lifecycle of downstream and
+// upstream connections and associated streams. It implements
+// go-control-plane's Cache interface in order to receive xDS-based requests,
+// send responses, and handle gRPC streams.
+//
+// This file manages the bookkeeping of downstream clients by tracking inbound
+// requests to their corresponding response channels. The contents of this file
+// are intended to only be used within the orchestrator module and should not
+// be exported.
+package orchestrator
+
+import (
+	"sync"
+
+	gcp "github.com/envoyproxy/go-control-plane/pkg/cache/v2"
+)
+
+// downstreamResponseMap is a map of downstream xDS client requests to response
+// channels.
+type downstreamResponseMap struct {
+	mu               sync.RWMutex
+	responseChannels map[*gcp.Request]chan gcp.Response
+}
+
+func newDownstreamResponseMap() downstreamResponseMap {
+	return downstreamResponseMap{
+		responseChannels: make(map[*gcp.Request]chan gcp.Response),
+	}
+}
+
+// createChannel initializes a new channel for a request if it doesn't already
+// exist.
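+//
+// A rough usage sketch of the map's lifecycle (illustrative only; the
+// variable names below are hypothetical):
+//
+//	channels := newDownstreamResponseMap()
+//	ch := channels.createChannel(req)  // one buffered channel per request pointer
+//	same, ok := channels.get(req)      // ok == true, same == ch
+//	_ = channels.delete(req)           // removes the entry and returns ch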
+func (d *downstreamResponseMap) createChannel(req *gcp.Request) chan gcp.Response {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	if _, ok := d.responseChannels[req]; !ok {
+		d.responseChannels[req] = make(chan gcp.Response, 1)
+	}
+	return d.responseChannels[req]
+}
+
+// get retrieves the channel where responses are set for the specified request.
+func (d *downstreamResponseMap) get(req *gcp.Request) (chan gcp.Response, bool) {
+	d.mu.RLock()
+	defer d.mu.RUnlock()
+	channel, ok := d.responseChannels[req]
+	return channel, ok
+}
+
+// delete removes the response channel and request entry from the map.
+// Note: We don't close the response channel prior to deletion because there
+// can be separate go routines that are still attempting to write to the
+// channel. We rely on garbage collection to clean up and close outstanding
+// response channels once the go routines finish writing to them.
+func (d *downstreamResponseMap) delete(req *gcp.Request) chan gcp.Response {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	if channel, ok := d.responseChannels[req]; ok {
+		delete(d.responseChannels, req)
+		return channel
+	}
+	return nil
+}
+
+// deleteAll removes all response channels and request entries from the map.
+// Note: We don't close the response channel prior to deletion because there
+// can be separate go routines that are still attempting to write to the
+// channel. We rely on garbage collection to clean up and close outstanding
+// response channels once the go routines finish writing to them.
+func (d *downstreamResponseMap) deleteAll(watchers map[*gcp.Request]bool) {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	for watch := range watchers {
+		if d.responseChannels[watch] != nil {
+			delete(d.responseChannels, watch)
+		}
+	}
+}
diff --git a/internal/app/orchestrator/orchestrator.go b/internal/app/orchestrator/orchestrator.go
index d1cba0ca..7b7be72f 100644
--- a/internal/app/orchestrator/orchestrator.go
+++ b/internal/app/orchestrator/orchestrator.go
@@ -9,6 +9,7 @@ package orchestrator
 import (
 	"context"
 	"fmt"
+	"sync"
 	"time"
 
 	bootstrapv1 "github.com/envoyproxy/xds-relay/pkg/api/bootstrap/v1"
@@ -24,6 +25,10 @@ import (
 
 const (
 	component = "orchestrator"
+
+	// unaggregatedPrefix is the prefix used to label discovery requests that
+	// could not be successfully mapped to an aggregation rule.
+	unaggregatedPrefix = "unaggregated_"
 )
 
 // Orchestrator has the following responsibilities:
@@ -48,6 +53,10 @@ const (
 // more details.
 type Orchestrator interface {
 	gcp.Cache
+
+	// This is called by the main shutdown handler and tests to clean up
+	// open channels.
+	shutdown(ctx context.Context)
 }
 
 type orchestrator struct {
@@ -56,31 +65,42 @@ type orchestrator struct {
 	upstreamClient upstream.Client
 
 	logger log.Logger
+
+	downstreamResponseMap downstreamResponseMap
+	upstreamResponseMap   upstreamResponseMap
 }
 
 // New instantiates the mapper, cache, upstream client components necessary for
 // the orchestrator to operate and returns an instance of the instantiated
 // orchestrator.
-func New(ctx context.Context,
+func New(
+	ctx context.Context,
 	l log.Logger,
 	mapper mapper.Mapper,
 	upstreamClient upstream.Client,
-	cacheConfig *bootstrapv1.Cache) Orchestrator {
+	cacheConfig *bootstrapv1.Cache,
+) Orchestrator {
 	orchestrator := &orchestrator{
-		logger:         l.Named(component),
-		mapper:         mapper,
-		upstreamClient: upstreamClient,
+		logger:                l.Named(component),
+		mapper:                mapper,
+		upstreamClient:        upstreamClient,
+		downstreamResponseMap: newDownstreamResponseMap(),
+		upstreamResponseMap:   newUpstreamResponseMap(),
 	}
 
 	// Initialize cache.
-	cache, err := cache.NewCache(int(cacheConfig.MaxEntries),
+	cache, err := cache.NewCache(
+		int(cacheConfig.MaxEntries),
 		orchestrator.onCacheEvicted,
-		time.Duration(cacheConfig.Ttl.Nanos)*time.Nanosecond)
+		time.Duration(cacheConfig.Ttl.Nanos)*time.Nanosecond,
+	)
 	if err != nil {
 		orchestrator.logger.With("error", err).Panic(ctx, "failed to initialize cache")
 	}
 	orchestrator.cache = cache
 
+	go orchestrator.shutdown(ctx)
+
 	return orchestrator
 }
 
@@ -95,16 +115,219 @@ func New(ctx context.Context,
 //
 // Cancel is an optional function to release resources in the producer. If
 // provided, the consumer may call this function multiple times.
-func (c *orchestrator) CreateWatch(req gcp.Request) (chan gcp.Response, func()) {
-	// TODO implement.
-	return nil, nil
+func (o *orchestrator) CreateWatch(req gcp.Request) (chan gcp.Response, func()) {
+	ctx := context.Background()
+
+	// If this is the first time we're seeing the request from the
+	// downstream client, initialize a channel to feed future responses.
+	responseChannel := o.downstreamResponseMap.createChannel(&req)
+
+	aggregatedKey, err := o.mapper.GetKey(req)
+	if err != nil {
+		// Can't map the request to an aggregated key. Log and continue to
+		// propagate the response upstream without aggregation.
+		o.logger.With("err", err).With("req node", req.GetNode()).Warn(ctx, "failed to map to aggregated key")
+		// Mimic the aggregated key.
+		// TODO (https://github.com/envoyproxy/xds-relay/issues/56). This key
+		// needs to be made more granular to uniquely identify a request.
+		aggregatedKey = fmt.Sprintf("%s%s_%s", unaggregatedPrefix, req.GetNode().GetId(), req.GetTypeUrl())
+	}
+
+	// Register the watch for future responses.
+	err = o.cache.AddRequest(aggregatedKey, &req)
+	if err != nil {
+		// If we fail to register the watch, we need to kill this stream by
+		// closing the response channel.
+		o.logger.With("err", err).With("key", aggregatedKey).With(
+			"req node", req.GetNode()).Error(ctx, "failed to add watch")
+		closedChannel := o.downstreamResponseMap.delete(&req)
+		return closedChannel, nil
+	}
+
+	// Check if we have a cached response first.
+	cached, err := o.cache.Fetch(aggregatedKey)
+	if err != nil {
+		// Log, and continue to propagate the response upstream.
+		o.logger.With("err", err).With("key", aggregatedKey).Warn(ctx, "failed to fetch aggregated key")
+	}
+
+	if cached != nil && cached.Resp != nil && cached.Resp.Raw.GetVersionInfo() != req.GetVersionInfo() {
+		// If we have a cached response and the version is different,
+		// immediately push the result to the response channel.
+		go func() { responseChannel <- convertToGcpResponse(cached.Resp, req) }()
+	}
+
+	// Check if we have an upstream stream open for this aggregated key. If
+	// not, open a stream with the representative request.
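+	// In other words, only the first downstream watch for a given aggregated
+	// key opens an upstream stream; subsequent watches for that key share it.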
+	if !o.upstreamResponseMap.exists(aggregatedKey) {
+		upstreamResponseChan, shutdown, err := o.upstreamClient.OpenStream(req)
+		if err != nil {
+			// TODO implement retry/back-off logic on error scenario.
+			// https://github.com/envoyproxy/xds-relay/issues/68
+			o.logger.With("err", err).With("key", aggregatedKey).Error(ctx, "Failed to open stream to origin server")
+		} else {
+			respChannel, upstreamOpenedPreviously := o.upstreamResponseMap.add(aggregatedKey, upstreamResponseChan)
+			if upstreamOpenedPreviously {
+				// A stream was opened previously due to a race between
+				// concurrent downstreams for the same aggregated key, between
+				// exists and add operations. In this event, simply close the
+				// slower stream and return the existing one.
+				shutdown()
+			} else {
+				// Spin up a go routine to watch for upstream responses.
+				// One routine is opened per aggregated key.
+				go o.watchUpstream(ctx, aggregatedKey, respChannel.response, respChannel.done, shutdown)
+			}
+		}
+	}
+
+	return responseChannel, o.onCancelWatch(aggregatedKey, &req)
 }
 
 // Fetch implements the polling method of the config cache using a non-empty request.
-func (c *orchestrator) Fetch(context.Context, discovery.DiscoveryRequest) (*gcp.Response, error) {
+func (o *orchestrator) Fetch(context.Context, discovery.DiscoveryRequest) (*gcp.Response, error) {
 	return nil, fmt.Errorf("Not implemented")
 }
 
-func (c *orchestrator) onCacheEvicted(key string, resource cache.Resource) {
-	// TODO implement.
+// watchUpstream is intended to be called in a go routine, to receive incoming
+// responses, cache the response, and fan out to downstream clients or
+// "watchers". There is a corresponding go routine for each aggregated key.
+//
+// This goroutine continually listens for upstream responses from the passed
+// `responseChannel`. For each response, we will:
+// - cache this latest response, replacing the previous stale response.
+// - retrieve the downstream watchers from the cache for this `aggregated key`.
+// - trigger the fanout process to downstream watchers by pushing to the
+//   individual downstream response channels in separate go routines.
+//
+// Additionally, this function tracks a `done` channel and a `shutdownUpstream`
+// function. `done` is a channel that gets closed in two places:
+// 1. when server shutdown is triggered. See the `shutdown` function in this
+//    file for more information.
+// 2. when cache TTL expires for this aggregated key. See the `onCacheEvicted`
+//    function in this file for more information.
+// When the `done` channel is closed, we call the `shutdownUpstream` callback
+// function. This will signify to the upstream client that we no longer require
+// responses from this stream because the downstream connections have been
+// terminated. The upstream client will clean up the stream accordingly.
+func (o *orchestrator) watchUpstream(
+	ctx context.Context,
+	aggregatedKey string,
+	responseChannel <-chan *discovery.DiscoveryResponse,
+	done <-chan bool,
+	shutdownUpstream func(),
+) {
+	for {
+		select {
+		case x, more := <-responseChannel:
+			if !more {
+				// A problem occurred fetching the response upstream; log and
+				// return.
+				// TODO implement retry/back-off logic on error scenario.
+				// https://github.com/envoyproxy/xds-relay/issues/68
+				o.logger.With("key", aggregatedKey).Error(ctx, "upstream error")
+				return
+			}
+			// Cache the response.
+			_, err := o.cache.SetResponse(aggregatedKey, *x)
+			if err != nil {
+				// TODO if set fails, we may need to retry upstream as well.
+				// Currently the fallback is to rely on a future response, but
+				// that probably isn't ideal.
+				// https://github.com/envoyproxy/xds-relay/issues/70
+				//
+				// If we fail to cache the new response, log and return the old one.
+				o.logger.With("err", err).With("key", aggregatedKey).
+					Error(ctx, "Failed to cache the response")
+			}
+
+			// Get downstream watchers and fan out.
+			// We retrieve from cache rather than directly fanning out the
+			// newly received response because the cache does additional
+			// resource serialization.
+			cached, err := o.cache.Fetch(aggregatedKey)
+			if err != nil {
+				o.logger.With("err", err).With("key", aggregatedKey).Error(ctx, "cache fetch failed")
+				// Can't do anything because we don't know who the watchers
+				// are. Drop the response.
+			} else {
+				if cached == nil || cached.Resp == nil {
+					// If cache is empty, there is nothing to fan out.
+					// Error. Sanity check. Shouldn't ever reach this since we
+					// just set the response, but it's a rare scenario that can
+					// happen if the cache TTL is set very short.
+					o.logger.With("key", aggregatedKey).Error(ctx, "attempted to fan out with no cached response")
+				} else {
+					// Golden path.
+					o.fanout(cached.Resp, cached.Requests, aggregatedKey)
+				}
+			}
+		case <-done:
+			// Exit when signaled that the stream has closed.
+			shutdownUpstream()
+			return
+		}
+	}
+}
+
+// fanout pushes the response to the response channels of all open downstream
+// watchers in parallel.
+func (o *orchestrator) fanout(resp *cache.Response, watchers map[*gcp.Request]bool, aggregatedKey string) {
+	var wg sync.WaitGroup
+	for watch := range watchers {
+		wg.Add(1)
+		go func(watch *gcp.Request) {
+			defer wg.Done()
+			if channel, ok := o.downstreamResponseMap.get(watch); ok {
+				select {
+				case channel <- convertToGcpResponse(resp, *watch):
+					break
+				default:
+					// If the channel is blocked, we simply drop the response
+					// and log an error.
+					// Alternative possibilities are discussed here:
+					// https://github.com/envoyproxy/xds-relay/pull/53#discussion_r420325553
+					o.logger.With("key", aggregatedKey).With("node ID", watch.GetNode().GetId()).
+						Error(context.Background(), "channel blocked during fanout")
+				}
+			}
+		}(watch)
+	}
+	// Wait for all fanouts to complete.
+	wg.Wait()
+}
+
+// onCacheEvicted is called when the cache evicts a response due to TTL or
+// other reasons. When this happens, we need to clean up open streams.
+// We shut down both the downstream watchers and the upstream stream.
+func (o *orchestrator) onCacheEvicted(key string, resource cache.Resource) {
+	// TODO Potential for improvements here to handle the thundering herd
+	// problem: https://github.com/envoyproxy/xds-relay/issues/71
+	o.downstreamResponseMap.deleteAll(resource.Requests)
+	o.upstreamResponseMap.delete(key)
+}
+
+// onCancelWatch cleans up the cached watch when called.
+func (o *orchestrator) onCancelWatch(aggregatedKey string, req *gcp.Request) func() {
+	return func() {
+		o.downstreamResponseMap.delete(req)
+		if err := o.cache.DeleteRequest(aggregatedKey, req); err != nil {
+			o.logger.With("key", aggregatedKey).With("err", err).Warn(context.Background(), "Failed to delete from cache")
+		}
+	}
+}
+
+// shutdown closes all upstream connections when ctx.Done is called.
+func (o *orchestrator) shutdown(ctx context.Context) {
+	<-ctx.Done()
+	o.upstreamResponseMap.deleteAll()
+}
+
+// convertToGcpResponse constructs the go-control-plane response from the
+// cached response.
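+// ResourceMarshaled and MarshaledResources are intended to hand go-control-plane
+// the serialization already performed by the cache (see MarshalResources in
+// cache.go), so resources are not re-marshaled for every downstream watcher;
+// this is also why fanout reads from the cache rather than from the raw
+// upstream response.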
+func convertToGcpResponse(resp *cache.Response, req gcp.Request) gcp.Response {
+	return gcp.Response{
+		Request:            req,
+		Version:            resp.Raw.GetVersionInfo(),
+		ResourceMarshaled:  true,
+		MarshaledResources: resp.MarshaledResources,
+	}
 }
diff --git a/internal/app/orchestrator/orchestrator_test.go b/internal/app/orchestrator/orchestrator_test.go
index 7fd91c9c..87a55864 100644
--- a/internal/app/orchestrator/orchestrator_test.go
+++ b/internal/app/orchestrator/orchestrator_test.go
@@ -2,19 +2,97 @@ package orchestrator
 
 import (
 	"context"
+	"io/ioutil"
 	"testing"
+	"time"
 
+	v2 "github.com/envoyproxy/go-control-plane/envoy/api/v2"
+	v2_core "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"
+	gcp "github.com/envoyproxy/go-control-plane/pkg/cache/v2"
+	"github.com/golang/protobuf/ptypes/any"
+	"github.com/golang/protobuf/ptypes/duration"
+	"github.com/stretchr/testify/assert"
+	"google.golang.org/protobuf/types/known/anypb"
+
+	"github.com/envoyproxy/xds-relay/internal/app/cache"
 	"github.com/envoyproxy/xds-relay/internal/app/mapper"
 	"github.com/envoyproxy/xds-relay/internal/app/upstream"
 	upstream_mock "github.com/envoyproxy/xds-relay/internal/app/upstream/mock"
 	"github.com/envoyproxy/xds-relay/internal/pkg/log"
+	"github.com/envoyproxy/xds-relay/internal/pkg/util/testutils"
+	"github.com/envoyproxy/xds-relay/internal/pkg/util/yamlproto"
+	aggregationv1 "github.com/envoyproxy/xds-relay/pkg/api/aggregation/v1"
 	bootstrapv1 "github.com/envoyproxy/xds-relay/pkg/api/bootstrap/v1"
-	"github.com/golang/protobuf/ptypes/duration"
+)
 
-	aggregationv1 "github.com/envoyproxy/xds-relay/pkg/api/aggregation/v1"
+type mockSimpleUpstreamClient struct {
+	responseChan <-chan *v2.DiscoveryResponse
+}
 
-	"github.com/stretchr/testify/assert"
-)
+func (m mockSimpleUpstreamClient) OpenStream(req v2.DiscoveryRequest) (<-chan *v2.DiscoveryResponse, func(), error) {
+	return m.responseChan, func() {}, nil
+}
+
+type mockMultiStreamUpstreamClient struct {
+	ldsResponseChan <-chan *v2.DiscoveryResponse
+	cdsResponseChan <-chan *v2.DiscoveryResponse
+
+	t      *testing.T
+	mapper mapper.Mapper
+}
+
+func (m mockMultiStreamUpstreamClient) OpenStream(
+	req v2.DiscoveryRequest,
+) (<-chan *v2.DiscoveryResponse, func(), error) {
+	aggregatedKey, err := m.mapper.GetKey(req)
+	assert.NoError(m.t, err)
+
+	if aggregatedKey == "lds" {
+		return m.ldsResponseChan, func() {}, nil
+	} else if aggregatedKey == "cds" {
+		return m.cdsResponseChan, func() {}, nil
+	}
+
+	m.t.Errorf("Unsupported aggregated key, %s", aggregatedKey)
+	return nil, func() {}, nil
+}
+
+func newMockOrchestrator(t *testing.T, mapper mapper.Mapper, upstreamClient upstream.Client) *orchestrator {
+	orchestrator := &orchestrator{
+		logger:                log.New("info"),
+		mapper:                mapper,
+		upstreamClient:        upstreamClient,
+		downstreamResponseMap: newDownstreamResponseMap(),
+		upstreamResponseMap:   newUpstreamResponseMap(),
+	}
+
+	cache, err := cache.NewCache(1000, orchestrator.onCacheEvicted, 10*time.Second)
+	assert.NoError(t, err)
+	orchestrator.cache = cache
+
+	return orchestrator
+}
+
+func newMockMapper(t *testing.T) mapper.Mapper {
+	bytes, err := ioutil.ReadFile("testdata/aggregation_rules.yaml") // key on request type
+	assert.NoError(t, err)
+
+	var config aggregationv1.KeyerConfiguration
+	err = yamlproto.FromYAMLToKeyerConfiguration(string(bytes), &config)
+	assert.NoError(t, err)
+
+	return mapper.NewMapper(&config)
+}
+
+func assertEqualResources(t *testing.T, got gcp.Response, expected v2.DiscoveryResponse, req gcp.Request) {
+	expectedResources, err := cache.MarshalResources(expected.Resources)
+	assert.NoError(t, err)
+	expectedResponse := cache.Response{
+		Raw:                expected,
+		MarshaledResources: expectedResources,
+	}
+	assert.Equal(t, convertToGcpResponse(&expectedResponse, req), got)
+}
 
 func TestNew(t *testing.T) {
 	// Trivial test to ensure orchestrator instantiates.
@@ -23,7 +101,8 @@ func TestNew(t *testing.T) {
 		upstream.CallOptions{},
 		nil,
 		nil,
-		func(m interface{}) error { return nil })
+		func(m interface{}) error { return nil },
+	)
 
 	config := aggregationv1.KeyerConfiguration{
 		Fragments: []*aggregationv1.KeyerConfiguration_Fragment{
@@ -44,3 +123,236 @@ func TestNew(t *testing.T) {
 	orchestrator := New(context.Background(), log.New("info"), requestMapper, upstreamClient, &cacheConfig)
 	assert.NotNil(t, orchestrator)
 }
+
+func TestGoldenPath(t *testing.T) {
+	upstreamResponseChannel := make(chan *v2.DiscoveryResponse)
+	mapper := newMockMapper(t)
+	orchestrator := newMockOrchestrator(
+		t,
+		mapper,
+		mockSimpleUpstreamClient{
+			responseChan: upstreamResponseChannel,
+		},
+	)
+	assert.NotNil(t, orchestrator)
+
+	req := gcp.Request{
+		TypeUrl: "type.googleapis.com/envoy.api.v2.Listener",
+	}
+
+	respChannel, cancelWatch := orchestrator.CreateWatch(req)
+	assert.NotNil(t, respChannel)
+	assert.Equal(t, 1, len(orchestrator.downstreamResponseMap.responseChannels))
+	testutils.AssertSyncMapLen(t, 1, orchestrator.upstreamResponseMap.internal)
+	orchestrator.upstreamResponseMap.internal.Range(func(key, val interface{}) bool {
+		assert.Equal(t, "lds", key.(string))
+		return true
+	})
+
+	resp := v2.DiscoveryResponse{
+		VersionInfo: "1",
+		TypeUrl:     "type.googleapis.com/envoy.api.v2.Listener",
+		Resources: []*any.Any{
+			&anypb.Any{
+				Value: []byte("lds resource"),
+			},
+		},
+	}
+	upstreamResponseChannel <- &resp
+
+	gotResponse := <-respChannel
+	assertEqualResources(t, gotResponse, resp, req)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	orchestrator.shutdown(ctx)
+	testutils.AssertSyncMapLen(t, 0, orchestrator.upstreamResponseMap.internal)
+
+	cancelWatch()
+	assert.Equal(t, 0, len(orchestrator.downstreamResponseMap.responseChannels))
+}
+
+func TestCachedResponse(t *testing.T) {
+	upstreamResponseChannel := make(chan *v2.DiscoveryResponse)
+	mapper := newMockMapper(t)
+	orchestrator := newMockOrchestrator(
+		t,
+		mapper,
+		mockSimpleUpstreamClient{
+			responseChan: upstreamResponseChannel,
+		},
+	)
+	assert.NotNil(t, orchestrator)
+
+	// Test scenario with different request and response versions.
+	// Version is different, so we expect a response.
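+	// (CreateWatch compares the cached response's version against the
+	// request's version and, on a mismatch, immediately pushes the cached
+	// response to the watcher; see orchestrator.go above.)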
+	req := gcp.Request{
+		VersionInfo: "0",
+		TypeUrl:     "type.googleapis.com/envoy.api.v2.Listener",
+	}
+
+	aggregatedKey, err := mapper.GetKey(req)
+	assert.NoError(t, err)
+	mockResponse := v2.DiscoveryResponse{
+		VersionInfo: "1",
+		TypeUrl:     "type.googleapis.com/envoy.api.v2.Listener",
+		Resources: []*any.Any{
+			&anypb.Any{
+				Value: []byte("lds resource"),
+			},
+		},
+	}
+	watchers, err := orchestrator.cache.SetResponse(aggregatedKey, mockResponse)
+	assert.NoError(t, err)
+	assert.Equal(t, 0, len(watchers))
+
+	respChannel, cancelWatch := orchestrator.CreateWatch(req)
+	assert.NotNil(t, respChannel)
+	assert.Equal(t, 1, len(orchestrator.downstreamResponseMap.responseChannels))
+	testutils.AssertSyncMapLen(t, 1, orchestrator.upstreamResponseMap.internal)
+	orchestrator.upstreamResponseMap.internal.Range(func(key, val interface{}) bool {
+		assert.Equal(t, "lds", key.(string))
+		return true
+	})
+
+	gotResponse := <-respChannel
+	assertEqualResources(t, gotResponse, mockResponse, req)
+
+	// Attempt pushing a more recent response from upstream.
+	resp := v2.DiscoveryResponse{
+		VersionInfo: "2",
+		TypeUrl:     "type.googleapis.com/envoy.api.v2.Listener",
+		Resources: []*any.Any{
+			&anypb.Any{
+				Value: []byte("some other lds resource"),
+			},
+		},
+	}
+
+	upstreamResponseChannel <- &resp
+	gotResponse = <-respChannel
+	assertEqualResources(t, gotResponse, resp, req)
+	testutils.AssertSyncMapLen(t, 1, orchestrator.upstreamResponseMap.internal)
+	orchestrator.upstreamResponseMap.internal.Range(func(key, val interface{}) bool {
+		assert.Contains(t, "lds", key.(string))
+		return true
+	})
+
+	// Test scenario with same request and response version.
+	// We expect a watch to be open but no response.
+	req2 := gcp.Request{
+		VersionInfo: "2",
+		TypeUrl:     "type.googleapis.com/envoy.api.v2.Listener",
+	}
+
+	respChannel2, cancelWatch2 := orchestrator.CreateWatch(req2)
+	assert.NotNil(t, respChannel2)
+	assert.Equal(t, 2, len(orchestrator.downstreamResponseMap.responseChannels))
+	testutils.AssertSyncMapLen(t, 1, orchestrator.upstreamResponseMap.internal)
+	orchestrator.upstreamResponseMap.internal.Range(func(key, val interface{}) bool {
+		assert.Contains(t, "lds", key.(string))
+		return true
+	})
+
+	// If we pass this point, it's safe to assume the respChannel2 is empty,
+	// otherwise the test would block and not complete.
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	orchestrator.shutdown(ctx)
+	testutils.AssertSyncMapLen(t, 0, orchestrator.upstreamResponseMap.internal)
+
+	cancelWatch()
+	assert.Equal(t, 1, len(orchestrator.downstreamResponseMap.responseChannels))
+	cancelWatch2()
+	assert.Equal(t, 0, len(orchestrator.downstreamResponseMap.responseChannels))
+}
+
+func TestMultipleWatchersAndUpstreams(t *testing.T) {
+	upstreamResponseChannelLDS := make(chan *v2.DiscoveryResponse)
+	upstreamResponseChannelCDS := make(chan *v2.DiscoveryResponse)
+	mapper := newMockMapper(t)
+	orchestrator := newMockOrchestrator(
+		t,
+		mapper,
+		mockMultiStreamUpstreamClient{
+			ldsResponseChan: upstreamResponseChannelLDS,
+			cdsResponseChan: upstreamResponseChannelCDS,
+			mapper:          mapper,
+			t:               t,
+		},
+	)
+	assert.NotNil(t, orchestrator)
+
+	req1 := gcp.Request{
+		TypeUrl: "type.googleapis.com/envoy.api.v2.Listener",
+		Node: &v2_core.Node{
+			Id: "req1",
+		},
+	}
+	req2 := gcp.Request{
+		TypeUrl: "type.googleapis.com/envoy.api.v2.Listener",
+		Node: &v2_core.Node{
+			Id: "req2",
+		},
+	}
+	req3 := gcp.Request{
+		TypeUrl: "type.googleapis.com/envoy.api.v2.Cluster",
+		Node: &v2_core.Node{
+			Id: "req3",
+		},
+	}
+
+	respChannel1, cancelWatch1 := orchestrator.CreateWatch(req1)
+	assert.NotNil(t, respChannel1)
+	respChannel2, cancelWatch2 := orchestrator.CreateWatch(req2)
+	assert.NotNil(t, respChannel2)
+	respChannel3, cancelWatch3 := orchestrator.CreateWatch(req3)
+	assert.NotNil(t, respChannel3)
+
+	upstreamResponseLDS := v2.DiscoveryResponse{
+		VersionInfo: "1",
+		TypeUrl:     "type.googleapis.com/envoy.api.v2.Listener",
+		Resources: []*any.Any{
+			&anypb.Any{
+				Value: []byte("lds resource"),
+			},
+		},
+	}
+	upstreamResponseCDS := v2.DiscoveryResponse{
+		VersionInfo: "1",
+		TypeUrl:     "type.googleapis.com/envoy.api.v2.Cluster",
+		Resources: []*any.Any{
+			&anypb.Any{
+				Value: []byte("cds resource"),
+			},
+		},
+	}
+
+	upstreamResponseChannelLDS <- &upstreamResponseLDS
+	upstreamResponseChannelCDS <- &upstreamResponseCDS
+
+	gotResponseFromChannel1 := <-respChannel1
+	gotResponseFromChannel2 := <-respChannel2
+	gotResponseFromChannel3 := <-respChannel3
+
+	assert.Equal(t, 3, len(orchestrator.downstreamResponseMap.responseChannels))
+	testutils.AssertSyncMapLen(t, 2, orchestrator.upstreamResponseMap.internal)
+	orchestrator.upstreamResponseMap.internal.Range(func(key, val interface{}) bool {
+		assert.Contains(t, []string{"lds", "cds"}, key.(string))
+		return true
+	})
+
+	assertEqualResources(t, gotResponseFromChannel1, upstreamResponseLDS, req1)
+	assertEqualResources(t, gotResponseFromChannel2, upstreamResponseLDS, req2)
+	assertEqualResources(t, gotResponseFromChannel3, upstreamResponseCDS, req3)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	orchestrator.shutdown(ctx)
+	testutils.AssertSyncMapLen(t, 0, orchestrator.upstreamResponseMap.internal)
+
+	cancelWatch1()
+	cancelWatch2()
+	cancelWatch3()
+	assert.Equal(t, 0, len(orchestrator.downstreamResponseMap.responseChannels))
+}
diff --git a/internal/app/orchestrator/testdata/aggregation_rules.yaml b/internal/app/orchestrator/testdata/aggregation_rules.yaml
new file mode 100644
index 00000000..0bcfe5fe
--- /dev/null
+++ b/internal/app/orchestrator/testdata/aggregation_rules.yaml
@@ -0,0 +1,26 @@
+fragments:
+  - rules:
+      - match:
+          request_type_match:
+            types:
+              - "type.googleapis.com/envoy.api.v2.Listener"
+        result:
+          string_fragment: "lds"
+      - match:
+          request_type_match:
+            types:
+              - "type.googleapis.com/envoy.api.v2.Cluster"
+        result:
+          string_fragment: "cds"
+      - match:
+          request_type_match:
+            types:
+              - "type.googleapis.com/envoy.api.v2.Route"
+        result:
+          string_fragment: "rds"
+      - match:
+          request_type_match:
+            types:
+              - "type.googleapis.com/envoy.api.v2.Endpoint"
+        result:
+          string_fragment: "eds"
diff --git a/internal/app/orchestrator/upstream.go b/internal/app/orchestrator/upstream.go
new file mode 100644
index 00000000..9ab04689
--- /dev/null
+++ b/internal/app/orchestrator/upstream.go
@@ -0,0 +1,84 @@
+// Package orchestrator is responsible for instrumenting inbound xDS client
+// requests to the correct aggregated key, forwarding a representative request
+// to the upstream origin server, and managing the lifecycle of downstream and
+// upstream connections and associated streams. It implements
+// go-control-plane's Cache interface in order to receive xDS-based requests,
+// send responses, and handle gRPC streams.
+//
+// This file manages the bookkeeping of upstream responses by tracking the
+// aggregated key and its corresponding receiver channel for upstream
+// responses. The contents of this file are intended to only be used within the
+// orchestrator module and should not be exported.
+package orchestrator
+
+import (
+	"sync"
+
+	discovery "github.com/envoyproxy/go-control-plane/envoy/api/v2"
+)
+
+// upstreamResponseMap is the map of aggregate key to the receive-only upstream
+// origin server response channels.
+//
+// sync.Map was chosen due to:
+// - support for concurrent locks on a per-key basis
+// - stable keys (when a given key is written once but read many times)
+// The main drawback is the lack of type support.
+type upstreamResponseMap struct {
+	// This is of type *sync.Map[string]upstreamResponseChannel, where the key
+	// is the xds-relay aggregated key.
+	internal *sync.Map
+}
+
+type upstreamResponseChannel struct {
+	response <-chan *discovery.DiscoveryResponse
+	done     chan bool
+}
+
+func newUpstreamResponseMap() upstreamResponseMap {
+	return upstreamResponseMap{
+		internal: &sync.Map{},
+	}
+}
+
+// exists returns true if the aggregatedKey exists.
+func (u *upstreamResponseMap) exists(aggregatedKey string) bool {
+	_, ok := u.internal.Load(aggregatedKey)
+	return ok
+}
+
+// add sets the response channel for the provided aggregated key. It also
+// initializes a done channel to be used during cleanup.
+func (u *upstreamResponseMap) add(
+	aggregatedKey string,
+	responseChannel <-chan *discovery.DiscoveryResponse,
+) (upstreamResponseChannel, bool) {
+	channel := upstreamResponseChannel{
+		response: responseChannel,
+		done:     make(chan bool, 1),
+	}
+	result, exists := u.internal.LoadOrStore(aggregatedKey, channel)
+	return result.(upstreamResponseChannel), exists
+}
+
+// delete signifies closure of the upstream stream and removes the map entry
+// for the specified aggregated key.
+func (u *upstreamResponseMap) delete(aggregatedKey string) {
+	if channel, ok := u.internal.Load(aggregatedKey); ok {
+		close(channel.(upstreamResponseChannel).done)
+		// The implementation of sync.Map will already check for key existence
+		// prior to issuing the delete, so we don't need to worry about
+		// deleting a non-existent key due to concurrent race conditions.
+		u.internal.Delete(aggregatedKey)
+	}
+}
+
+// deleteAll signifies closure of all upstream streams and removes the map
+// entries. This is called during server shutdown.
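+// Deleting entries while ranging is safe here: sync.Map.Range tolerates
+// concurrent Store and Delete calls, at the cost of not observing a
+// consistent snapshot of the map.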
+func (u *upstreamResponseMap) deleteAll() {
+	u.internal.Range(func(aggregatedKey, channel interface{}) bool {
+		close(channel.(upstreamResponseChannel).done)
+		u.internal.Delete(aggregatedKey)
+		return true
+	})
+}
diff --git a/internal/pkg/util/testutils/syncmap.go b/internal/pkg/util/testutils/syncmap.go
new file mode 100644
index 00000000..58cf2266
--- /dev/null
+++ b/internal/pkg/util/testutils/syncmap.go
@@ -0,0 +1,23 @@
+package testutils
+
+import (
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// AssertSyncMapLen is a custom solution for checking the length of sync.Map.
+// sync.Map does not currently offer support for length checks:
+// https://github.com/golang/go/issues/20680
+func AssertSyncMapLen(t *testing.T, len int, sm *sync.Map) {
+	count := 0
+	var mu sync.Mutex
+	sm.Range(func(key, val interface{}) bool {
+		mu.Lock()
+		count++
+		mu.Unlock()
+		return true
+	})
+	assert.Equal(t, count, len)
+}
diff --git a/internal/pkg/util/testutils/syncmap_test.go b/internal/pkg/util/testutils/syncmap_test.go
new file mode 100644
index 00000000..391d8d1a
--- /dev/null
+++ b/internal/pkg/util/testutils/syncmap_test.go
@@ -0,0 +1,18 @@
+package testutils
+
+import (
+	"sync"
+	"testing"
+)
+
+func TestAssertSyncMapLen(t *testing.T) {
+	var sm sync.Map
+
+	sm.Store("foo", 1)
+	sm.Store("foo", 2)
+	sm.Store("bar", 1)
+	AssertSyncMapLen(t, 2, &sm)
+
+	sm.Delete("foo")
+	AssertSyncMapLen(t, 1, &sm)
+}
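The fanout in orchestrator.go relies on a non-blocking channel send so that one stuck watcher cannot stall delivery to the others. In isolation, that pattern looks roughly like the following standalone sketch (standard library only; the watcher setup and payload are made up for illustration):

	package main

	import (
		"fmt"
		"sync"
	)

	func main() {
		watchers := []chan string{
			make(chan string, 1), // healthy watcher: buffered, accepts the send
			make(chan string),    // stuck watcher: nobody reads it, so the send is dropped
		}

		var wg sync.WaitGroup
		for i, ch := range watchers {
			wg.Add(1)
			go func(i int, ch chan string) {
				defer wg.Done()
				select {
				case ch <- "discovery response":
					// delivered
				default:
					// channel would block; drop and log, as fanout does
					fmt.Printf("watcher %d blocked, response dropped\n", i)
				}
			}(i, ch)
		}
		wg.Wait()
		fmt.Println(<-watchers[0]) // prints "discovery response"
	}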