Skip to content

Commit

Permalink
Merge pull request #657 from martina-if/autoenable-allow-ssh-flag
Browse files Browse the repository at this point in the history
Enable allow-ssh flag when a key is specified
  • Loading branch information
martina-if authored Mar 27, 2019
2 parents 956d19f + 050d3bb commit 90e1e79
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 82 deletions.
24 changes: 7 additions & 17 deletions pkg/ctl/create/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func createClusterCmd(g *cmdutils.Grouping) *cobra.Command {
group := g.New(cmd)

exampleClusterName := utils.ClusterName("", "")
exampleNodeGroupName := utils.NodeGroupName("", "")
exampleNodeGroupName := NodeGroupName("", "")

group.InFlagSet("General", func(fs *pflag.FlagSet) {
fs.StringVarP(&cfg.Metadata.Name, "name", "n", "", fmt.Sprintf("EKS cluster name (generated if unspecified, e.g. %q)", exampleClusterName))
Expand Down Expand Up @@ -176,7 +176,7 @@ func doCreateCluster(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg stri

skipNodeGroupsIfRequested(cfg)

if err := CheckEachNodeGroup(ngFilter, cfg, NewNodeGroupChecker); err != nil {
if err := setNodeGroupDefaults(ngFilter, cfg.NodeGroups); err != nil {
return err
}
} else {
Expand All @@ -194,17 +194,7 @@ func doCreateCluster(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg stri

skipNodeGroupsIfRequested(cfg)

err := CheckEachNodeGroup(ngFilter, cfg, func(i int, ng *api.NodeGroup) error {
if ng.AllowSSH && ng.SSHPublicKeyPath == "" {
return fmt.Errorf("--ssh-public-key must be non-empty string")
}

// generate nodegroup name or use flag
ng.Name = utils.NodeGroupName(ng.Name, "")

return nil
})
if err != nil {
if err := configureNodeGroups(ngFilter, cfg.NodeGroups, cmd); err != nil {
return err
}
}
Expand Down Expand Up @@ -305,7 +295,7 @@ func doCreateCluster(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg stri
return err
}

if err := CheckEachNodeGroup(ngFilter, cfg, canUseForPrivateNodeGroups); err != nil {
if err := checkEachNodeGroup(ngFilter, cfg.NodeGroups, canUseForPrivateNodeGroups); err != nil {
return err
}

Expand All @@ -332,7 +322,7 @@ func doCreateCluster(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg stri
return err
}

if err := CheckEachNodeGroup(ngFilter, cfg, canUseForPrivateNodeGroups); err != nil {
if err := checkEachNodeGroup(ngFilter, cfg.NodeGroups, canUseForPrivateNodeGroups); err != nil {
return err
}

Expand All @@ -345,7 +335,7 @@ func doCreateCluster(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg stri
return err
}

err := CheckEachNodeGroup(ngFilter, cfg, func(_ int, ng *api.NodeGroup) error {
err := checkEachNodeGroup(ngFilter, cfg.NodeGroups, func(_ int, ng *api.NodeGroup) error {
// resolve AMI
if err := ctl.EnsureAMI(meta.Version, ng); err != nil {
return err
Expand Down Expand Up @@ -434,7 +424,7 @@ func doCreateCluster(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg stri
return err
}

err = CheckEachNodeGroup(ngFilter, cfg, func(_ int, ng *api.NodeGroup) error {
err = checkEachNodeGroup(ngFilter, cfg.NodeGroups, func(_ int, ng *api.NodeGroup) error {
// authorise nodes to join
if err = authconfigmap.AddNodeGroup(clientSet, ng); err != nil {
return err
Expand Down
24 changes: 6 additions & 18 deletions pkg/ctl/create/nodegroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func createNodeGroupCmd(g *cmdutils.Grouping) *cobra.Command {

group := g.New(cmd)

exampleNodeGroupName := utils.NodeGroupName("", "")
exampleNodeGroupName := NodeGroupName("", "")

group.InFlagSet("General", func(fs *pflag.FlagSet) {
fs.StringVar(&cfg.Metadata.Name, "cluster", "", "name of the EKS cluster to add the nodegroup to")
Expand Down Expand Up @@ -135,7 +135,7 @@ func doCreateNodeGroups(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg s
return err
}

if err := CheckEachNodeGroup(ngFilter, cfg, NewNodeGroupChecker); err != nil {
if err := checkEachNodeGroup(ngFilter, cfg.NodeGroups, NewNodeGroupChecker); err != nil {
return err
}
} else {
Expand All @@ -153,22 +153,10 @@ func doCreateNodeGroups(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg s
if cmd.Flag(f).Changed {
return fmt.Errorf("cannot use --%s unless a config file is specified via --config-file/-f", f)
}
return nil
}

err := CheckEachNodeGroup(ngFilter, cfg, func(i int, ng *api.NodeGroup) error {
if ng.AllowSSH && ng.SSHPublicKeyPath == "" {
return fmt.Errorf("--ssh-public-key must be non-empty string")
}

// generate nodegroup name or use either flag or argument
if utils.NodeGroupName(ng.Name, nameArg) == "" {
return cmdutils.ErrNameFlagAndArg(ng.Name, nameArg)
}
ng.Name = utils.NodeGroupName(ng.Name, nameArg)

return nil
})
if err != nil {
if err := configureNodeGroups(ngFilter, cfg.NodeGroups, cmd); err != nil {
return err
}

Expand Down Expand Up @@ -203,7 +191,7 @@ func doCreateNodeGroups(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg s
return err
}

err := CheckEachNodeGroup(ngFilter, cfg, func(_ int, ng *api.NodeGroup) error {
err := checkEachNodeGroup(ngFilter, cfg.NodeGroups, func(_ int, ng *api.NodeGroup) error {
// resolve AMI
if err := ctl.EnsureAMI(meta.Version, ng); err != nil {
return err
Expand Down Expand Up @@ -263,7 +251,7 @@ func doCreateNodeGroups(p *api.ProviderConfig, cfg *api.ClusterConfig, nameArg s
return err
}

err = CheckEachNodeGroup(ngFilter, cfg, func(_ int, ng *api.NodeGroup) error {
err = checkEachNodeGroup(ngFilter, cfg.NodeGroups, func(_ int, ng *api.NodeGroup) error {
if updateAuthConfigMap {
// authorise nodes to join
if err = authconfigmap.AddNodeGroup(clientSet, ng); err != nil {
Expand Down
126 changes: 126 additions & 0 deletions pkg/ctl/create/nodegroup_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package create

import (
"fmt"
"github.com/spf13/cobra"
"github.com/weaveworks/eksctl/pkg/ami"
api "github.com/weaveworks/eksctl/pkg/apis/eksctl.io/v1alpha4"
"github.com/weaveworks/eksctl/pkg/utils"
"math/rand"
"time"
)

const (
randNodeGroupNameLength = 8
randNodeGroupNameComponents = "abcdef0123456789"
)

var r = rand.New(rand.NewSource(time.Now().UnixNano()))

// NodeGroupName generates a name string when a and b are empty strings.
// If either a or b are non-empty, it returns whichever is non-empty.
// If neither a nor b are empty, it returns empty name, to indicate
// ambiguous usage.
// It uses a different naming scheme from ClusterName, so that users can
// easily distinguish a cluster name from nodegroup name.
func NodeGroupName(a, b string) string {
return utils.UseNameOrGenerate(a, b, func() string {
name := make([]byte, randNodeGroupNameLength)
for i := 0; i < randNodeGroupNameLength; i++ {
name[i] = randNodeGroupNameComponents[r.Intn(len(randNodeGroupNameComponents))]
}
return fmt.Sprintf("ng-%s", string(name))
})
}

func configureNodeGroups(ngFilter *NodeGroupFilter, nodeGroups []*api.NodeGroup, cmd *cobra.Command) error {
return checkEachNodeGroup(ngFilter, nodeGroups, func(i int, ng *api.NodeGroup) error {
if ng.AllowSSH && ng.SSHPublicKeyPath == "" {
return fmt.Errorf("--ssh-public-key must be non-empty string")
}

if cmd.Flag("ssh-public-key").Changed {
ng.AllowSSH = true
}

// generate nodegroup name or use flag
ng.Name = NodeGroupName(ng.Name, "")

return nil
})
}

func setNodeGroupDefaults(ngFilter *NodeGroupFilter, nodeGroups []*api.NodeGroup) error {
return checkEachNodeGroup(ngFilter, nodeGroups, func(i int, ng *api.NodeGroup) error {

if err := api.ValidateNodeGroup(i, ng); err != nil {
return err
}

// apply defaults
if ng.InstanceType == "" {
ng.InstanceType = api.DefaultNodeType
}
if ng.AMIFamily == "" {
ng.AMIFamily = ami.ImageFamilyAmazonLinux2
}
if ng.AMI == "" {
ng.AMI = ami.ResolverStatic
}

if ng.SecurityGroups == nil {
ng.SecurityGroups = &api.NodeGroupSGs{
AttachIDs: []string{},
}
}
if ng.SecurityGroups.WithLocal == nil {
ng.SecurityGroups.WithLocal = api.NewBoolTrue()
}
if ng.SecurityGroups.WithShared == nil {
ng.SecurityGroups.WithShared = api.NewBoolTrue()
}

// Enable SSH when a key is provided
if ng.SSHPublicKeyPath != "" {
ng.AllowSSH = true
}

if ng.AllowSSH && ng.SSHPublicKeyPath == "" {
ng.SSHPublicKeyPath = defaultSSHPublicKey
}

if ng.VolumeSize > 0 {
if ng.VolumeType == "" {
ng.VolumeType = api.DefaultNodeVolumeType
}
}

if ng.IAM == nil {
ng.IAM = &api.NodeGroupIAM{}
}
if ng.IAM.WithAddonPolicies.ImageBuilder == nil {
ng.IAM.WithAddonPolicies.ImageBuilder = api.NewBoolFalse()
}
if ng.IAM.WithAddonPolicies.AutoScaler == nil {
ng.IAM.WithAddonPolicies.AutoScaler = api.NewBoolFalse()
}
if ng.IAM.WithAddonPolicies.ExternalDNS == nil {
ng.IAM.WithAddonPolicies.ExternalDNS = api.NewBoolFalse()
}

return nil
})
}

// CheckEachNodeGroup iterates over each nodegroup and calls check function
// (this is needed to avoid common goroutine-for-loop pitfall)
func checkEachNodeGroup(ngFilter *NodeGroupFilter, nodeGroups []*api.NodeGroup, check func(i int, ng *api.NodeGroup) error) error {
for i, ng := range nodeGroups {
if ngFilter.Match(ng) {
if err := check(i, ng); err != nil {
return err
}
}
}
return nil
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package create_test
package create

import (
"bytes"
Expand All @@ -7,7 +7,6 @@ import (
. "github.com/onsi/gomega"

api "github.com/weaveworks/eksctl/pkg/apis/eksctl.io/v1alpha4"
. "github.com/weaveworks/eksctl/pkg/ctl/create"
"github.com/weaveworks/eksctl/pkg/printers"
)

Expand Down Expand Up @@ -274,7 +273,7 @@ var _ = Describe("create utils", func() {
f := NewNodeGroupFilter()

names := []string{}
CheckEachNodeGroup(f, cfg, func(i int, nodeGroup *api.NodeGroup) error {
checkEachNodeGroup(f, cfg.NodeGroups, func(i int, nodeGroup *api.NodeGroup) error {
Expect(nodeGroup).To(Equal(cfg.NodeGroups[i]))
names = append(names, nodeGroup.Name)
return nil
Expand All @@ -285,7 +284,7 @@ var _ = Describe("create utils", func() {
cfg.NodeGroups[0].Name = "ng-x0"
cfg.NodeGroups[1].Name = "ng-x1"
cfg.NodeGroups[2].Name = "ng-x2"
CheckEachNodeGroup(f, cfg, func(i int, nodeGroup *api.NodeGroup) error {
checkEachNodeGroup(f, cfg.NodeGroups, func(i int, nodeGroup *api.NodeGroup) error {
Expect(nodeGroup).To(Equal(cfg.NodeGroups[i]))
names = append(names, nodeGroup.Name)
return nil
Expand All @@ -300,7 +299,7 @@ var _ = Describe("create utils", func() {
f := NewNodeGroupFilter()

names := []string{}
CheckEachNodeGroup(f, cfg, func(i int, nodeGroup *api.NodeGroup) error {
checkEachNodeGroup(f, cfg.NodeGroups, func(i int, nodeGroup *api.NodeGroup) error {
Expect(nodeGroup).To(Equal(cfg.NodeGroups[i]))
names = append(names, nodeGroup.Name)
return nil
Expand All @@ -315,7 +314,7 @@ var _ = Describe("create utils", func() {

err = f.ApplyOnlyFilter([]string{"test-ng1?", "te*-ng3?"}, cfg)
Expect(err).ToNot(HaveOccurred())
CheckEachNodeGroup(f, cfg, func(i int, nodeGroup *api.NodeGroup) error {
checkEachNodeGroup(f, cfg.NodeGroups, func(i int, nodeGroup *api.NodeGroup) error {
Expect(nodeGroup).To(Equal(cfg.NodeGroups[i]))
names = append(names, nodeGroup.Name)
return nil
Expand All @@ -332,9 +331,9 @@ var _ = Describe("create utils", func() {
printer := printers.NewJSONPrinter()

names := []string{}
CheckEachNodeGroup(f, cfg, NewNodeGroupChecker)
checkEachNodeGroup(f, cfg.NodeGroups, NewNodeGroupChecker)

CheckEachNodeGroup(f, cfg, func(i int, nodeGroup *api.NodeGroup) error {
checkEachNodeGroup(f, cfg.NodeGroups, func(i int, nodeGroup *api.NodeGroup) error {
Expect(nodeGroup).To(Equal(cfg.NodeGroups[i]))
names = append(names, nodeGroup.Name)
return nil
Expand Down
13 changes: 0 additions & 13 deletions pkg/ctl/create/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,6 @@ func (f *NodeGroupFilter) LogInfo(cfg *api.ClusterConfig) {
}
}

// CheckEachNodeGroup iterates over each nodegroup and calls check function
// (this is needed to avoid common goroutine-for-loop pitfall)
func CheckEachNodeGroup(f *NodeGroupFilter, cfg *api.ClusterConfig, check func(i int, ng *api.NodeGroup) error) error {
for i, ng := range cfg.NodeGroups {
if f.Match(ng) {
if err := check(i, ng); err != nil {
return err
}
}
}
return nil
}

// NewNodeGroupChecker validates a new nodegroup and applies defaults
func NewNodeGroupChecker(i int, ng *api.NodeGroup) error {
if err := api.ValidateNodeGroup(i, ng); err != nil {
Expand Down
30 changes: 4 additions & 26 deletions pkg/utils/namer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package utils

import (
"fmt"
"math/rand"
"time"

"github.com/kubicorn/kubicorn/pkg/namer"
)

func useNameOrGenerate(a, b string, generate func() string) string {
// UseNameOrGenerate picks one of the provided strings or generates a
// new one using the provided generate function
func UseNameOrGenerate(a, b string, generate func() string) string {
if a != "" && b != "" {
return ""
}
Expand All @@ -26,30 +27,7 @@ func useNameOrGenerate(a, b string, generate func() string) string {
// If neither a nor b are empty, it returns empty name, to indicate
// ambiguous usage.
func ClusterName(a, b string) string {
return useNameOrGenerate(a, b, func() string {
return UseNameOrGenerate(a, b, func() string {
return fmt.Sprintf("%s-%d", namer.RandomName(), time.Now().Unix())
})
}

var r = rand.New(rand.NewSource(time.Now().UnixNano()))

const (
randNodeGroupNameLength = 8
randNodeGroupNameComponents = "abcdef0123456789"
)

// NodeGroupName generates a name string when a and b are empty strings.
// If either a or b are non-empty, it returns whichever is non-empty.
// If neither a nor b are empty, it returns empty name, to indicate
// ambiguous usage.
// It uses a different naming scheme from ClusterName, so that users can
// easily distinguish a cluster name from nodegroup name.
func NodeGroupName(a, b string) string {
return useNameOrGenerate(a, b, func() string {
name := make([]byte, randNodeGroupNameLength)
for i := 0; i < randNodeGroupNameLength; i++ {
name[i] = randNodeGroupNameComponents[r.Intn(len(randNodeGroupNameComponents))]
}
return fmt.Sprintf("ng-%s", string(name))
})
}

0 comments on commit 90e1e79

Please sign in to comment.