diff --git a/internal/adapters/consul/sync.go b/internal/adapters/consul/sync.go index 1ebb43e3f..7f9b778fb 100644 --- a/internal/adapters/consul/sync.go +++ b/internal/adapters/consul/sync.go @@ -7,15 +7,17 @@ import ( "sync" "time" - "github.com/hashicorp/consul-api-gateway/internal/common" - "github.com/hashicorp/consul-api-gateway/internal/consul" - "github.com/hashicorp/consul-api-gateway/internal/core" "github.com/hashicorp/consul/api" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-multierror" + + "github.com/hashicorp/consul-api-gateway/internal/common" + "github.com/hashicorp/consul-api-gateway/internal/consul" + "github.com/hashicorp/consul-api-gateway/internal/core" ) type syncState struct { + ingress *api.IngressGatewayConfigEntry routers *consul.ConfigEntryIndex splitters *consul.ConfigEntryIndex defaults *consul.ConfigEntryIndex @@ -208,19 +210,20 @@ func discoveryChain(gateway core.ResolvedGateway) (*api.IngressGatewayConfigEntr return ingress, routers, splitters, defaults } -func (a *SyncAdapter) entriesForGateway(id core.GatewayID) (*consul.ConfigEntryIndex, *consul.ConfigEntryIndex, *consul.ConfigEntryIndex) { +func (a *SyncAdapter) entriesForGateway(id core.GatewayID) (*api.IngressGatewayConfigEntry, *consul.ConfigEntryIndex, *consul.ConfigEntryIndex, *consul.ConfigEntryIndex) { existing, found := a.sync[id] if !found { routers := consul.NewConfigEntryIndex(api.ServiceRouter) splitters := consul.NewConfigEntryIndex(api.ServiceSplitter) defaults := consul.NewConfigEntryIndex(api.ServiceDefaults) - return routers, splitters, defaults + return nil, routers, splitters, defaults } - return existing.routers, existing.splitters, existing.defaults + return existing.ingress, existing.routers, existing.splitters, existing.defaults } -func (a *SyncAdapter) setEntriesForGateway(gateway core.ResolvedGateway, routers *consul.ConfigEntryIndex, splitters *consul.ConfigEntryIndex, defaults *consul.ConfigEntryIndex) { +func (a *SyncAdapter) setEntriesForGateway(gateway core.ResolvedGateway, ingress *api.IngressGatewayConfigEntry, routers *consul.ConfigEntryIndex, splitters *consul.ConfigEntryIndex, defaults *consul.ConfigEntryIndex) { a.sync[gateway.ID] = syncState{ + ingress: ingress, routers: routers, splitters: splitters, defaults: defaults, @@ -257,12 +260,7 @@ func (a *SyncAdapter) Clear(ctx context.Context, id core.GatewayID) error { defer a.logger.Trace("entries cleared", "time", time.Now(), "spent", time.Since(started)) } - ingress := &api.IngressGatewayConfigEntry{ - Kind: api.IngressGateway, - Name: id.Service, - Namespace: id.ConsulNamespace, - } - existingRouters, existingSplitters, existingDefaults := a.entriesForGateway(id) + ingress, existingRouters, existingSplitters, existingDefaults := a.entriesForGateway(id) removedRouters := existingRouters.ToArray() removedSplitters := existingSplitters.ToArray() removedDefaults := existingDefaults.ToArray() @@ -299,7 +297,7 @@ func (a *SyncAdapter) Clear(ctx context.Context, id core.GatewayID) error { return nil } -func (a *SyncAdapter) Sync(ctx context.Context, gateway core.ResolvedGateway) error { +func (a *SyncAdapter) Sync(ctx context.Context, gateway core.ResolvedGateway) (bool, error) { a.mutex.Lock() defer a.mutex.Unlock() @@ -314,7 +312,7 @@ func (a *SyncAdapter) Sync(ctx context.Context, gateway core.ResolvedGateway) er } ingress, computedRouters, computedSplitters, computedDefaults := discoveryChain(gateway) - existingRouters, existingSplitters, existingDefaults := a.entriesForGateway(gateway.ID) + _, existingRouters, existingSplitters, existingDefaults := a.entriesForGateway(gateway.ID) // Since we can't make multiple config entry changes in a single transaction we must // perform the operations in a set that is least likely to induce downtime. @@ -330,6 +328,14 @@ func (a *SyncAdapter) Sync(ctx context.Context, gateway core.ResolvedGateway) er removedDefaults := computedDefaults.Difference(existingDefaults).ToArray() if a.logger.IsTrace() { + started := time.Now() + resolved, err := json.MarshalIndent(gateway, "", " ") + if err == nil { + a.logger.Trace("reconciling gateway snapshot", "gateway", string(resolved)) + } + a.logger.Trace("started reconciliation", "time", started) + defer a.logger.Trace("reconciliation finished", "time", time.Now(), "spent", time.Since(started)) + ingressEntry, err := json.MarshalIndent(ingress, "", " ") if err == nil { a.logger.Trace("setting ingress", "items", string(ingressEntry)) @@ -346,33 +352,33 @@ func (a *SyncAdapter) Sync(ctx context.Context, gateway core.ResolvedGateway) er // defaults need to go first, otherwise the routers are always configured to use tcp if err := a.setConfigEntries(ctx, addedDefaults...); err != nil { - return fmt.Errorf("error adding service defaults config entries: %w", err) + return false, fmt.Errorf("error adding service defaults config entries: %w", err) } if err := a.setConfigEntries(ctx, addedRouters...); err != nil { - return fmt.Errorf("error adding service router config entries: %w", err) + return false, fmt.Errorf("error adding service router config entries: %w", err) } if err := a.setConfigEntries(ctx, addedSplitters...); err != nil { - return fmt.Errorf("error adding service splitter config entries: %w", err) + return false, fmt.Errorf("error adding service splitter config entries: %w", err) } if err := a.setConfigEntries(ctx, ingress); err != nil { - return fmt.Errorf("error adding ingress config entry: %w", err) + return false, fmt.Errorf("error adding ingress config entry: %w", err) } if err := a.deleteConfigEntries(ctx, removedRouters...); err != nil { - return fmt.Errorf("error removing service router config entries: %w", err) + return false, fmt.Errorf("error removing service router config entries: %w", err) } if err := a.deleteConfigEntries(ctx, removedSplitters...); err != nil { - return fmt.Errorf("error removing service splitter config entries: %w", err) + return false, fmt.Errorf("error removing service splitter config entries: %w", err) } if err := a.deleteConfigEntries(ctx, removedDefaults...); err != nil { - return fmt.Errorf("error removing service defaults config entries: %w", err) + return false, fmt.Errorf("error removing service defaults config entries: %w", err) } - a.setEntriesForGateway(gateway, computedRouters, computedSplitters, computedDefaults) + a.setEntriesForGateway(gateway, ingress, computedRouters, computedSplitters, computedDefaults) if err := a.syncIntentionsForGateway(gateway.ID, ingress); err != nil { - return fmt.Errorf("error syncing service intention config entries: %w", err) + return false, fmt.Errorf("error syncing service intention config entries: %w", err) } - return nil + return true, nil } diff --git a/internal/adapters/consul/sync_test.go b/internal/adapters/consul/sync_test.go index 28b07a1bc..46477b039 100644 --- a/internal/adapters/consul/sync_test.go +++ b/internal/adapters/consul/sync_test.go @@ -196,7 +196,7 @@ func TestConsulSyncAdapter_Sync(t *testing.T) { }}, } - err = adapter.Sync(ctx, gateway) + _, err = adapter.Sync(ctx, gateway) require.NoError(t, err) require.Eventually(t, func() bool { diff --git a/internal/core/interfaces.go b/internal/core/interfaces.go index fa0a7411f..af16c8662 100644 --- a/internal/core/interfaces.go +++ b/internal/core/interfaces.go @@ -7,6 +7,6 @@ import ( // SyncAdapter is used for synchronizing store state to // an external system type SyncAdapter interface { - Sync(ctx context.Context, gateway ResolvedGateway) error + Sync(ctx context.Context, gateway ResolvedGateway) (bool, error) Clear(ctx context.Context, id GatewayID) error } diff --git a/internal/envoy/handler.go b/internal/envoy/handler.go index 7a2deaf77..21657c91c 100644 --- a/internal/envoy/handler.go +++ b/internal/envoy/handler.go @@ -15,13 +15,14 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/consul-api-gateway/internal/metrics" + "github.com/hashicorp/consul-api-gateway/internal/store" ) // RequestHandler implements the handlers for an SDS Delta server type RequestHandler struct { logger hclog.Logger secretManager SecretManager - registry GatewaySecretRegistry + store store.Store nodeMap sync.Map streamContexts sync.Map activeStreams int64 @@ -29,9 +30,9 @@ type RequestHandler struct { // NewRequestHandler initializes a RequestHandler instance and wraps it in a github.com/envoyproxy/go-control-plane/pkg/server/v3,(*CallbackFuncs) // so that it can be used by the stock go-control-plane server implementation -func NewRequestHandler(logger hclog.Logger, registry GatewaySecretRegistry, secretManager SecretManager) *server.CallbackFuncs { +func NewRequestHandler(logger hclog.Logger, store store.Store, secretManager SecretManager) *server.CallbackFuncs { handler := &RequestHandler{ - registry: registry, + store: store, logger: logger, secretManager: secretManager, } @@ -85,7 +86,12 @@ func (r *RequestHandler) OnStreamRequest(streamID int64, req *discovery.Discover resources := req.GetResourceNames() // check to make sure we're actually authorized to do this - allowed, err := r.registry.CanFetchSecrets(ctx, GatewayFromContext(ctx), resources) + gateway, err := r.store.GetGateway(ctx, GatewayFromContext(ctx)) + if err != nil { + r.logger.Error("error fetching gateway", "error", err) + return err + } + allowed, err := gateway.CanFetchSecrets(ctx, resources) if err != nil { r.logger.Error("error checking gateway secrets", "error", err) return err diff --git a/internal/envoy/handler_test.go b/internal/envoy/handler_test.go index f07ebdb5a..631677e5c 100644 --- a/internal/envoy/handler_test.go +++ b/internal/envoy/handler_test.go @@ -12,6 +12,8 @@ import ( "github.com/stretchr/testify/require" "github.com/hashicorp/consul-api-gateway/internal/envoy/mocks" + storeMocks "github.com/hashicorp/consul-api-gateway/internal/store/mocks" + "github.com/hashicorp/go-hclog" ) @@ -27,9 +29,11 @@ func TestOnStreamRequest(t *testing.T) { defer ctrl.Finish() secrets := mocks.NewMockSecretManager(ctrl) - registry := mocks.NewMockGatewaySecretRegistry(ctrl) - registry.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any(), requestedSecrets).Return(true, nil) - handler := NewRequestHandler(hclog.NewNullLogger(), registry, secrets) + store := storeMocks.NewMockStore(ctrl) + gateway := storeMocks.NewMockGateway(ctrl) + store.EXPECT().GetGateway(gomock.Any(), gomock.Any()).Return(gateway, nil) + gateway.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any()).Return(true, nil) + handler := NewRequestHandler(hclog.NewNullLogger(), store, secrets) request := &discovery.DiscoveryRequest{ ResourceNames: requestedSecrets, @@ -57,9 +61,11 @@ func TestOnStreamRequest_PermissionError(t *testing.T) { defer ctrl.Finish() secrets := mocks.NewMockSecretManager(ctrl) - registry := mocks.NewMockGatewaySecretRegistry(ctrl) - registry.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any(), requestedSecrets).Return(false, nil) - handler := NewRequestHandler(hclog.NewNullLogger(), registry, secrets) + store := storeMocks.NewMockStore(ctrl) + gateway := storeMocks.NewMockGateway(ctrl) + store.EXPECT().GetGateway(gomock.Any(), gomock.Any()).Return(gateway, nil) + gateway.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any()).Return(false, nil) + handler := NewRequestHandler(hclog.NewNullLogger(), store, secrets) request := &discovery.DiscoveryRequest{ ResourceNames: requestedSecrets, @@ -87,9 +93,11 @@ func TestOnStreamRequest_SetResourcesForNodeError(t *testing.T) { expectedErr := errors.New("error") secrets := mocks.NewMockSecretManager(ctrl) - registry := mocks.NewMockGatewaySecretRegistry(ctrl) - registry.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any(), requestedSecrets).Return(true, nil) - handler := NewRequestHandler(hclog.NewNullLogger(), registry, secrets) + store := storeMocks.NewMockStore(ctrl) + gateway := storeMocks.NewMockGateway(ctrl) + store.EXPECT().GetGateway(gomock.Any(), gomock.Any()).Return(gateway, nil) + gateway.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any()).Return(true, nil) + handler := NewRequestHandler(hclog.NewNullLogger(), store, secrets) request := &discovery.DiscoveryRequest{ ResourceNames: requestedSecrets, @@ -117,9 +125,11 @@ func TestOnStreamRequest_Graceful(t *testing.T) { defer ctrl.Finish() secrets := mocks.NewMockSecretManager(ctrl) - registry := mocks.NewMockGatewaySecretRegistry(ctrl) - registry.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any(), requestedSecrets).Return(true, nil) - handler := NewRequestHandler(hclog.NewNullLogger(), registry, secrets) + store := storeMocks.NewMockStore(ctrl) + gateway := storeMocks.NewMockGateway(ctrl) + store.EXPECT().GetGateway(gomock.Any(), gomock.Any()).Return(gateway, nil) + gateway.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any()).Return(true, nil) + handler := NewRequestHandler(hclog.NewNullLogger(), store, secrets) request := &discovery.DiscoveryRequest{ ResourceNames: requestedSecrets, @@ -146,9 +156,11 @@ func TestOnStreamClosed(t *testing.T) { defer ctrl.Finish() secrets := mocks.NewMockSecretManager(ctrl) - registry := mocks.NewMockGatewaySecretRegistry(ctrl) - registry.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any(), requestedSecrets).Return(true, nil) - handler := NewRequestHandler(hclog.NewNullLogger(), registry, secrets) + store := storeMocks.NewMockStore(ctrl) + gateway := storeMocks.NewMockGateway(ctrl) + store.EXPECT().GetGateway(gomock.Any(), gomock.Any()).Return(gateway, nil) + gateway.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any()).Return(true, nil) + handler := NewRequestHandler(hclog.NewNullLogger(), store, secrets) request := &discovery.DiscoveryRequest{ ResourceNames: requestedSecrets, @@ -173,8 +185,8 @@ func TestOnStreamClosed_Graceful(t *testing.T) { defer ctrl.Finish() secrets := mocks.NewMockSecretManager(ctrl) - registry := mocks.NewMockGatewaySecretRegistry(ctrl) - handler := NewRequestHandler(hclog.NewNullLogger(), registry, secrets) + store := storeMocks.NewMockStore(ctrl) + handler := NewRequestHandler(hclog.NewNullLogger(), store, secrets) // no-ops instead of panics without setting up the stream context in the open call handler.OnStreamClosed(1) @@ -187,8 +199,8 @@ func TestOnStreamOpen(t *testing.T) { defer ctrl.Finish() secrets := mocks.NewMockSecretManager(ctrl) - registry := mocks.NewMockGatewaySecretRegistry(ctrl) - handler := NewRequestHandler(hclog.NewNullLogger(), registry, secrets) + store := storeMocks.NewMockStore(ctrl) + handler := NewRequestHandler(hclog.NewNullLogger(), store, secrets) // errors on non secret requests err := handler.OnStreamOpen(context.Background(), 1, resource.ClusterType) diff --git a/internal/envoy/middleware.go b/internal/envoy/middleware.go index b73e7b3b4..4a15aaf8d 100644 --- a/internal/envoy/middleware.go +++ b/internal/envoy/middleware.go @@ -11,8 +11,10 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - "github.com/hashicorp/consul-api-gateway/internal/core" "github.com/hashicorp/go-hclog" + + "github.com/hashicorp/consul-api-gateway/internal/core" + "github.com/hashicorp/consul-api-gateway/internal/store" ) //go:generate mockgen -source ./middleware.go -destination ./mocks/middleware.go -package mocks GatewaySecretRegistry @@ -48,30 +50,20 @@ func GatewayFromContext(ctx context.Context) core.GatewayID { return value.(core.GatewayID) } -// GatewaySecretRegistry is used as the authority for determining what gateways the SDS server -// should actually respond to because they're managed by consul-api-gateway -type GatewaySecretRegistry interface { - // GatewayExists is used to determine whether or not we know a particular gateway instance - GatewayExists(ctx context.Context, info core.GatewayID) (bool, error) - // CanFetchSecrets is used to determine whether a gateway should be able to fetch a set - // of secrets it has requested - CanFetchSecrets(ctx context.Context, info core.GatewayID, secrets []string) (bool, error) -} - // SPIFFEStreamMiddleware verifies the spiffe entries for the certificate // and sets the client identidy on the request context. If no // spiffe information is detected, or if the service is unknown, // the request is rejected. -func SPIFFEStreamMiddleware(logger hclog.Logger, fetcher CertificateFetcher, registry GatewaySecretRegistry) grpc.StreamServerInterceptor { +func SPIFFEStreamMiddleware(logger hclog.Logger, fetcher CertificateFetcher, store store.Store) grpc.StreamServerInterceptor { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - if info, ok := verifySPIFFE(ss.Context(), logger, registry); ok { + if info, ok := verifySPIFFE(ss.Context(), logger, store); ok { return handler(srv, wrapStream(ss, info)) } return status.Errorf(codes.Unauthenticated, "unable to authenticate request") } } -func verifySPIFFE(ctx context.Context, logger hclog.Logger, registry GatewaySecretRegistry) (core.GatewayID, bool) { +func verifySPIFFE(ctx context.Context, logger hclog.Logger, store store.Store) (core.GatewayID, bool) { if p, ok := peer.FromContext(ctx); ok { if mtls, ok := p.AuthInfo.(credentials.TLSInfo); ok { // grab the peer certificate info @@ -90,12 +82,12 @@ func verifySPIFFE(ctx context.Context, logger hclog.Logger, registry GatewaySecr continue } // if we're tracking the gateway then we're good - exists, err := registry.GatewayExists(ctx, info) + gateway, err := store.GetGateway(ctx, info) if err != nil { logger.Error("error checking for gateway, skipping", "error", err) continue } - if exists { + if gateway != nil { return info, true } logger.Warn("gateway not found", "namespace", info.ConsulNamespace, "service", info.Service) diff --git a/internal/envoy/mocks/middleware.go b/internal/envoy/mocks/middleware.go index f4c820a96..0959a044e 100644 --- a/internal/envoy/mocks/middleware.go +++ b/internal/envoy/mocks/middleware.go @@ -3,64 +3,3 @@ // Package mocks is a generated GoMock package. package mocks - -import ( - context "context" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - core "github.com/hashicorp/consul-api-gateway/internal/core" -) - -// MockGatewaySecretRegistry is a mock of GatewaySecretRegistry interface. -type MockGatewaySecretRegistry struct { - ctrl *gomock.Controller - recorder *MockGatewaySecretRegistryMockRecorder -} - -// MockGatewaySecretRegistryMockRecorder is the mock recorder for MockGatewaySecretRegistry. -type MockGatewaySecretRegistryMockRecorder struct { - mock *MockGatewaySecretRegistry -} - -// NewMockGatewaySecretRegistry creates a new mock instance. -func NewMockGatewaySecretRegistry(ctrl *gomock.Controller) *MockGatewaySecretRegistry { - mock := &MockGatewaySecretRegistry{ctrl: ctrl} - mock.recorder = &MockGatewaySecretRegistryMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockGatewaySecretRegistry) EXPECT() *MockGatewaySecretRegistryMockRecorder { - return m.recorder -} - -// CanFetchSecrets mocks base method. -func (m *MockGatewaySecretRegistry) CanFetchSecrets(ctx context.Context, info core.GatewayID, secrets []string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CanFetchSecrets", ctx, info, secrets) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CanFetchSecrets indicates an expected call of CanFetchSecrets. -func (mr *MockGatewaySecretRegistryMockRecorder) CanFetchSecrets(ctx, info, secrets interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanFetchSecrets", reflect.TypeOf((*MockGatewaySecretRegistry)(nil).CanFetchSecrets), ctx, info, secrets) -} - -// GatewayExists mocks base method. -func (m *MockGatewaySecretRegistry) GatewayExists(ctx context.Context, info core.GatewayID) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GatewayExists", ctx, info) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GatewayExists indicates an expected call of GatewayExists. -func (mr *MockGatewaySecretRegistryMockRecorder) GatewayExists(ctx, info interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GatewayExists", reflect.TypeOf((*MockGatewaySecretRegistry)(nil).GatewayExists), ctx, info) -} diff --git a/internal/envoy/sds.go b/internal/envoy/sds.go index 911df58e7..c17f69b52 100644 --- a/internal/envoy/sds.go +++ b/internal/envoy/sds.go @@ -23,6 +23,7 @@ import ( grpcint "github.com/hashicorp/consul-api-gateway/internal/grpc" "github.com/hashicorp/consul-api-gateway/internal/metrics" + "github.com/hashicorp/consul-api-gateway/internal/store" ) //go:generate mockgen -source ./sds.go -destination ./mocks/sds.go -package mocks CertificateFetcher @@ -52,19 +53,19 @@ type SDSServer struct { client SecretClient bindAddress string protocol string - gatewayRegistry GatewaySecretRegistry + store store.Store certificateForcePullInterval time.Duration } // NEWSDSServer initializes an SDSServer instance -func NewSDSServer(logger hclog.Logger, fetcher CertificateFetcher, client SecretClient, registry GatewaySecretRegistry) *SDSServer { +func NewSDSServer(logger hclog.Logger, fetcher CertificateFetcher, client SecretClient, store store.Store) *SDSServer { return &SDSServer{ logger: logger, fetcher: fetcher, client: client, bindAddress: defaultGRPCBindAddress, protocol: "tcp", - gatewayRegistry: registry, + store: store, certificateForcePullInterval: defaultCertificateForcePullInterval, } } @@ -98,13 +99,13 @@ func (s *SDSServer) Run(ctx context.Context) error { }, ClientAuth: tls.RequireAndVerifyClientCert, })), - grpc.StreamInterceptor(SPIFFEStreamMiddleware(s.logger, s.fetcher, s.gatewayRegistry)), + grpc.StreamInterceptor(SPIFFEStreamMiddleware(s.logger, s.fetcher, s.store)), } s.server = grpc.NewServer(opts...) resourceCache := cache.NewLinearCache(resource.SecretType, cache.WithLogger(wrapEnvoyLogger(s.logger.Named("cache")))) secretManager := NewSecretManager(s.client, resourceCache, s.logger.Named("secret-manager")) - handler := NewRequestHandler(s.logger.Named("handler"), s.gatewayRegistry, secretManager) + handler := NewRequestHandler(s.logger.Named("handler"), s.store, secretManager) sdsServer := server.NewServer(childCtx, resourceCache, handler) secretservice.RegisterSecretDiscoveryServiceServer(s.server, sdsServer) listener, err := net.Listen(s.protocol, s.bindAddress) diff --git a/internal/envoy/sds_test.go b/internal/envoy/sds_test.go index 44fdafef8..d088faca1 100644 --- a/internal/envoy/sds_test.go +++ b/internal/envoy/sds_test.go @@ -25,10 +25,13 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/status" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/consul-api-gateway/internal/envoy/mocks" + "github.com/hashicorp/consul-api-gateway/internal/store" "github.com/hashicorp/consul-api-gateway/internal/store/memory" + storeMocks "github.com/hashicorp/consul-api-gateway/internal/store/mocks" gwTesting "github.com/hashicorp/consul-api-gateway/internal/testing" - "github.com/hashicorp/go-hclog" ) func TestSDSRunCertificateVerification(t *testing.T) { @@ -36,11 +39,12 @@ func TestSDSRunCertificateVerification(t *testing.T) { ca, server, client := gwTesting.DefaultCertificates() - err := runTestServer(t, ca.CertBytes, func(ctrl *gomock.Controller) GatewaySecretRegistry { - gatewayRegistry := mocks.NewMockGatewaySecretRegistry(ctrl) - gatewayRegistry.EXPECT().GatewayExists(gomock.Any(), gomock.Any()).MinTimes(1).Return(true, nil) - gatewayRegistry.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any(), gomock.Any()).MinTimes(1).Return(true, nil) - return gatewayRegistry + err := runTestServer(t, ca.CertBytes, func(ctrl *gomock.Controller) store.Store { + store := storeMocks.NewMockStore(ctrl) + gateway := storeMocks.NewMockGateway(ctrl) + store.EXPECT().GetGateway(gomock.Any(), gomock.Any()).MinTimes(1).Return(gateway, nil) + gateway.EXPECT().CanFetchSecrets(gomock.Any(), gomock.Any()).MinTimes(1).Return(true, nil) + return store }, func(serverAddress string, fetcher *mocks.MockCertificateFetcher) { fetcher.EXPECT().TLSCertificate().MinTimes(1).Return(&server.X509) @@ -171,10 +175,10 @@ func TestSDSSPIFFENoMatchingGateway(t *testing.T) { ca, server, client := gwTesting.DefaultCertificates() - err := runTestServer(t, ca.CertBytes, func(ctrl *gomock.Controller) GatewaySecretRegistry { - gatewayRegistry := mocks.NewMockGatewaySecretRegistry(ctrl) - gatewayRegistry.EXPECT().GatewayExists(gomock.Any(), gomock.Any()).MinTimes(1).Return(false, nil) - return gatewayRegistry + err := runTestServer(t, ca.CertBytes, func(ctrl *gomock.Controller) store.Store { + store := storeMocks.NewMockStore(ctrl) + store.EXPECT().GetGateway(gomock.Any(), gomock.Any()).MinTimes(1).Return(nil, nil) + return store }, func(serverAddress string, fetcher *mocks.MockCertificateFetcher) { fetcher.EXPECT().TLSCertificate().Return(&server.X509) err := testClientSDS(t, serverAddress, client, ca.CertBytes) @@ -225,7 +229,7 @@ func testClientSDS(t *testing.T, address string, cert *gwTesting.CertificateInfo }) } -func runTestServer(t *testing.T, ca []byte, registryFn func(*gomock.Controller) GatewaySecretRegistry, callback func(serverAddress string, fetcher *mocks.MockCertificateFetcher)) error { +func runTestServer(t *testing.T, ca []byte, registryFn func(*gomock.Controller) store.Store, callback func(serverAddress string, fetcher *mocks.MockCertificateFetcher)) error { t.Helper() ctx, cancel := context.WithCancel(context.Background()) @@ -259,7 +263,7 @@ func runTestServer(t *testing.T, ca []byte, registryFn func(*gomock.Controller) sds.bindAddress = serverAddress sds.protocol = "unix" if registryFn != nil { - sds.gatewayRegistry = registryFn(ctrl) + sds.store = registryFn(ctrl) } errEarlyTestTermination := errors.New("early termination") diff --git a/internal/k8s/reconciler/binder.go b/internal/k8s/reconciler/binder.go new file mode 100644 index 000000000..37d2abbec --- /dev/null +++ b/internal/k8s/reconciler/binder.go @@ -0,0 +1,208 @@ +package reconciler + +import ( + "context" + "fmt" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + klabels "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/types" + gwv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" + gwv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" + + "github.com/hashicorp/consul-api-gateway/internal/k8s/gatewayclient" + "github.com/hashicorp/consul-api-gateway/internal/k8s/reconciler/common" + "github.com/hashicorp/consul-api-gateway/internal/k8s/reconciler/errors" + "github.com/hashicorp/consul-api-gateway/internal/k8s/reconciler/state" + "github.com/hashicorp/consul-api-gateway/internal/k8s/utils" +) + +const ( + // NamespaceNameLabel represents that label added automatically to namespaces is newer Kubernetes clusters + NamespaceNameLabel = "kubernetes.io/metadata.name" +) + +// binder wraps a Gateway and the corresponding GatewayState and encapsulates +// the logic for binding new routes to that Gateway. +type binder struct { + Client gatewayclient.Client + Gateway *gwv1beta1.Gateway + GatewayState *state.GatewayState +} + +func newBinder(client gatewayclient.Client, gateway *gwv1beta1.Gateway, state *state.GatewayState) *binder { + return &binder{ + Client: client, + Gateway: gateway, + GatewayState: state, + } +} + +// Bind will attempt to bind the provided route to all listeners on the Gateway and +// remove the route from any listeners that the route should no longer be bound to. +// The latter is important for scenarios such as the route's parent changing. +func (b *binder) Bind(ctx context.Context, route *K8sRoute) []string { + var boundListeners []string + + // If the route doesn't reference this Gateway, remove the route + // from any listeners that it may have previously bound to + if !b.routeReferencesThisGateway(route) { + for _, listenerState := range b.GatewayState.Listeners { + delete(listenerState.Routes, route.ID()) + } + return boundListeners + } + + // The route does reference this Gateway, so attempt to bind to each listener + for _, ref := range route.CommonRouteSpec().ParentRefs { + for i, listener := range b.Gateway.Spec.Listeners { + listenerState := b.GatewayState.Listeners[i] + if b.canBind(ctx, listener, listenerState, ref, route) { + listenerState.Routes[route.ID()] = route.resolve(b.GatewayState.ConsulNamespace, b.Gateway, listener) + boundListeners = append(boundListeners, string(listener.Name)) + } else { + // If the route cannot bind to this listener, remove the route + // in case it was previously bound + delete(listenerState.Routes, route.ID()) + } + } + } + + return boundListeners +} + +func (b *binder) routeReferencesThisGateway(route *K8sRoute) bool { + thisGateway := utils.NamespacedName(b.Gateway) + for _, ref := range route.CommonRouteSpec().ParentRefs { + gatewayReferenced, isGatewayTypeRef := utils.ReferencesGateway(route.GetNamespace(), ref) + if isGatewayTypeRef && gatewayReferenced == thisGateway { + return true + } + } + return false +} + +func (b *binder) canBind(ctx context.Context, listener gwv1beta1.Listener, state *state.ListenerState, ref gwv1alpha2.ParentReference, route *K8sRoute) bool { + if state.Status.Ready.HasError() { + return false + } + + // must is only true if there's a ref with a specific listener name + // meaning if we must attach, but cannot, it's an error + allowed, must := routeMatchesListener(listener.Name, ref.SectionName) + if !allowed { + return false + } + + if !routeKindIsAllowedForListener(common.SupportedKindsFor(listener.Protocol), route) { + if must { + route.RouteState.BindFailed(errors.NewBindErrorRouteKind("route kind not allowed for listener"), ref) + } + return false + } + + allowed, err := routeAllowedForListenerNamespaces(ctx, b.Gateway.Namespace, listener.AllowedRoutes, route, b.Client) + if err != nil { + route.RouteState.BindFailed(fmt.Errorf("error checking listener namespaces: %w", err), ref) + return false + } + if !allowed { + if must { + route.RouteState.BindFailed(errors.NewBindErrorListenerNamespacePolicy("route not allowed because of listener namespace policy"), ref) + } + return false + } + + if !route.matchesHostname(listener.Hostname) { + if must { + route.RouteState.BindFailed(errors.NewBindErrorHostnameMismatch("route does not match listener hostname"), ref) + } + return false + } + + // check if the route is valid, if not, then return a status about it being rejected + if !route.RouteState.ResolutionErrors.Empty() { + route.RouteState.BindFailed(errors.NewBindErrorRouteInvalid("route is in an invalid state and cannot bind"), ref) + return false + } + + route.RouteState.Bound(ref) + return true +} + +// routeAllowedForListenerNamespaces determines whether the route is allowed +// to bind to the Gateway based on the AllowedRoutes namespace selectors. +func routeAllowedForListenerNamespaces(ctx context.Context, gatewayNS string, allowedRoutes *gwv1beta1.AllowedRoutes, route *K8sRoute, c gatewayclient.Client) (bool, error) { + var namespaceSelector *gwv1beta1.RouteNamespaces + if allowedRoutes != nil { + // check gateway namespace + namespaceSelector = allowedRoutes.Namespaces + } + + // set default if namespace selector is nil + from := gwv1beta1.NamespacesFromSame + if namespaceSelector != nil && namespaceSelector.From != nil && *namespaceSelector.From != "" { + from = *namespaceSelector.From + } + switch from { + case gwv1beta1.NamespacesFromAll: + return true, nil + case gwv1beta1.NamespacesFromSame: + return gatewayNS == route.GetNamespace(), nil + case gwv1beta1.NamespacesFromSelector: + namespaceSelector, err := metav1.LabelSelectorAsSelector(namespaceSelector.Selector) + if err != nil { + return false, fmt.Errorf("error parsing label selector: %w", err) + } + + // retrieve the route's namespace and determine whether selector matches + namespace, err := c.GetNamespace(ctx, types.NamespacedName{Name: route.GetNamespace()}) + if err != nil { + return false, fmt.Errorf("error retrieving namespace for route: %w", err) + } + + return namespaceSelector.Matches(toNamespaceSet(namespace.GetName(), namespace.GetLabels())), nil + } + return false, nil +} + +func routeKindIsAllowedForListener(kinds []gwv1beta1.RouteGroupKind, route *K8sRoute) bool { + if kinds == nil { + return true + } + + gvk := route.GroupVersionKind() + for _, kind := range kinds { + group := gwv1beta1.GroupName + if kind.Group != nil && *kind.Group != "" { + group = string(*kind.Group) + } + if string(kind.Kind) == gvk.Kind && group == gvk.Group { + return true + } + } + + return false +} + +func toNamespaceSet(name string, labels map[string]string) klabels.Labels { + // If namespace label is not set, implicitly insert it to support older Kubernetes versions + if labels[NamespaceNameLabel] == name { + // Already set, avoid copies + return klabels.Set(labels) + } + // First we need a copy to not modify the underlying object + ret := make(map[string]string, len(labels)+1) + for k, v := range labels { + ret[k] = v + } + ret[NamespaceNameLabel] = name + return klabels.Set(ret) +} + +func routeMatchesListener(listenerName gwv1beta1.SectionName, routeSectionName *gwv1alpha2.SectionName) (can bool, must bool) { + if routeSectionName == nil { + return true, false + } + return string(listenerName) == string(*routeSectionName), true +} diff --git a/internal/k8s/reconciler/binder_test.go b/internal/k8s/reconciler/binder_test.go new file mode 100644 index 000000000..5cbc9a77c --- /dev/null +++ b/internal/k8s/reconciler/binder_test.go @@ -0,0 +1,584 @@ +package reconciler + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + core "k8s.io/api/core/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + gwv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" + gwv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" + + "github.com/hashicorp/go-hclog" + + "github.com/hashicorp/consul-api-gateway/internal/k8s/gatewayclient/mocks" + "github.com/hashicorp/consul-api-gateway/internal/k8s/reconciler/state" +) + +func TestBinder(t *testing.T) { + t.Parallel() + + same := gwv1beta1.NamespacesFromSame + selector := gwv1beta1.NamespacesFromSelector + other := gwv1alpha2.Namespace("other") + routeMeta := meta.TypeMeta{} + routeMeta.SetGroupVersionKind(schema.GroupVersionKind{ + Group: gwv1alpha2.GroupVersion.Group, + Version: gwv1alpha2.GroupVersion.Version, + Kind: "HTTPRoute", + }) + udpMeta := meta.TypeMeta{} + udpMeta.SetGroupVersionKind(schema.GroupVersionKind{ + Group: gwv1alpha2.GroupVersion.Group, + Version: gwv1alpha2.GroupVersion.Version, + Kind: "UDPRoute", + }) + + for _, test := range []struct { + name string + gateway *gwv1beta1.Gateway + namespace *core.Namespace + listenerError error + route Route + didBind bool + }{ + { + name: "no match", + gateway: &gwv1beta1.Gateway{ + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{}}, + }, + }, + route: &gwv1alpha2.HTTPRoute{}, + didBind: false, + }, + { + name: "match", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{}}, + }, + }, + route: &gwv1alpha2.HTTPRoute{ + Spec: gwv1alpha2.HTTPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + }}, + }, + }, + }, + didBind: true, + }, + { + name: "bad route type", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{ + Protocol: gwv1beta1.HTTPProtocolType, + }}, + }, + }, + route: &gwv1alpha2.UDPRoute{ + TypeMeta: udpMeta, + Spec: gwv1alpha2.UDPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + }}, + }, + }, + }, + didBind: false, + }, + { + name: "good route type", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{ + Protocol: gwv1beta1.HTTPProtocolType, + }}, + }, + }, + route: &gwv1alpha2.HTTPRoute{ + TypeMeta: routeMeta, + Spec: gwv1alpha2.HTTPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + }}, + }, + }, + }, + didBind: true, + }, + { + name: "listener not ready", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{}}, + }, + }, + route: &gwv1alpha2.HTTPRoute{ + Spec: gwv1alpha2.HTTPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + }}, + }, + }, + }, + listenerError: errors.New("invalid"), + didBind: false, + }, + { + name: "not allowed namespace", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + Namespace: "other", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{ + Name: gwv1beta1.SectionName("listener"), + Protocol: gwv1beta1.HTTPProtocolType, + AllowedRoutes: &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &same, + }, + }, + }}, + }, + }, + route: &gwv1alpha2.HTTPRoute{ + TypeMeta: routeMeta, + Spec: gwv1alpha2.HTTPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + Namespace: &other, + SectionName: sectionNamePtr("listener"), + }}, + }, + }, + }, + didBind: false, + }, + { + name: "allowed namespace", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + Namespace: "other", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{ + Name: gwv1beta1.SectionName("listener"), + Protocol: gwv1beta1.HTTPProtocolType, + AllowedRoutes: &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &same, + }, + }, + }}, + }, + }, + route: &gwv1alpha2.HTTPRoute{ + TypeMeta: routeMeta, + ObjectMeta: meta.ObjectMeta{ + Namespace: "other", + }, + Spec: gwv1alpha2.HTTPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + Namespace: &other, + SectionName: sectionNamePtr("listener"), + }}, + }, + }, + }, + didBind: true, + }, + { + name: "not allowed namespace match", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + Namespace: "other", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{ + Name: gwv1beta1.SectionName("listener"), + Protocol: gwv1beta1.HTTPProtocolType, + AllowedRoutes: &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &selector, + Selector: &meta.LabelSelector{ + MatchExpressions: []meta.LabelSelectorRequirement{{ + Key: "test", + Operator: meta.LabelSelectorOpIn, + Values: []string{"foo"}, + }}, + }, + }, + }, + }}, + }, + }, + route: &gwv1alpha2.HTTPRoute{ + TypeMeta: routeMeta, + ObjectMeta: meta.ObjectMeta{ + Namespace: "other", + }, + Spec: gwv1alpha2.HTTPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + Namespace: &other, + SectionName: sectionNamePtr("listener"), + }}, + }, + }, + }, + namespace: &core.Namespace{ + ObjectMeta: meta.ObjectMeta{ + Labels: map[string]string{ + "test": "bar", + }, + }, + }, + didBind: false, + }, + { + name: "allowed namespace match", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + Namespace: "other", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{ + Name: gwv1beta1.SectionName("listener"), + Protocol: gwv1beta1.HTTPProtocolType, + AllowedRoutes: &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &selector, + Selector: &meta.LabelSelector{ + MatchExpressions: []meta.LabelSelectorRequirement{{ + Key: "test", + Operator: meta.LabelSelectorOpIn, + Values: []string{"foo"}, + }}, + }, + }, + }, + }}, + }, + }, + route: &gwv1alpha2.HTTPRoute{ + TypeMeta: routeMeta, + ObjectMeta: meta.ObjectMeta{ + Namespace: "other", + }, + Spec: gwv1alpha2.HTTPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + Namespace: &other, + SectionName: sectionNamePtr("listener"), + }}, + }, + }, + }, + namespace: &core.Namespace{ + ObjectMeta: meta.ObjectMeta{ + Labels: map[string]string{ + "test": "foo", + }, + }, + }, + didBind: true, + }, + { + name: "hostname no match", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{ + Name: gwv1beta1.SectionName("listener"), + Hostname: hostnamePtr("host"), + }}, + }, + }, + route: &gwv1alpha2.HTTPRoute{ + Spec: gwv1alpha2.HTTPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + SectionName: sectionNamePtr("listener"), + }}, + }, + Hostnames: []gwv1alpha2.Hostname{"other"}, + }, + }, + didBind: false, + }, + { + name: "hostname match", + gateway: &gwv1beta1.Gateway{ + ObjectMeta: meta.ObjectMeta{ + Name: "gateway", + }, + Spec: gwv1beta1.GatewaySpec{ + Listeners: []gwv1beta1.Listener{{ + Name: gwv1beta1.SectionName("listener"), + Hostname: hostnamePtr("host"), + }}, + }, + }, + route: &gwv1alpha2.HTTPRoute{ + Spec: gwv1alpha2.HTTPRouteSpec{ + CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ + ParentRefs: []gwv1alpha2.ParentReference{{ + Name: "gateway", + SectionName: sectionNamePtr("listener"), + }}, + }, + Hostnames: []gwv1alpha2.Hostname{"other", "host"}, + }, + }, + didBind: true, + }, + } { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + client := mocks.NewMockClient(ctrl) + + factory := NewFactory(FactoryConfig{ + Logger: hclog.NewNullLogger(), + }) + gatewayState := state.InitialGatewayState(test.gateway) + if test.listenerError != nil { + gatewayState.Listeners[0].Status.Ready.Invalid = test.listenerError + } + if test.namespace != nil { + client.EXPECT().GetNamespace(gomock.Any(), gomock.Any()).Return(test.namespace, nil) + } + + binder := newBinder(client, test.gateway, gatewayState) + listeners := binder.Bind(context.Background(), factory.NewRoute(test.route)) + if test.didBind { + require.NotEmpty(t, listeners) + } else { + require.Empty(t, listeners) + } + }) + } +} + +func TestRouteAllowedForListenerNamespaces(t *testing.T) { + t.Parallel() + + factory := NewFactory(FactoryConfig{ + Logger: hclog.NewNullLogger(), + }) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + client := mocks.NewMockClient(ctrl) + + // same + same := gwv1beta1.NamespacesFromSame + + allowed, err := routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &same, + }, + }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ + ObjectMeta: meta.ObjectMeta{ + Namespace: "expected", + }, + }), client) + require.NoError(t, err) + require.True(t, allowed) + + allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &same, + }, + }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ + ObjectMeta: meta.ObjectMeta{ + Namespace: "other", + }, + }), client) + require.NoError(t, err) + require.False(t, allowed) + + // all + all := gwv1beta1.NamespacesFromAll + allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &all, + }, + }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ + ObjectMeta: meta.ObjectMeta{ + Namespace: "other", + }, + }), client) + require.NoError(t, err) + require.True(t, allowed) + + // selector + selector := gwv1beta1.NamespacesFromSelector + + matchingNamespace := &core.Namespace{ + ObjectMeta: meta.ObjectMeta{ + Labels: map[string]string{ + "label": "test", + "kubernetes.io/metadata.name": "expected", + }}} + invalidNamespace := &core.Namespace{ObjectMeta: meta.ObjectMeta{Labels: map[string]string{}}} + + client.EXPECT().GetNamespace(context.Background(), types.NamespacedName{Name: "expected"}).Return(invalidNamespace, nil).Times(1) + allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &selector, + Selector: &meta.LabelSelector{ + MatchLabels: map[string]string{ + "label": "test", + }, + }, + }, + }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ + ObjectMeta: meta.ObjectMeta{ + Namespace: "expected", + }, + }), client) + require.NoError(t, err) + require.False(t, allowed) + + client.EXPECT().GetNamespace(context.Background(), types.NamespacedName{Name: "expected"}).Return(matchingNamespace, nil).Times(1) + allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &selector, + Selector: &meta.LabelSelector{ + MatchLabels: map[string]string{ + "label": "test", + }, + }, + }, + }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ + ObjectMeta: meta.ObjectMeta{ + Namespace: "expected", + }, + }), client) + require.NoError(t, err) + require.True(t, allowed) + + _, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &selector, + Selector: &meta.LabelSelector{ + MatchExpressions: []meta.LabelSelectorRequirement{{ + Key: "test", + Operator: meta.LabelSelectorOperator("invalid"), + }}, + }, + }, + }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ + ObjectMeta: meta.ObjectMeta{ + Namespace: "expected", + }, + }), client) + require.Error(t, err) + + // unknown + unknown := gwv1beta1.FromNamespaces("unknown") + allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ + Namespaces: &gwv1beta1.RouteNamespaces{ + From: &unknown, + }, + }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ + ObjectMeta: meta.ObjectMeta{ + Namespace: "expected", + }, + }), client) + require.NoError(t, err) + require.False(t, allowed) +} + +func TestRouteKindIsAllowedForListener(t *testing.T) { + t.Parallel() + + factory := NewFactory(FactoryConfig{ + Logger: hclog.NewNullLogger(), + }) + + routeMeta := meta.TypeMeta{} + routeMeta.SetGroupVersionKind(schema.GroupVersionKind{ + Group: gwv1alpha2.GroupVersion.Group, + Version: gwv1alpha2.GroupVersion.Version, + Kind: "HTTPRoute", + }) + require.True(t, routeKindIsAllowedForListener([]gwv1beta1.RouteGroupKind{{ + Group: (*gwv1beta1.Group)(&gwv1alpha2.GroupVersion.Group), + Kind: "HTTPRoute", + }}, factory.NewRoute(&gwv1alpha2.HTTPRoute{ + TypeMeta: routeMeta, + }))) + require.False(t, routeKindIsAllowedForListener([]gwv1beta1.RouteGroupKind{{ + Group: (*gwv1beta1.Group)(&gwv1alpha2.GroupVersion.Group), + Kind: "TCPRoute", + }}, factory.NewRoute(&gwv1alpha2.HTTPRoute{ + TypeMeta: routeMeta, + }))) +} + +func TestRouteMatchesListener(t *testing.T) { + t.Parallel() + + name := gwv1alpha2.SectionName("name") + can, must := routeMatchesListener("name", &name) + assert.True(t, can) + assert.True(t, must) + + can, must = routeMatchesListener("name", nil) + assert.True(t, can) + assert.False(t, must) + + can, must = routeMatchesListener("other", &name) + assert.False(t, can) + assert.True(t, must) +} + +func sectionNamePtr(name string) *gwv1alpha2.SectionName { + value := gwv1alpha2.SectionName(name) + return &value +} + +func hostnamePtr(name string) *gwv1beta1.Hostname { + value := gwv1beta1.Hostname(name) + return &value +} diff --git a/internal/k8s/reconciler/gateway.go b/internal/k8s/reconciler/gateway.go index 49ccb25aa..9a81bf43e 100644 --- a/internal/k8s/reconciler/gateway.go +++ b/internal/k8s/reconciler/gateway.go @@ -92,47 +92,80 @@ func (g *K8sGateway) Meta() map[string]string { } } -func (g *K8sGateway) Listeners() []store.Listener { - listeners := []store.Listener{} - - for _, listener := range g.listeners { - listeners = append(listeners, listener) +// Bind returns the name of the listeners to which a route bound +func (g *K8sGateway) Bind(ctx context.Context, route store.Route) []string { + k8sRoute, ok := route.(*K8sRoute) + if !ok { + return nil } - return listeners + return newBinder(g.client, g.Gateway, g.GatewayState).Bind(ctx, k8sRoute) } -func (g *K8sGateway) ShouldUpdate(other store.Gateway) bool { - if other == nil { - return false +func (g *K8sGateway) Remove(ctx context.Context, routeID string) error { + for _, listener := range g.GatewayState.Listeners { + delete(listener.Routes, routeID) } - if g == nil { - return true + return nil +} + +func (g *K8sGateway) Resolve() core.ResolvedGateway { + listeners := []core.ResolvedListener{} + for i, listener := range g.Gateway.Spec.Listeners { + state := g.GatewayState.Listeners[i] + if state.Valid() { + listeners = append(listeners, g.resolveListener(state, listener)) + } + } + return core.ResolvedGateway{ + ID: g.ID(), + Meta: g.Meta(), + Listeners: listeners, } +} - otherGateway, ok := other.(*K8sGateway) - if !ok { - return false +func (g *K8sGateway) resolveListener(state *state.ListenerState, listener gwv1beta1.Listener) core.ResolvedListener { + routes := []core.ResolvedRoute{} + for _, route := range state.Routes { + routes = append(routes, route) + } + protocol, _ := utils.ProtocolToConsul(state.Protocol) + + return core.ResolvedListener{ + Name: listenerName(listener), + Hostname: listenerHostname(listener), + Port: int(listener.Port), + Protocol: protocol, + TLS: state.TLS, + Routes: routes, } - return !utils.ResourceVersionGreater(g.Gateway.ResourceVersion, otherGateway.Gateway.ResourceVersion) } -func (g *K8sGateway) ShouldBind(route store.Route) bool { - k8sRoute, ok := route.(*K8sRoute) - if !ok { - return false +func (g *K8sGateway) CanFetchSecrets(_ context.Context, secrets []string) (bool, error) { + certificates := make(map[string]struct{}) + for _, listener := range g.GatewayState.Listeners { + for _, cert := range listener.TLS.Certificates { + certificates[cert] = struct{}{} + } } - - for _, ref := range k8sRoute.CommonRouteSpec().ParentRefs { - if namespacedName, isGateway := utils.ReferencesGateway(k8sRoute.GetNamespace(), ref); isGateway { - if utils.NamespacedName(g.Gateway) == namespacedName { - return true - } + for _, secret := range secrets { + if _, found := certificates[secret]; !found { + return false, nil } } - return false + return true, nil +} + +func (g *K8sGateway) Listeners() []store.Listener { + listeners := []store.Listener{} + + for _, listener := range g.listeners { + listeners = append(listeners, listener) + } + + return listeners } func (g *K8sGateway) TrackSync(ctx context.Context, sync func() (bool, error)) error { @@ -165,3 +198,17 @@ func (g *K8sGateway) TrackSync(ctx context.Context, sync func() (bool, error)) e } return nil } + +func listenerHostname(listener gwv1beta1.Listener) string { + if listener.Hostname != nil { + return string(*listener.Hostname) + } + return "" +} + +func listenerName(listener gwv1beta1.Listener) string { + if listener.Name != "" { + return string(listener.Name) + } + return defaultListenerName +} diff --git a/internal/k8s/reconciler/gateway_test.go b/internal/k8s/reconciler/gateway_test.go index 61a9c736d..e58527160 100644 --- a/internal/k8s/reconciler/gateway_test.go +++ b/internal/k8s/reconciler/gateway_test.go @@ -6,11 +6,9 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" meta "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/utils/pointer" - gwv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" gwv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" "github.com/hashicorp/go-hclog" @@ -18,8 +16,6 @@ import ( internalCore "github.com/hashicorp/consul-api-gateway/internal/core" "github.com/hashicorp/consul-api-gateway/internal/k8s/gatewayclient/mocks" "github.com/hashicorp/consul-api-gateway/internal/k8s/reconciler/state" - "github.com/hashicorp/consul-api-gateway/internal/k8s/service" - storeMocks "github.com/hashicorp/consul-api-gateway/internal/store/mocks" apigwv1alpha1 "github.com/hashicorp/consul-api-gateway/pkg/apis/v1alpha1" ) @@ -194,94 +190,3 @@ func TestGatewayTrackSync(t *testing.T) { return false, expected })) } - -func TestGatewayShouldUpdate(t *testing.T) { - t.Parallel() - - factory := NewFactory(FactoryConfig{ - Logger: hclog.NewNullLogger(), - }) - - gw := &gwv1beta1.Gateway{} - gateway := factory.NewGateway(NewGatewayConfig{ - Gateway: gw, - ConsulNamespace: "consul", - }) - - otherGW := &gwv1beta1.Gateway{} - other := factory.NewGateway(NewGatewayConfig{ - Gateway: otherGW, - ConsulNamespace: "consul", - }) - - // Have equal resource version - gateway.Gateway.ObjectMeta.ResourceVersion = `0` - other.Gateway.ObjectMeta.ResourceVersion = `0` - assert.True(t, gateway.ShouldUpdate(other)) - - // Have greater resource version - gateway.Gateway.ObjectMeta.ResourceVersion = `1` - other.Gateway.ObjectMeta.ResourceVersion = `0` - assert.False(t, gateway.ShouldUpdate(other)) - - // Have lesser resource version - gateway.Gateway.ObjectMeta.ResourceVersion = `0` - other.Gateway.ObjectMeta.ResourceVersion = `1` - assert.True(t, gateway.ShouldUpdate(other)) - - // Have non-numeric resource version - gateway.Gateway.ObjectMeta.ResourceVersion = `a` - other.Gateway.ObjectMeta.ResourceVersion = `0` - assert.True(t, gateway.ShouldUpdate(other)) - - // Other gateway non-numeric resource version - gateway.Gateway.ObjectMeta.ResourceVersion = `0` - other.Gateway.ObjectMeta.ResourceVersion = `a` - assert.False(t, gateway.ShouldUpdate(other)) - - // Other gateway nil - assert.False(t, gateway.ShouldUpdate(nil)) - - // Have nil gateway - gateway = nil - assert.True(t, gateway.ShouldUpdate(other)) -} - -func TestGatewayShouldBind(t *testing.T) { - t.Parallel() - - factory := NewFactory(FactoryConfig{ - Logger: hclog.NewNullLogger(), - }) - - gw := &gwv1beta1.Gateway{} - gateway := factory.NewGateway(NewGatewayConfig{ - Gateway: gw, - ConsulNamespace: "consul", - }) - gateway.Gateway.Name = "name" - - require.False(t, gateway.ShouldBind(storeMocks.NewMockRoute(nil))) - - route := newK8sRoute(&gwv1alpha2.HTTPRoute{}, K8sRouteConfig{ - Logger: hclog.NewNullLogger(), - }) - route.RouteState.ResolutionErrors.Add(service.NewConsulResolutionError("test")) - require.False(t, gateway.ShouldBind(route)) - - require.True(t, gateway.ShouldBind(newK8sRoute(&gwv1alpha2.HTTPRoute{ - Spec: gwv1alpha2.HTTPRouteSpec{ - CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ - ParentRefs: []gwv1alpha2.ParentReference{{ - Name: "name", - }}, - }, - }, - }, K8sRouteConfig{ - Logger: hclog.NewNullLogger(), - }))) - - require.False(t, gateway.ShouldBind(newK8sRoute(&gwv1alpha2.HTTPRoute{}, K8sRouteConfig{ - Logger: hclog.NewNullLogger(), - }))) -} diff --git a/internal/k8s/reconciler/listener.go b/internal/k8s/reconciler/listener.go index 58b955e48..96aa49304 100644 --- a/internal/k8s/reconciler/listener.go +++ b/internal/k8s/reconciler/listener.go @@ -195,7 +195,7 @@ func (l *K8sListener) canBind(ctx context.Context, ref gwv1alpha2.ParentReferenc return false, nil } - if !route.MatchesHostname(l.listener.Hostname) { + if !route.matchesHostname(l.listener.Hostname) { l.logger.Trace("route does not match listener hostname", "route", route.ID()) if must { return false, rerrors.NewBindErrorHostnameMismatch("route does not match listener hostname") @@ -216,7 +216,10 @@ func (l *K8sListener) canBind(ctx context.Context, ref gwv1alpha2.ParentReferenc func (l *K8sListener) OnRouteAdded(route store.Route) { atomic.AddInt32(&l.routeCount, 1) - l.ListenerState.Routes[route.ID()] = route.Resolve(l) + + if k8sRoute, ok := route.(*K8sRoute); ok { + l.ListenerState.Routes[route.ID()] = k8sRoute.resolve(l.consulNamespace, l.gateway, l.listener) + } } func (l *K8sListener) OnRouteRemoved(routeID string) { diff --git a/internal/k8s/reconciler/route.go b/internal/k8s/reconciler/route.go index 20a07b73e..cb899a965 100644 --- a/internal/k8s/reconciler/route.go +++ b/internal/k8s/reconciler/route.go @@ -58,16 +58,6 @@ func newK8sRoute(route Route, config K8sRouteConfig) *K8sRoute { } } -func (r *K8sRoute) parentKeyForGateway(parent types.NamespacedName) (string, bool) { - for _, p := range r.Parents() { - gatewayName, isGateway := utils.ReferencesGateway(r.GetNamespace(), p) - if isGateway && gatewayName == parent { - return asJSON(p), true - } - } - return "", false -} - func (r *K8sRoute) ID() string { switch r.Route.(type) { case *gwv1alpha2.HTTPRoute: @@ -78,7 +68,7 @@ func (r *K8sRoute) ID() string { return "" } -func (r *K8sRoute) MatchesHostname(hostname *gwv1beta1.Hostname) bool { +func (r *K8sRoute) matchesHostname(hostname *gwv1beta1.Hostname) bool { switch route := r.Route.(type) { case *gwv1alpha2.HTTPRoute: return routeMatchesListenerHostname(hostname, route.Spec.Hostnames) @@ -142,19 +132,22 @@ func (r *K8sRoute) Resolve(listener store.Listener) core.ResolvedRoute { return nil } - prefix := fmt.Sprintf("consul-api-gateway_%s_", k8sListener.gateway.Name) - namespace := k8sListener.consulNamespace - hostname := k8sListener.Config().Hostname + return r.resolve(k8sListener.consulNamespace, k8sListener.gateway, k8sListener.listener) +} + +func (r *K8sRoute) resolve(namespace string, gateway *gwv1beta1.Gateway, listener gwv1beta1.Listener) core.ResolvedRoute { + hostname := listenerHostname(listener) + switch route := r.Route.(type) { case *gwv1alpha2.HTTPRoute: return converter.NewHTTPRouteConverter(converter.HTTPRouteConverterConfig{ Namespace: namespace, Hostname: hostname, - Prefix: prefix, + Prefix: fmt.Sprintf("consul-api-gateway_%s_", gateway.Name), Meta: map[string]string{ "external-source": "consul-api-gateway", - "consul-api-gateway/k8s/Gateway.Name": k8sListener.gateway.Name, - "consul-api-gateway/k8s/Gateway.Namespace": k8sListener.gateway.Namespace, + "consul-api-gateway/k8s/Gateway.Name": gateway.Name, + "consul-api-gateway/k8s/Gateway.Namespace": gateway.Namespace, "consul-api-gateway/k8s/HTTPRoute.Name": r.GetName(), "consul-api-gateway/k8s/HTTPRoute.Namespace": r.GetNamespace(), }, @@ -165,11 +158,11 @@ func (r *K8sRoute) Resolve(listener store.Listener) core.ResolvedRoute { return converter.NewTCPRouteConverter(converter.TCPRouteConverterConfig{ Namespace: namespace, Hostname: hostname, - Prefix: prefix, + Prefix: fmt.Sprintf("consul-api-gateway_%s_", gateway.Name), Meta: map[string]string{ "external-source": "consul-api-gateway", - "consul-api-gateway/k8s/Gateway.Name": k8sListener.gateway.Name, - "consul-api-gateway/k8s/Gateway.Namespace": k8sListener.gateway.Namespace, + "consul-api-gateway/k8s/Gateway.Name": gateway.Name, + "consul-api-gateway/k8s/Gateway.Namespace": gateway.Namespace, "consul-api-gateway/k8s/TCPRoute.Name": r.GetName(), "consul-api-gateway/k8s/TCPRoute.Namespace": r.GetNamespace(), }, @@ -197,32 +190,16 @@ func (r *K8sRoute) Validate(ctx context.Context) error { return r.validator.Validate(ctx, r.RouteState, r.Route) } -func (r *K8sRoute) OnBindFailed(err error, gateway store.Gateway) { - k8sGateway, ok := gateway.(*K8sGateway) - if ok { - id, found := r.parentKeyForGateway(utils.NamespacedName(k8sGateway.Gateway)) - if found { - r.RouteState.ParentStatuses.BindFailed(r.RouteState.ResolutionErrors, err, id) - } - } -} - -func (r *K8sRoute) OnBound(gateway store.Gateway) { - k8sGateway, ok := gateway.(*K8sGateway) - if ok { - id, found := r.parentKeyForGateway(utils.NamespacedName(k8sGateway.Gateway)) - if found { - r.RouteState.ParentStatuses.Bound(id) - } - } -} - func (r *K8sRoute) OnGatewayRemoved(gateway store.Gateway) { k8sGateway, ok := gateway.(*K8sGateway) if ok { - id, found := r.parentKeyForGateway(utils.NamespacedName(k8sGateway.Gateway)) - if found { - r.RouteState.ParentStatuses.Remove(id) + parent := utils.NamespacedName(k8sGateway.Gateway) + for _, p := range r.Parents() { + gatewayName, isGateway := utils.ReferencesGateway(r.GetNamespace(), p) + if isGateway && gatewayName == parent { + r.RouteState.Remove(p) + return + } } } } diff --git a/internal/k8s/reconciler/route_test.go b/internal/k8s/reconciler/route_test.go index 1962c231a..28f122ca7 100644 --- a/internal/k8s/reconciler/route_test.go +++ b/internal/k8s/reconciler/route_test.go @@ -16,7 +16,6 @@ import ( gwv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" clientMocks "github.com/hashicorp/consul-api-gateway/internal/k8s/gatewayclient/mocks" - "github.com/hashicorp/consul-api-gateway/internal/k8s/reconciler/state" ) func TestRouteID(t *testing.T) { @@ -145,7 +144,7 @@ func TestRouteMatchesHostname(t *testing.T) { }, }, K8sRouteConfig{ Logger: hclog.NewNullLogger(), - }).MatchesHostname(&hostname)) + }).matchesHostname(&hostname)) require.False(t, newK8sRoute(&gwv1alpha2.HTTPRoute{ Spec: gwv1alpha2.HTTPRouteSpec{ @@ -153,18 +152,22 @@ func TestRouteMatchesHostname(t *testing.T) { }, }, K8sRouteConfig{ Logger: hclog.NewNullLogger(), - }).MatchesHostname(&hostname)) + }).matchesHostname(&hostname)) // check where the underlying route doesn't implement // a matching routine require.True(t, newK8sRoute(&gwv1alpha2.TCPRoute{}, K8sRouteConfig{ Logger: hclog.NewNullLogger(), - }).MatchesHostname(&hostname)) + }).matchesHostname(&hostname)) } func TestRouteResolve(t *testing.T) { t.Parallel() + factory := NewFactory(FactoryConfig{ + Logger: hclog.NewNullLogger(), + }) + gateway := &gwv1beta1.Gateway{ ObjectMeta: meta.ObjectMeta{ Name: "expected", @@ -172,35 +175,14 @@ func TestRouteResolve(t *testing.T) { } listener := gwv1beta1.Listener{} - require.Nil(t, newK8sRoute(&gwv1alpha2.HTTPRoute{}, K8sRouteConfig{ - Logger: hclog.NewNullLogger(), - }).Resolve(nil)) - - require.Nil(t, newK8sRoute(&core.Pod{}, K8sRouteConfig{ - Logger: hclog.NewNullLogger(), - }).Resolve(NewK8sListener(gateway, listener, K8sListenerConfig{ - Logger: hclog.NewNullLogger(), - State: &state.ListenerState{}, - }))) + require.Nil(t, factory.NewRoute(&core.Pod{}).resolve("", gateway, listener)) - require.NotNil(t, newK8sRoute(&gwv1alpha2.HTTPRoute{}, K8sRouteConfig{ - Logger: hclog.NewNullLogger(), - }).Resolve(NewK8sListener(gateway, listener, K8sListenerConfig{ - Logger: hclog.NewNullLogger(), - State: &state.ListenerState{}, - }))) + require.NotNil(t, factory.NewRoute(&gwv1alpha2.HTTPRoute{}).resolve("", gateway, listener)) } func TestRouteSyncStatus(t *testing.T) { t.Parallel() - gateway := newK8sGateway(&gwv1beta1.Gateway{ - ObjectMeta: meta.ObjectMeta{ - Name: "expected", - }, - }, K8sGatewayConfig{ - Logger: hclog.NewNullLogger(), - }) inner := &gwv1alpha2.HTTPRoute{ Spec: gwv1alpha2.HTTPRouteSpec{ CommonRouteSpec: gwv1alpha2.CommonRouteSpec{ @@ -245,7 +227,7 @@ func TestRouteSyncStatus(t *testing.T) { Logger: logger, Client: client, }) - route.OnBound(gateway) + route.RouteState.Bound(gwv1alpha2.ParentReference{Name: "expected"}) expected := errors.New("expected") client.EXPECT().UpdateStatus(gomock.Any(), inner).Return(expected) diff --git a/internal/k8s/reconciler/state/route.go b/internal/k8s/reconciler/state/route.go index eb9c091f4..0018b9750 100644 --- a/internal/k8s/reconciler/state/route.go +++ b/internal/k8s/reconciler/state/route.go @@ -1,6 +1,9 @@ package state import ( + gwv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" + + "github.com/hashicorp/consul-api-gateway/internal/k8s/reconciler/common" "github.com/hashicorp/consul-api-gateway/internal/k8s/reconciler/status" "github.com/hashicorp/consul-api-gateway/internal/k8s/service" ) @@ -20,3 +23,15 @@ func NewRouteState() *RouteState { ParentStatuses: make(status.RouteStatuses), } } + +func (r *RouteState) BindFailed(err error, ref gwv1alpha2.ParentReference) { + r.ParentStatuses.BindFailed(r.ResolutionErrors, err, common.AsJSON(ref)) +} + +func (r *RouteState) Bound(ref gwv1alpha2.ParentReference) { + r.ParentStatuses.Bound(common.AsJSON(ref)) +} + +func (r *RouteState) Remove(ref gwv1alpha2.ParentReference) { + r.ParentStatuses.Remove(common.AsJSON(ref)) +} diff --git a/internal/k8s/reconciler/utils.go b/internal/k8s/reconciler/utils.go index 3b2077087..666f85ea1 100644 --- a/internal/k8s/reconciler/utils.go +++ b/internal/k8s/reconciler/utils.go @@ -1,35 +1,18 @@ package reconciler import ( - "context" "encoding/json" - "fmt" "reflect" "sort" "strings" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - klabels "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/types" gwv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" gwv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" - - "github.com/hashicorp/consul-api-gateway/internal/k8s/gatewayclient" -) - -const ( - // NamespaceNameLabel represents that label added automatically to namespaces is newer Kubernetes clusters - NamespaceNameLabel = "kubernetes.io/metadata.name" ) -func routeMatchesListener(listenerName gwv1beta1.SectionName, routeSectionName *gwv1alpha2.SectionName) (can bool, must bool) { - if routeSectionName == nil { - return true, false - } - return string(listenerName) == string(*routeSectionName), true -} - func routeMatchesListenerHostname(listenerHostname *gwv1beta1.Hostname, hostnames []gwv1alpha2.Hostname) bool { if listenerHostname == nil || len(hostnames) == 0 { return true @@ -66,76 +49,6 @@ func hostnamesMatch(a gwv1alpha2.Hostname, b gwv1beta1.Hostname) bool { return string(a) == string(b) } -func routeKindIsAllowedForListener(kinds []gwv1beta1.RouteGroupKind, route *K8sRoute) bool { - if kinds == nil { - return true - } - - gvk := route.GroupVersionKind() - for _, kind := range kinds { - group := gwv1beta1.GroupName - if kind.Group != nil && *kind.Group != "" { - group = string(*kind.Group) - } - if string(kind.Kind) == gvk.Kind && group == gvk.Group { - return true - } - } - - return false -} - -// routeAllowedForListenerNamespaces determines whether the route is allowed -// to bind to the Gateway based on the AllowedRoutes namespace selectors. -func routeAllowedForListenerNamespaces(ctx context.Context, gatewayNS string, allowedRoutes *gwv1beta1.AllowedRoutes, route *K8sRoute, c gatewayclient.Client) (bool, error) { - var namespaceSelector *gwv1beta1.RouteNamespaces - if allowedRoutes != nil { - // check gateway namespace - namespaceSelector = allowedRoutes.Namespaces - } - - // set default if namespace selector is nil - from := gwv1beta1.NamespacesFromSame - if namespaceSelector != nil && namespaceSelector.From != nil && *namespaceSelector.From != "" { - from = *namespaceSelector.From - } - switch from { - case gwv1beta1.NamespacesFromAll: - return true, nil - case gwv1beta1.NamespacesFromSame: - return gatewayNS == route.GetNamespace(), nil - case gwv1beta1.NamespacesFromSelector: - namespaceSelector, err := metav1.LabelSelectorAsSelector(namespaceSelector.Selector) - if err != nil { - return false, fmt.Errorf("error parsing label selector: %w", err) - } - - // retrieve the route's namespace and determine whether selector matches - namespace, err := c.GetNamespace(ctx, types.NamespacedName{Name: route.GetNamespace()}) - if err != nil { - return false, fmt.Errorf("error retrieving namespace for route: %w", err) - } - - return namespaceSelector.Matches(toNamespaceSet(namespace.GetName(), namespace.GetLabels())), nil - } - return false, nil -} - -func toNamespaceSet(name string, labels map[string]string) klabels.Labels { - // If namespace label is not set, implicitly insert it to support older Kubernetes versions - if labels[NamespaceNameLabel] == name { - // Already set, avoid copies - return klabels.Set(labels) - } - // First we need a copy to not modify the underlying object - ret := make(map[string]string, len(labels)+1) - for k, v := range labels { - ret[k] = v - } - ret[NamespaceNameLabel] = name - return klabels.Set(ret) -} - func sortParents(parents []gwv1alpha2.RouteParentStatus) []gwv1alpha2.RouteParentStatus { for _, parent := range parents { sort.SliceStable(parent.Conditions, func(i, j int) bool { diff --git a/internal/k8s/reconciler/utils_test.go b/internal/k8s/reconciler/utils_test.go index 0f3f06687..78673c4d8 100644 --- a/internal/k8s/reconciler/utils_test.go +++ b/internal/k8s/reconciler/utils_test.go @@ -1,41 +1,14 @@ package reconciler import ( - "context" "testing" - "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" - core "k8s.io/api/core/v1" meta "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/apimachinery/pkg/types" gwv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" gwv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" - - "github.com/hashicorp/go-hclog" - - "github.com/hashicorp/consul-api-gateway/internal/k8s/gatewayclient/mocks" ) -func TestRouteMatchesListener(t *testing.T) { - t.Parallel() - - listenerName := gwv1beta1.SectionName("name") - routeSectionName := gwv1alpha2.SectionName("name") - can, must := routeMatchesListener(listenerName, &routeSectionName) - require.True(t, can) - require.True(t, must) - - can, must = routeMatchesListener(listenerName, nil) - require.True(t, can) - require.False(t, must) - - can, must = routeMatchesListener(gwv1beta1.SectionName("other"), &routeSectionName) - require.False(t, can) - require.True(t, must) -} - func TestRouteMatchesListenerHostname(t *testing.T) { t.Parallel() @@ -64,166 +37,6 @@ func TestHostnamesMatch(t *testing.T) { require.True(t, hostnamesMatch("a.b.test", "a.b.test")) } -func TestRouteKindIsAllowedForListener(t *testing.T) { - t.Parallel() - - factory := NewFactory(FactoryConfig{ - Logger: hclog.NewNullLogger(), - }) - - routeMeta := meta.TypeMeta{} - routeMeta.SetGroupVersionKind(schema.GroupVersionKind{ - Group: gwv1alpha2.GroupVersion.Group, - Version: gwv1alpha2.GroupVersion.Version, - Kind: "HTTPRoute", - }) - - require.True(t, routeKindIsAllowedForListener( - []gwv1beta1.RouteGroupKind{{ - Group: (*gwv1beta1.Group)(&gwv1alpha2.GroupVersion.Group), - Kind: "HTTPRoute", - }}, - factory.NewRoute(&gwv1alpha2.HTTPRoute{TypeMeta: routeMeta}))) - - require.False(t, routeKindIsAllowedForListener( - []gwv1beta1.RouteGroupKind{{ - Group: (*gwv1beta1.Group)(&gwv1alpha2.GroupVersion.Group), - Kind: "TCPRoute", - }}, - factory.NewRoute(&gwv1alpha2.HTTPRoute{TypeMeta: routeMeta}))) -} - -func TestRouteAllowedForListenerNamespaces(t *testing.T) { - t.Parallel() - - factory := NewFactory(FactoryConfig{ - Logger: hclog.NewNullLogger(), - }) - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - client := mocks.NewMockClient(ctrl) - - // same - same := gwv1beta1.NamespacesFromSame - - allowed, err := routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ - Namespaces: &gwv1beta1.RouteNamespaces{ - From: &same, - }, - }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ - ObjectMeta: meta.ObjectMeta{ - Namespace: "expected", - }, - }), client) - require.NoError(t, err) - require.True(t, allowed) - - allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ - Namespaces: &gwv1beta1.RouteNamespaces{ - From: &same, - }, - }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ - ObjectMeta: meta.ObjectMeta{ - Namespace: "other", - }, - }), client) - require.NoError(t, err) - require.False(t, allowed) - - // all - all := gwv1beta1.NamespacesFromAll - allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ - Namespaces: &gwv1beta1.RouteNamespaces{ - From: &all, - }, - }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ - ObjectMeta: meta.ObjectMeta{ - Namespace: "other", - }, - }), client) - require.NoError(t, err) - require.True(t, allowed) - - // selector - selector := gwv1beta1.NamespacesFromSelector - - matchingNamespace := &core.Namespace{ - ObjectMeta: meta.ObjectMeta{ - Labels: map[string]string{ - "label": "test", - "kubernetes.io/metadata.name": "expected", - }}} - invalidNamespace := &core.Namespace{ObjectMeta: meta.ObjectMeta{Labels: map[string]string{}}} - - client.EXPECT().GetNamespace(context.Background(), types.NamespacedName{Name: "expected"}).Return(invalidNamespace, nil).Times(1) - allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ - Namespaces: &gwv1beta1.RouteNamespaces{ - From: &selector, - Selector: &meta.LabelSelector{ - MatchLabels: map[string]string{ - "label": "test", - }, - }, - }, - }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ - ObjectMeta: meta.ObjectMeta{ - Namespace: "expected", - }, - }), client) - require.NoError(t, err) - require.False(t, allowed) - - client.EXPECT().GetNamespace(context.Background(), types.NamespacedName{Name: "expected"}).Return(matchingNamespace, nil).Times(1) - allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ - Namespaces: &gwv1beta1.RouteNamespaces{ - From: &selector, - Selector: &meta.LabelSelector{ - MatchLabels: map[string]string{ - "label": "test", - }, - }, - }, - }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ - ObjectMeta: meta.ObjectMeta{ - Namespace: "expected", - }, - }), client) - require.NoError(t, err) - require.True(t, allowed) - - _, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ - Namespaces: &gwv1beta1.RouteNamespaces{ - From: &selector, - Selector: &meta.LabelSelector{ - MatchExpressions: []meta.LabelSelectorRequirement{{ - Key: "test", - Operator: meta.LabelSelectorOperator("invalid"), - }}, - }, - }, - }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ - ObjectMeta: meta.ObjectMeta{ - Namespace: "expected", - }, - }), client) - require.Error(t, err) - - // unknown - unknown := gwv1beta1.FromNamespaces("unknown") - allowed, err = routeAllowedForListenerNamespaces(context.Background(), "expected", &gwv1beta1.AllowedRoutes{ - Namespaces: &gwv1beta1.RouteNamespaces{ - From: &unknown, - }, - }, factory.NewRoute(&gwv1alpha2.HTTPRoute{ - ObjectMeta: meta.ObjectMeta{ - Namespace: "expected", - }, - }), client) - require.NoError(t, err) - require.False(t, allowed) -} - func TestConditionEqual(t *testing.T) { t.Parallel() diff --git a/internal/store/interfaces.go b/internal/store/interfaces.go index a55fdf9ac..ec4e8cd33 100644 --- a/internal/store/interfaces.go +++ b/internal/store/interfaces.go @@ -10,14 +10,6 @@ import ( type CompareResult int -const ( - CompareResultInvalid CompareResult = iota - CompareResultNewer - CompareResultNotEqual - CompareResultEqual - CompareResultStatusNotEqual -) - // StatusTrackingGateway is an optional extension // of Gateway. If supported by a Store, when // a Gateway is synced to an external location, @@ -32,10 +24,10 @@ type StatusTrackingGateway interface { // Gateway describes a gateway. type Gateway interface { ID() core.GatewayID - Meta() map[string]string - ShouldUpdate(other Gateway) bool - Listeners() []Listener - ShouldBind(route Route) bool + Bind(ctx context.Context, route Route) []string + Remove(ctx context.Context, id string) error + Resolve() core.ResolvedGateway + CanFetchSecrets(ctx context.Context, secrets []string) (bool, error) } // ListenerConfig contains the common configuration @@ -77,8 +69,6 @@ type StatusTrackingRoute interface { Route SyncStatus(ctx context.Context) error - OnBound(gateway Gateway) - OnBindFailed(err error, gateway Gateway) OnGatewayRemoved(gateway Gateway) } @@ -86,13 +76,11 @@ type StatusTrackingRoute interface { // source integrations type Route interface { ID() string - Resolve(listener Listener) core.ResolvedRoute } // Store is used for persisting and querying gateways and routes type Store interface { - GatewayExists(ctx context.Context, id core.GatewayID) (bool, error) - CanFetchSecrets(ctx context.Context, id core.GatewayID, secrets []string) (bool, error) + GetGateway(ctx context.Context, id core.GatewayID) (Gateway, error) DeleteGateway(ctx context.Context, id core.GatewayID) error UpsertGateway(ctx context.Context, gateway Gateway, updateConditionFn func(current Gateway) bool) error DeleteRoute(ctx context.Context, id string) error diff --git a/internal/store/memory/gateway.go b/internal/store/memory/gateway.go deleted file mode 100644 index e61a45f10..000000000 --- a/internal/store/memory/gateway.go +++ /dev/null @@ -1,153 +0,0 @@ -package memory - -import ( - "context" - - "github.com/hashicorp/go-hclog" - "github.com/hashicorp/go-multierror" - - "github.com/hashicorp/consul-api-gateway/internal/core" - "github.com/hashicorp/consul-api-gateway/internal/store" -) - -type gatewayState struct { - store.Gateway - - logger hclog.Logger - adapter core.SyncAdapter - listeners map[string]*listenerState - secrets map[string]struct{} - needsSync bool -} - -// newGatewayState creates a bound gateway -func newGatewayState(logger hclog.Logger, gateway store.Gateway, adapter core.SyncAdapter) *gatewayState { - id := gateway.ID() - - secrets := make(map[string]struct{}) - gatewayLogger := logger.With("gateway.consul.namespace", id.ConsulNamespace, "gateway.consul.service", id.Service) - listeners := make(map[string]*listenerState) - for _, listener := range gateway.Listeners() { - for _, cert := range listener.Config().TLS.Certificates { - secrets[cert] = struct{}{} - } - listeners[listener.ID()] = newListenerState(gatewayLogger, gateway, listener) - } - - return &gatewayState{ - Gateway: gateway, - logger: gatewayLogger, - adapter: adapter, - listeners: listeners, - secrets: secrets, - needsSync: false, - } -} - -// Remove removes a route from the gateway's listeners if -// it is bound to a listener -func (g *gatewayState) Remove(id string) { - for _, listener := range g.listeners { - listener.RemoveRoute(id) - } -} - -func (g *gatewayState) TryBind(ctx context.Context, route store.Route) { - g.logger.Trace("checking if route can bind to gateway", "route", route.ID()) - if g.ShouldBind(route) { - bound := false - var bindError error - for _, l := range g.listeners { - g.logger.Trace("checking if route can bind to listener", "listener", l.name, "route", route.ID()) - canBind, err := l.CanBind(ctx, route) - if err != nil { - // consider each route distinct for the purposes of binding - g.logger.Debug("error binding route to gateway", "error", err, "route", route.ID()) - l.RemoveRoute(route.ID()) - bindError = multierror.Append(bindError, err) - } - if canBind { - g.logger.Trace("setting listener route", "listener", l.name, "route", route.ID()) - l.SetRoute(route) - bound = true - } - } - if tracker, ok := route.(store.StatusTrackingRoute); ok { - if !bound { - tracker.OnBindFailed(bindError, g.Gateway) - } else { - tracker.OnBound(g.Gateway) - } - } - } else { - // Clean up route from gateway listeners if ParentRef no longer - // references gateway - g.Remove(route.ID()) - } -} - -func (g *gatewayState) ShouldUpdate(other store.Gateway) bool { - if other == nil { - return false - } - if g == nil { - return true - } - - return g.Gateway.ShouldUpdate(other) -} - -func (g *gatewayState) ShouldSync(ctx context.Context) bool { - if g.needsSync { - return true - } - - for _, listener := range g.listeners { - if listener.ShouldSync() { - return true - } - } - - return false -} - -func (g *gatewayState) MarkSynced() { - g.needsSync = false -} - -func (g *gatewayState) Sync(ctx context.Context) (bool, error) { - didSync := false - - if g.ShouldSync(ctx) { - g.logger.Trace("syncing gateway") - if err := g.sync(ctx); err != nil { - return false, err - } - didSync = true - } - - g.MarkSynced() - for _, listener := range g.listeners { - listener.MarkSynced() - } - - return didSync, nil -} - -func (g *gatewayState) sync(ctx context.Context) error { - return g.adapter.Sync(ctx, g.Resolve()) -} - -func (g *gatewayState) Resolve() core.ResolvedGateway { - listeners := []core.ResolvedListener{} - for _, listener := range g.listeners { - if listener.Listener.IsValid() { - listeners = append(listeners, listener.Resolve()) - } - } - return core.ResolvedGateway{ - ID: g.ID(), - Meta: g.Meta(), - Listeners: listeners, - } -} diff --git a/internal/store/memory/listener.go b/internal/store/memory/listener.go deleted file mode 100644 index b6db5e498..000000000 --- a/internal/store/memory/listener.go +++ /dev/null @@ -1,112 +0,0 @@ -package memory - -import ( - "reflect" - - "github.com/hashicorp/go-hclog" - - "github.com/hashicorp/consul-api-gateway/internal/core" - "github.com/hashicorp/consul-api-gateway/internal/store" -) - -const ( - defaultListenerName = "default" -) - -// boundListener wraps a listener and its set of routes -type listenerState struct { - store.Listener - - gateway store.Gateway - - logger hclog.Logger - name string - hostname string - port int - protocol string - - routes map[string]core.ResolvedRoute - - needsSync bool -} - -func newListenerState(logger hclog.Logger, gateway store.Gateway, listener store.Listener) *listenerState { - listenerConfig := listener.Config() - - name := defaultListenerName - if listenerConfig.Name != "" { - name = listenerConfig.Name - } - hostname := "" - if listenerConfig.Hostname != "" { - hostname = listenerConfig.Hostname - } - - return &listenerState{ - Listener: listener, - gateway: gateway, - logger: logger.With("listener", name), - name: name, - port: listenerConfig.Port, - protocol: listenerConfig.Protocol, - hostname: hostname, - routes: make(map[string]core.ResolvedRoute), - needsSync: true, - } -} - -func (l *listenerState) RemoveRoute(id string) { - if _, found := l.routes[id]; !found { - return - } - l.logger.Trace("removing route from listener", "route", id) - if tracker, ok := l.Listener.(store.RouteTrackingListener); ok { - tracker.OnRouteRemoved(id) - } - - l.needsSync = true - delete(l.routes, id) -} - -func (l *listenerState) SetRoute(route store.Route) { - l.logger.Trace("setting route on listener", "route", route.ID()) - if resolved := route.Resolve(l.Listener); resolved != nil { - stored, found := l.routes[route.ID()] - if found && reflect.DeepEqual(stored, resolved) { - // don't bother updating if the route is the same - return - } - if tracker, ok := l.Listener.(store.RouteTrackingListener); ok { - if !found { - tracker.OnRouteAdded(route) - } - } - - l.routes[route.ID()] = resolved - - l.needsSync = true - } -} - -func (l *listenerState) ShouldSync() bool { - return l.needsSync -} - -func (l *listenerState) MarkSynced() { - l.needsSync = false -} - -func (l *listenerState) Resolve() core.ResolvedListener { - routes := []core.ResolvedRoute{} - for _, route := range l.routes { - routes = append(routes, route) - } - return core.ResolvedListener{ - Name: l.name, - Hostname: l.hostname, - Port: l.port, - Protocol: l.protocol, - TLS: l.Listener.Config().TLS, - Routes: routes, - } -} diff --git a/internal/store/memory/store.go b/internal/store/memory/store.go index ca641b6cb..ed1e528a7 100644 --- a/internal/store/memory/store.go +++ b/internal/store/memory/store.go @@ -2,7 +2,6 @@ package memory import ( "context" - "reflect" "sync" "sync/atomic" "time" @@ -24,7 +23,7 @@ type Store struct { logger hclog.Logger adapter core.SyncAdapter - gateways map[core.GatewayID]*gatewayState + gateways map[core.GatewayID]store.Gateway routes map[string]store.Route // This mutex acts as a stop-the-world type global mutex, as the store is a singleton. @@ -46,7 +45,7 @@ func NewStore(config StoreConfig) *Store { logger: config.Logger, adapter: config.Adapter, routes: make(map[string]store.Route), - gateways: make(map[core.GatewayID]*gatewayState), + gateways: make(map[core.GatewayID]store.Gateway), } } @@ -58,22 +57,6 @@ func (s *Store) GatewayExists(ctx context.Context, id core.GatewayID) (bool, err return found, nil } -func (s *Store) CanFetchSecrets(ctx context.Context, id core.GatewayID, secrets []string) (bool, error) { - s.mutex.RLock() - defer s.mutex.RUnlock() - - gateway, found := s.gateways[id] - if !found { - return false, nil - } - for _, secret := range secrets { - if _, found := gateway.secrets[secret]; !found { - return false, nil - } - } - return true, nil -} - func (s *Store) GetGateway(ctx context.Context, id core.GatewayID) (store.Gateway, error) { s.mutex.RLock() defer s.mutex.RUnlock() @@ -82,20 +65,20 @@ func (s *Store) GetGateway(ctx context.Context, id core.GatewayID) (store.Gatewa if !found { return nil, nil } - return gateway.Gateway, nil + return gateway, nil } -func (s *Store) syncGateway(ctx context.Context, gateway *gatewayState) error { - if tracker, ok := gateway.Gateway.(store.StatusTrackingGateway); ok { +func (s *Store) syncGateway(ctx context.Context, gateway store.Gateway) error { + if tracker, ok := gateway.(store.StatusTrackingGateway); ok { return tracker.TrackSync(ctx, func() (bool, error) { - return gateway.Sync(ctx) + return s.adapter.Sync(ctx, gateway.Resolve()) }) } - _, err := gateway.Sync(ctx) + _, err := s.adapter.Sync(ctx, gateway.Resolve()) return err } -func (s *Store) syncGateways(ctx context.Context, gateways ...*gatewayState) error { +func (s *Store) syncGateways(ctx context.Context, gateways ...store.Gateway) error { var syncGroup multierror.Group for _, gw := range gateways { @@ -135,7 +118,7 @@ func (s *Store) syncRouteStatuses(ctx context.Context) error { // access any potentially references stored from previous callbacks in the // status updating callbacks in our interfaces -- otherwise proper locking // is needed. -func (s *Store) sync(ctx context.Context, gateways ...*gatewayState) error { +func (s *Store) sync(ctx context.Context, gateways ...store.Gateway) error { var syncGroup multierror.Group if gateways == nil { @@ -147,13 +130,12 @@ func (s *Store) sync(ctx context.Context, gateways ...*gatewayState) error { syncGroup.Go(func() error { return s.syncGateways(ctx, gateways...) }) + syncGroup.Go(func() error { return s.syncRouteStatuses(ctx) }) - if err := syncGroup.Wait().ErrorOrNil(); err != nil { - return err - } - return nil + + return syncGroup.Wait().ErrorOrNil() } func (s *Store) Sync(ctx context.Context) error { @@ -175,12 +157,6 @@ func (s *Store) SyncAtInterval(ctx context.Context) { case <-ticker.C: s.mutex.Lock() - // Force each gateway to sync its state even though listeners - // on the gateway may not be marked as needing a sync right now - for _, gateway := range s.gateways { - gateway.needsSync = true - } - if err := s.sync(ctx); err != nil { s.logger.Warn("An error occurred during memory store sync, some changes may be out of sync", "error", err) } else { @@ -198,7 +174,7 @@ func (s *Store) DeleteRoute(ctx context.Context, id string) error { s.logger.Trace("deleting route", "id", id) for _, gateway := range s.gateways { - gateway.Remove(id) + gateway.Remove(ctx, id) } delete(s.routes, id) @@ -222,7 +198,7 @@ func (s *Store) UpsertRoute(ctx context.Context, route store.Route, updateCondit // bind to gateways for _, gateway := range s.gateways { - gateway.TryBind(ctx, route) + gateway.Bind(ctx, route) } // sync the gateways to consul and route statuses to k8s @@ -236,29 +212,17 @@ func (s *Store) UpsertGateway(ctx context.Context, gateway store.Gateway, update id := gateway.ID() current, found := s.gateways[id] - var currentGW store.Gateway - if found { - currentGW = current.Gateway - } - if updateConditionFn != nil && !updateConditionFn(currentGW) { + if updateConditionFn != nil && !updateConditionFn(current) { // No-op return nil } - updated := newGatewayState(s.logger, gateway, s.adapter) - s.gateways[id] = updated + s.gateways[id] = gateway // bind routes to this gateway for _, route := range s.routes { - updated.TryBind(ctx, route) - } - - if found && reflect.DeepEqual(current.Resolve(), updated.Resolve()) { - // we have the exact same render tree, mark the gateway as already synced - for _, listener := range updated.listeners { - listener.MarkSynced() - } + gateway.Bind(ctx, route) } if !found { @@ -268,7 +232,7 @@ func (s *Store) UpsertGateway(ctx context.Context, gateway store.Gateway, update } // sync the gateway to consul and any updated route statuses - return s.sync(ctx, s.gateways[id]) + return s.sync(ctx, gateway) } func (s *Store) DeleteGateway(ctx context.Context, id core.GatewayID) error { diff --git a/internal/store/mocks/interfaces.go b/internal/store/mocks/interfaces.go index 3c0bb6fbc..b836716b1 100644 --- a/internal/store/mocks/interfaces.go +++ b/internal/store/mocks/interfaces.go @@ -36,74 +36,75 @@ func (m *MockStatusTrackingGateway) EXPECT() *MockStatusTrackingGatewayMockRecor return m.recorder } -// ID mocks base method. -func (m *MockStatusTrackingGateway) ID() core.GatewayID { +// Bind mocks base method. +func (m *MockStatusTrackingGateway) Bind(ctx context.Context, route store.Route) []string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ID") - ret0, _ := ret[0].(core.GatewayID) + ret := m.ctrl.Call(m, "Bind", ctx, route) + ret0, _ := ret[0].([]string) return ret0 } -// ID indicates an expected call of ID. -func (mr *MockStatusTrackingGatewayMockRecorder) ID() *gomock.Call { +// Bind indicates an expected call of Bind. +func (mr *MockStatusTrackingGatewayMockRecorder) Bind(ctx, route interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockStatusTrackingGateway)(nil).ID)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bind", reflect.TypeOf((*MockStatusTrackingGateway)(nil).Bind), ctx, route) } -// Listeners mocks base method. -func (m *MockStatusTrackingGateway) Listeners() []store.Listener { +// CanFetchSecrets mocks base method. +func (m *MockStatusTrackingGateway) CanFetchSecrets(ctx context.Context, secrets []string) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Listeners") - ret0, _ := ret[0].([]store.Listener) - return ret0 + ret := m.ctrl.Call(m, "CanFetchSecrets", ctx, secrets) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// Listeners indicates an expected call of Listeners. -func (mr *MockStatusTrackingGatewayMockRecorder) Listeners() *gomock.Call { +// CanFetchSecrets indicates an expected call of CanFetchSecrets. +func (mr *MockStatusTrackingGatewayMockRecorder) CanFetchSecrets(ctx, secrets interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Listeners", reflect.TypeOf((*MockStatusTrackingGateway)(nil).Listeners)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanFetchSecrets", reflect.TypeOf((*MockStatusTrackingGateway)(nil).CanFetchSecrets), ctx, secrets) } -// Meta mocks base method. -func (m *MockStatusTrackingGateway) Meta() map[string]string { +// ID mocks base method. +func (m *MockStatusTrackingGateway) ID() core.GatewayID { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Meta") - ret0, _ := ret[0].(map[string]string) + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(core.GatewayID) return ret0 } -// Meta indicates an expected call of Meta. -func (mr *MockStatusTrackingGatewayMockRecorder) Meta() *gomock.Call { +// ID indicates an expected call of ID. +func (mr *MockStatusTrackingGatewayMockRecorder) ID() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Meta", reflect.TypeOf((*MockStatusTrackingGateway)(nil).Meta)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockStatusTrackingGateway)(nil).ID)) } -// ShouldBind mocks base method. -func (m *MockStatusTrackingGateway) ShouldBind(route store.Route) bool { +// Remove mocks base method. +func (m *MockStatusTrackingGateway) Remove(ctx context.Context, id string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ShouldBind", route) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "Remove", ctx, id) + ret0, _ := ret[0].(error) return ret0 } -// ShouldBind indicates an expected call of ShouldBind. -func (mr *MockStatusTrackingGatewayMockRecorder) ShouldBind(route interface{}) *gomock.Call { +// Remove indicates an expected call of Remove. +func (mr *MockStatusTrackingGatewayMockRecorder) Remove(ctx, id interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldBind", reflect.TypeOf((*MockStatusTrackingGateway)(nil).ShouldBind), route) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockStatusTrackingGateway)(nil).Remove), ctx, id) } -// ShouldUpdate mocks base method. -func (m *MockStatusTrackingGateway) ShouldUpdate(other store.Gateway) bool { +// Resolve mocks base method. +func (m *MockStatusTrackingGateway) Resolve() core.ResolvedGateway { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ShouldUpdate", other) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "Resolve") + ret0, _ := ret[0].(core.ResolvedGateway) return ret0 } -// ShouldUpdate indicates an expected call of ShouldUpdate. -func (mr *MockStatusTrackingGatewayMockRecorder) ShouldUpdate(other interface{}) *gomock.Call { +// Resolve indicates an expected call of Resolve. +func (mr *MockStatusTrackingGatewayMockRecorder) Resolve() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldUpdate", reflect.TypeOf((*MockStatusTrackingGateway)(nil).ShouldUpdate), other) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockStatusTrackingGateway)(nil).Resolve)) } // TrackSync mocks base method. @@ -143,74 +144,75 @@ func (m *MockGateway) EXPECT() *MockGatewayMockRecorder { return m.recorder } -// ID mocks base method. -func (m *MockGateway) ID() core.GatewayID { +// Bind mocks base method. +func (m *MockGateway) Bind(ctx context.Context, route store.Route) []string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ID") - ret0, _ := ret[0].(core.GatewayID) + ret := m.ctrl.Call(m, "Bind", ctx, route) + ret0, _ := ret[0].([]string) return ret0 } -// ID indicates an expected call of ID. -func (mr *MockGatewayMockRecorder) ID() *gomock.Call { +// Bind indicates an expected call of Bind. +func (mr *MockGatewayMockRecorder) Bind(ctx, route interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockGateway)(nil).ID)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bind", reflect.TypeOf((*MockGateway)(nil).Bind), ctx, route) } -// Listeners mocks base method. -func (m *MockGateway) Listeners() []store.Listener { +// CanFetchSecrets mocks base method. +func (m *MockGateway) CanFetchSecrets(ctx context.Context, secrets []string) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Listeners") - ret0, _ := ret[0].([]store.Listener) - return ret0 + ret := m.ctrl.Call(m, "CanFetchSecrets", ctx, secrets) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// Listeners indicates an expected call of Listeners. -func (mr *MockGatewayMockRecorder) Listeners() *gomock.Call { +// CanFetchSecrets indicates an expected call of CanFetchSecrets. +func (mr *MockGatewayMockRecorder) CanFetchSecrets(ctx, secrets interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Listeners", reflect.TypeOf((*MockGateway)(nil).Listeners)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanFetchSecrets", reflect.TypeOf((*MockGateway)(nil).CanFetchSecrets), ctx, secrets) } -// Meta mocks base method. -func (m *MockGateway) Meta() map[string]string { +// ID mocks base method. +func (m *MockGateway) ID() core.GatewayID { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Meta") - ret0, _ := ret[0].(map[string]string) + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(core.GatewayID) return ret0 } -// Meta indicates an expected call of Meta. -func (mr *MockGatewayMockRecorder) Meta() *gomock.Call { +// ID indicates an expected call of ID. +func (mr *MockGatewayMockRecorder) ID() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Meta", reflect.TypeOf((*MockGateway)(nil).Meta)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockGateway)(nil).ID)) } -// ShouldBind mocks base method. -func (m *MockGateway) ShouldBind(route store.Route) bool { +// Remove mocks base method. +func (m *MockGateway) Remove(ctx context.Context, id string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ShouldBind", route) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "Remove", ctx, id) + ret0, _ := ret[0].(error) return ret0 } -// ShouldBind indicates an expected call of ShouldBind. -func (mr *MockGatewayMockRecorder) ShouldBind(route interface{}) *gomock.Call { +// Remove indicates an expected call of Remove. +func (mr *MockGatewayMockRecorder) Remove(ctx, id interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldBind", reflect.TypeOf((*MockGateway)(nil).ShouldBind), route) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockGateway)(nil).Remove), ctx, id) } -// ShouldUpdate mocks base method. -func (m *MockGateway) ShouldUpdate(other store.Gateway) bool { +// Resolve mocks base method. +func (m *MockGateway) Resolve() core.ResolvedGateway { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ShouldUpdate", other) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "Resolve") + ret0, _ := ret[0].(core.ResolvedGateway) return ret0 } -// ShouldUpdate indicates an expected call of ShouldUpdate. -func (mr *MockGatewayMockRecorder) ShouldUpdate(other interface{}) *gomock.Call { +// Resolve indicates an expected call of Resolve. +func (mr *MockGatewayMockRecorder) Resolve() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldUpdate", reflect.TypeOf((*MockGateway)(nil).ShouldUpdate), other) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockGateway)(nil).Resolve)) } // MockRouteTrackingListener is a mock of RouteTrackingListener interface. @@ -434,30 +436,6 @@ func (mr *MockStatusTrackingRouteMockRecorder) ID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockStatusTrackingRoute)(nil).ID)) } -// OnBindFailed mocks base method. -func (m *MockStatusTrackingRoute) OnBindFailed(err error, gateway store.Gateway) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnBindFailed", err, gateway) -} - -// OnBindFailed indicates an expected call of OnBindFailed. -func (mr *MockStatusTrackingRouteMockRecorder) OnBindFailed(err, gateway interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnBindFailed", reflect.TypeOf((*MockStatusTrackingRoute)(nil).OnBindFailed), err, gateway) -} - -// OnBound mocks base method. -func (m *MockStatusTrackingRoute) OnBound(gateway store.Gateway) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnBound", gateway) -} - -// OnBound indicates an expected call of OnBound. -func (mr *MockStatusTrackingRouteMockRecorder) OnBound(gateway interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnBound", reflect.TypeOf((*MockStatusTrackingRoute)(nil).OnBound), gateway) -} - // OnGatewayRemoved mocks base method. func (m *MockStatusTrackingRoute) OnGatewayRemoved(gateway store.Gateway) { m.ctrl.T.Helper() @@ -470,20 +448,6 @@ func (mr *MockStatusTrackingRouteMockRecorder) OnGatewayRemoved(gateway interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnGatewayRemoved", reflect.TypeOf((*MockStatusTrackingRoute)(nil).OnGatewayRemoved), gateway) } -// Resolve mocks base method. -func (m *MockStatusTrackingRoute) Resolve(listener store.Listener) core.ResolvedRoute { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Resolve", listener) - ret0, _ := ret[0].(core.ResolvedRoute) - return ret0 -} - -// Resolve indicates an expected call of Resolve. -func (mr *MockStatusTrackingRouteMockRecorder) Resolve(listener interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockStatusTrackingRoute)(nil).Resolve), listener) -} - // SyncStatus mocks base method. func (m *MockStatusTrackingRoute) SyncStatus(ctx context.Context) error { m.ctrl.T.Helper() @@ -535,20 +499,6 @@ func (mr *MockRouteMockRecorder) ID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockRoute)(nil).ID)) } -// Resolve mocks base method. -func (m *MockRoute) Resolve(listener store.Listener) core.ResolvedRoute { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Resolve", listener) - ret0, _ := ret[0].(core.ResolvedRoute) - return ret0 -} - -// Resolve indicates an expected call of Resolve. -func (mr *MockRouteMockRecorder) Resolve(listener interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockRoute)(nil).Resolve), listener) -} - // MockStore is a mock of Store interface. type MockStore struct { ctrl *gomock.Controller @@ -572,21 +522,6 @@ func (m *MockStore) EXPECT() *MockStoreMockRecorder { return m.recorder } -// CanFetchSecrets mocks base method. -func (m *MockStore) CanFetchSecrets(ctx context.Context, id core.GatewayID, secrets []string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CanFetchSecrets", ctx, id, secrets) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CanFetchSecrets indicates an expected call of CanFetchSecrets. -func (mr *MockStoreMockRecorder) CanFetchSecrets(ctx, id, secrets interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanFetchSecrets", reflect.TypeOf((*MockStore)(nil).CanFetchSecrets), ctx, id, secrets) -} - // DeleteGateway mocks base method. func (m *MockStore) DeleteGateway(ctx context.Context, id core.GatewayID) error { m.ctrl.T.Helper() @@ -615,19 +550,19 @@ func (mr *MockStoreMockRecorder) DeleteRoute(ctx, id interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRoute", reflect.TypeOf((*MockStore)(nil).DeleteRoute), ctx, id) } -// GatewayExists mocks base method. -func (m *MockStore) GatewayExists(ctx context.Context, id core.GatewayID) (bool, error) { +// GetGateway mocks base method. +func (m *MockStore) GetGateway(ctx context.Context, id core.GatewayID) (store.Gateway, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GatewayExists", ctx, id) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "GetGateway", ctx, id) + ret0, _ := ret[0].(store.Gateway) ret1, _ := ret[1].(error) return ret0, ret1 } -// GatewayExists indicates an expected call of GatewayExists. -func (mr *MockStoreMockRecorder) GatewayExists(ctx, id interface{}) *gomock.Call { +// GetGateway indicates an expected call of GetGateway. +func (mr *MockStoreMockRecorder) GetGateway(ctx, id interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GatewayExists", reflect.TypeOf((*MockStore)(nil).GatewayExists), ctx, id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGateway", reflect.TypeOf((*MockStore)(nil).GetGateway), ctx, id) } // Sync mocks base method.