diff --git a/changelog/fragments/1712108631-Use-policy-outputs-when-running-in-agent-mode.yaml b/changelog/fragments/1712108631-Use-policy-outputs-when-running-in-agent-mode.yaml new file mode 100644 index 000000000..f2dfbf600 --- /dev/null +++ b/changelog/fragments/1712108631-Use-policy-outputs-when-running-in-agent-mode.yaml @@ -0,0 +1,38 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: feature + +# Change summary; an 80ish characters long description of the change. +summary: Use policy outputs when running in agent-mode + +# Long description; in case the summary is not enough to describe the change +# this field accommodates a description without length limits. +# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment. +description: | + Fleet-server will retrieve and use the output from the policy when running in agent-mode. + This allows fleet-server to connect to multiple Elasticsearch hosts, provided it can successfully + connect to the host given at enrollment/installation. + We expect that the host provided during enrollment/installation is never removed as a valid output: + fleet-server does not persist the output settings it retrieves locally, so it must always be able to + connect with the options specified at enrollment/installation. + +# Affected component; a word indicating the component this changeset affects. +component: + +# PR URL; optional; the PR number that added the changeset. +# If not present, it is automatically filled by the tooling by finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +pr: 3411 + +# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present, it is automatically filled by the tooling with the issue linked to the PR number. +issue: https://github.com/elastic/elastic-agent/issues/2784 diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go index 6a9ec05df..3550a3848 100644 --- a/internal/pkg/config/config.go +++ b/internal/pkg/config/config.go @@ -39,12 +39,13 @@ const kRedacted = "[redacted]" // The env vars that `elastic-agent container` command uses are unrelated. // The agent will do all substitutions before sending fleet-server the complete config. 
type Config struct { - Fleet Fleet `config:"fleet"` - Output Output `config:"output"` - Inputs []Input `config:"inputs"` - Logging Logging `config:"logging"` - HTTP HTTP `config:"http"` - m sync.Mutex + Fleet Fleet `config:"fleet"` + Output Output `config:"output"` + Inputs []Input `config:"inputs"` + Logging Logging `config:"logging"` + HTTP HTTP `config:"http"` + RevisionIdx int64 `config:",ignore"` + m sync.Mutex } var deprecatedConfigOptions = map[string]string{ diff --git a/internal/pkg/config/output.go b/internal/pkg/config/output.go index 185d68542..75fecd687 100644 --- a/internal/pkg/config/output.go +++ b/internal/pkg/config/output.go @@ -5,6 +5,7 @@ package config import ( + "crypto/tls" "fmt" "net" "net/http" @@ -25,6 +26,14 @@ import ( const httpTransportLongPollTimeout = 10 * time.Minute const schemeHTTP = "http" +const ( + DefaultElasticsearchHost = "localhost:9200" + DefaultElasticsearchTimeout = 90 * time.Second + DefaultElasticsearchMaxRetries = 3 + DefaultElasticsearchMaxConnPerHost = 128 + DefaultElasticsearchMaxContentLength = 100 * 1024 * 1024 +) + var hasScheme = regexp.MustCompile(`^([a-z][a-z0-9+\-.]*)://`) // Output is the output configuration to elasticsearch. @@ -54,11 +63,11 @@ type Elasticsearch struct { // InitDefaults initializes the defaults for the configuration. func (c *Elasticsearch) InitDefaults() { c.Protocol = schemeHTTP - c.Hosts = []string{"localhost:9200"} - c.Timeout = 90 * time.Second - c.MaxRetries = 3 - c.MaxConnPerHost = 128 - c.MaxContentLength = 100 * 1024 * 1024 + c.Hosts = []string{DefaultElasticsearchHost} + c.Timeout = DefaultElasticsearchTimeout + c.MaxRetries = DefaultElasticsearchMaxRetries + c.MaxConnPerHost = DefaultElasticsearchMaxConnPerHost + c.MaxContentLength = DefaultElasticsearchMaxContentLength } // Validate ensures that the configuration is valid. @@ -173,6 +182,108 @@ func (c *Elasticsearch) ToESConfig(longPoll bool) (elasticsearch.Config, error) }, nil } +// MergeElasticsearchFromPolicy will merge elasticsearch settings retrieved from the fleet-server's policy into the base configuration and return the resulting config. +// ucfg.Merge and config.Config.Merge will both fail at merging these configs because the verification mode is not detected as a string type value. +func MergeElasticsearchFromPolicy(cfg, pol Elasticsearch) Elasticsearch { + res := Elasticsearch{ + Protocol: cfg.Protocol, + Hosts: cfg.Hosts, + Headers: cfg.Headers, + ServiceToken: cfg.ServiceToken, // ServiceToken will always be specified from the settings and not in the policy. + ServiceTokenPath: cfg.ServiceTokenPath, + ProxyURL: cfg.ProxyURL, + ProxyDisable: cfg.ProxyDisable, + ProxyHeaders: cfg.ProxyHeaders, + TLS: mergeElasticsearchTLS(cfg.TLS, pol.TLS), // tls can be a special case + MaxRetries: cfg.MaxRetries, + MaxConnPerHost: cfg.MaxConnPerHost, + Timeout: cfg.Timeout, + MaxContentLength: cfg.MaxContentLength, + } + // If the policy has a non-default Hosts value, use its values for Protocol and Hosts. + if pol.Hosts != nil && !(len(pol.Hosts) == 1 && pol.Hosts[0] == DefaultElasticsearchHost) { + res.Protocol = pol.Protocol + res.Hosts = pol.Hosts + } + if pol.Headers != nil { + res.Headers = pol.Headers + } + // If the policy ProxyURL is set, use all of the policy's Proxy values. 
+ if pol.ProxyURL != "" { + res.ProxyURL = pol.ProxyURL + res.ProxyDisable = pol.ProxyDisable + res.ProxyHeaders = pol.ProxyHeaders + } + if pol.MaxRetries != DefaultElasticsearchMaxRetries { + res.MaxRetries = pol.MaxRetries + } + if pol.MaxConnPerHost != DefaultElasticsearchMaxConnPerHost { + res.MaxConnPerHost = pol.MaxConnPerHost + } + if pol.Timeout != DefaultElasticsearchTimeout { + res.Timeout = pol.Timeout + } + if pol.MaxContentLength != DefaultElasticsearchMaxContentLength { + res.MaxContentLength = pol.MaxContentLength + } + return res +} + +// mergeElasticsearchTLS merges the TLS settings received from the fleet-server's policy into the settings the agent passes +func mergeElasticsearchTLS(cfg, pol *tlscommon.Config) *tlscommon.Config { + if cfg == nil && pol == nil { + return nil + } else if cfg == nil && pol != nil { + return pol + } else if cfg != nil && pol == nil { + return cfg + } + res := &tlscommon.Config{ + Enabled: cfg.Enabled, + VerificationMode: cfg.VerificationMode, + Versions: cfg.Versions, + CipherSuites: cfg.CipherSuites, + CAs: cfg.CAs, + Certificate: cfg.Certificate, + CurveTypes: cfg.CurveTypes, + Renegotiation: cfg.Renegotiation, + CASha256: cfg.CASha256, + CATrustedFingerprint: cfg.CATrustedFingerprint, + } + if pol.Enabled != nil { + res.Enabled = pol.Enabled + } + if pol.VerificationMode != tlscommon.VerifyFull { + res.VerificationMode = pol.VerificationMode // VerificationMode defaults to VerifyFull + } + if pol.Versions != nil { + res.Versions = pol.Versions + } + if pol.CipherSuites != nil { + res.CipherSuites = pol.CipherSuites + } + if pol.CAs != nil { + res.CAs = pol.CAs + } + if pol.Certificate.Certificate != "" { + res.Certificate = pol.Certificate + } + if pol.CurveTypes != nil { + res.CurveTypes = pol.CurveTypes + } + if pol.Renegotiation != tlscommon.TLSRenegotiationSupport(tls.RenegotiateNever) { + res.Renegotiation = pol.Renegotiation + } + if pol.CASha256 != nil { + res.CASha256 = pol.CASha256 + } + if pol.CATrustedFingerprint != "" { + res.CATrustedFingerprint = pol.CATrustedFingerprint + } + + return res +} + // Validate validates that only elasticsearch is defined on the output. 
func (c *Output) Validate() error { if c.Extra == nil { diff --git a/internal/pkg/config/output_test.go b/internal/pkg/config/output_test.go index a2fdf5cfd..8e9143c9c 100644 --- a/internal/pkg/config/output_test.go +++ b/internal/pkg/config/output_test.go @@ -382,3 +382,170 @@ func setTestEnv(t *testing.T, env map[string]string) { t.Setenv(k, v) } } + +func TestMergeElasticsearchFromPolicy(t *testing.T) { + cfg := Elasticsearch{ + Protocol: "http", + Hosts: []string{"elasticsearch:9200"}, + ServiceToken: "token", + Timeout: time.Second, + MaxRetries: 1, + MaxConnPerHost: 1, + MaxContentLength: 1, + } + tests := []struct { + name string + pol Elasticsearch + res Elasticsearch + }{{ + name: "default policy", + pol: Elasticsearch{ + Hosts: []string{"localhost:9200"}, + Timeout: DefaultElasticsearchTimeout, + MaxRetries: DefaultElasticsearchMaxRetries, + MaxConnPerHost: DefaultElasticsearchMaxConnPerHost, + MaxContentLength: DefaultElasticsearchMaxContentLength, + }, + res: Elasticsearch{ + Protocol: "http", + Hosts: []string{"elasticsearch:9200"}, + ServiceToken: "token", + Timeout: time.Second, + MaxRetries: 1, + MaxConnPerHost: 1, + MaxContentLength: 1, + }, + }, { + name: "hosts differ", + pol: Elasticsearch{ + Protocol: "https", + Hosts: []string{"elasticsearch:9200", "other:9200"}, + Timeout: DefaultElasticsearchTimeout, + MaxRetries: DefaultElasticsearchMaxRetries, + MaxConnPerHost: DefaultElasticsearchMaxConnPerHost, + MaxContentLength: DefaultElasticsearchMaxContentLength, + }, + res: Elasticsearch{ + Protocol: "https", + Hosts: []string{"elasticsearch:9200", "other:9200"}, + ServiceToken: "token", + Timeout: time.Second, + MaxRetries: 1, + MaxConnPerHost: 1, + MaxContentLength: 1, + }, + }, { + name: "all non tls attributes differ", + pol: Elasticsearch{ + Protocol: "https", + Hosts: []string{"elasticsearch:9200", "other:9200"}, + Headers: map[string]string{"custom": "value"}, + ProxyURL: "http://proxy:8080", + ProxyDisable: false, + ProxyHeaders: map[string]string{"proxyhead": "proxyval"}, + Timeout: time.Second * 2, + MaxRetries: 2, + MaxConnPerHost: 3, + MaxContentLength: 4, + }, + res: Elasticsearch{ + Protocol: "https", + Hosts: []string{"elasticsearch:9200", "other:9200"}, + Headers: map[string]string{"custom": "value"}, + ProxyURL: "http://proxy:8080", + ProxyDisable: false, + ProxyHeaders: map[string]string{"proxyhead": "proxyval"}, + ServiceToken: "token", + Timeout: 2 * time.Second, + MaxRetries: 2, + MaxConnPerHost: 3, + MaxContentLength: 4, + }, + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := MergeElasticsearchFromPolicy(cfg, tc.pol) + assert.Equal(t, tc.res.Protocol, res.Protocol) + require.Len(t, res.Hosts, len(tc.res.Hosts)) + for i, host := range tc.res.Hosts { + assert.Equalf(t, host, res.Hosts[i], "host %d does not match", i) + } + require.Len(t, res.Headers, len(tc.res.Headers)) + for k, v := range tc.res.Headers { + assert.Equal(t, v, res.Headers[k]) + } + assert.Equal(t, tc.res.ServiceToken, res.ServiceToken) + assert.Equal(t, tc.res.ServiceTokenPath, res.ServiceTokenPath) + assert.Equal(t, tc.res.ProxyURL, res.ProxyURL) + assert.Equal(t, tc.res.ProxyDisable, res.ProxyDisable) + require.Len(t, res.ProxyHeaders, len(tc.res.ProxyHeaders)) + for k, v := range tc.res.ProxyHeaders { + assert.Equal(t, v, res.ProxyHeaders[k]) + } + assert.Nil(t, res.TLS) + assert.Equal(t, tc.res.MaxRetries, res.MaxRetries) + assert.Equal(t, tc.res.MaxConnPerHost, res.MaxConnPerHost) + assert.Equal(t, tc.res.Timeout, res.Timeout) + assert.Equal(t, 
tc.res.MaxContentLength, res.MaxContentLength) + }) + } +} + +func TestMergeElasticsearchTLS(t *testing.T) { + enabled := true + disabled := false + t.Run("both nil", func(t *testing.T) { + res := mergeElasticsearchTLS(nil, nil) + assert.Nil(t, res) + }) + t.Run("cfg not nil", func(t *testing.T) { + res := mergeElasticsearchTLS(&tlscommon.Config{ + Enabled: &enabled, + VerificationMode: tlscommon.VerifyFull, + }, nil) + require.NotNil(t, res) + assert.True(t, *res.Enabled) + assert.Equal(t, tlscommon.VerifyFull, res.VerificationMode) + }) + t.Run("pol not nil", func(t *testing.T) { + res := mergeElasticsearchTLS(nil, &tlscommon.Config{ + Enabled: &enabled, + VerificationMode: tlscommon.VerifyFull, + }) + require.NotNil(t, res) + assert.True(t, *res.Enabled) + assert.Equal(t, tlscommon.VerifyFull, res.VerificationMode) + }) + t.Run("both not nil", func(t *testing.T) { + res := mergeElasticsearchTLS(&tlscommon.Config{ + Enabled: &disabled, + VerificationMode: tlscommon.VerifyFull, + }, &tlscommon.Config{ + Enabled: &enabled, + VerificationMode: tlscommon.VerifyCertificate, + Versions: []tlscommon.TLSVersion{tlscommon.TLSVersion13}, + CipherSuites: []tlscommon.CipherSuite{tlscommon.CipherSuite(tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA)}, + CAs: []string{"/path/to/ca.crt"}, + Certificate: tlscommon.CertificateConfig{ + Certificate: "/path/to/cert.crt", + Key: "/path/to/key.crt", + }, + CASha256: []string{"casha256val"}, + CATrustedFingerprint: "fingerprint", + }) + require.NotNil(t, res) + assert.True(t, *res.Enabled) + assert.Equal(t, tlscommon.VerifyCertificate, res.VerificationMode) + require.Len(t, res.Versions, 1) + assert.Equal(t, tlscommon.TLSVersion13, res.Versions[0]) + require.Len(t, res.CipherSuites, 1) + assert.Equal(t, tlscommon.CipherSuite(tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA), res.CipherSuites[0]) + require.Len(t, res.CAs, 1) + assert.Equal(t, "/path/to/ca.crt", res.CAs[0]) + assert.Equal(t, "/path/to/cert.crt", res.Certificate.Certificate) + assert.Equal(t, "/path/to/key.crt", res.Certificate.Key) + require.Len(t, res.CASha256, 1) + assert.Equal(t, "casha256val", res.CASha256[0]) + assert.Equal(t, "fingerprint", res.CATrustedFingerprint) + }) +} diff --git a/internal/pkg/policy/self.go b/internal/pkg/policy/self.go index 672904cc8..02b26685b 100644 --- a/internal/pkg/policy/self.go +++ b/internal/pkg/policy/self.go @@ -8,10 +8,12 @@ import ( "context" "errors" "fmt" + "strings" "sync" "time" "github.com/elastic/elastic-agent-client/v7/pkg/client" + "github.com/elastic/go-ucfg" "go.elastic.co/apm/v2" "github.com/rs/zerolog" @@ -32,6 +34,10 @@ const DefaultCheckTime = 5 * time.Second // DefaultCheckTimeout is the default timeout when checking for policies. const DefaultCheckTimeout = 30 * time.Second +const fleetserverInput = "fleet-server" + +var ErrInvalidOutput = fmt.Errorf("policy output invalid") + type enrollmentTokenFetcher func(ctx context.Context, bulker bulk.Bulk, policyID string) ([]model.EnrollmentAPIKey, error) type SelfMonitor interface { @@ -48,12 +54,14 @@ type selfMonitorT struct { fleet config.Fleet bulker bulk.Bulk monitor monitor.Monitor + cfgCh chan<- *config.Config policyID string state client.UnitState reporter state.Reporter - policy *model.Policy + policy *model.Policy + lastRev int64 policyF policyFetcher policiesIndex string @@ -67,11 +75,12 @@ type selfMonitorT struct { // // Ensures that the policy that this Fleet Server attached to exists and that it // has a Fleet Server input defined. 
-func NewSelfMonitor(fleet config.Fleet, bulker bulk.Bulk, monitor monitor.Monitor, policyID string, reporter state.Reporter) SelfMonitor { +func NewSelfMonitor(fleet config.Fleet, bulker bulk.Bulk, monitor monitor.Monitor, policyID string, reporter state.Reporter, cfgCh chan<- *config.Config) SelfMonitor { return &selfMonitorT{ fleet: fleet, bulker: bulker, monitor: monitor, + cfgCh: cfgCh, policyID: policyID, state: client.UnitStateStarting, reporter: reporter, @@ -174,9 +183,17 @@ func (m *selfMonitorT) processPolicies(ctx context.Context, policies []model.Pol policy := latest[i] if m.policyID != "" && policy.PolicyID == m.policyID { m.policy = &policy + err := m.sendPolicyOutput() + if err != nil { + m.log.Warn().Err(err).Int64(logger.RevisionIdx, m.lastRev).Str(logger.PolicyID, m.policyID).Msg("Failed to send fleet-server output") + } break } else if m.policyID == "" && policy.DefaultFleetServer { m.policy = &policy + err := m.sendPolicyOutput() + if err != nil { + m.log.Warn().Err(err).Int64(logger.RevisionIdx, m.lastRev).Msg("Failed to send default policy fleet-server output") + } break } } @@ -187,6 +204,89 @@ func (m *selfMonitorT) groupByLatest(policies []model.Policy) map[string]model.P return groupByLatest(policies) } +// sendPolicyOutput will parse the policy and send it through the config channel with only Output.Elasticsearch and RevisionIdx set. +// It will not send to the config channel if the policy revision_idx has not changed. +// It returns any errors encountered when parsing the policy. +func (m *selfMonitorT) sendPolicyOutput() error { + // policy revision has not changed + if m.policy.RevisionIdx == m.lastRev { + return nil + } + // always copy revisionIdx + m.lastRev = m.policy.RevisionIdx + + name, ok := getFleetOutputName(m.policy) + if !ok { + return fmt.Errorf("unable to find fleet-server use_output attribute") + } + data, ok := m.policy.Data.Outputs[name] + if !ok { + return fmt.Errorf("unable to find output name %q in policy", name) + } + outType, ok := data["type"].(string) + if !ok { + return fmt.Errorf("output name %s has non-string in type attribute: %w", name, ErrInvalidOutput) + } + if outType != OutputTypeElasticsearch { + return fmt.Errorf("output %s is type: %q, expected: elasticsearch", name, outType) + } + + var policyES config.Elasticsearch + output, err := ucfg.NewFrom(data, config.DefaultOptions...) + if err != nil { + return fmt.Errorf("unable to create config from output data: %w", err) + } + if err := output.Unpack(&policyES, config.DefaultOptions...); err != nil { + return fmt.Errorf("unable to unpack config data to config.Elasticsearch: %w", err) + } + + // The output block in the policy may not have the scheme set so we need to manually set it. 
+ isHTTPS := false + for _, host := range policyES.Hosts { + if strings.HasPrefix(strings.ToLower(host), "https") { + isHTTPS = true + break + } + } + if isHTTPS { + policyES.Protocol = "https" + } + m.cfgCh <- &config.Config{ + Output: config.Output{ + Elasticsearch: policyES, + }, + RevisionIdx: m.lastRev, + } + return nil +} + +// getFleetOutputName returns the output name that the fleet-server input of the policy uses +func getFleetOutputName(p *model.Policy) (string, bool) { + if p.Data == nil { + return "", false + } + for _, input := range p.Data.Inputs { + val, found := input["type"] + if !found { + continue + } + typ, ok := val.(string) + if !ok { + continue + } + if typ != fleetserverInput { + continue + } + val, found = input["use_output"] + if !found { + return "", false + } + out, ok := val.(string) + return out, ok + } + return "", false +} + func (m *selfMonitorT) updateState(ctx context.Context) (client.UnitState, error) { m.mut.Lock() defer m.mut.Unlock() @@ -294,7 +394,7 @@ func HasFleetServerInput(inputs []map[string]interface{}) bool { if !ok { return false } - if attr == "fleet-server" { + if attr == fleetserverInput { return true } } diff --git a/internal/pkg/policy/self_test.go b/internal/pkg/policy/self_test.go index 34ade9fe7..1514ffac0 100644 --- a/internal/pkg/policy/self_test.go +++ b/internal/pkg/policy/self_test.go @@ -19,7 +19,9 @@ import ( "github.com/elastic/elastic-agent-client/v7/pkg/client" "github.com/gofrs/uuid" "github.com/rs/xid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/elastic/fleet-server/v7/internal/pkg/bulk" "github.com/elastic/fleet-server/v7/internal/pkg/config" @@ -54,7 +56,7 @@ func TestSelfMonitor_DefaultPolicy(t *testing.T) { emptyBulkerMap := make(map[string]bulk.Bulk) bulker.On("GetBulkerMap").Return(emptyBulkerMap) - monitor := NewSelfMonitor(cfg, bulker, mm, "", reporter) + monitor := NewSelfMonitor(cfg, bulker, mm, "", reporter, make(chan *config.Config, 2)) sm := monitor.(*selfMonitorT) sm.policyF = func(ctx context.Context, bulker bulk.Bulk, opt ...dl.Option) ([]model.Policy, error) { return []model.Policy{}, nil @@ -125,7 +127,7 @@ func TestSelfMonitor_DefaultPolicy(t *testing.T) { rId = xid.New().String() pData = model.PolicyData{Inputs: []map[string]interface{}{ { - "type": "fleet-server", + "type": fleetserverInput, }, }} policy = model.Policy{ @@ -193,7 +195,7 @@ func TestSelfMonitor_DefaultPolicy_Degraded(t *testing.T) { emptyBulkerMap := make(map[string]bulk.Bulk) bulker.On("GetBulkerMap").Return(emptyBulkerMap) - monitor := NewSelfMonitor(cfg, bulker, mm, "", reporter) + monitor := NewSelfMonitor(cfg, bulker, mm, "", reporter, make(chan *config.Config, 1)) sm := monitor.(*selfMonitorT) sm.checkTime = 100 * time.Millisecond @@ -241,7 +243,7 @@ func TestSelfMonitor_DefaultPolicy_Degraded(t *testing.T) { rId := xid.New().String() pData := model.PolicyData{Inputs: []map[string]interface{}{ { - "type": "fleet-server", + "type": fleetserverInput, }, }} policy := model.Policy{ @@ -352,7 +354,8 @@ func TestSelfMonitor_SpecificPolicy(t *testing.T) { emptyBulkerMap := make(map[string]bulk.Bulk) bulker.On("GetBulkerMap").Return(emptyBulkerMap) - monitor := NewSelfMonitor(cfg, bulker, mm, policyID, reporter) + chConfig := make(chan *config.Config, 2) + monitor := NewSelfMonitor(cfg, bulker, mm, policyID, reporter, chConfig) sm := monitor.(*selfMonitorT) sm.policyF = func(ctx context.Context, bulker bulk.Bulk, opt ...dl.Option) ([]model.Policy, error) { 
return []model.Policy{}, nil @@ -420,11 +423,16 @@ func TestSelfMonitor_SpecificPolicy(t *testing.T) { }, ftesting.RetrySleep(1*time.Second)) rId = xid.New().String() - pData = model.PolicyData{Inputs: []map[string]interface{}{ - { - "type": "fleet-server", + pData = model.PolicyData{ + Inputs: []map[string]interface{}{{"type": fleetserverInput, "use_output": "default"}}, + Outputs: map[string]map[string]interface{}{ + "default": map[string]interface{}{ + "type": "elasticsearch", + "hosts": []interface{}{"https://elasticsearch:9200"}, + "protocol": "https", + }, }, - }} + } policy = model.Policy{ ESDocument: model.ESDocument{ Id: rId, @@ -465,6 +473,15 @@ func TestSelfMonitor_SpecificPolicy(t *testing.T) { if merr != nil && merr != context.Canceled { t.Fatal(merr) } + + select { + case cfg := <-chConfig: + assert.Equal(t, int64(1), cfg.RevisionIdx) + require.Len(t, cfg.Output.Elasticsearch.Hosts, 1) + assert.Equal(t, "https://elasticsearch:9200", cfg.Output.Elasticsearch.Hosts[0]) + default: + t.Fatal("no policy on config channel") + } } func TestSelfMonitor_SpecificPolicy_Degraded(t *testing.T) { @@ -490,7 +507,7 @@ func TestSelfMonitor_SpecificPolicy_Degraded(t *testing.T) { emptyBulkerMap := make(map[string]bulk.Bulk) bulker.On("GetBulkerMap").Return(emptyBulkerMap) - monitor := NewSelfMonitor(cfg, bulker, mm, policyID, reporter) + monitor := NewSelfMonitor(cfg, bulker, mm, policyID, reporter, make(chan *config.Config, 1)) sm := monitor.(*selfMonitorT) sm.checkTime = 100 * time.Millisecond @@ -537,7 +554,7 @@ func TestSelfMonitor_SpecificPolicy_Degraded(t *testing.T) { rId := xid.New().String() pData := model.PolicyData{Inputs: []map[string]interface{}{ { - "type": "fleet-server", + "type": fleetserverInput, }, }} policy := model.Policy{ @@ -762,3 +779,70 @@ func TestSelfMonitor_reportOutputSkipIfNotFound(t *testing.T) { bulker.AssertExpectations(t) outputBulker.AssertExpectations(t) } + +func TestGetFleetOutputName(t *testing.T) { + tests := []struct { + name string + policy *model.Policy + found bool + outname string + }{{ + name: "found single input", + policy: &model.Policy{ + Data: &model.PolicyData{ + Inputs: []map[string]interface{}{{ + "type": fleetserverInput, + "use_output": "default", + }}, + }, + }, + found: true, + outname: "default", + }, { + name: "found multiple inputs", + policy: &model.Policy{ + Data: &model.PolicyData{ + Inputs: []map[string]interface{}{{ + "type": "system", + "use_output": "default", + }, { + "type": fleetserverInput, + "use_output": "custom", + }}, + }, + }, + found: true, + outname: "custom", + }, { + name: "use_output not found", + policy: &model.Policy{ + Data: &model.PolicyData{ + Inputs: []map[string]interface{}{{ + "type": fleetserverInput, + }}, + }, + }, + found: false, + outname: "", + }, { + name: "no match", + policy: &model.Policy{ + Data: &model.PolicyData{ + Inputs: []map[string]interface{}{{ + "type": "system", + "use_output": "custom", + }}, + }, + }, + found: false, + outname: "", + }} + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + name, ok := getFleetOutputName(tc.policy) + assert.Equal(t, tc.found, ok) + assert.Equal(t, tc.outname, name) + }) + } +} diff --git a/internal/pkg/server/agent.go b/internal/pkg/server/agent.go index 0d39a1c92..774009977 100644 --- a/internal/pkg/server/agent.go +++ b/internal/pkg/server/agent.go @@ -56,9 +56,6 @@ type Agent struct { srvCtx context.Context srvCanceller context.CancelFunc srvDone chan bool - - l sync.RWMutex - cfg *config.Config } // NewAgent returns an Agent that 
will gather connection information from the passed reader. @@ -88,14 +85,16 @@ func NewAgent(cliCfg *ucfg.Config, reader io.Reader, bi build.Info, reloadables func (a *Agent) Run(ctx context.Context) error { log := zerolog.Ctx(ctx) a.agent.RegisterDiagnosticHook("fleet-server config", "fleet-server's current configuration", "fleet-server.yml", "application/yml", func() []byte { - a.l.RLock() - if a.cfg == nil { - a.l.RUnlock() + if a.srv == nil { + log.Warn().Msg("Diagnostics hook failure fleet-server is nil.") + return nil + } + cfg := a.srv.GetConfig() + if cfg == nil { log.Warn().Msg("Diagnostics hook failure config is nil.") return nil } - cfg := a.cfg.Redact() - a.l.RUnlock() + cfg = cfg.Redact() p, err := yaml.Marshal(cfg) if err != nil { log.Error().Err(err).Msg("Diagnostics hook failure config unable to marshal yaml.") @@ -336,9 +335,6 @@ func (a *Agent) reconfigure(ctx context.Context) error { if err != nil { return err } - a.l.Lock() - a.cfg = cfg - a.l.Unlock() // reload the generic reloadables for _, r := range a.reloadables { diff --git a/internal/pkg/server/fleet.go b/internal/pkg/server/fleet.go index 081b8b1c4..dbfc82a13 100644 --- a/internal/pkg/server/fleet.go +++ b/internal/pkg/server/fleet.go @@ -11,11 +11,11 @@ import ( "os" "reflect" "runtime/debug" + "sync" "time" "github.com/elastic/elastic-agent-client/v7/pkg/client" - "github.com/elastic/fleet-server/v7/internal/pkg/state" - + "github.com/elastic/go-ucfg" "go.elastic.co/apm/v2" apmtransport "go.elastic.co/apm/v2/transport" @@ -30,10 +30,12 @@ import ( "github.com/elastic/fleet-server/v7/internal/pkg/dl" "github.com/elastic/fleet-server/v7/internal/pkg/es" "github.com/elastic/fleet-server/v7/internal/pkg/gc" + "github.com/elastic/fleet-server/v7/internal/pkg/logger" "github.com/elastic/fleet-server/v7/internal/pkg/monitor" "github.com/elastic/fleet-server/v7/internal/pkg/policy" "github.com/elastic/fleet-server/v7/internal/pkg/profile" "github.com/elastic/fleet-server/v7/internal/pkg/scheduler" + "github.com/elastic/fleet-server/v7/internal/pkg/state" "github.com/elastic/fleet-server/v7/internal/pkg/ver" "github.com/hashicorp/go-version" @@ -52,6 +54,10 @@ type Fleet struct { cfgCh chan *config.Config cache cache.Cache reporter state.Reporter + + // Used for diagnostics reporting + l sync.RWMutex + cfg *config.Config } // NewFleet creates the actual fleet server service. 
@@ -74,6 +80,12 @@ type runFunc func(context.Context) error type runFuncCfg func(context.Context, *config.Config) error +func (f *Fleet) GetConfig() *config.Config { + f.l.RLock() + defer f.l.RUnlock() + return f.cfg +} + // Run runs the fleet server func (f *Fleet) Run(ctx context.Context, initCfg *config.Config) error { log := zerolog.Ctx(ctx) @@ -191,10 +203,59 @@ LOOP: } curCfg = newCfg + f.l.Lock() + f.cfg = curCfg + f.l.Unlock() select { - case newCfg = <-f.cfgCh: + case cfg := <-f.cfgCh: log.Info().Msg("Server configuration update") + if cfg.Inputs == nil && cfg.RevisionIdx != 0 { // cfg only contains updated output retrieved from policy + rev := cfg.RevisionIdx + esOutput := config.MergeElasticsearchFromPolicy(curCfg.Output.Elasticsearch, cfg.Output.Elasticsearch) + + // test config + cli, err := es.NewClient(ctx, + &config.Config{ + Output: config.Output{ + Elasticsearch: esOutput, + }, + }, + false, + elasticsearchOptions(curCfg.Inputs[0].Server.Instrumentation.Enabled, f.bi)..., + ) + if err != nil { + log.Warn().Int64(logger.RevisionIdx, rev).Err(err).Msg("unable to create elasticsearch client from policy output") + continue + } + remoteVersion, err := ver.CheckCompatibility(ctx, cli, f.bi.Version) + if err != nil { + // NOTE The error can indicate a bad network connection, bad TLS settings, etc. + // But if the error is an ErrElasticVersionConflict then something is very wrong + if errors.Is(err, es.ErrElasticVersionConflict) { + log.Error().Err(err).Int64(logger.RevisionIdx, rev).Interface("output", esOutput).Interface("bootstrap", curCfg.Output.Elasticsearch).Str("remote_version", remoteVersion).Msg("Elasticsearch version constraint failed for new output") + } else { + log.Warn().Err(err).Int64(logger.RevisionIdx, rev).Msg("Failed version compatibility check using output from policy") + } + continue + } + // work around to get a new cfg object based off curCfg + // we override the output with esOutput and have a complete config with a new mutex + tmp, err := ucfg.NewFrom(curCfg, config.DefaultOptions...) + if err != nil { + log.Error().Err(err).Int64(logger.RevisionIdx, rev).Msg("Unable to convert config") + continue + } + err = tmp.Unpack(cfg, config.DefaultOptions...) + if err != nil { + log.Error().Err(err).Int64(logger.RevisionIdx, rev).Msg("Unable to unpack config") + continue + } + log.Info().Int64(logger.RevisionIdx, rev).Msg("Using output from policy") + cfg.Output.Elasticsearch = esOutput + cfg.RevisionIdx = rev + } + newCfg = cfg case err := <-ech: f.reporter.UpdateState(client.UnitStateFailed, fmt.Sprintf("Error - %s", err), nil) //nolint:errcheck // unclear on what should we do if updating the status fails? 
log.Error().Err(err).Msg("Fleet Server failed") @@ -489,7 +550,7 @@ func (f *Fleet) runSubsystems(ctx context.Context, cfg *config.Config, g *errgro if f.standAlone { sm = policy.NewStandAloneSelfMonitor(bulker, f.reporter) } else { - sm = policy.NewSelfMonitor(cfg.Fleet, bulker, pim, cfg.Inputs[0].Policy.ID, f.reporter) + sm = policy.NewSelfMonitor(cfg.Fleet, bulker, pim, cfg.Inputs[0].Policy.ID, f.reporter, f.cfgCh) } g.Go(loggedRunFunc(ctx, "Policy self monitor", sm.Run)) diff --git a/internal/pkg/server/fleet_integration_test.go b/internal/pkg/server/fleet_integration_test.go index 1b8b117c9..115ec0e9b 100644 --- a/internal/pkg/server/fleet_integration_test.go +++ b/internal/pkg/server/fleet_integration_test.go @@ -19,6 +19,7 @@ import ( "net/http/httptest" "path" "strings" + "sync/atomic" "testing" "time" @@ -27,6 +28,7 @@ import ( "github.com/gofrs/uuid" "github.com/google/go-cmp/cmp" "github.com/hashicorp/go-cleanhttp" + "github.com/rs/zerolog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" @@ -48,7 +50,7 @@ const ( serverVersion = "8.0.0" localhost = "localhost" - testWaitServerUp = 3 * time.Second + testWaitServerUp = 10 * time.Second enrollBody = `{ "type": "PERMANENT", @@ -71,6 +73,9 @@ type tserver struct { srv *Fleet enrollKey string bulker bulk.Bulk + + outputReloadSuccess atomic.Int32 + outputReloadFailure atomic.Int32 } func (s *tserver) baseURL() string { @@ -212,11 +217,21 @@ func startTestServer(t *testing.T, ctx context.Context, policyD model.PolicyData g, ctx := errgroup.WithContext(ctx) + // Since we start the server in agent mode we need a way to detect if the policy monitor has reloaded the output + // NOTE: This code is brittle as it depends on a log string message match + tsrv := &tserver{cfg: cfg, g: g, srv: srv, enrollKey: key.Token(), bulker: bulker} + ctx = testlog.SetLogger(t).Hook(zerolog.HookFunc(func(e *zerolog.Event, level zerolog.Level, message string) { + if level == zerolog.InfoLevel && message == "Using output from policy" { + tsrv.outputReloadSuccess.Add(1) + } else if level == zerolog.WarnLevel && message == "Failed version compatibility check using output from policy" { + tsrv.outputReloadFailure.Add(1) + } + })).WithContext(ctx) + g.Go(func() error { return srv.Run(ctx, cfg) }) - tsrv := &tserver{cfg: cfg, g: g, srv: srv, enrollKey: key.Token(), bulker: bulker} err = tsrv.waitServerUp(ctx, testWaitServerUp) if err != nil { return nil, fmt.Errorf("unable to start server: %w", err) @@ -254,13 +269,14 @@ func (s *tserver) waitServerUp(ctx context.Context, dur time.Duration) error { return status.Status == "HEALTHY", nil } + // Wait for the server to be in a healthy state after for { healthy, err := isHealthy() if err != nil { return err } if healthy { - return nil + break } select { @@ -269,6 +285,7 @@ func (s *tserver) waitServerUp(ctx context.Context, dur time.Duration) error { case <-time.After(100 * time.Millisecond): } } + return nil } func (s *tserver) buildURL(id string, cmd string) string { @@ -303,7 +320,6 @@ func TestServerConfigErrorReload(t *testing.T) { require.NoError(t, err) logger.Init(cfg, "fleet-server") //nolint:errcheck // test logging setup - ctx = testlog.SetLogger(t).WithContext(ctx) bulker := ftesting.SetupBulk(ctx, t) policyID := uuid.Must(uuid.NewV4()).String() @@ -376,9 +392,14 @@ func TestServerConfigErrorReload(t *testing.T) { mReporter.On("UpdateState", client.UnitStateConfiguring, mock.Anything, mock.Anything).Return(nil) mReporter.On("UpdateState", 
client.UnitStateHealthy, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { // Call cancel to stop the server once it's healthy - cancel() + go func() { + // FIXME: A short delay is needed here as the mock failure call on line 388 is not being detected correctly in tests + time.Sleep(100 * time.Millisecond) + cancel() + }() }).Return(nil) mReporter.On("UpdateState", client.UnitStateStopping, mock.Anything, mock.Anything).Return(nil) + mReporter.On("UpdateState", client.UnitStateFailed, mock.MatchedBy(func(err error) bool { return errors.Is(err, context.Canceled) }), mock.Anything).Return(nil).Maybe() // set bad config cfg.Output.Elasticsearch.ServiceToken = "incorrect" @@ -398,6 +419,66 @@ func TestServerConfigErrorReload(t *testing.T) { mReporter.AssertExpectations(t) } +func TestServerReloadOutputOnly(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start test server + srv, err := startTestServer(t, ctx, policyData) + require.NoError(t, err) + + // Give an output that will not work - it should not use this + cfg := config.Config{ + Output: config.Output{ + Elasticsearch: config.Elasticsearch{ + Protocol: "http", + Hosts: []string{ + "http://fake:9200", + }, + }, + }, + RevisionIdx: 2, + } + err = srv.srv.Reload(ctx, &cfg) + require.NoError(t, err) + + for i := 0; i < 5; i++ { + if srv.outputReloadFailure.Load() > 0 { + break + } + time.Sleep(time.Second) + } + require.NotZero(t, srv.outputReloadFailure.Load(), "Did not detect elasticsearch output client failure") + + // Give an output that works + cfg = config.Config{ + Output: config.Output{ + Elasticsearch: config.Elasticsearch{ + Protocol: "http", + Hosts: []string{ + "http://localhost:9200", + "http://other:9200", + }, + }, + }, + RevisionIdx: 3, + } + + successes := srv.outputReloadSuccess.Load() + err = srv.srv.Reload(ctx, &cfg) + require.NoError(t, err) + for i := 0; i < 5; i++ { + if srv.outputReloadSuccess.Load() > successes { + break + } + time.Sleep(time.Second) + } + require.Greater(t, srv.outputReloadSuccess.Load(), successes, "Did not detect elasticsearch output client success") + + cancel() + srv.waitExit() //nolint:errcheck // test case +} + func TestServerUnauthorized(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -405,7 +486,6 @@ func TestServerUnauthorized(t *testing.T) { // Start test server srv, err := startTestServer(t, ctx, policyData) require.NoError(t, err) - ctx = testlog.SetLogger(t).WithContext(ctx) agentID := uuid.Must(uuid.NewV4()).String() cli := cleanhttp.DefaultClient() @@ -424,7 +504,6 @@ func TestServerUnauthorized(t *testing.T) { // Not sure if this is right response, just capturing what we have so far // TODO: revisit error response format t.Run("no auth header", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) for _, u := range allurls { req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewBuffer([]byte("{}"))) if err != nil { @@ -454,7 +533,6 @@ func TestServerUnauthorized(t *testing.T) { // Unauthorized, expecting error from /_security/_authenticate t.Run("unauthorized", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) for _, u := range agenturls { req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewBuffer([]byte("{}"))) require.NoError(t, err) @@ -506,7 +584,6 @@ func TestServerInstrumentation(t *testing.T) { // Start test server with instrumentation disabled srv, err := startTestServer(t, ctx, policyData, WithAPM(server.URL, false)) 
require.NoError(t, err) - ctx = testlog.SetLogger(t).WithContext(ctx) agentID := "1e4954ce-af37-4731-9f4a-407b08e69e42" checkinURL := srv.buildURL(agentID, "checkin") @@ -591,7 +668,6 @@ func Test_SmokeTest_Agent_Calls(t *testing.T) { // Start test server srv, err := startTestServer(t, ctx, policyData) require.NoError(t, err) - ctx = testlog.SetLogger(t).WithContext(ctx) cli := cleanhttp.DefaultClient() @@ -698,6 +774,9 @@ func Test_SmokeTest_Agent_Calls(t *testing.T) { // When decoding to a (typed) struct, the default will implicitly be false if it's missing _, ok = ackObj["errors"] require.Falsef(t, ok, "expected response to have no errors attribute, errors are present: %+v", ackObj) + + cancel() + srv.waitExit() //nolint:errcheck // test case } func EnrollAgent(enrollBody string, t *testing.T, ctx context.Context, srv *tserver) (string, string) { @@ -752,7 +831,6 @@ func Test_Agent_Enrollment_Id(t *testing.T) { // Start test server srv, err := startTestServer(t, ctx, policyData) require.NoError(t, err) - ctx = testlog.SetLogger(t).WithContext(ctx) t.Log("Enroll the first agent with enrollment_id") firstAgentID, _ := EnrollAgent(enrollBodyWEnrollmentID, t, ctx, srv) @@ -780,6 +858,9 @@ func Test_Agent_Enrollment_Id(t *testing.T) { } else { t.Fatal("duplicate agent found after enrolling with same enrollment id") } + + cancel() + srv.waitExit() //nolint:errcheck // test case } func Test_Agent_Enrollment_Id_Invalidated_API_key(t *testing.T) { @@ -799,7 +880,6 @@ func Test_Agent_Enrollment_Id_Invalidated_API_key(t *testing.T) { // Start test server srv, err := startTestServer(t, ctx, policyData) require.NoError(t, err) - ctx = testlog.SetLogger(t).WithContext(ctx) t.Log("Enroll the first agent with enrollment_id") firstAgentID, _ := EnrollAgent(enrollBodyWEnrollmentID, t, ctx, srv) @@ -839,6 +919,9 @@ func Test_Agent_Enrollment_Id_Invalidated_API_key(t *testing.T) { } else { t.Fatal("duplicate agent found after enrolling with same enrollment id") } + + cancel() + srv.waitExit() //nolint:errcheck // test case } func Test_Agent_Auth_errors(t *testing.T) { @@ -848,7 +931,6 @@ func Test_Agent_Auth_errors(t *testing.T) { // Start test server srv, err := startTestServer(t, ctx, policyData) require.NoError(t, err) - ctx = testlog.SetLogger(t).WithContext(ctx) cli := cleanhttp.DefaultClient() @@ -887,7 +969,6 @@ func Test_Agent_Auth_errors(t *testing.T) { require.NotEmpty(t, id) t.Run("use enroll key for checkin", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/"+id+"/checkin", strings.NewReader(checkinBody)) require.NoError(t, err) req.Header.Set("Authorization", "ApiKey "+srv.enrollKey) @@ -900,7 +981,6 @@ func Test_Agent_Auth_errors(t *testing.T) { require.Equal(t, http.StatusNotFound, res.StatusCode) // NOTE this is a 404 and not a 400 }) t.Run("wrong agent ID", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/bad-agent-id/checkin", strings.NewReader(checkinBody)) require.NoError(t, err) req.Header.Set("Authorization", "ApiKey "+key) @@ -913,7 +993,6 @@ func Test_Agent_Auth_errors(t *testing.T) { require.Equal(t, http.StatusForbidden, res.StatusCode) }) t.Run("use another agent's api key", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/enroll", strings.NewReader(enrollBody)) 
require.NoError(t, err) req.Header.Set("Authorization", "ApiKey "+srv.enrollKey) @@ -953,7 +1032,6 @@ func Test_Agent_Auth_errors(t *testing.T) { require.Equal(t, http.StatusForbidden, res.StatusCode) }) t.Run("use api key for enrollment", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/enroll", strings.NewReader(enrollBody)) require.NoError(t, err) req.Header.Set("Authorization", "ApiKey "+key) @@ -964,6 +1042,9 @@ func Test_Agent_Auth_errors(t *testing.T) { res.Body.Close() require.Equal(t, http.StatusInternalServerError, res.StatusCode) }) + + cancel() + srv.waitExit() //nolint:errcheck // test case } func Test_Agent_request_errors(t *testing.T) { @@ -973,11 +1054,9 @@ func Test_Agent_request_errors(t *testing.T) { // Start test server srv, err := startTestServer(t, ctx, policyData) require.NoError(t, err) - ctx = testlog.SetLogger(t).WithContext(ctx) cli := cleanhttp.DefaultClient() t.Run("no auth", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/enroll", strings.NewReader(enrollBody)) require.NoError(t, err) req.Header.Set("User-Agent", "elastic agent "+serverVersion) @@ -988,7 +1067,6 @@ func Test_Agent_request_errors(t *testing.T) { require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("bad path", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/temporary", strings.NewReader(enrollBody)) require.NoError(t, err) req.Header.Set("Authorization", "ApiKey "+srv.enrollKey) @@ -1000,7 +1078,6 @@ func Test_Agent_request_errors(t *testing.T) { require.Equal(t, http.StatusNotFound, res.StatusCode) }) t.Run("wrong method", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "PUT", srv.baseURL()+"/api/fleet/agents/enroll", strings.NewReader(enrollBody)) require.NoError(t, err) req.Header.Set("Authorization", "ApiKey "+srv.enrollKey) @@ -1012,7 +1089,6 @@ func Test_Agent_request_errors(t *testing.T) { require.Equal(t, http.StatusMethodNotAllowed, res.StatusCode) }) t.Run("no body", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/enroll", nil) require.NoError(t, err) req.Header.Set("Authorization", "ApiKey "+srv.enrollKey) @@ -1024,7 +1100,6 @@ func Test_Agent_request_errors(t *testing.T) { require.Equal(t, http.StatusBadRequest, res.StatusCode) }) t.Run("no user agent", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/enroll", strings.NewReader(enrollBody)) require.NoError(t, err) req.Header.Set("Authorization", "ApiKey "+srv.enrollKey) @@ -1035,7 +1110,6 @@ func Test_Agent_request_errors(t *testing.T) { require.Equal(t, http.StatusBadRequest, res.StatusCode) }) t.Run("bad user agent", func(t *testing.T) { - ctx := testlog.SetLogger(t).WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/enroll", strings.NewReader(enrollBody)) require.NoError(t, err) req.Header.Set("Authorization", "ApiKey "+srv.enrollKey) @@ -1046,6 +1120,9 @@ func Test_Agent_request_errors(t *testing.T) { res.Body.Close() require.Equal(t, http.StatusBadRequest, res.StatusCode) }) + + cancel() + 
srv.waitExit() //nolint:errcheck // test case } func Test_SmokeTest_CheckinPollTimeout(t *testing.T) { @@ -1055,7 +1132,6 @@ func Test_SmokeTest_CheckinPollTimeout(t *testing.T) { // Start test server srv, err := startTestServer(t, ctx, policyData) require.NoError(t, err) - ctx = testlog.SetLogger(t).WithContext(ctx) cli := cleanhttp.DefaultClient() @@ -1187,7 +1263,6 @@ func Test_SmokeTest_CheckinPollShutdown(t *testing.T) { // Start test server srv, err := startTestServer(t, ctx, policyData) require.NoError(t, err) - ctx = testlog.SetLogger(t).WithContext(ctx) cli := cleanhttp.DefaultClient() @@ -1291,4 +1366,7 @@ func Test_SmokeTest_CheckinPollShutdown(t *testing.T) { err = json.Unmarshal(p, &checkinResponse) require.NoError(t, err) require.Equal(t, token, *checkinResponse.AckToken) + + cancel() + srv.waitExit() //nolint:errcheck // test case } diff --git a/internal/pkg/server/fleet_secrets_integration_test.go b/internal/pkg/server/fleet_secrets_integration_test.go index e3daddc1f..65d5dcf3c 100644 --- a/internal/pkg/server/fleet_secrets_integration_test.go +++ b/internal/pkg/server/fleet_secrets_integration_test.go @@ -230,4 +230,7 @@ func Test_Agent_Policy_Secrets(t *testing.T) { "package_var_secret": "secret_value", "type": "fleet-server", }, input) + + cancel() + srv.waitExit() //nolint:errcheck // test case } diff --git a/internal/pkg/server/remote_es_output_integration_test.go b/internal/pkg/server/remote_es_output_integration_test.go index c4478f052..330686f20 100644 --- a/internal/pkg/server/remote_es_output_integration_test.go +++ b/internal/pkg/server/remote_es_output_integration_test.go @@ -249,6 +249,8 @@ func Test_Agent_Remote_ES_Output(t *testing.T) { verifyRemoteAPIKey(t, ctx, apiKeyID, true) + cancel() + srv.waitExit() //nolint:errcheck // test case } func verifyRemoteAPIKey(t *testing.T, ctx context.Context, apiKeyID string, invalidated bool) { @@ -392,6 +394,8 @@ func Test_Agent_Remote_ES_Output_ForceUnenroll(t *testing.T) { t.Log("Verify that remote API key is invalidated") verifyRemoteAPIKey(t, ctx, apiKeyID, true) + cancel() + srv.waitExit() //nolint:errcheck // test case } func Test_Agent_Remote_ES_Output_Unenroll(t *testing.T) { @@ -508,4 +512,6 @@ func Test_Agent_Remote_ES_Output_Unenroll(t *testing.T) { t.Log("Verify that remote API key is invalidated") verifyRemoteAPIKey(t, ctx, apiKeyID, true) + cancel() + srv.waitExit() //nolint:errcheck // test case } diff --git a/internal/pkg/testing/log/log.go b/internal/pkg/testing/log/log.go index 386bdc6bf..27ac3d759 100644 --- a/internal/pkg/testing/log/log.go +++ b/internal/pkg/testing/log/log.go @@ -14,7 +14,7 @@ import ( // loggest is set to debug level func SetLogger(tb testing.TB) zerolog.Logger { tb.Helper() - tw := zerolog.TestWriter{T: tb, Frame: 4} - log := zerolog.New(tw).Level(zerolog.DebugLevel) + tw := zerolog.TestWriter{T: tb, Frame: 5} + log := zerolog.New(zerolog.SyncWriter(tw)).Level(zerolog.DebugLevel) return log } diff --git a/internal/pkg/testing/setup.go b/internal/pkg/testing/setup.go index 69a7facee..61a9de97a 100644 --- a/internal/pkg/testing/setup.go +++ b/internal/pkg/testing/setup.go @@ -29,7 +29,7 @@ var defaultCfgData = []byte(` output: elasticsearch: hosts: '${ELASTICSEARCH_HOSTS:localhost:9200}' - service_token: '${ELASTICSEARCH_SERVICE_TOKEN}' + service_token: '${ELASTICSEARCH_SERVICE_TOKEN:test}' fleet: agent: id: 1e4954ce-af37-4731-9f4a-407b08e69e42
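For reference, a minimal usage sketch of the merge precedence introduced above, assuming the `github.com/elastic/fleet-server/v7/internal/pkg/config` import path and the struct fields shown in this diff; the host names, service token, and timeout values are made up for illustration:

```go
package main

import (
	"fmt"
	"time"

	"github.com/elastic/fleet-server/v7/internal/pkg/config"
)

func main() {
	// Output settings passed in at enrollment/installation; the service token
	// only ever comes from here, never from the policy.
	bootstrap := config.Elasticsearch{
		Protocol:     "http",
		Hosts:        []string{"elasticsearch:9200"},
		ServiceToken: "token",
		Timeout:      30 * time.Second,
	}

	// Output block as it might arrive from the policy. Non-default values win;
	// fields left at their defaults keep the bootstrap values.
	policy := config.Elasticsearch{
		Protocol:         "https",
		Hosts:            []string{"es1:9200", "es2:9200"},
		Timeout:          config.DefaultElasticsearchTimeout, // default, so the bootstrap Timeout is kept
		MaxRetries:       config.DefaultElasticsearchMaxRetries,
		MaxConnPerHost:   config.DefaultElasticsearchMaxConnPerHost,
		MaxContentLength: config.DefaultElasticsearchMaxContentLength,
	}

	merged := config.MergeElasticsearchFromPolicy(bootstrap, policy)
	fmt.Println(merged.Protocol)     // https - policy Hosts are non-default, so its Protocol wins
	fmt.Println(merged.Hosts)        // [es1:9200 es2:9200]
	fmt.Println(merged.ServiceToken) // token - always taken from the bootstrap config
	fmt.Println(merged.Timeout)      // 30s - policy left it at the default
}
```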
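And a short sketch of the handshake this change sets up between the policy self-monitor and Fleet.Run: an update carrying only Output.Elasticsearch plus the policy RevisionIdx, which the run loop distinguishes from a full agent-pushed configuration by its nil Inputs. The helper name and channel wiring here are illustrative, not part of the patch:

```go
package main

import (
	"fmt"

	"github.com/elastic/fleet-server/v7/internal/pkg/config"
)

// outputOnlyUpdate mirrors the shape selfMonitorT.sendPolicyOutput places on the
// config channel: only Output.Elasticsearch and RevisionIdx are populated.
func outputOnlyUpdate(es config.Elasticsearch, revision int64) *config.Config {
	return &config.Config{
		Output:      config.Output{Elasticsearch: es},
		RevisionIdx: revision,
	}
}

func main() {
	cfgCh := make(chan *config.Config, 1) // same channel type NewSelfMonitor now accepts

	cfgCh <- outputOnlyUpdate(config.Elasticsearch{
		Protocol: "https",
		Hosts:    []string{"https://elasticsearch:9200"}, // illustrative host
	}, 3)

	cfg := <-cfgCh
	// Fleet.Run uses this shape to tell an output-only reload apart from a full
	// configuration update: no inputs, but a non-zero policy revision.
	if cfg.Inputs == nil && cfg.RevisionIdx != 0 {
		fmt.Println("output-only update for policy revision", cfg.RevisionIdx)
	}
}
```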