diff --git a/command/agent/consul/client.go b/command/agent/consul/client.go index 767fd4c343a3..b1ad827502d8 100644 --- a/command/agent/consul/client.go +++ b/command/agent/consul/client.go @@ -249,13 +249,21 @@ func (c *ServiceClient) sync() error { sdereg++ } + // Track services whose ports have changed as their checks may also + // need updating + portsChanged := make(map[string]struct{}, len(c.services)) + // Add Nomad services missing from Consul - for id, service := range c.services { - if _, ok := consulServices[id]; ok { - // Already in Consul; skipping - continue + for id, locals := range c.services { + if remotes, ok := consulServices[id]; ok { + if locals.Port == remotes.Port { + // Already exists in Consul; skip + continue + } + // Port changed, reregister it and its checks + portsChanged[id] = struct{}{} } - if err = c.client.ServiceRegister(service); err != nil { + if err = c.client.ServiceRegister(locals); err != nil { return err } sreg++ @@ -264,7 +272,7 @@ func (c *ServiceClient) sync() error { // Remove Nomad checks in Consul but unknown locally for id, check := range consulChecks { if _, ok := c.checks[id]; ok { - // Known check, skip + // Known check, leave it continue } if !isNomadService(check.ServiceID) { @@ -280,9 +288,11 @@ func (c *ServiceClient) sync() error { // Add Nomad checks missing from Consul for id, check := range c.checks { - if _, ok := consulChecks[id]; ok { - // Already in Consul; skipping - continue + if check, ok := consulChecks[id]; ok { + if _, changed := portsChanged[check.ServiceID]; !changed { + // Already in Consul and ports didn't change; skipping + continue + } } if err := c.client.CheckRegister(check); err != nil { return err @@ -291,11 +301,11 @@ func (c *ServiceClient) sync() error { // Handle starting scripts if script, ok := c.scripts[id]; ok { - // If it's already running, don't run it again - if _, running := c.runningScripts[id]; running { - continue + // If it's already running, cancel and replace + if oldScript, running := c.runningScripts[id]; running { + oldScript.cancel() } - // Not running, start and store the handle + // Start and store the handle c.runningScripts[id] = script.run() } } @@ -456,8 +466,6 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta newIDs[makeTaskServiceID(allocID, newTask.Name, s)] = s } - parseAddr := newTask.FindHostAndPortFor - // Loop over existing Service IDs to see if they have been removed or // updated. for existingID, existingSvc := range existingIDs { @@ -471,8 +479,10 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta continue } - // Service exists and wasn't updated, don't add it later - delete(newIDs, existingID) + if newSvc.PortLabel == existingSvc.PortLabel { + // Service exists and hasn't changed, don't add it later + delete(newIDs, existingID) + } // Check to see what checks were updated existingChecks := make(map[string]struct{}, len(existingSvc.Checks)) @@ -484,28 +494,9 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta for _, check := range newSvc.Checks { checkID := createCheckID(existingID, check) if _, exists := existingChecks[checkID]; exists { - // Check already exists; skip it + // Check exists, so don't remove it delete(existingChecks, checkID) - continue - } - - // New check, register it - if check.Type == structs.ServiceCheckScript { - if exec == nil { - return fmt.Errorf("driver doesn't support script checks") - } - ops.scripts = append(ops.scripts, newScriptCheck( - existingID, newTask.Name, checkID, check, exec, c.client, c.logger, c.shutdownCh)) } - host, port := parseAddr(existingSvc.PortLabel) - if check.PortLabel != "" { - host, port = parseAddr(check.PortLabel) - } - checkReg, err := createCheckReg(existingID, checkID, check, host, port) - if err != nil { - return err - } - ops.regChecks = append(ops.regChecks, checkReg) } // Remove existing checks not in updated service diff --git a/command/agent/consul/unit_test.go b/command/agent/consul/unit_test.go index 66934663b941..4882be1cc7cb 100644 --- a/command/agent/consul/unit_test.go +++ b/command/agent/consul/unit_test.go @@ -15,6 +15,12 @@ import ( "github.com/hashicorp/nomad/nomad/structs" ) +const ( + // Ports used in testTask + xPort = 1234 + yPort = 1235 +) + func testLogger() *log.Logger { if testing.Verbose() { return log.New(os.Stderr, "", log.LstdFlags) @@ -28,7 +34,10 @@ func testTask() *structs.Task { Resources: &structs.Resources{ Networks: []*structs.NetworkResource{ { - DynamicPorts: []structs.Port{{Label: "x", Value: 1234}}, + DynamicPorts: []structs.Port{ + {Label: "x", Value: xPort}, + {Label: "y", Value: yPort}, + }, }, }, }, @@ -49,12 +58,20 @@ type testFakeCtx struct { FakeConsul *fakeConsul Task *structs.Task + // Ticked whenever a script is called + execs chan int + + // If non-nil will be called by script checks ExecFunc func(ctx context.Context, cmd string, args []string) ([]byte, int, error) } // Exec implements the ScriptExecutor interface and will use an alternate // implementation t.ExecFunc if non-nil. func (t *testFakeCtx) Exec(ctx context.Context, cmd string, args []string) ([]byte, int, error) { + select { + case t.execs <- 1: + default: + } if t.ExecFunc == nil { // Default impl is just "ok" return []byte("ok"), 0, nil @@ -84,6 +101,7 @@ func setupFake() *testFakeCtx { ServiceClient: NewServiceClient(fc, testLogger()), FakeConsul: fc, Task: testTask(), + execs: make(chan int, 100), } } @@ -242,6 +260,188 @@ func TestConsul_ChangeTags(t *testing.T) { } } +// TestConsul_ChangePorts asserts that changing the ports on a service updates +// it in Consul. Since ports are part of the service ID this is a slightly +// different code path than changing tags. +func TestConsul_ChangePorts(t *testing.T) { + ctx := setupFake() + ctx.Task.Services[0].Checks = []*structs.ServiceCheck{ + { + Name: "c1", + Type: "tcp", + Interval: time.Second, + Timeout: time.Second, + PortLabel: "x", + }, + { + Name: "c2", + Type: "script", + Interval: 9000 * time.Hour, + Timeout: time.Second, + }, + { + Name: "c3", + Type: "http", + Protocol: "http", + Path: "/", + Interval: time.Second, + Timeout: time.Second, + PortLabel: "y", + }, + } + + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx); err != nil { + t.Fatalf("unexpected error registering task: %v", err) + } + + if err := ctx.syncOnce(); err != nil { + t.Fatalf("unexpected error syncing task: %v", err) + } + + if n := len(ctx.FakeConsul.services); n != 1 { + t.Fatalf("expected 1 service but found %d:\n%#v", n, ctx.FakeConsul.services) + } + + origServiceKey := "" + for k, v := range ctx.FakeConsul.services { + origServiceKey = k + if v.Name != ctx.Task.Services[0].Name { + t.Errorf("expected Name=%q != %q", ctx.Task.Services[0].Name, v.Name) + } + if !reflect.DeepEqual(v.Tags, ctx.Task.Services[0].Tags) { + t.Errorf("expected Tags=%v != %v", ctx.Task.Services[0].Tags, v.Tags) + } + if v.Port != xPort { + t.Errorf("expected Port x=%v but found: %v", xPort, v.Port) + } + } + + if n := len(ctx.FakeConsul.checks); n != 3 { + t.Fatalf("expected 3 checks but found %d:\n%#v", n, ctx.FakeConsul.checks) + } + + origTCPKey := "" + origScriptKey := "" + origHTTPKey := "" + for k, v := range ctx.FakeConsul.checks { + switch v.Name { + case "c1": + origTCPKey = k + if expected := fmt.Sprintf(":%d", xPort); v.TCP != expected { + t.Errorf("expected Port x=%v but found: %v", expected, v.TCP) + } + case "c2": + origScriptKey = k + select { + case <-ctx.execs: + if n := len(ctx.execs); n > 0 { + t.Errorf("expected 1 exec but found: %d", n+1) + } + case <-time.After(3 * time.Second): + t.Errorf("script not called in time") + } + case "c3": + origHTTPKey = k + if expected := fmt.Sprintf("http://:%d/", yPort); v.HTTP != expected { + t.Errorf("expected Port y=%v but found: %v", expected, v.HTTP) + } + default: + t.Fatalf("unexpected check: %q", v.Name) + } + } + + // Now update the PortLabel on the Service and Check c3 + origTask := ctx.Task + ctx.Task = testTask() + ctx.Task.Services[0].PortLabel = "y" + ctx.Task.Services[0].Checks = []*structs.ServiceCheck{ + { + Name: "c1", + Type: "tcp", + Interval: time.Second, + Timeout: time.Second, + PortLabel: "x", + }, + { + Name: "c2", + Type: "script", + Interval: 9000 * time.Hour, + Timeout: time.Second, + }, + { + Name: "c3", + Type: "http", + Protocol: "http", + Path: "/", + Interval: time.Second, + Timeout: time.Second, + // Removed PortLabel + }, + } + if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx); err != nil { + t.Fatalf("unexpected error registering task: %v", err) + } + if err := ctx.syncOnce(); err != nil { + t.Fatalf("unexpected error syncing task: %v", err) + } + + if n := len(ctx.FakeConsul.services); n != 1 { + t.Fatalf("expected 1 service but found %d:\n%#v", n, ctx.FakeConsul.services) + } + + for k, v := range ctx.FakeConsul.services { + if k != origServiceKey { + t.Errorf("unexpected key change; was: %q -- but found %q", origServiceKey, k) + } + if v.Name != ctx.Task.Services[0].Name { + t.Errorf("expected Name=%q != %q", ctx.Task.Services[0].Name, v.Name) + } + if !reflect.DeepEqual(v.Tags, ctx.Task.Services[0].Tags) { + t.Errorf("expected Tags=%v != %v", ctx.Task.Services[0].Tags, v.Tags) + } + if v.Port != yPort { + t.Errorf("expected Port y=%v but found: %v", yPort, v.Port) + } + } + + if n := len(ctx.FakeConsul.checks); n != 3 { + t.Fatalf("expected 3 check but found %d:\n%#v", n, ctx.FakeConsul.checks) + } + + for k, v := range ctx.FakeConsul.checks { + switch v.Name { + case "c1": + if k != origTCPKey { + t.Errorf("unexpected key change for %s from %q to %q", v.Name, origTCPKey, k) + } + if expected := fmt.Sprintf(":%d", xPort); v.TCP != expected { + t.Errorf("expected Port x=%v but found: %v", expected, v.TCP) + } + case "c2": + if k != origScriptKey { + t.Errorf("unexpected key change for %s from %q to %q", v.Name, origScriptKey, k) + } + select { + case <-ctx.execs: + if n := len(ctx.execs); n > 0 { + t.Errorf("expected 1 exec but found: %d", n+1) + } + case <-time.After(3 * time.Second): + t.Errorf("script not called in time") + } + case "c3": + if k == origHTTPKey { + t.Errorf("expected %s key to change from %q", v.Name, k) + } + if expected := fmt.Sprintf("http://:%d/", yPort); v.HTTP != expected { + t.Errorf("expected Port y=%v but found: %v", expected, v.HTTP) + } + default: + t.Errorf("Unkown check: %q", k) + } + } +} + // TestConsul_RegServices tests basic service registration. func TestConsul_RegServices(t *testing.T) { ctx := setupFake()