From 5d63a133241cf2782210db8feef06f6240aac2dc Mon Sep 17 00:00:00 2001 From: James Rasell Date: Wed, 16 Mar 2022 15:29:52 +0100 Subject: [PATCH] client: add Nomad service registration implementation. --- client/serviceregistration/nsd/doc.go | 4 + client/serviceregistration/nsd/nsd.go | 396 +++++++++++++++ client/serviceregistration/nsd/nsd_test.go | 553 +++++++++++++++++++++ 3 files changed, 953 insertions(+) create mode 100644 client/serviceregistration/nsd/doc.go create mode 100644 client/serviceregistration/nsd/nsd.go create mode 100644 client/serviceregistration/nsd/nsd_test.go diff --git a/client/serviceregistration/nsd/doc.go b/client/serviceregistration/nsd/doc.go new file mode 100644 index 000000000000..f86c8458d840 --- /dev/null +++ b/client/serviceregistration/nsd/doc.go @@ -0,0 +1,4 @@ +// Package nsd provides Nomad service registration and therefore discovery +// capabilities for Nomad clients. The name nsd was used instead of Nomad to +// avoid conflict with the existing nomad package. +package nsd diff --git a/client/serviceregistration/nsd/nsd.go b/client/serviceregistration/nsd/nsd.go new file mode 100644 index 000000000000..7a45cb880d06 --- /dev/null +++ b/client/serviceregistration/nsd/nsd.go @@ -0,0 +1,396 @@ +package nsd + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-multierror" + "github.com/hashicorp/nomad/client/serviceregistration" + "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/nomad/structs" +) + +type ServiceRegistrationHandler struct { + log hclog.Logger + cfg *ServiceRegistrationHandlerCfg + + // registrationEnabled tracks whether this handler is enabled for + // registrations. This is needed as it's possible a client has its config + // changed whilst allocations using this provider are running on it. In + // this situation we need to be able to deregister services, but disallow + // registering new ones. + registrationEnabled bool + + // minBackoffInterval in the starting and lowest possible interval between + // RPC retries. + minBackoffInterval time.Duration + + // maxBackoffInterval is the maximum duration between consecutive RPC + // retries. Any calculated value that exceeds this, should be reset to this + // value. + maxBackoffInterval time.Duration + + // maxBackoffDuration is the maximum time an RPC will be retried before + // being classed as terminally failed. This doesn't apply to upserts which + // use maxBackoffUpsertDuration. + maxBackoffDuration time.Duration + + // maxBackoffUpsertDuration + maxBackoffUpsertDuration time.Duration + + // shutDownCh coordinates shutting down the handler and any long-running + // processes, such as the RPC retry. + shutDownCh chan struct{} +} + +// ServiceRegistrationHandlerCfg holds critical information used during the +// normal process of the ServiceRegistrationHandler. It is used to keep the +// NewServiceRegistrationHandler function signature small and easy to modify. +type ServiceRegistrationHandlerCfg struct { + + // Enabled tracks whether this client feature is enabled. + Enabled bool + + // Datacenter, NodeID, and Region are all properties of the Nomad client + // and are used to perform RPC requests. + Datacenter string + NodeID string + Region string + + // NodeSecret is the secret ID of the node and is used to authenticate RPC + // requests. + NodeSecret string + + // RPCFn is the client RPC function which is used to perform client to + // server service registration RPC calls. + RPCFn func(method string, args, resp interface{}) error +} + +// NewServiceRegistrationHandler returns a ready to use +// ServiceRegistrationHandler which implements the serviceregistration.Handler +// interface. +func NewServiceRegistrationHandler( + log hclog.Logger, cfg *ServiceRegistrationHandlerCfg) serviceregistration.Handler { + return &ServiceRegistrationHandler{ + cfg: cfg, + log: log.Named("service_registration.nomad"), + registrationEnabled: cfg.Enabled, + minBackoffInterval: time.Second, + maxBackoffInterval: time.Minute, + maxBackoffDuration: time.Hour * 24, + maxBackoffUpsertDuration: 20 * time.Second, + shutDownCh: make(chan struct{}), + } +} + +func (s *ServiceRegistrationHandler) RegisterWorkload(workload *serviceregistration.WorkloadServices) error { + + // Check whether we are enabled or not first. Hitting this likely means + // there is a bug within the implicit constraint, or process using it, as + // that should guard ever placing an allocation on this client. + if !s.registrationEnabled { + return errors.New(`service registration provider "nomad" not enabled`) + } + + // Collect all errors generating service registrations. + var mErr multierror.Error + + registrations := make([]*structs.ServiceRegistration, len(workload.Services)) + + // Iterate over the services and generate a hydrated registration object for + // each. All services are part of a single allocation, therefore we cannot + // have one failure without all becoming a failure. + for i, serviceSpec := range workload.Services { + serviceRegistration, err := s.generateNomadServiceRegistration(serviceSpec, workload) + if err != nil { + mErr.Errors = append(mErr.Errors, err) + } else if mErr.ErrorOrNil() == nil { + registrations[i] = serviceRegistration + } + } + + // If we generated any errors, return this to the caller. + if err := mErr.ErrorOrNil(); err != nil { + return err + } + + args := structs.ServiceRegistrationUpsertRequest{ + Services: registrations, + WriteRequest: structs.WriteRequest{ + Region: s.cfg.Region, + AuthToken: s.cfg.NodeSecret, + }, + } + + var resp structs.ServiceRegistrationUpsertResponse + + return s.retryFunc(func() error { + return s.cfg.RPCFn(structs.ServiceRegistrationUpsertRPCMethod, &args, &resp) + }, structs.ServiceRegistrationUpsertRPCMethod) +} + +// RemoveWorkload iterates the services and removes them from the service +// registration state. +// +// This function works regardless of whether the client has this feature +// enabled. This covers situations where the feature is disabled, yet still has +// allocations which, when stopped need their registrations removed. +func (s *ServiceRegistrationHandler) RemoveWorkload(workload *serviceregistration.WorkloadServices) { + for _, serviceSpec := range workload.Services { + go s.removeWorkload(workload, serviceSpec) + } +} + +func (s *ServiceRegistrationHandler) removeWorkload( + workload *serviceregistration.WorkloadServices, serviceSpec *structs.Service) { + + // Generate the consistent ID for this service, so we know what to remove. + id := serviceregistration.MakeAllocServiceID(workload.AllocID, workload.Name(), serviceSpec) + + deleteArgs := structs.ServiceRegistrationDeleteByIDRequest{ + ID: id, + WriteRequest: structs.WriteRequest{ + Region: s.cfg.Region, + Namespace: workload.Namespace, + AuthToken: s.cfg.NodeSecret, + }, + } + + var deleteResp structs.ServiceRegistrationDeleteByIDResponse + + // Create our function that will be retried. + f := func() error { + err := s.cfg.RPCFn(structs.ServiceRegistrationDeleteByIDRPCMethod, &deleteArgs, &deleteResp) + if err == nil { + return nil + } + + // The Nomad API exposes service registration deletion to handle + // orphaned service registrations. In the event a service is removed + // accidentally that is still running, we will hit this error when we + // eventually want to remove it. We therefore want to handle this, + // while ensuring the operator can see. + if strings.Contains(err.Error(), "service registration not found") { + s.log.Info("attempted to delete non-existent service registration", + "service_id", id, "namespace", workload.Namespace) + return nil + } + + return err + } + + // Perform our function retry, logging any error with enough detail for + // operators to debug properly. + if err := s.retryFunc(f, structs.ServiceRegistrationDeleteByIDRPCMethod); err != nil { + s.log.Error("failed to delete service registration", + "error", err, "service_id", id, "namespace", workload.Namespace) + } +} + +func (s *ServiceRegistrationHandler) UpdateWorkload(old, new *serviceregistration.WorkloadServices) error { + + // Overwrite the workload with the deduplicated versions. + old, new = s.dedupUpdatedWorkload(old, new) + + // Use the register error as an update protection and only ever deregister + // when this has completed successfully. In the event of an error, we can + // return this to the caller stack without modifying state in a weird half + // manner. + if len(new.Services) > 0 { + if err := s.RegisterWorkload(new); err != nil { + return err + } + } + + if len(old.Services) > 0 { + s.RemoveWorkload(old) + } + + return nil +} + +// dedupUpdatedWorkload works through the request old and new workload to +// return a deduplicated set of services. +// +// This is within its own function to make testing easier. +func (s *ServiceRegistrationHandler) dedupUpdatedWorkload( + oldWork, newWork *serviceregistration.WorkloadServices) ( + *serviceregistration.WorkloadServices, *serviceregistration.WorkloadServices) { + + // Create copies of the old and new workload services. These specifically + // ignore the services array so this can be populated as the function + // decides what is needed. + oldCopy := oldWork.Copy() + oldCopy.Services = make([]*structs.Service, 0) + + newCopy := newWork.Copy() + newCopy.Services = make([]*structs.Service, 0) + + // Generate and populate a mapping of the new service registration IDs. + newIDs := make(map[string]*structs.Service, len(newWork.Services)) + + for _, s := range newWork.Services { + newIDs[serviceregistration.MakeAllocServiceID(newWork.AllocID, newWork.Name(), s)] = s + } + + // Iterate through the old services in order to identify whether they can + // be modified solely via upsert, or whether they need to be deleted. + for _, oldService := range oldWork.Services { + + // Generate the service ID of the old service. If this is not found + // within the new mapping then we need to remove it. + oldID := serviceregistration.MakeAllocServiceID(oldWork.AllocID, oldWork.Name(), oldService) + newSvc, ok := newIDs[oldID] + if !ok { + oldCopy.Services = append(oldCopy.Services, oldService) + continue + } + + // Add the new service into the array for upserting and remove its + // entry for the map. Doing it here is efficient as we are already + // inside a loop. + // + // There isn't much point in hashing the old/new services as we would + // still need to ensure the service has previously been registered + // before discarding it from future RPC calls. The Nomad state handles + // performing the diff gracefully, therefore this will still be a + // single RPC. + newCopy.Services = append(newCopy.Services, newSvc) + delete(newIDs, oldID) + } + + // Iterate the remaining new IDs to add them to the registration array. It + // catches any that didn't get added via the previous loop. + for _, newSvc := range newIDs { + newCopy.Services = append(newCopy.Services, newSvc) + } + + return oldCopy, newCopy +} + +// AllocRegistrations is currently a noop implementation as the Nomad provider +// does not support health check which is the sole subsystem caller of this +// function. +func (s *ServiceRegistrationHandler) AllocRegistrations(_ string) (*serviceregistration.AllocRegistration, error) { + return nil, nil +} + +// UpdateTTL is currently a noop implementation as the Nomad provider does not +// support health check which is the sole subsystem caller of this function. +func (s *ServiceRegistrationHandler) UpdateTTL(_, _, _, _ string) error { + return nil +} + +// Shutdown is used to initiate shutdown of the handler. This is specifically +// used to exit any routines running retry functions without leaving them +// orphaned. +func (s *ServiceRegistrationHandler) Shutdown() { close(s.shutDownCh) } + +// retryFunc handles performing retries of a passed function with backoff. The +// initial function will be triggered immediately; any error returned by this +// function should be considered terminal. It is designed to handle RPC calls +// only, with the function wrapper adding flexibility. +func (s *ServiceRegistrationHandler) retryFunc(fn func() error, method string) error { + + // This context is used to enforce the retry timeout. + maxBackoffDuration := s.maxBackoffDuration + if method == structs.ServiceRegistrationUpsertRPCMethod { + maxBackoffDuration = s.maxBackoffUpsertDuration + } + + ctx, cancel := context.WithTimeout(context.TODO(), maxBackoffDuration) + defer cancel() + + // Store the error outside the loop, so we can always return the last error + // recorded from the RPC. + var err error + + // Copy the backoff, so we can make local changes without altering the + // stored base value. + backoff := s.minBackoffInterval + + // Create a new timer, initially set to zero so that it fires straight + // away. + t, stop := helper.NewSafeTimer(0) + defer stop() + + for { + select { + case <-s.shutDownCh: + s.log.Debug("shutting down handler") + return err + case <-ctx.Done(): + return err + case <-t.C: + } + + // Execute the function. + if err = fn(); err == nil { + break + } + + if backoff < s.maxBackoffInterval { + backoff = backoff * 2 + if backoff > s.maxBackoffInterval { + backoff = s.maxBackoffInterval + } + } + + // Log that the RPC failed along with useful context and reset the + // timer using the new backoff value. + s.log.Debug("service registration RPC failed", + "method", method, "retry", backoff, "error", err) + t.Reset(backoff) + } + + return nil +} + +// generateNomadServiceRegistration is a helper to build the Nomad specific +// registration object on a per-service basis. +func (s *ServiceRegistrationHandler) generateNomadServiceRegistration( + serviceSpec *structs.Service, workload *serviceregistration.WorkloadServices) (*structs.ServiceRegistration, error) { + + // Service address modes default to auto. + addrMode := serviceSpec.AddressMode + if addrMode == "" { + addrMode = structs.AddressModeAuto + } + + // Determine the address to advertise based on the mode. + ip, port, err := serviceregistration.GetAddress( + addrMode, serviceSpec.PortLabel, workload.Networks, + workload.DriverNetwork, workload.Ports, workload.NetworkStatus) + if err != nil { + return nil, fmt.Errorf("unable to get address for service %q: %v", serviceSpec.Name, err) + } + + // Build the tags to use for this registration which is a result of whether + // this is a canary, or not. + var tags []string + + if workload.Canary && len(serviceSpec.CanaryTags) > 0 { + tags = make([]string, len(serviceSpec.CanaryTags)) + copy(tags, serviceSpec.CanaryTags) + } else { + tags = make([]string, len(serviceSpec.Tags)) + copy(tags, serviceSpec.Tags) + } + + return &structs.ServiceRegistration{ + ID: serviceregistration.MakeAllocServiceID(workload.AllocID, workload.Name(), serviceSpec), + ServiceName: serviceSpec.Name, + NodeID: s.cfg.NodeID, + JobID: workload.JobID, + AllocID: workload.AllocID, + Namespace: workload.Namespace, + Datacenter: s.cfg.Datacenter, + Tags: tags, + Address: ip, + Port: port, + }, nil +} diff --git a/client/serviceregistration/nsd/nsd_test.go b/client/serviceregistration/nsd/nsd_test.go new file mode 100644 index 000000000000..935c247b73bf --- /dev/null +++ b/client/serviceregistration/nsd/nsd_test.go @@ -0,0 +1,553 @@ +package nsd + +import ( + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/client/serviceregistration" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServiceRegistrationHandler_RegisterWorkload(t *testing.T) { + testCases := []struct { + inputCfg *ServiceRegistrationHandlerCfg + inputWorkload *serviceregistration.WorkloadServices + expectedRPCs map[string]int + expectedError error + name string + }{ + { + inputCfg: &ServiceRegistrationHandlerCfg{ + Enabled: false, + }, + inputWorkload: mockWorkload(), + expectedRPCs: map[string]int{}, + expectedError: errors.New(`service registration provider "nomad" not enabled`), + name: "registration disabled", + }, + { + inputCfg: &ServiceRegistrationHandlerCfg{ + Enabled: true, + }, + inputWorkload: mockWorkload(), + expectedRPCs: map[string]int{structs.ServiceRegistrationUpsertRPCMethod: 1}, + expectedError: nil, + name: "registration enabled", + }, + } + + // Create a logger we can use for all tests. + log := hclog.NewNullLogger() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + // Add the mock RPC functionality. + mockRPC := mockRPC{callCounts: map[string]int{}} + tc.inputCfg.RPCFn = mockRPC.RPC + + // Create the handler and run the tests. + h := NewServiceRegistrationHandler(log, tc.inputCfg) + + actualErr := h.RegisterWorkload(tc.inputWorkload) + require.Equal(t, tc.expectedError, actualErr) + require.Equal(t, tc.expectedRPCs, mockRPC.calls()) + }) + } +} + +func TestServiceRegistrationHandler_RemoveWorkload(t *testing.T) { + testCases := []struct { + inputCfg *ServiceRegistrationHandlerCfg + inputWorkload *serviceregistration.WorkloadServices + expectedRPCs map[string]int + expectedError error + name string + }{ + { + inputCfg: &ServiceRegistrationHandlerCfg{ + Enabled: false, + }, + inputWorkload: mockWorkload(), + expectedRPCs: map[string]int{structs.ServiceRegistrationDeleteByIDRPCMethod: 2}, + expectedError: nil, + name: "registration disabled multiple services", + }, + { + inputCfg: &ServiceRegistrationHandlerCfg{ + Enabled: true, + }, + inputWorkload: mockWorkload(), + expectedRPCs: map[string]int{structs.ServiceRegistrationDeleteByIDRPCMethod: 2}, + expectedError: nil, + name: "registration enabled multiple services", + }, + } + + // Create a logger we can use for all tests. + log := hclog.NewNullLogger() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + // Add the mock RPC functionality. + mockRPC := mockRPC{callCounts: map[string]int{}} + tc.inputCfg.RPCFn = mockRPC.RPC + + // Create the handler and run the tests. + h := NewServiceRegistrationHandler(log, tc.inputCfg) + + h.RemoveWorkload(tc.inputWorkload) + + require.Eventually(t, func() bool { + return assert.Equal(t, tc.expectedRPCs, mockRPC.calls()) + }, 100*time.Millisecond, 10*time.Millisecond) + }) + } +} + +func TestServiceRegistrationHandler_UpdateWorkload(t *testing.T) { + testCases := []struct { + inputCfg *ServiceRegistrationHandlerCfg + inputOldWorkload *serviceregistration.WorkloadServices + inputNewWorkload *serviceregistration.WorkloadServices + expectedRPCs map[string]int + expectedError error + name string + }{ + { + inputCfg: &ServiceRegistrationHandlerCfg{ + Enabled: true, + }, + inputOldWorkload: mockWorkload(), + inputNewWorkload: &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{ + { + Name: "changed-redis-db", + AddressMode: structs.AddressModeHost, + PortLabel: "db", + }, + { + Name: "changed-redis-http", + AddressMode: structs.AddressModeHost, + PortLabel: "http", + }, + }, + Ports: []structs.AllocatedPortMapping{ + { + Label: "db", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "http", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + }, + expectedRPCs: map[string]int{ + structs.ServiceRegistrationUpsertRPCMethod: 1, + structs.ServiceRegistrationDeleteByIDRPCMethod: 2, + }, + expectedError: nil, + name: "delete and upsert", + }, + { + inputCfg: &ServiceRegistrationHandlerCfg{ + Enabled: true, + }, + inputOldWorkload: mockWorkload(), + inputNewWorkload: &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{ + { + Name: "redis-db", + AddressMode: structs.AddressModeHost, + PortLabel: "db", + Tags: []string{"foo"}, + }, + { + Name: "redis-http", + AddressMode: structs.AddressModeHost, + PortLabel: "http", + Tags: []string{"bar"}, + }, + }, + Ports: []structs.AllocatedPortMapping{ + { + Label: "db", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "http", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + }, + expectedRPCs: map[string]int{ + structs.ServiceRegistrationUpsertRPCMethod: 1, + }, + expectedError: nil, + name: "upsert only", + }, + } + + // Create a logger we can use for all tests. + log := hclog.NewNullLogger() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + // Add the mock RPC functionality. + mockRPC := mockRPC{callCounts: map[string]int{}} + tc.inputCfg.RPCFn = mockRPC.RPC + + // Create the handler and run the tests. + h := NewServiceRegistrationHandler(log, tc.inputCfg) + + require.Equal(t, tc.expectedError, h.UpdateWorkload(tc.inputOldWorkload, tc.inputNewWorkload)) + + require.Eventually(t, func() bool { + return assert.Equal(t, tc.expectedRPCs, mockRPC.calls()) + }, 100*time.Millisecond, 10*time.Millisecond) + }) + } + +} + +func TestServiceRegistrationHandler_dedupUpdatedWorkload(t *testing.T) { + testCases := []struct { + inputOldWorkload *serviceregistration.WorkloadServices + inputNewWorkload *serviceregistration.WorkloadServices + expectedOldOutput *serviceregistration.WorkloadServices + expectedNewOutput *serviceregistration.WorkloadServices + name string + }{ + { + inputOldWorkload: mockWorkload(), + inputNewWorkload: &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{ + { + Name: "changed-redis-db", + AddressMode: structs.AddressModeHost, + PortLabel: "db", + }, + { + Name: "changed-redis-http", + AddressMode: structs.AddressModeHost, + PortLabel: "http", + }, + }, + Ports: []structs.AllocatedPortMapping{ + { + Label: "db", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "http", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + }, + expectedOldOutput: mockWorkload(), + expectedNewOutput: &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{ + { + Name: "changed-redis-db", + AddressMode: structs.AddressModeHost, + PortLabel: "db", + }, + { + Name: "changed-redis-http", + AddressMode: structs.AddressModeHost, + PortLabel: "http", + }, + }, + Ports: []structs.AllocatedPortMapping{ + { + Label: "db", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "http", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + }, + name: "service names changed", + }, + { + inputOldWorkload: mockWorkload(), + inputNewWorkload: &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{ + { + Name: "redis-db", + AddressMode: structs.AddressModeHost, + PortLabel: "db", + Tags: []string{"foo"}, + }, + { + Name: "redis-http", + AddressMode: structs.AddressModeHost, + PortLabel: "http", + Tags: []string{"bar"}, + }, + }, + Ports: []structs.AllocatedPortMapping{ + { + Label: "db", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "http", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + }, + expectedOldOutput: &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{}, + Ports: []structs.AllocatedPortMapping{ + { + Label: "db", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "http", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + }, + expectedNewOutput: &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{ + { + Name: "redis-db", + AddressMode: structs.AddressModeHost, + PortLabel: "db", + Tags: []string{"foo"}, + }, + { + Name: "redis-http", + AddressMode: structs.AddressModeHost, + PortLabel: "http", + Tags: []string{"bar"}, + }, + }, + Ports: []structs.AllocatedPortMapping{ + { + Label: "db", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "http", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + }, + name: "tags updated", + }, + { + inputOldWorkload: mockWorkload(), + inputNewWorkload: &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{ + { + Name: "redis-db", + AddressMode: structs.AddressModeHost, + PortLabel: "dbs", + }, + { + Name: "redis-http", + AddressMode: structs.AddressModeHost, + PortLabel: "https", + }, + }, + Ports: []structs.AllocatedPortMapping{ + { + Label: "dbs", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "https", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + }, + expectedOldOutput: mockWorkload(), + expectedNewOutput: &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{ + { + Name: "redis-db", + AddressMode: structs.AddressModeHost, + PortLabel: "dbs", + }, + { + Name: "redis-http", + AddressMode: structs.AddressModeHost, + PortLabel: "https", + }, + }, + Ports: []structs.AllocatedPortMapping{ + { + Label: "dbs", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "https", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + }, + name: "canary tags updated", + }, + } + + s := &ServiceRegistrationHandler{} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOld, actualNew := s.dedupUpdatedWorkload(tc.inputOldWorkload, tc.inputNewWorkload) + require.ElementsMatch(t, tc.expectedOldOutput.Services, actualOld.Services) + require.ElementsMatch(t, tc.expectedNewOutput.Services, actualNew.Services) + }) + } +} + +func mockWorkload() *serviceregistration.WorkloadServices { + return &serviceregistration.WorkloadServices{ + AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c", + Task: "redis", + Group: "cache", + JobID: "example", + Canary: false, + Namespace: "default", + Services: []*structs.Service{ + { + Name: "redis-db", + AddressMode: structs.AddressModeHost, + PortLabel: "db", + }, + { + Name: "redis-http", + AddressMode: structs.AddressModeHost, + PortLabel: "http", + }, + }, + Ports: []structs.AllocatedPortMapping{ + { + Label: "db", + HostIP: "10.10.13.2", + Value: 23098, + }, + { + Label: "http", + HostIP: "10.10.13.2", + Value: 24098, + }, + }, + } +} + +// mockRPC mocks and tracks RPC calls made for testing. +type mockRPC struct { + + // callCounts tracks how many times each RPC method has been called. The + // lock should be used to access this. + callCounts map[string]int + l sync.RWMutex +} + +// calls returns the mapping counting the number of calls made to each RPC +// method. +func (mr *mockRPC) calls() map[string]int { + mr.l.RLock() + defer mr.l.RUnlock() + return mr.callCounts +} + +// RPC mocks the server RPCs, acting as though any request succeeds. +func (mr *mockRPC) RPC(method string, _, _ interface{}) error { + switch method { + case structs.ServiceRegistrationUpsertRPCMethod, structs.ServiceRegistrationDeleteByIDRPCMethod: + mr.l.Lock() + mr.callCounts[method]++ + mr.l.Unlock() + return nil + default: + return fmt.Errorf("unexpected RPC method: %v", method) + } +}