diff --git a/agent/config/config.go b/agent/config/config.go index fcf44e78fd9..ba805d2e82f 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -192,8 +192,17 @@ func NewConfig(ec2client ec2.EC2MetadataClient) (*Config, error) { config.Merge(userDataConfig(ec2client)) if config.AWSRegion == "" { - // Get it from metadata only if we need to (network io) - config.Merge(ec2MetadataConfig(ec2client)) + if config.NoIID { + // get it from AWS SDK if we don't have instance identity document + awsRegion, err := ec2client.Region() + if err != nil { + errs = append(errs, err) + } + config.AWSRegion = awsRegion + } else { + // Get it from metadata only if we need to (network io) + config.Merge(ec2MetadataConfig(ec2client)) + } } return config, config.mergeDefaultConfig(errs) diff --git a/agent/config/config_test.go b/agent/config/config_test.go index 28f2527cfb9..2c7eadff663 100644 --- a/agent/config/config_test.go +++ b/agent/config/config_test.go @@ -67,6 +67,25 @@ func TestBrokenEC2MetadataEndpoint(t *testing.T) { assert.Zero(t, config.APIEndpoint, "Endpoint env variable not set; endpoint should be blank") } +func TestGetRegionWithNoIID(t *testing.T) { + defer setTestEnv("AWS_DEFAULT_REGION", "")() + ctrl := gomock.NewController(t) + mockEc2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) + + userDataResponse := `{ "ECSAgentConfiguration":{ + "Cluster":"arn:aws:ecs:us-east-1:123456789012:cluster/my-cluster", + "APIEndpoint":"https://some-endpoint.com", + "NoIID":true + }}` + mockEc2Metadata.EXPECT().GetUserData().Return(userDataResponse, nil) + mockEc2Metadata.EXPECT().Region().Return("us-east-1", nil) + + config, err := NewConfig(mockEc2Metadata) + assert.NoError(t, err) + assert.Equal(t, config.AWSRegion, "us-east-1", "Wrong region") + assert.Equal(t, config.APIEndpoint, "https://some-endpoint.com", "Endpoint env variable not set; endpoint should be blank") +} + func TestEnvironmentConfig(t *testing.T) { defer setTestRegion()() defer setTestEnv("ECS_CLUSTER", "myCluster")() diff --git a/agent/ec2/blackhole_ec2_metadata_client.go b/agent/ec2/blackhole_ec2_metadata_client.go index a136ddae312..81601f6e4a9 100644 --- a/agent/ec2/blackhole_ec2_metadata_client.go +++ b/agent/ec2/blackhole_ec2_metadata_client.go @@ -60,3 +60,7 @@ func (blackholeMetadataClient) GetDynamicData(path string) (string, error) { func (blackholeMetadataClient) GetUserData() (string, error) { return "", errors.New("blackholed") } + +func (blackholeMetadataClient) Region() (string, error) { + return "", errors.New("blackholed") +} diff --git a/agent/ec2/ec2_metadata_client.go b/agent/ec2/ec2_metadata_client.go index 17992971439..00940015540 100644 --- a/agent/ec2/ec2_metadata_client.go +++ b/agent/ec2/ec2_metadata_client.go @@ -55,6 +55,7 @@ type HttpClient interface { GetDynamicData(string) (string, error) GetInstanceIdentityDocument() (ec2metadata.EC2InstanceIdentityDocument, error) GetUserData() (string, error) + Region() (string, error) } // EC2MetadataClient is the client used to get metadata from instance metadata service @@ -68,6 +69,7 @@ type EC2MetadataClient interface { PrimaryENIMAC() (string, error) InstanceID() (string, error) GetUserData() (string, error) + Region() (string, error) } type ec2MetadataClientImpl struct { @@ -155,3 +157,8 @@ func (c *ec2MetadataClientImpl) InstanceID() (string, error) { func (c *ec2MetadataClientImpl) GetUserData() (string, error) { return c.client.GetUserData() } + +// Region returns the region the instance is running in. +func (c *ec2MetadataClientImpl) Region() (string, error) { + return c.client.Region() +} diff --git a/agent/ec2/mocks/ec2_mocks.go b/agent/ec2/mocks/ec2_mocks.go index 19c60e84c0a..03f83c27e9c 100644 --- a/agent/ec2/mocks/ec2_mocks.go +++ b/agent/ec2/mocks/ec2_mocks.go @@ -141,6 +141,19 @@ func (mr *MockEC2MetadataClientMockRecorder) PrimaryENIMAC() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrimaryENIMAC", reflect.TypeOf((*MockEC2MetadataClient)(nil).PrimaryENIMAC)) } +// Region mocks base method +func (m *MockEC2MetadataClient) Region() (string, error) { + ret := m.ctrl.Call(m, "Region") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Region indicates an expected call of Region +func (mr *MockEC2MetadataClientMockRecorder) Region() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Region", reflect.TypeOf((*MockEC2MetadataClient)(nil).Region)) +} + // SubnetID mocks base method func (m *MockEC2MetadataClient) SubnetID(arg0 string) (string, error) { ret := m.ctrl.Call(m, "SubnetID", arg0) @@ -242,6 +255,19 @@ func (mr *MockHttpClientMockRecorder) GetUserData() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserData", reflect.TypeOf((*MockHttpClient)(nil).GetUserData)) } +// Region mocks base method +func (m *MockHttpClient) Region() (string, error) { + ret := m.ctrl.Call(m, "Region") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Region indicates an expected call of Region +func (mr *MockHttpClientMockRecorder) Region() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Region", reflect.TypeOf((*MockHttpClient)(nil).Region)) +} + // MockClient is a mock of Client interface type MockClient struct { ctrl *gomock.Controller