-
Notifications
You must be signed in to change notification settings - Fork 1
/
ecs.go
210 lines (177 loc) · 5.91 KB
/
ecs.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
package drain
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"regexp"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ecs"
)
// AWS API clients shared by every function in this package. Each one is
// created at package initialisation from its own default session.
var (
	// ecsClient is used for container-instance and task calls.
	ecsClient = ecs.New(session.New())
	// ec2client is used for instance-state and user-data lookups.
	ec2client = ec2.New(session.New())
)

// ecsRegExp matches an ECS_CLUSTER=... assignment; capture group 1 holds the
// value. Note that the group uses `*`, so it can match an empty value.
var ecsRegExp = regexp.MustCompile(`ECS_CLUSTER=([0-9A-Za-z_\-]*)`)

// Sentinel errors reported while resolving an instance's cluster from its
// user data.
var (
	// ErrMissingUserData is returned when the instance has no user data.
	ErrMissingUserData = errors.New("This instance seems not to have UserData")
	// ErrMissingECSClusterInUserData is returned when no ECS_CLUSTER value
	// can be extracted from the instance's user data.
	ErrMissingECSClusterInUserData = errors.New("This instance seems not to have EcsCluster definition in UserData")
	// ErrInstanceTerminated is returned when the instance is shutting down or
	// already terminated.
	ErrInstanceTerminated = errors.New("This instance is already terminated")
)
// Drain sets the ECS container instance that backs EC2 instance ec2Instance in
// cluster ecsCluster to DRAINING and blocks until every task collected from it
// has reached the STOPPED status.
//
// If the container instance reports no running tasks, Drain returns
// immediately without changing its state. If the instance is already
// DRAINING (for example on a retried invocation), Drain does not re-issue the
// state change but still waits for the instance's tasks. Task states are
// polled every 10 seconds and there is no overall timeout. AWS API errors are
// returned unchanged.
func Drain(ecsCluster, ec2Instance string) error {
	// Getting ECS Container instance representation by EC2 instance ID
	instance, err := getContainerInstance(ecsCluster, ec2Instance)
	if err != nil {
		return err
	}
	printJSON("Container instance", instance)

	tasksToShutdownCount := aws.Int64Value(instance.RunningTasksCount)

	var (
		runningTaskArns []*string
		tasksListed     bool
	)
	// If some tasks are running on the instance we need to drain it and wait
	// for all of them to shut down.
	for tasksToShutdownCount > 0 {
		if aws.StringValue(instance.Status) != ecs.ContainerInstanceStatusDraining {
			fmt.Println("Starting draining and waiting for all tasks to shutdown")
			_, err = ecsClient.UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{
				Cluster:            &ecsCluster,
				ContainerInstances: []*string{instance.ContainerInstanceArn},
				Status:             aws.String(ecs.ContainerInstanceStatusDraining),
			})
			if err != nil {
				return err
			}
			if runningTaskArns, err = listInstanceTaskArns(ecsCluster, instance.ContainerInstanceArn); err != nil {
				return err
			}
			tasksListed = true
			// Refresh the instance so the next iteration sees DRAINING and does
			// not issue the state change again.
			if instance, err = getContainerInstance(ecsCluster, ec2Instance); err != nil {
				return err
			}
		} else if !tasksListed {
			// The instance was already DRAINING when Drain was called. Without
			// this branch no tasks would ever be collected and Drain would
			// report success while tasks are still running.
			if runningTaskArns, err = listInstanceTaskArns(ecsCluster, instance.ContainerInstanceArn); err != nil {
				return err
			}
			tasksListed = true
		}

		if len(runningTaskArns) == 0 {
			fmt.Println("no running tasks found")
			break
		}

		// Wait explicitly for tasks to become STOPPED; otherwise the instance
		// may be stopped while tasks are still in the DEACTIVATING state.
		// See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-lifecycle.html
		taskStates, notStopped, descErr := describeTaskStates(ecsCluster, runningTaskArns)
		if descErr != nil {
			return descErr
		}
		if len(taskStates) == 0 {
			fmt.Println("no tasks found")
		}
		printJSON("Instance task states", taskStates)

		tasksToShutdownCount = notStopped
		// Only wait when another poll is actually needed.
		if tasksToShutdownCount > 0 {
			time.Sleep(10 * time.Second)
		}
	}
	fmt.Println("Drain finished")
	return nil
}

