diff --git a/internal/agent/agent.go b/internal/agent/agent.go index fcc0f4c8..de7e05f2 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -40,7 +40,8 @@ type Config struct { // NewAgent returns a new instance of Agent with the given configuration func NewAgent(config *Config) (*Agent, error) { - collectorClient := collector.NewCollectorClient(config.DiscoveriesConfig.CollectorConfig, http.DefaultClient) + agentClient := http.Client{Timeout: 30 * time.Second} + collectorClient := collector.NewCollectorClient(config.DiscoveriesConfig.CollectorConfig, &agentClient) discoveries := []discovery.Discovery{ discovery.NewClusterDiscovery(collectorClient, *config.DiscoveriesConfig), @@ -137,7 +138,7 @@ func (a *Agent) Stop(ctxCancel context.CancelFunc) { func (a *Agent) startDiscoverTicker(ctx context.Context, d discovery.Discovery) { tick := func() { - result, err := d.Discover() + result, err := d.Discover(ctx) if err != nil { result = fmt.Sprintf("Error while running discovery '%s': %s", d.GetID(), err) log.Errorln(result) @@ -150,7 +151,7 @@ func (a *Agent) startDiscoverTicker(ctx context.Context, d discovery.Discovery) func (a *Agent) startHeartbeatTicker(ctx context.Context) { tick := func() { - err := a.collectorClient.Heartbeat() + err := a.collectorClient.Heartbeat(ctx) if err != nil { log.Errorf("Error while sending the heartbeat to the server: %s", err) } diff --git a/internal/core/cloud/aws.go b/internal/core/cloud/aws.go index d1984b3a..6ae6048f 100644 --- a/internal/core/cloud/aws.go +++ b/internal/core/cloud/aws.go @@ -7,6 +7,7 @@ Based on package cloud import ( + "context" "encoding/json" "fmt" "io" @@ -52,7 +53,7 @@ type Placement struct { Region string `json:"region"` } -func NewAWSMetadata(client HTTPClient) (*AWSMetadata, error) { +func NewAWSMetadata(ctx context.Context, client HTTPClient) (*AWSMetadata, error) { var err error awsMetadata := &AWSMetadata{ AmiID: "", @@ -86,7 +87,7 @@ func NewAWSMetadata(client HTTPClient) (*AWSMetadata, error) { } firstElementsList := []string{fmt.Sprintf("%s/", awsMetadataResource)} - metadata, err := buildAWSMetadata(client, awsMetadataURL, firstElementsList) + metadata, err := buildAWSMetadata(ctx, client, awsMetadataURL, firstElementsList) if err != nil { return nil, err } @@ -104,13 +105,18 @@ func NewAWSMetadata(client HTTPClient) (*AWSMetadata, error) { return awsMetadata, err } -func buildAWSMetadata(client HTTPClient, url string, elements []string) (map[string]interface{}, error) { +func buildAWSMetadata( + ctx context.Context, + client HTTPClient, + url string, + elements []string, +) (map[string]interface{}, error) { metadata := make(map[string]interface{}) for _, element := range elements { newURL := url + element - response, err := requestMetadata(client, newURL) + response, err := requestMetadata(ctx, client, newURL) if err != nil { return metadata, err } @@ -119,7 +125,7 @@ func buildAWSMetadata(client HTTPClient, url string, elements []string) (map[str currentElement := strings.Trim(element, "/") newElements := strings.Split(fmt.Sprintf("%v", response), "\n") - metadata[currentElement], err = buildAWSMetadata(client, newURL, newElements) + metadata[currentElement], err = buildAWSMetadata(ctx, client, newURL, newElements) if err != nil { return nil, err } @@ -131,8 +137,8 @@ func buildAWSMetadata(client HTTPClient, url string, elements []string) (map[str return metadata, nil } -func requestMetadata(client HTTPClient, url string) (interface{}, error) { - req, _ := http.NewRequest(http.MethodGet, url, nil) +func requestMetadata(ctx context.Context, client HTTPClient, url string) (interface{}, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) resp, err := client.Do(req) if err != nil { diff --git a/internal/core/cloud/aws_test.go b/internal/core/cloud/aws_test.go index 42b4aa15..8504e43d 100644 --- a/internal/core/cloud/aws_test.go +++ b/internal/core/cloud/aws_test.go @@ -2,6 +2,7 @@ package cloud_test import ( "bytes" + "context" "io" "net/http" "os" @@ -24,6 +25,7 @@ func TestAWSMetadataTestSuite(t *testing.T) { } func (suite *AWSMetadataTestSuite) TestNewAWSMetadata() { + ctx := context.TODO() clientMock := new(mocks.HTTPClient) fixtures := []string{ @@ -65,7 +67,7 @@ func (suite *AWSMetadataTestSuite) TestNewAWSMetadata() { ).Once() } - m, err := cloud.NewAWSMetadata(clientMock) + m, err := cloud.NewAWSMetadata(ctx, clientMock) suite.NoError(err) diff --git a/internal/core/cloud/azure.go b/internal/core/cloud/azure.go index dd4c8c1e..4a14d24b 100644 --- a/internal/core/cloud/azure.go +++ b/internal/core/cloud/azure.go @@ -7,6 +7,7 @@ package cloud import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -142,7 +143,7 @@ type Subnet struct { Prefix string `json:"prefix,omitempty"` } -func NewAzureMetadata(client HTTPClient) (*AzureMetadata, error) { +func NewAzureMetadata(ctx context.Context, client HTTPClient) (*AzureMetadata, error) { var err error m := &AzureMetadata{ Compute: Compute{ @@ -220,7 +221,12 @@ func NewAzureMetadata(client HTTPClient) (*AzureMetadata, error) { }, } - req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s/metadata/instance", azureAPIAddress), nil) + req, _ := http.NewRequestWithContext( + ctx, + http.MethodGet, + fmt.Sprintf("http://%s/metadata/instance", azureAPIAddress), + nil, + ) req.Header.Add("Metadata", "True") q := req.URL.Query() diff --git a/internal/core/cloud/azure_test.go b/internal/core/cloud/azure_test.go index 732719b8..2a7c7ec1 100644 --- a/internal/core/cloud/azure_test.go +++ b/internal/core/cloud/azure_test.go @@ -2,6 +2,7 @@ package cloud_test import ( "bytes" + "context" "io" "net/http" "os" @@ -23,6 +24,7 @@ func TestAzureMetadataTestSuite(t *testing.T) { } func (suite *AzureMetadataTestSuite) TestNewAzureMetadata() { + ctx := context.TODO() clientMock := new(mocks.HTTPClient) aFile, _ := os.Open(helpers.GetFixturePath("discovery/azure/azure_metadata.json")) @@ -38,7 +40,7 @@ func (suite *AzureMetadataTestSuite) TestNewAzureMetadata() { response, nil, ) - m, err := cloud.NewAzureMetadata(clientMock) + m, err := cloud.NewAzureMetadata(ctx, clientMock) expectedMeta := &cloud.AzureMetadata{ Compute: cloud.Compute{ diff --git a/internal/core/cloud/gcp.go b/internal/core/cloud/gcp.go index 04feb998..96c3d3a4 100644 --- a/internal/core/cloud/gcp.go +++ b/internal/core/cloud/gcp.go @@ -8,6 +8,7 @@ package cloud import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -47,7 +48,7 @@ type GCPProject struct { ProjectID string `json:"projectId,omitempty"` } -func NewGCPMetadata(client HTTPClient) (*GCPMetadata, error) { +func NewGCPMetadata(ctx context.Context, client HTTPClient) (*GCPMetadata, error) { var err error m := &GCPMetadata{ Instance: GCPInstance{ @@ -63,7 +64,7 @@ func NewGCPMetadata(client HTTPClient) (*GCPMetadata, error) { }, } - req, _ := http.NewRequest(http.MethodGet, gcpMetadataURL, nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, gcpMetadataURL, nil) req.Header.Add("Metadata-Flavor", gcpMetadataFlavorHeader) q := req.URL.Query() diff --git a/internal/core/cloud/gcp_test.go b/internal/core/cloud/gcp_test.go index 4bd3c6e8..fd1ecbba 100644 --- a/internal/core/cloud/gcp_test.go +++ b/internal/core/cloud/gcp_test.go @@ -2,6 +2,7 @@ package cloud_test import ( "bytes" + "context" "io" "net/http" "os" @@ -23,6 +24,7 @@ func TestGcpMetadataTestSuite(t *testing.T) { } func (suite *GcpMetadataTestSuite) TestNewGCPMetadata() { + ctx := context.TODO() clientMock := new(mocks.HTTPClient) aFile, _ := os.Open(helpers.GetFixturePath("discovery/gcp/gcp_metadata.json")) @@ -38,7 +40,7 @@ func (suite *GcpMetadataTestSuite) TestNewGCPMetadata() { response, nil, ) - m, err := cloud.NewGCPMetadata(clientMock) + m, err := cloud.NewGCPMetadata(ctx, clientMock) expectedMeta := &cloud.GCPMetadata{ Instance: cloud.GCPInstance{ diff --git a/internal/core/cloud/metadata.go b/internal/core/cloud/metadata.go index 36a368e5..a053cdc7 100644 --- a/internal/core/cloud/metadata.go +++ b/internal/core/cloud/metadata.go @@ -1,6 +1,7 @@ package cloud import ( + "context" "regexp" "strings" @@ -181,7 +182,11 @@ func (i *Identifier) IdentifyCloudProvider() (string, error) { return "", nil } -func NewCloudInstance(commandExecutor utils.CommandExecutor, client HTTPClient) (*Instance, error) { +func NewCloudInstance( + ctx context.Context, + commandExecutor utils.CommandExecutor, + client HTTPClient, +) (*Instance, error) { var err error var cloudMetadata interface{} @@ -200,14 +205,14 @@ func NewCloudInstance(commandExecutor utils.CommandExecutor, client HTTPClient) switch provider { case Azure: { - cloudMetadata, err = NewAzureMetadata(client) + cloudMetadata, err = NewAzureMetadata(ctx, client) if err != nil { return nil, err } } case AWS: { - awsMetadata, err := NewAWSMetadata(client) + awsMetadata, err := NewAWSMetadata(ctx, client) if err != nil { return nil, err } @@ -215,7 +220,7 @@ func NewCloudInstance(commandExecutor utils.CommandExecutor, client HTTPClient) } case GCP: { - gcpMetadata, err := NewGCPMetadata(client) + gcpMetadata, err := NewGCPMetadata(ctx, client) if err != nil { return nil, err } diff --git a/internal/core/cloud/metadata_test.go b/internal/core/cloud/metadata_test.go index 481c8d06..24d6d313 100644 --- a/internal/core/cloud/metadata_test.go +++ b/internal/core/cloud/metadata_test.go @@ -2,6 +2,7 @@ package cloud_test import ( "bytes" + "context" "io" "net/http" "testing" @@ -238,6 +239,8 @@ func (suite *CloudMetadataTestSuite) TestIdentifyCloudProviderNoCloud() { } func (suite *CloudMetadataTestSuite) TestNewCloudInstanceAzure() { + ctx := context.TODO() + suite.mockExecutor. On("Exec", "dmidecode", "-s", "chassis-asset-tag"). Return(dmidecodeAzure(), nil) @@ -253,7 +256,7 @@ func (suite *CloudMetadataTestSuite) TestNewCloudInstanceAzure() { response, nil, ) - c, err := cloud.NewCloudInstance(suite.mockExecutor, suite.mockHTTPClient) + c, err := cloud.NewCloudInstance(ctx, suite.mockExecutor, suite.mockHTTPClient) suite.NoError(err) suite.Equal("azure", c.Provider) @@ -263,6 +266,7 @@ func (suite *CloudMetadataTestSuite) TestNewCloudInstanceAzure() { } func (suite *CloudMetadataTestSuite) TestNewCloudInstanceAWS() { + ctx := context.TODO() suite.mockExecutor. On("Exec", "dmidecode", "-s", "chassis-asset-tag"). Return(dmidecodeEmpty(), nil). @@ -289,7 +293,7 @@ func (suite *CloudMetadataTestSuite) TestNewCloudInstanceAWS() { On("Do", mock.AnythingOfType("*http.Request")). Return(response2, nil) - c, err := cloud.NewCloudInstance(suite.mockExecutor, suite.mockHTTPClient) + c, err := cloud.NewCloudInstance(ctx, suite.mockExecutor, suite.mockHTTPClient) suite.NoError(err) suite.Equal("aws", c.Provider) @@ -299,6 +303,7 @@ func (suite *CloudMetadataTestSuite) TestNewCloudInstanceAWS() { } func (suite *CloudMetadataTestSuite) TestNewInstanceNutanix() { + ctx := context.TODO() suite.mockExecutor. On("Exec", "dmidecode", "-s", "chassis-asset-tag"). Return(dmidecodeEmpty(), nil). @@ -311,7 +316,7 @@ func (suite *CloudMetadataTestSuite) TestNewInstanceNutanix() { On("Exec", "dmidecode"). Return(dmidecodeNutanix(), nil) - c, err := cloud.NewCloudInstance(suite.mockExecutor, suite.mockHTTPClient) + c, err := cloud.NewCloudInstance(ctx, suite.mockExecutor, suite.mockHTTPClient) suite.NoError(err) suite.Equal("nutanix", c.Provider) @@ -320,6 +325,7 @@ func (suite *CloudMetadataTestSuite) TestNewInstanceNutanix() { } func (suite *CloudMetadataTestSuite) TestNewInstanceKVM() { + ctx := context.TODO() suite.mockExecutor. On("Exec", "dmidecode", "-s", "chassis-asset-tag"). Return(dmidecodeEmpty(), nil). @@ -334,7 +340,7 @@ func (suite *CloudMetadataTestSuite) TestNewInstanceKVM() { On("Exec", "systemd-detect-virt"). Return(systemdDetectVirtKVM(), nil) - c, err := cloud.NewCloudInstance(suite.mockExecutor, suite.mockHTTPClient) + c, err := cloud.NewCloudInstance(ctx, suite.mockExecutor, suite.mockHTTPClient) suite.NoError(err) suite.Equal("kvm", c.Provider) @@ -343,6 +349,7 @@ func (suite *CloudMetadataTestSuite) TestNewInstanceKVM() { } func (suite *CloudMetadataTestSuite) TestNewCloudInstanceNoCloud() { + ctx := context.TODO() suite.mockExecutor. On("Exec", "dmidecode", "-s", "chassis-asset-tag"). Return(dmidecodeEmpty(), nil). @@ -357,7 +364,7 @@ func (suite *CloudMetadataTestSuite) TestNewCloudInstanceNoCloud() { On("Exec", "systemd-detect-virt"). Return(systemdDetectVirtEmpty(), nil) - c, err := cloud.NewCloudInstance(suite.mockExecutor, suite.mockHTTPClient) + c, err := cloud.NewCloudInstance(ctx, suite.mockExecutor, suite.mockHTTPClient) suite.NoError(err) suite.Equal("", c.Provider) diff --git a/internal/core/sapsystem/sapcontrol.go b/internal/core/sapsystem/sapcontrol.go index 938c6de1..40f9e860 100644 --- a/internal/core/sapsystem/sapcontrol.go +++ b/internal/core/sapsystem/sapcontrol.go @@ -1,6 +1,7 @@ package sapsystem import ( + "context" "fmt" "github.com/pkg/errors" @@ -13,18 +14,18 @@ type SAPControl struct { Properties []*sapcontrol.InstanceProperty } -func NewSAPControl(w sapcontrol.WebService) (*SAPControl, error) { - properties, err := w.GetInstanceProperties() +func NewSAPControl(ctx context.Context, w sapcontrol.WebService) (*SAPControl, error) { + properties, err := w.GetInstanceProperties(ctx) if err != nil { return nil, errors.Wrap(err, "SAPControl web service error") } - processes, err := w.GetProcessList() + processes, err := w.GetProcessList(ctx) if err != nil { return nil, errors.Wrap(err, "SAPControl web service error") } - instances, err := w.GetSystemInstanceList() + instances, err := w.GetSystemInstanceList(ctx) if err != nil { return nil, errors.Wrap(err, "SAPControl web service error") } diff --git a/internal/core/sapsystem/sapcontrolapi/mocks/WebService.go b/internal/core/sapsystem/sapcontrolapi/mocks/WebService.go index 37fd3a1d..a6c308f6 100644 --- a/internal/core/sapsystem/sapcontrolapi/mocks/WebService.go +++ b/internal/core/sapsystem/sapcontrolapi/mocks/WebService.go @@ -1,8 +1,10 @@ -// Code generated by mockery v2.12.3. DO NOT EDIT. +// Code generated by mockery v2.32.3. DO NOT EDIT. package mocks import ( + context "context" + mock "github.com/stretchr/testify/mock" sapcontrolapi "github.com/trento-project/agent/internal/core/sapsystem/sapcontrolapi" ) @@ -12,22 +14,25 @@ type WebService struct { mock.Mock } -// GetInstanceProperties provides a mock function with given fields: -func (_m *WebService) GetInstanceProperties() (*sapcontrolapi.GetInstancePropertiesResponse, error) { - ret := _m.Called() +// GetInstanceProperties provides a mock function with given fields: ctx +func (_m *WebService) GetInstanceProperties(ctx context.Context) (*sapcontrolapi.GetInstancePropertiesResponse, error) { + ret := _m.Called(ctx) var r0 *sapcontrolapi.GetInstancePropertiesResponse - if rf, ok := ret.Get(0).(func() *sapcontrolapi.GetInstancePropertiesResponse); ok { - r0 = rf() + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*sapcontrolapi.GetInstancePropertiesResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *sapcontrolapi.GetInstancePropertiesResponse); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sapcontrolapi.GetInstancePropertiesResponse) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,22 +40,25 @@ func (_m *WebService) GetInstanceProperties() (*sapcontrolapi.GetInstancePropert return r0, r1 } -// GetProcessList provides a mock function with given fields: -func (_m *WebService) GetProcessList() (*sapcontrolapi.GetProcessListResponse, error) { - ret := _m.Called() +// GetProcessList provides a mock function with given fields: ctx +func (_m *WebService) GetProcessList(ctx context.Context) (*sapcontrolapi.GetProcessListResponse, error) { + ret := _m.Called(ctx) var r0 *sapcontrolapi.GetProcessListResponse - if rf, ok := ret.Get(0).(func() *sapcontrolapi.GetProcessListResponse); ok { - r0 = rf() + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*sapcontrolapi.GetProcessListResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *sapcontrolapi.GetProcessListResponse); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sapcontrolapi.GetProcessListResponse) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -58,22 +66,25 @@ func (_m *WebService) GetProcessList() (*sapcontrolapi.GetProcessListResponse, e return r0, r1 } -// GetSystemInstanceList provides a mock function with given fields: -func (_m *WebService) GetSystemInstanceList() (*sapcontrolapi.GetSystemInstanceListResponse, error) { - ret := _m.Called() +// GetSystemInstanceList provides a mock function with given fields: ctx +func (_m *WebService) GetSystemInstanceList(ctx context.Context) (*sapcontrolapi.GetSystemInstanceListResponse, error) { + ret := _m.Called(ctx) var r0 *sapcontrolapi.GetSystemInstanceListResponse - if rf, ok := ret.Get(0).(func() *sapcontrolapi.GetSystemInstanceListResponse); ok { - r0 = rf() + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*sapcontrolapi.GetSystemInstanceListResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *sapcontrolapi.GetSystemInstanceListResponse); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sapcontrolapi.GetSystemInstanceListResponse) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -81,22 +92,25 @@ func (_m *WebService) GetSystemInstanceList() (*sapcontrolapi.GetSystemInstanceL return r0, r1 } -// GetVersionInfo provides a mock function with given fields: -func (_m *WebService) GetVersionInfo() (*sapcontrolapi.GetVersionInfoResponse, error) { - ret := _m.Called() +// GetVersionInfo provides a mock function with given fields: ctx +func (_m *WebService) GetVersionInfo(ctx context.Context) (*sapcontrolapi.GetVersionInfoResponse, error) { + ret := _m.Called(ctx) var r0 *sapcontrolapi.GetVersionInfoResponse - if rf, ok := ret.Get(0).(func() *sapcontrolapi.GetVersionInfoResponse); ok { - r0 = rf() + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*sapcontrolapi.GetVersionInfoResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *sapcontrolapi.GetVersionInfoResponse); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sapcontrolapi.GetVersionInfoResponse) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -104,22 +118,25 @@ func (_m *WebService) GetVersionInfo() (*sapcontrolapi.GetVersionInfoResponse, e return r0, r1 } -// HACheckConfig provides a mock function with given fields: -func (_m *WebService) HACheckConfig() (*sapcontrolapi.HACheckConfigResponse, error) { - ret := _m.Called() +// HACheckConfig provides a mock function with given fields: ctx +func (_m *WebService) HACheckConfig(ctx context.Context) (*sapcontrolapi.HACheckConfigResponse, error) { + ret := _m.Called(ctx) var r0 *sapcontrolapi.HACheckConfigResponse - if rf, ok := ret.Get(0).(func() *sapcontrolapi.HACheckConfigResponse); ok { - r0 = rf() + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*sapcontrolapi.HACheckConfigResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *sapcontrolapi.HACheckConfigResponse); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sapcontrolapi.HACheckConfigResponse) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -127,22 +144,25 @@ func (_m *WebService) HACheckConfig() (*sapcontrolapi.HACheckConfigResponse, err return r0, r1 } -// HAGetFailoverConfig provides a mock function with given fields: -func (_m *WebService) HAGetFailoverConfig() (*sapcontrolapi.HAGetFailoverConfigResponse, error) { - ret := _m.Called() +// HAGetFailoverConfig provides a mock function with given fields: ctx +func (_m *WebService) HAGetFailoverConfig(ctx context.Context) (*sapcontrolapi.HAGetFailoverConfigResponse, error) { + ret := _m.Called(ctx) var r0 *sapcontrolapi.HAGetFailoverConfigResponse - if rf, ok := ret.Get(0).(func() *sapcontrolapi.HAGetFailoverConfigResponse); ok { - r0 = rf() + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*sapcontrolapi.HAGetFailoverConfigResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *sapcontrolapi.HAGetFailoverConfigResponse); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sapcontrolapi.HAGetFailoverConfigResponse) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -150,13 +170,12 @@ func (_m *WebService) HAGetFailoverConfig() (*sapcontrolapi.HAGetFailoverConfigR return r0, r1 } -type NewWebServiceT interface { +// NewWebService creates a new instance of WebService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewWebService(t interface { mock.TestingT Cleanup(func()) -} - -// NewWebService creates a new instance of WebService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewWebService(t NewWebServiceT) *WebService { +}) *WebService { mock := &WebService{} mock.Mock.Test(t) diff --git a/internal/core/sapsystem/sapcontrolapi/webservice.go b/internal/core/sapsystem/sapcontrolapi/webservice.go index 38ac0b4f..173ec142 100644 --- a/internal/core/sapsystem/sapcontrolapi/webservice.go +++ b/internal/core/sapsystem/sapcontrolapi/webservice.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "path" + "time" "github.com/hooklift/gowsdl/soap" ) @@ -16,12 +17,12 @@ import ( //go:generate mockery --all type WebService interface { - GetInstanceProperties() (*GetInstancePropertiesResponse, error) - GetProcessList() (*GetProcessListResponse, error) - GetSystemInstanceList() (*GetSystemInstanceListResponse, error) - GetVersionInfo() (*GetVersionInfoResponse, error) - HACheckConfig() (*HACheckConfigResponse, error) - HAGetFailoverConfig() (*HAGetFailoverConfigResponse, error) + GetInstanceProperties(ctx context.Context) (*GetInstancePropertiesResponse, error) + GetProcessList(ctx context.Context) (*GetProcessListResponse, error) + GetSystemInstanceList(ctx context.Context) (*GetSystemInstanceListResponse, error) + GetVersionInfo(ctx context.Context) (*GetVersionInfoResponse, error) + HACheckConfig(ctx context.Context) (*HACheckConfigResponse, error) + HAGetFailoverConfig(ctx context.Context) (*HAGetFailoverConfigResponse, error) } type STATECOLOR string @@ -168,6 +169,7 @@ func NewWebServiceUnix(instNumber string) WebService { socket := path.Join("/tmp", fmt.Sprintf(".sapstream5%s13", instNumber)) udsClient := &http.Client{ + Timeout: 30 * time.Second, Transport: &http.Transport{ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { d := net.Dialer{} @@ -186,10 +188,10 @@ func NewWebServiceUnix(instNumber string) WebService { } // GetInstanceProperties returns a list of available instance features and information how to get it. -func (s *webService) GetInstanceProperties() (*GetInstancePropertiesResponse, error) { +func (s *webService) GetInstanceProperties(ctx context.Context) (*GetInstancePropertiesResponse, error) { request := &GetInstanceProperties{} response := &GetInstancePropertiesResponse{} - err := s.client.Call("''", request, response) + err := s.client.CallContext(ctx, "''", request, response) if err != nil { return nil, err } @@ -199,10 +201,10 @@ func (s *webService) GetInstanceProperties() (*GetInstancePropertiesResponse, er // GetProcessList returns a list of all processes directly started by the webservice // according to the SAP start profile. -func (s *webService) GetProcessList() (*GetProcessListResponse, error) { +func (s *webService) GetProcessList(ctx context.Context) (*GetProcessListResponse, error) { request := &GetProcessList{} response := &GetProcessListResponse{} - err := s.client.Call("''", request, response) + err := s.client.CallContext(ctx, "''", request, response) if err != nil { return nil, err } @@ -212,10 +214,10 @@ func (s *webService) GetProcessList() (*GetProcessListResponse, error) { // GetSystemInstanceList returns a list of all processes directly started by the webservice // according to the SAP start profile. -func (s *webService) GetSystemInstanceList() (*GetSystemInstanceListResponse, error) { +func (s *webService) GetSystemInstanceList(ctx context.Context) (*GetSystemInstanceListResponse, error) { request := &GetSystemInstanceList{} response := &GetSystemInstanceListResponse{} - err := s.client.Call("''", request, response) + err := s.client.CallContext(ctx, "''", request, response) if err != nil { return nil, err } @@ -224,10 +226,10 @@ func (s *webService) GetSystemInstanceList() (*GetSystemInstanceListResponse, er } // GetVersionInfo returns a list version information for the most important files of the instance -func (s *webService) GetVersionInfo() (*GetVersionInfoResponse, error) { +func (s *webService) GetVersionInfo(ctx context.Context) (*GetVersionInfoResponse, error) { request := &GetVersionInfo{} response := &GetVersionInfoResponse{} - err := s.client.Call("''", request, response) + err := s.client.CallContext(ctx, "''", request, response) if err != nil { return nil, err } @@ -236,10 +238,10 @@ func (s *webService) GetVersionInfo() (*GetVersionInfoResponse, error) { } // HACheckConfig checks high availability configurration and status of the system -func (s *webService) HACheckConfig() (*HACheckConfigResponse, error) { +func (s *webService) HACheckConfig(ctx context.Context) (*HACheckConfigResponse, error) { request := &HACheckConfig{} response := &HACheckConfigResponse{} - err := s.client.Call("''", request, &response) + err := s.client.CallContext(ctx, "''", request, &response) if err != nil { return nil, err } @@ -248,10 +250,10 @@ func (s *webService) HACheckConfig() (*HACheckConfigResponse, error) { } // HAGetFailoverConfig returns HA failover third party information -func (s *webService) HAGetFailoverConfig() (*HAGetFailoverConfigResponse, error) { +func (s *webService) HAGetFailoverConfig(ctx context.Context) (*HAGetFailoverConfigResponse, error) { request := &HAGetFailoverConfig{} response := &HAGetFailoverConfigResponse{} - err := s.client.Call("''", request, &response) + err := s.client.CallContext(ctx, "''", request, &response) if err != nil { return nil, err } diff --git a/internal/core/sapsystem/sapinstance.go b/internal/core/sapsystem/sapinstance.go index 498fe203..eb4b18e6 100644 --- a/internal/core/sapsystem/sapinstance.go +++ b/internal/core/sapsystem/sapinstance.go @@ -1,6 +1,7 @@ package sapsystem import ( + "context" "fmt" "os" "path" @@ -34,13 +35,17 @@ type SAPInstance struct { HdbnsutilSRstate HdbnsutilSRstate } -func NewSAPInstance(w sapcontrolapi.WebService, executor utils.CommandExecutor) (*SAPInstance, error) { +func NewSAPInstance( + ctx context.Context, + w sapcontrolapi.WebService, + executor utils.CommandExecutor, +) (*SAPInstance, error) { host, err := os.Hostname() if err != nil { return nil, err } - scontrol, err := NewSAPControl(w) + scontrol, err := NewSAPControl(ctx, w) if err != nil { return nil, err } diff --git a/internal/core/sapsystem/sapsystem.go b/internal/core/sapsystem/sapsystem.go index 73b8ee21..add8857b 100644 --- a/internal/core/sapsystem/sapsystem.go +++ b/internal/core/sapsystem/sapsystem.go @@ -2,6 +2,7 @@ package sapsystem import ( "bufio" + "context" "crypto/md5" //nolint:gosec "fmt" "io" @@ -81,8 +82,9 @@ func Md5sum(data string) string { return fmt.Sprintf("%x", md5.Sum([]byte(data))) //nolint:gosec } -func NewDefaultSAPSystemsList() (SAPSystemsList, error) { +func NewDefaultSAPSystemsList(ctx context.Context) (SAPSystemsList, error) { return NewSAPSystemsList( + ctx, afero.NewOsFs(), utils.Executor{}, sapcontrolapi.WebServiceUnix{}, @@ -90,6 +92,7 @@ func NewDefaultSAPSystemsList() (SAPSystemsList, error) { } func NewSAPSystemsList( + ctx context.Context, fs afero.Fs, executor utils.CommandExecutor, webService sapcontrolapi.WebServiceConnector, @@ -103,7 +106,7 @@ func NewSAPSystemsList( // Find systems for _, sysPath := range systemPaths { - system, err := NewSAPSystem(fs, executor, webService, sysPath) + system, err := NewSAPSystem(ctx, fs, executor, webService, sysPath) if err != nil { log.Printf("Error discovering a SAP system: %s", err) continue @@ -125,6 +128,7 @@ func (sl SAPSystemsList) GetSIDsString() string { } func NewSAPSystem( + ctx context.Context, fs afero.Fs, executor utils.CommandExecutor, webService sapcontrolapi.WebServiceConnector, @@ -151,7 +155,7 @@ func NewSAPSystem( // Find instances for _, instPath := range instPaths { webService := webService.New(instPath[1]) - instance, err := NewSAPInstance(webService, executor) + instance, err := NewSAPInstance(ctx, webService, executor) if err != nil { log.Errorf("Error discovering a SAP instance: %s", err) continue diff --git a/internal/core/sapsystem/sapsystem_test.go b/internal/core/sapsystem/sapsystem_test.go index 8ee87e40..82675902 100644 --- a/internal/core/sapsystem/sapsystem_test.go +++ b/internal/core/sapsystem/sapsystem_test.go @@ -2,6 +2,7 @@ package sapsystem_test import ( + "context" "io" "os" "testing" @@ -28,8 +29,8 @@ func TestSAPSystemTestSuite(t *testing.T) { func fakeNewWebService(instName string, features string) sapcontrolapi.WebService { mockWebService := new(sapControlMocks.WebService) - - mockWebService.On("GetInstanceProperties").Return(&sapcontrol.GetInstancePropertiesResponse{ + ctx := context.TODO() + mockWebService.On("GetInstanceProperties", ctx).Return(&sapcontrol.GetInstancePropertiesResponse{ Properties: []*sapcontrol.InstanceProperty{ { Property: "SAPSYSTEMNAME", @@ -49,11 +50,11 @@ func fakeNewWebService(instName string, features string) sapcontrolapi.WebServic }, }, nil) - mockWebService.On("GetProcessList").Return(&sapcontrol.GetProcessListResponse{ + mockWebService.On("GetProcessList", ctx).Return(&sapcontrol.GetProcessListResponse{ Processes: []*sapcontrol.OSProcess{}, }, nil) - mockWebService.On("GetSystemInstanceList").Return(&sapcontrol.GetSystemInstanceListResponse{ + mockWebService.On("GetSystemInstanceList", ctx).Return(&sapcontrol.GetSystemInstanceListResponse{ Instances: []*sapcontrol.SAPInstance{ { Hostname: "host", @@ -116,6 +117,7 @@ func mockSappfpar() []byte { } func (suite *SAPSystemTestSuite) TestNewSAPSystemsList() { + ctx := context.TODO() appFS := afero.NewMemMapFs() err := appFS.MkdirAll("/usr/sap/DEV/ASCS01", 0755) suite.NoError(err) @@ -132,7 +134,7 @@ func (suite *SAPSystemTestSuite) TestNewSAPSystemsList() { mockWebServiceConnector.On("New", "01").Return(fakeNewWebService("ASCS01", "")) mockWebServiceConnector.On("New", "02").Return(fakeNewWebService("ERS02", "")) - systems, err := sapsystem.NewSAPSystemsList(appFS, mockCommand, mockWebServiceConnector) + systems, err := sapsystem.NewSAPSystemsList(ctx, appFS, mockCommand, mockWebServiceConnector) suite.Len(systems, 2) suite.Equal(systems[0].SID, "DEV") @@ -141,6 +143,7 @@ func (suite *SAPSystemTestSuite) TestNewSAPSystemsList() { } func (suite *SAPSystemTestSuite) TestNewSAPSystem() { + ctx := context.TODO() mockCommand := new(utilsMocks.CommandExecutor) mockWebServiceConnector := new(sapControlMocks.WebServiceConnector) mockWebServiceConnector.On("New", "01").Return(fakeNewWebService("ASCS01", "")) @@ -194,7 +197,7 @@ func (suite *SAPSystemTestSuite) TestNewSAPSystem() { sappfparCmd := "sappfpar SAPSYSTEMNAME SAPGLOBALHOST SAPFQDN SAPDBHOST dbs/hdb/dbname dbs/hdb/schema rdisp/msp/msserv rdisp/msserv_internal name=DEV" mockCommand.On("Exec", "su", "-lc", sappfparCmd, "devadm").Return(mockSappfpar(), nil) - system, err := sapsystem.NewSAPSystem(appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") + system, err := sapsystem.NewSAPSystem(ctx, appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") suite.Equal(sapsystem.Unknown, system.Type) suite.Contains("ASCS01", system.Instances[0].Name) @@ -216,6 +219,7 @@ func mockSystemReplicationStatus() []byte { } func (suite *SAPSystemTestSuite) TestDetectSystemId_Database() { + ctx := context.TODO() appFS, err := mockDEVFileSystem() suite.NoError(err) nameserverContent := []byte(` @@ -241,7 +245,7 @@ key2 = value2 On("Exec", "su", "-lc", "/usr/sap/DEV/HDB00/exe/hdbnsutil -sr_state -sapcontrol=1", "devadm"). Return(mockHdbnsutilSrstate(), nil) - system, err := sapsystem.NewSAPSystem(appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") + system, err := sapsystem.NewSAPSystem(ctx, appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") suite.Equal("089d1a278481b86e821237f8e98e6de7", system.ID) suite.Equal(sapsystem.Database, system.Type) @@ -249,6 +253,7 @@ key2 = value2 } func (suite *SAPSystemTestSuite) TestDetectSystemId_Application() { + ctx := context.TODO() appFS, err := mockDEVFileSystem() suite.NoError(err) mockCommand := new(utilsMocks.CommandExecutor) @@ -258,7 +263,7 @@ func (suite *SAPSystemTestSuite) TestDetectSystemId_Application() { sappfparCmd := "sappfpar SAPSYSTEMNAME SAPGLOBALHOST SAPFQDN SAPDBHOST dbs/hdb/dbname dbs/hdb/schema rdisp/msp/msserv rdisp/msserv_internal name=DEV" mockCommand.On("Exec", "su", "-lc", sappfparCmd, "devadm").Return(mockSappfpar(), nil) - system, err := sapsystem.NewSAPSystem(appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") + system, err := sapsystem.NewSAPSystem(ctx, appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") suite.Equal("089d1a278481b86e821237f8e98e6de7", system.ID) suite.Equal(sapsystem.Application, system.Type) @@ -266,6 +271,7 @@ func (suite *SAPSystemTestSuite) TestDetectSystemId_Application() { } func (suite *SAPSystemTestSuite) TestDetectSystemId_Diagnostics() { + ctx := context.TODO() appFS, err := mockDEVFileSystem() suite.NoError(err) machineIDContent := []byte(`dummy-machine-id`) @@ -280,7 +286,7 @@ func (suite *SAPSystemTestSuite) TestDetectSystemId_Diagnostics() { mockWebServiceConnector.On("New", "01").Return(fakeNewWebService("HDB00", "SMDAGENT")) - system, err := sapsystem.NewSAPSystem(appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") + system, err := sapsystem.NewSAPSystem(ctx, appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") suite.Equal("d3d5dd5ec501127e0011a2531e3b11ff", system.ID) suite.Equal(sapsystem.DiagnosticsAgent, system.Type) @@ -288,6 +294,7 @@ func (suite *SAPSystemTestSuite) TestDetectSystemId_Diagnostics() { } func (suite *SAPSystemTestSuite) TestDetectSystemId_Unknown() { + ctx := context.TODO() appFS, err := mockDEVFileSystem() suite.NoError(err) mockCommand := new(utilsMocks.CommandExecutor) @@ -295,7 +302,7 @@ func (suite *SAPSystemTestSuite) TestDetectSystemId_Unknown() { mockWebServiceConnector.On("New", "01").Return(fakeNewWebService("HDB00", "UNKNOWN")) - system, err := sapsystem.NewSAPSystem(appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") + system, err := sapsystem.NewSAPSystem(ctx, appFS, mockCommand, mockWebServiceConnector, "/usr/sap/DEV") suite.Equal("-", system.ID) suite.Equal(sapsystem.Unknown, system.Type) @@ -366,10 +373,11 @@ func (suite *SAPSystemTestSuite) TestGetDBAddress_ResolveError() { } func (suite *SAPSystemTestSuite) TestNewSAPInstanceDatabase() { + ctx := context.TODO() mockWebService := new(sapControlMocks.WebService) mockCommand := new(utilsMocks.CommandExecutor) - mockWebService.On("GetInstanceProperties").Return(&sapcontrol.GetInstancePropertiesResponse{ + mockWebService.On("GetInstanceProperties", ctx).Return(&sapcontrol.GetInstancePropertiesResponse{ Properties: []*sapcontrol.InstanceProperty{ { Property: "prop1", @@ -394,7 +402,7 @@ func (suite *SAPSystemTestSuite) TestNewSAPInstanceDatabase() { }, }, nil) - mockWebService.On("GetProcessList").Return(&sapcontrol.GetProcessListResponse{ + mockWebService.On("GetProcessList", ctx).Return(&sapcontrol.GetProcessListResponse{ Processes: []*sapcontrol.OSProcess{ { Name: "enserver", @@ -417,7 +425,7 @@ func (suite *SAPSystemTestSuite) TestNewSAPInstanceDatabase() { }, }, nil) - mockWebService.On("GetSystemInstanceList").Return(&sapcontrol.GetSystemInstanceListResponse{ + mockWebService.On("GetSystemInstanceList", ctx).Return(&sapcontrol.GetSystemInstanceListResponse{ Instances: []*sapcontrol.SAPInstance{ { Hostname: "host1", @@ -452,7 +460,7 @@ func (suite *SAPSystemTestSuite) TestNewSAPInstanceDatabase() { mockHdbnsutilSrstate(), nil, ) - sapInstance, _ := sapsystem.NewSAPInstance(mockWebService, mockCommand) + sapInstance, _ := sapsystem.NewSAPInstance(ctx, mockWebService, mockCommand) host, _ := os.Hostname() expectedInstance := &sapsystem.SAPInstance{ @@ -615,9 +623,10 @@ func (suite *SAPSystemTestSuite) TestNewSAPInstanceDatabase() { } func (suite *SAPSystemTestSuite) TestNewSAPInstanceApp() { + ctx := context.TODO() mockWebService := new(sapControlMocks.WebService) - mockWebService.On("GetInstanceProperties").Return(&sapcontrol.GetInstancePropertiesResponse{ + mockWebService.On("GetInstanceProperties", ctx).Return(&sapcontrol.GetInstancePropertiesResponse{ Properties: []*sapcontrol.InstanceProperty{ { Property: "prop1", @@ -642,7 +651,7 @@ func (suite *SAPSystemTestSuite) TestNewSAPInstanceApp() { }, }, nil) - mockWebService.On("GetProcessList").Return(&sapcontrol.GetProcessListResponse{ + mockWebService.On("GetProcessList", ctx).Return(&sapcontrol.GetProcessListResponse{ Processes: []*sapcontrol.OSProcess{ { Name: "enserver", @@ -665,7 +674,7 @@ func (suite *SAPSystemTestSuite) TestNewSAPInstanceApp() { }, }, nil) - mockWebService.On("GetSystemInstanceList").Return(&sapcontrol.GetSystemInstanceListResponse{ + mockWebService.On("GetSystemInstanceList", ctx).Return(&sapcontrol.GetSystemInstanceListResponse{ Instances: []*sapcontrol.SAPInstance{ { Hostname: "host1", @@ -688,7 +697,7 @@ func (suite *SAPSystemTestSuite) TestNewSAPInstanceApp() { }, }, nil) - sapInstance, _ := sapsystem.NewSAPInstance(mockWebService, new(utilsMocks.CommandExecutor)) + sapInstance, _ := sapsystem.NewSAPInstance(ctx, mockWebService, new(utilsMocks.CommandExecutor)) host, _ := os.Hostname() expectedInstance := &sapsystem.SAPInstance{ @@ -903,6 +912,7 @@ func (suite *SAPSystemTestSuite) TestFindProfiles() { } func (suite *SAPSystemTestSuite) TestDetectType() { + ctx := context.TODO() cases := []struct { instance *sapcontrol.SAPInstance expectedType sapsystem.SystemType @@ -947,7 +957,7 @@ func (suite *SAPSystemTestSuite) TestDetectType() { for _, tt := range cases { mockWebService := new(sapControlMocks.WebService) mockWebService. - On("GetInstanceProperties"). + On("GetInstanceProperties", ctx). Return(&sapcontrol.GetInstancePropertiesResponse{ Properties: []*sapcontrol.InstanceProperty{ { @@ -967,16 +977,16 @@ func (suite *SAPSystemTestSuite) TestDetectType() { }, }, }, nil). - On("GetProcessList"). + On("GetProcessList", ctx). Return(&sapcontrol.GetProcessListResponse{ Processes: []*sapcontrol.OSProcess{}, }, nil). - On("GetSystemInstanceList").Return(&sapcontrol.GetSystemInstanceListResponse{ + On("GetSystemInstanceList", ctx).Return(&sapcontrol.GetSystemInstanceListResponse{ Instances: []*sapcontrol.SAPInstance{tt.instance}, }, nil) mockCommand := new(mocks.CommandExecutor) - instance, err := sapsystem.NewSAPInstance(mockWebService, mockCommand) + instance, err := sapsystem.NewSAPInstance(ctx, mockWebService, mockCommand) suite.NoError(err) suite.Equal(tt.expectedType, instance.Type) @@ -984,9 +994,10 @@ func (suite *SAPSystemTestSuite) TestDetectType() { } func (suite *SAPSystemTestSuite) TestDetectType_Database() { + ctx := context.TODO() mockWebService := new(sapControlMocks.WebService) mockWebService. - On("GetInstanceProperties"). + On("GetInstanceProperties", ctx). Return(&sapcontrol.GetInstancePropertiesResponse{ Properties: []*sapcontrol.InstanceProperty{ { @@ -1006,11 +1017,11 @@ func (suite *SAPSystemTestSuite) TestDetectType_Database() { }, }, }, nil). - On("GetProcessList"). + On("GetProcessList", ctx). Return(&sapcontrol.GetProcessListResponse{ Processes: []*sapcontrol.OSProcess{}, }, nil). - On("GetSystemInstanceList").Return(&sapcontrol.GetSystemInstanceListResponse{ + On("GetSystemInstanceList", ctx).Return(&sapcontrol.GetSystemInstanceListResponse{ Instances: []*sapcontrol.SAPInstance{ { Hostname: "host1", @@ -1032,7 +1043,7 @@ func (suite *SAPSystemTestSuite) TestDetectType_Database() { On("Exec", "su", "-lc", "/usr/sap/HDB/HDB00/exe/hdbnsutil -sr_state -sapcontrol=1", "hdbadm"). Return(mockHdbnsutilSrstate(), nil) - instance, err := sapsystem.NewSAPInstance(mockWebService, mockCommand) + instance, err := sapsystem.NewSAPInstance(ctx, mockWebService, mockCommand) suite.NoError(err) suite.Equal(sapsystem.Database, instance.Type) diff --git a/internal/discovery/cloud.go b/internal/discovery/cloud.go index 9f1f3778..96d1855b 100644 --- a/internal/discovery/cloud.go +++ b/internal/discovery/cloud.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "fmt" "net/http" "time" @@ -36,14 +37,14 @@ func (d CloudDiscovery) GetInterval() time.Duration { return d.interval } -func (d CloudDiscovery) Discover() (string, error) { - client := &http.Client{Transport: &http.Transport{Proxy: nil}} - cloudData, err := cloud.NewCloudInstance(utils.Executor{}, client) +func (d CloudDiscovery) Discover(ctx context.Context) (string, error) { + client := &http.Client{Transport: &http.Transport{Proxy: nil}, Timeout: 30 * time.Second} + cloudData, err := cloud.NewCloudInstance(ctx, utils.Executor{}, client) if err != nil { return "", err } - err = d.collectorClient.Publish(d.id, cloudData) + err = d.collectorClient.Publish(ctx, d.id, cloudData) if err != nil { log.Debugf("Error while sending cloud discovery to data collector: %s", err) return "", err diff --git a/internal/discovery/cluster.go b/internal/discovery/cluster.go index 64c202a2..2cb31cbc 100644 --- a/internal/discovery/cluster.go +++ b/internal/discovery/cluster.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "fmt" "time" @@ -36,14 +37,14 @@ func (c ClusterDiscovery) GetInterval() time.Duration { } // Execute one iteration of a discovery and publish the results to the collector -func (c ClusterDiscovery) Discover() (string, error) { +func (c ClusterDiscovery) Discover(ctx context.Context) (string, error) { cluster, err := cluster.NewCluster() if err != nil { log.Debugf("Error creating the cluster data object: %s", err) } - err = c.collectorClient.Publish(c.id, cluster) + err = c.collectorClient.Publish(ctx, c.id, cluster) if err != nil { log.Debugf("Error while sending cluster discovery to data collector: %s", err) return "", err diff --git a/internal/discovery/collector/client.go b/internal/discovery/collector/client.go index 0cee288b..eaeaf8a9 100644 --- a/internal/discovery/collector/client.go +++ b/internal/discovery/collector/client.go @@ -2,6 +2,7 @@ package collector import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -10,8 +11,8 @@ import ( ) type Client interface { - Publish(discoveryType string, payload interface{}) error - Heartbeat() error + Publish(ctx context.Context, discoveryType string, payload interface{}) error + Heartbeat(ctx context.Context) error } type Collector struct { @@ -32,7 +33,7 @@ func NewCollectorClient(config *Config, httpClient *http.Client) *Collector { } } -func (c *Collector) Publish(discoveryType string, payload interface{}) error { +func (c *Collector) Publish(ctx context.Context, discoveryType string, payload interface{}) error { log.Debugf("Sending %s to data collector", discoveryType) requestBody, err := json.Marshal(map[string]interface{}{ @@ -46,7 +47,7 @@ func (c *Collector) Publish(discoveryType string, payload interface{}) error { url := fmt.Sprintf("%s/api/v1/collect", c.config.ServerURL) - req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestBody)) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(requestBody)) if err != nil { return err } @@ -67,10 +68,10 @@ func (c *Collector) Publish(discoveryType string, payload interface{}) error { return nil } -func (c *Collector) Heartbeat() error { +func (c *Collector) Heartbeat(ctx context.Context) error { url := fmt.Sprintf("%s/api/v1/hosts/%s/heartbeat", c.config.ServerURL, c.config.AgentID) - req, err := http.NewRequest("POST", url, nil) + req, err := http.NewRequestWithContext(ctx, "POST", url, nil) if err != nil { return err } diff --git a/internal/discovery/collector/client_test.go b/internal/discovery/collector/client_test.go index 9d3a3255..81f3637f 100644 --- a/internal/discovery/collector/client_test.go +++ b/internal/discovery/collector/client_test.go @@ -1,6 +1,7 @@ package collector_test import ( + "context" "encoding/json" "fmt" "io" @@ -42,6 +43,7 @@ func (suite *CollectorClientTestSuite) SetupSuite() { } func (suite *CollectorClientTestSuite) TestCollectorClientPublishingSuccess() { + ctx := context.TODO() discoveredDataPayload := struct { FieldA string }{ @@ -68,12 +70,13 @@ func (suite *CollectorClientTestSuite) TestCollectorClientPublishingSuccess() { } }) - err := suite.collectorClient.Publish(discoveryType, discoveredDataPayload) + err := suite.collectorClient.Publish(ctx, discoveryType, discoveredDataPayload) suite.NoError(err) } func (suite *CollectorClientTestSuite) TestCollectorClientPublishingFailure() { + ctx := context.TODO() suite.httpClient.Transport = helpers.RoundTripFunc(func(req *http.Request) *http.Response { suite.Equal(req.URL.String(), "https://localhost/api/v1/collect") return &http.Response{ @@ -81,19 +84,20 @@ func (suite *CollectorClientTestSuite) TestCollectorClientPublishingFailure() { } }) - err := suite.collectorClient.Publish("some_discovery_type", struct{}{}) + err := suite.collectorClient.Publish(ctx, "some_discovery_type", struct{}{}) suite.Error(err) } func (suite *CollectorClientTestSuite) TestCollectorClientHeartbeat() { + ctx := context.TODO() suite.httpClient.Transport = helpers.RoundTripFunc(func(req *http.Request) *http.Response { suite.Equal(req.URL.String(), fmt.Sprintf("https://localhost/api/v1/hosts/%s/heartbeat", DummyAgentID)) return &http.Response{ StatusCode: 204, } }) - err := suite.collectorClient.Heartbeat() + err := suite.collectorClient.Heartbeat(ctx) suite.NoError(err) } diff --git a/internal/discovery/collector/publishing_test.go b/internal/discovery/collector/publishing_test.go index 594dd86c..c9175465 100644 --- a/internal/discovery/collector/publishing_test.go +++ b/internal/discovery/collector/publishing_test.go @@ -2,6 +2,7 @@ package collector_test import ( + "context" "encoding/json" "io" "net/http" @@ -108,6 +109,7 @@ func (suite *PublishingTestSuite) TestCollectorClientPublishingSAPSystemDiagnost type AssertionFunc func(requestBodyAgainstCollector string) func (suite *PublishingTestSuite) runDiscoveryScenario(discoveryType string, payload interface{}, assertion AssertionFunc) { + ctx := context.TODO() collectorClient := suite.configuredClient suite.httpClient.Transport = helpers.RoundTripFunc(func(req *http.Request) *http.Response { @@ -132,7 +134,7 @@ func (suite *PublishingTestSuite) runDiscoveryScenario(discoveryType string, pay } }) - err := collectorClient.Publish(discoveryType, payload) + err := collectorClient.Publish(ctx, discoveryType, payload) suite.NoError(err) } diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index cdfeed9f..486dd540 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "time" "github.com/trento-project/agent/internal/discovery/collector" @@ -24,7 +25,7 @@ type Discovery interface { // Returns an arbitrary unique string identifier of the discovery GetID() string // Execute the discovery mechanism - Discover() (string, error) + Discover(ctx context.Context) (string, error) // Get interval GetInterval() time.Duration } diff --git a/internal/discovery/host.go b/internal/discovery/host.go index 894d2477..dd3f47c1 100644 --- a/internal/discovery/host.go +++ b/internal/discovery/host.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "fmt" "net" "strconv" @@ -43,7 +44,7 @@ func (d HostDiscovery) GetInterval() time.Duration { } // Execute one iteration of a discovery and publish to the collector -func (d HostDiscovery) Discover() (string, error) { +func (d HostDiscovery) Discover(ctx context.Context) (string, error) { ipAddresses, err := getHostIPAddresses() if err != nil { return "", err @@ -60,7 +61,7 @@ func (d HostDiscovery) Discover() (string, error) { InstallationSource: version.InstallationSource, } - err = d.collectorClient.Publish(d.id, host) + err = d.collectorClient.Publish(ctx, d.id, host) if err != nil { log.Debugf("Error while sending host discovery to data collector: %s", err) return "", err diff --git a/internal/discovery/sapsystem.go b/internal/discovery/sapsystem.go index 089e7d94..6bb0b866 100644 --- a/internal/discovery/sapsystem.go +++ b/internal/discovery/sapsystem.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "fmt" "time" @@ -34,14 +35,14 @@ func (d SAPSystemsDiscovery) GetInterval() time.Duration { return d.interval } -func (d SAPSystemsDiscovery) Discover() (string, error) { - systems, err := sapsystem.NewDefaultSAPSystemsList() +func (d SAPSystemsDiscovery) Discover(ctx context.Context) (string, error) { + systems, err := sapsystem.NewDefaultSAPSystemsList(ctx) if err != nil { return "", err } - err = d.collectorClient.Publish(d.id, systems) + err = d.collectorClient.Publish(ctx, d.id, systems) if err != nil { log.Debugf("Error while sending sapsystem discovery to data collector: %s", err) return "", err diff --git a/internal/discovery/saptune.go b/internal/discovery/saptune.go index 378d1cce..5a510046 100644 --- a/internal/discovery/saptune.go +++ b/internal/discovery/saptune.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "encoding/json" "time" @@ -41,7 +42,7 @@ func (d SaptuneDiscovery) GetInterval() time.Duration { return d.interval } -func (d SaptuneDiscovery) Discover() (string, error) { +func (d SaptuneDiscovery) Discover(ctx context.Context) (string, error) { var saptunePayload SaptuneDiscoveryPayload saptuneRetriever, err := saptune.NewSaptune(utils.Executor{}) @@ -73,7 +74,7 @@ func (d SaptuneDiscovery) Discover() (string, error) { } } - err = d.collectorClient.Publish(d.id, saptunePayload) + err = d.collectorClient.Publish(ctx, d.id, saptunePayload) if err != nil { log.Debugf("Error while sending saptune discovery to data collector: %s", err) return "", err diff --git a/internal/discovery/subscription.go b/internal/discovery/subscription.go index 4f129a06..2e27f3fb 100644 --- a/internal/discovery/subscription.go +++ b/internal/discovery/subscription.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "fmt" "time" @@ -37,13 +38,13 @@ func (d SubscriptionDiscovery) GetInterval() time.Duration { return d.interval } -func (d SubscriptionDiscovery) Discover() (string, error) { +func (d SubscriptionDiscovery) Discover(ctx context.Context) (string, error) { subsData, err := subscription.NewSubscriptions(utils.Executor{}) if err != nil { return "", err } - err = d.collectorClient.Publish(d.id, subsData) + err = d.collectorClient.Publish(ctx, d.id, subsData) if err != nil { log.Debugf("Error while sending subscription discovery to data collector: %s", err) return "", err diff --git a/internal/factsengine/gatherers/sapcontrol.go b/internal/factsengine/gatherers/sapcontrol.go index 5ef270ee..30f99f3b 100644 --- a/internal/factsengine/gatherers/sapcontrol.go +++ b/internal/factsengine/gatherers/sapcontrol.go @@ -1,6 +1,7 @@ package gatherers import ( + "context" "encoding/json" "fmt" "path/filepath" @@ -18,7 +19,7 @@ const ( ) // nolint:gochecknoglobals -var whitelistedSapControlArguments = map[string]func(sapcontrolapi.WebService) (interface{}, error){ +var whitelistedSapControlArguments = map[string]func(context.Context, sapcontrolapi.WebService) (interface{}, error){ "GetProcessList": mapGetProcessList, "GetSystemInstanceList": mapGetSystemInstanceList, "GetVersionInfo": mapGetVersionInfo, @@ -104,6 +105,7 @@ func NewSapControlGatherer(webService sapcontrolapi.WebServiceConnector, fs afer } func (s *SapControlGatherer) Gather(factsRequests []entities.FactRequest) ([]entities.Fact, error) { + ctx := context.Background() cachedFacts := make(map[string]entities.Fact) log.Infof("Starting %s facts gathering process", SapControlGathererName) @@ -148,7 +150,7 @@ func (s *SapControlGatherer) Gather(factsRequests []entities.FactRequest) ([]ent for _, instanceData := range instances { instanceName, instanceNumber := instanceData[0], instanceData[1] conn := s.webService.New(instanceNumber) - output, err := webmethod(conn) + output, err := webmethod(ctx, conn) if err != nil { log.Error(SapcontrolWebmethodError. Wrap(fmt.Sprintf("argument %s for %s/%s", factReq.Argument, sid, instanceName)). @@ -205,8 +207,8 @@ func initSystemsMap(fs afero.Fs) (map[string][][]string, error) { return foundSystems, err } -func mapGetProcessList(conn sapcontrolapi.WebService) (interface{}, error) { - output, err := conn.GetProcessList() +func mapGetProcessList(ctx context.Context, conn sapcontrolapi.WebService) (interface{}, error) { + output, err := conn.GetProcessList(ctx) if err != nil { return nil, err } @@ -214,8 +216,8 @@ func mapGetProcessList(conn sapcontrolapi.WebService) (interface{}, error) { return output.Processes, nil } -func mapGetSystemInstanceList(conn sapcontrolapi.WebService) (interface{}, error) { - output, err := conn.GetSystemInstanceList() +func mapGetSystemInstanceList(ctx context.Context, conn sapcontrolapi.WebService) (interface{}, error) { + output, err := conn.GetSystemInstanceList(ctx) if err != nil { return nil, err } @@ -223,8 +225,8 @@ func mapGetSystemInstanceList(conn sapcontrolapi.WebService) (interface{}, error return output.Instances, nil } -func mapGetVersionInfo(conn sapcontrolapi.WebService) (interface{}, error) { - output, err := conn.GetVersionInfo() +func mapGetVersionInfo(ctx context.Context, conn sapcontrolapi.WebService) (interface{}, error) { + output, err := conn.GetVersionInfo(ctx) if err != nil { return nil, err } @@ -252,8 +254,8 @@ func mapGetVersionInfo(conn sapcontrolapi.WebService) (interface{}, error) { return versions, nil } -func mapHACheckConfig(conn sapcontrolapi.WebService) (interface{}, error) { - output, err := conn.HACheckConfig() +func mapHACheckConfig(ctx context.Context, conn sapcontrolapi.WebService) (interface{}, error) { + output, err := conn.HACheckConfig(ctx) if err != nil { return nil, err } @@ -261,8 +263,8 @@ func mapHACheckConfig(conn sapcontrolapi.WebService) (interface{}, error) { return output.Checks, nil } -func mapHAGetFailoverConfig(conn sapcontrolapi.WebService) (interface{}, error) { - output, err := conn.HAGetFailoverConfig() +func mapHAGetFailoverConfig(ctx context.Context, conn sapcontrolapi.WebService) (interface{}, error) { + output, err := conn.HAGetFailoverConfig(ctx) if err != nil { return nil, err } diff --git a/internal/factsengine/gatherers/sapcontrol_test.go b/internal/factsengine/gatherers/sapcontrol_test.go index eb42fec0..7b7225dd 100644 --- a/internal/factsengine/gatherers/sapcontrol_test.go +++ b/internal/factsengine/gatherers/sapcontrol_test.go @@ -1,6 +1,7 @@ package gatherers_test import ( + "context" "fmt" "testing" @@ -105,8 +106,9 @@ func (suite *SapControlGathererSuite) TestSapControlGathererEmptyFileSystem() { } func (suite *SapControlGathererSuite) TestSapControlGathererCacheHit() { + ctx := context.Background() mockWebService := new(sapControlMocks.WebService) - mockWebService.On("GetProcessList").Return(&sapcontrol.GetProcessListResponse{ + mockWebService.On("GetProcessList", ctx).Return(&sapcontrol.GetProcessListResponse{ Processes: []*sapcontrol.OSProcess{ { Name: "process1", @@ -213,6 +215,7 @@ func (suite *SapControlGathererSuite) TestSapControlGathererCacheHit() { } func (suite *SapControlGathererSuite) TestSapControlGathererMultipleInstaces() { + ctx := context.Background() testFS := afero.NewMemMapFs() err := testFS.MkdirAll("/usr/sap/PRD/ASCS00", 0644) suite.NoError(err) @@ -224,7 +227,7 @@ func (suite *SapControlGathererSuite) TestSapControlGathererMultipleInstaces() { suite.NoError(err) mockWebService := new(sapControlMocks.WebService) - mockWebService.On("GetProcessList").Return(&sapcontrol.GetProcessListResponse{ + mockWebService.On("GetProcessList", ctx).Return(&sapcontrol.GetProcessListResponse{ Processes: []*sapcontrol.OSProcess{ { Name: "process1", @@ -236,7 +239,7 @@ func (suite *SapControlGathererSuite) TestSapControlGathererMultipleInstaces() { }, nil) mockWebServiceError := new(sapControlMocks.WebService) - mockWebServiceError.On("GetProcessList").Return(nil, fmt.Errorf("some error")) + mockWebServiceError.On("GetProcessList", ctx).Return(nil, fmt.Errorf("some error")) suite.webService. On("New", "00").Return(mockWebService). @@ -341,8 +344,9 @@ func (suite *SapControlGathererSuite) TestSapControlGathererMultipleInstaces() { } func (suite *SapControlGathererSuite) TestSapControlGathererGetSystemInstanceList() { + ctx := context.Background() mockWebService := new(sapControlMocks.WebService) - mockWebService.On("GetSystemInstanceList").Return(&sapcontrol.GetSystemInstanceListResponse{ + mockWebService.On("GetSystemInstanceList", ctx).Return(&sapcontrol.GetSystemInstanceListResponse{ Instances: []*sapcontrol.SAPInstance{ { Hostname: "host1", @@ -410,8 +414,9 @@ func (suite *SapControlGathererSuite) TestSapControlGathererGetSystemInstanceLis } func (suite *SapControlGathererSuite) TestSapControlGathererGetVersionInfo() { + ctx := context.Background() mockWebService := new(sapControlMocks.WebService) - mockWebService.On("GetVersionInfo").Return(&sapcontrol.GetVersionInfoResponse{ + mockWebService.On("GetVersionInfo", ctx).Return(&sapcontrol.GetVersionInfoResponse{ InstanceVersions: []*sapcontrol.VersionInfo{ { Filename: "/usr/sap/NWP/ERS10/exe/sapstartsrv", @@ -491,8 +496,9 @@ func (suite *SapControlGathererSuite) TestSapControlGathererGetVersionInfo() { } func (suite *SapControlGathererSuite) TestSapControlGathererHACheckConfig() { + ctx := context.Background() mockWebService := new(sapControlMocks.WebService) - mockWebService.On("HACheckConfig").Return(&sapcontrol.HACheckConfigResponse{ + mockWebService.On("HACheckConfig", ctx).Return(&sapcontrol.HACheckConfigResponse{ Checks: []*sapcontrol.HACheck{ { Description: "desc1", @@ -558,8 +564,9 @@ func (suite *SapControlGathererSuite) TestSapControlGathererHACheckConfig() { } func (suite *SapControlGathererSuite) TestSapControlGathererHAGetFailoverConfig() { + ctx := context.Background() mockWebService := new(sapControlMocks.WebService) - mockWebService.On("HAGetFailoverConfig").Return(&sapcontrol.HAGetFailoverConfigResponse{ + mockWebService.On("HAGetFailoverConfig", ctx).Return(&sapcontrol.HAGetFailoverConfigResponse{ HAActive: false, HANodes: &[]string{"node1"}, }, nil)