// listInstanceTaskArns returns the de-duplicated ARNs of tasks on
// containerInstanceArn, following ListTasks pagination. ListTasks filters on a
// task's *desired* status (RUNNING when unset, per the ListTasks API
// reference), so both RUNNING and STOPPED are queried. Once draining has
// begun, ECS may already have set the desired status of the instance's tasks
// to STOPPED while they are still shutting down.
func listInstanceTaskArns(ecsCluster string, containerInstanceArn *string) ([]*string, error) {
	seen := map[string]bool{}
	var arns []*string
	for _, desired := range []string{ecs.DesiredStatusRunning, ecs.DesiredStatusStopped} {
		err := ecsClient.ListTasksPages(&ecs.ListTasksInput{
			Cluster:           &ecsCluster,
			ContainerInstance: containerInstanceArn,
			DesiredStatus:     aws.String(desired),
		}, func(page *ecs.ListTasksOutput, lastPage bool) bool {
			for _, arn := range page.TaskArns {
				// A task may move between the two listings; count it once.
				if arn == nil || seen[*arn] {
					continue
				}
				seen[*arn] = true
				arns = append(arns, arn)
			}
			return !lastPage
		})
		if err != nil {
			return nil, err
		}
	}
	return arns, nil
}

// describeTaskStates describes taskArns in batches of at most 100 (the
// DescribeTasks per-call limit). It returns the number of tasks per
// LastStatus and how many of them have not yet reached STOPPED. Tasks without
// a LastStatus are ignored.
func describeTaskStates(ecsCluster string, taskArns []*string) (map[string]int, int64, error) {
	const maxDescribeTasks = 100
	states := map[string]int{}
	var notStopped int64
	for start := 0; start < len(taskArns); start += maxDescribeTasks {
		end := start + maxDescribeTasks
		if end > len(taskArns) {
			end = len(taskArns)
		}
		resp, err := ecsClient.DescribeTasks(&ecs.DescribeTasksInput{
			Cluster: &ecsCluster,
			Tasks:   taskArns[start:end],
		})
		if err != nil {
			return nil, 0, err
		}
		if resp == nil {
			continue
		}
		for _, task := range resp.Tasks {
			if task == nil || task.LastStatus == nil {
				continue
			}
			states[*task.LastStatus]++
			if *task.LastStatus != ecs.DesiredStatusStopped {
				notStopped++
			}
		}
	}
	return states, notStopped, nil
}
// getContainerInstance finds the container instance in cluster ecsCluster
// whose EC2 instance ID equals ec2Instance.
//
// It pages through ListContainerInstances and describes each page with
// DescribeContainerInstances, stopping as soon as a match is found. Errors
// from either API call are returned unchanged. If no instance matches, a
// "not found" error is returned.
func getContainerInstance(ecsCluster, ec2Instance string) (*ecs.ContainerInstance, error) {
	var (
		found       *ecs.ContainerInstance
		describeErr error
	)
	listInput := &ecs.ListContainerInstancesInput{Cluster: &ecsCluster}
	err := ecsClient.ListContainerInstancesPages(listInput, func(page *ecs.ListContainerInstancesOutput, lastPage bool) bool {
		// Nothing to describe on an empty page (e.g. an empty cluster).
		if len(page.ContainerInstanceArns) == 0 {
			return !lastPage
		}
		resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
			Cluster:            &ecsCluster,
			ContainerInstances: page.ContainerInstanceArns,
		})
		if err != nil {
			// Surface the real failure instead of a misleading "not found".
			describeErr = err
			return false
		}
		for _, ci := range resp.ContainerInstances {
			if ci != nil && ci.Ec2InstanceId != nil && *ci.Ec2InstanceId == ec2Instance {
				found = ci
				return false
			}
		}
		return !lastPage
	})
	if err != nil {
		return nil, err
	}
	if describeErr != nil {
		return nil, describeErr
	}
	if found == nil {
		return nil, fmt.Errorf("%q not found in the cluster %q", ec2Instance, ecsCluster)
	}
	return found, nil
}
// GetClusterNameFromInstanceUserData returns the ECS cluster name found in the
// ECS_CLUSTER assignment of the user data of EC2 instance ec2Instance.
//
// It returns ErrInstanceTerminated if the instance is shutting down or already
// terminated, and ErrMissingUserData if the instance has no user data. If no
// cluster name can be extracted, it returns the error from
// parseECSClusterValue (ErrMissingECSClusterInUserData). AWS API errors and
// base64 decoding errors are returned unchanged.
func GetClusterNameFromInstanceUserData(ec2Instance string) (string, error) {
	// Check the instance state and refuse to work on instances that are going away.
	resp, err := ec2client.DescribeInstances(&ec2.DescribeInstancesInput{
		InstanceIds: []*string{&ec2Instance},
	})
	if err != nil {
		return "", err
	}
	if len(resp.Reservations) > 0 && len(resp.Reservations[0].Instances) > 0 {
		// State may be absent from the response; skip the check rather than
		// dereferencing a nil pointer.
		if state := resp.Reservations[0].Instances[0].State; state != nil {
			switch aws.StringValue(state.Name) {
			case ec2.InstanceStateNameTerminated, ec2.InstanceStateNameShuttingDown:
				return "", ErrInstanceTerminated
			}
		}
	}

	att, err := ec2client.DescribeInstanceAttribute(&ec2.DescribeInstanceAttributeInput{
		InstanceId: &ec2Instance,
		Attribute:  aws.String("userData"),
	})
	if err != nil {
		return "", err
	}
	// No user data means the instance is probably not part of an ECS cluster.
	if att == nil || att.UserData == nil || att.UserData.Value == nil {
		return "", ErrMissingUserData
	}
	// The attribute value arrives base64-encoded.
	decodedUserData, err := base64.StdEncoding.DecodeString(*att.UserData.Value)
	if err != nil {
		return "", err
	}
	return parseECSClusterValue(string(decodedUserData))
}
// parseECSClusterValue extracts the cluster name assigned to ECS_CLUSTER in str
// (the decoded instance user data) using ecsRegExp.
//
// ecsRegExp's capture group can match an empty string, for example for a bare
// `ECS_CLUSTER=` or a value that starts with a quote. Such matches are skipped
// and the first non-empty value is returned. Previously an empty name was
// returned as success. If no non-empty value exists, str is printed for
// troubleshooting and ErrMissingECSClusterInUserData is returned.
func parseECSClusterValue(str string) (string, error) {
	for _, m := range ecsRegExp.FindAllStringSubmatch(str, -1) {
		if len(m) > 1 && m[1] != "" {
			return m[1], nil
		}
	}
	// NOTE(review): this prints the full user data, which may contain secrets.
	fmt.Printf("UserData:\n%s\n", str)
	return "", ErrMissingECSClusterInUserData
}
// printJSON writes one line to stdout: text, a space, then data encoded as
// compact JSON. This keeps structured values on a single line, which suits
// AWS CloudWatch Logs. If data cannot be encoded, the line carries a
// {"marshalError": "..."} object instead of being silently dropped.
func printJSON(text string, data interface{}) {
	b, err := json.Marshal(data)
	if err != nil {
		// Marshalling a map[string]string cannot fail, so this error is ignored.
		b, _ = json.Marshal(map[string]string{"marshalError": err.Error()})
	}
	fmt.Println(text, string(b))
}