Skip to content

Commit

Permalink
remove code duplication, add user/pass flags and auth tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Schmidt committed Jun 25, 2024
1 parent 260d06f commit 92a3b46
Show file tree
Hide file tree
Showing 23 changed files with 431 additions and 148 deletions.
69 changes: 23 additions & 46 deletions cmd/createIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"strings"
"time"

avs "github.com/aerospike/avs-client-go"
"github.com/aerospike/avs-client-go/protos"
commonFlags "github.com/aerospike/tools-common-go/flags"
"github.com/spf13/cobra"
Expand All @@ -21,7 +20,7 @@ import (

//nolint:govet // Padding not a concern for a CLI
var createIndexFlags = &struct {
flags.ClientFlags
clientFlags flags.ClientFlags
namespace string
sets []string
indexName string
Expand All @@ -39,7 +38,7 @@ var createIndexFlags = &struct {
hnswBatchEnabled flags.BoolOptionalFlag
timeout time.Duration
}{
ClientFlags: *flags.NewClientFlags(),
clientFlags: *flags.NewClientFlags(),
storageNamespace: flags.StringOptionalFlag{},
storageSet: flags.StringOptionalFlag{},
hnswMaxEdges: flags.Uint32OptionalFlag{},
Expand Down Expand Up @@ -68,7 +67,7 @@ func newCreateIndexFlagSet() *pflag.FlagSet {
flagSet.Var(&createIndexFlags.hnswBatchMaxRecords, flags.BatchMaxRecords, commonFlags.DefaultWrapHelpString("Maximum number of records to fit in a batch. The default value is 10000.")) //nolint:lll // For readability
flagSet.Var(&createIndexFlags.hnswBatchInterval, flags.BatchInterval, commonFlags.DefaultWrapHelpString("The maximum amount of time in milliseconds to wait before finalizing a batch. The default value is 10000.")) //nolint:lll // For readability
flagSet.Var(&createIndexFlags.hnswBatchEnabled, flags.BatchEnabled, commonFlags.DefaultWrapHelpString("Enables batching for index updates. Default is true meaning batching is enabled.")) //nolint:lll // For readability
flagSet.AddFlagSet(createIndexFlags.NewClientFlagSet())
flagSet.AddFlagSet(createIndexFlags.clientFlags.NewClientFlagSet())

return flagSet
}
Expand Down Expand Up @@ -105,56 +104,34 @@ func newCreateIndexCmd() *cobra.Command {
return nil
},
RunE: func(_ *cobra.Command, _ []string) error {
hosts, isLoadBalancer := parseBothHostSeedsFlag(createIndexFlags.Seeds, createIndexFlags.Host)

logger.Debug("parsed flags",
slog.String(flags.Host, createIndexFlags.Host.String()),
slog.String(flags.Seeds, createIndexFlags.Seeds.String()),
slog.String(flags.ListenerName, createIndexFlags.ListenerName.String()),
slog.Bool(flags.TLSCaFile, createIndexFlags.TLSRootCAFile != nil),
slog.Bool(flags.TLSCaPath, createIndexFlags.TLSRootCAPath != nil),
slog.Bool(flags.TLSCertFile, createIndexFlags.TLSCertFile != nil),
slog.Bool(flags.TLSKeyFile, createIndexFlags.TLSKeyFile != nil),
slog.Bool(flags.TLSKeyFilePass, createIndexFlags.TLSKeyFilePass != nil),
slog.String(flags.Namespace, createIndexFlags.namespace),
slog.Any(flags.Sets, createIndexFlags.sets),
slog.String(flags.IndexName, createIndexFlags.indexName),
slog.String(flags.VectorField, createIndexFlags.vectorField),
slog.Uint64(flags.Dimension, uint64(createIndexFlags.dimensions)),
slog.Any(flags.IndexMeta, createIndexFlags.indexMeta),
slog.String(flags.DistanceMetric, createIndexFlags.distanceMetric.String()),
slog.Duration(flags.Timeout, createIndexFlags.timeout),
slog.Any(flags.StorageNamespace, createIndexFlags.storageNamespace.String()),
slog.Any(flags.StorageSet, createIndexFlags.storageSet.String()),
slog.Any(flags.MaxEdges, createIndexFlags.hnswMaxEdges.String()),
slog.Any(flags.Ef, createIndexFlags.hnswEf),
slog.Any(flags.ConstructionEf, createIndexFlags.hnswConstructionEf.String()),
slog.Any(flags.BatchMaxRecords, createIndexFlags.hnswBatchMaxRecords.String()),
slog.Any(flags.BatchInterval, createIndexFlags.hnswBatchInterval.String()),
slog.Any(flags.BatchEnabled, createIndexFlags.hnswBatchEnabled.String()),
append(createIndexFlags.clientFlags.NewSLogAttr(),
slog.String(flags.Namespace, createIndexFlags.namespace),
slog.Any(flags.Sets, createIndexFlags.sets),
slog.String(flags.IndexName, createIndexFlags.indexName),
slog.String(flags.VectorField, createIndexFlags.vectorField),
slog.Uint64(flags.Dimension, uint64(createIndexFlags.dimensions)),
slog.Any(flags.IndexMeta, createIndexFlags.indexMeta),
slog.String(flags.DistanceMetric, createIndexFlags.distanceMetric.String()),
slog.Duration(flags.Timeout, createIndexFlags.timeout),
slog.Any(flags.StorageNamespace, createIndexFlags.storageNamespace.String()),
slog.Any(flags.StorageSet, createIndexFlags.storageSet.String()),
slog.Any(flags.MaxEdges, createIndexFlags.hnswMaxEdges.String()),
slog.Any(flags.Ef, createIndexFlags.hnswEf),
slog.Any(flags.ConstructionEf, createIndexFlags.hnswConstructionEf.String()),
slog.Any(flags.BatchMaxRecords, createIndexFlags.hnswBatchMaxRecords.String()),
slog.Any(flags.BatchInterval, createIndexFlags.hnswBatchInterval.String()),
slog.Any(flags.BatchEnabled, createIndexFlags.hnswBatchEnabled.String()),
)...,
)

ctx, cancel := context.WithTimeout(context.Background(), createIndexFlags.timeout)
defer cancel()

tlsConfig, err := createIndexFlags.NewTLSConfig()
adminClient, err := createClientFromFlags(&createIndexFlags.clientFlags, createIndexFlags.timeout)
if err != nil {
logger.Error("failed to create TLS config", slog.Any("error", err))
return err
}

adminClient, err := avs.NewAdminClient(
ctx, hosts, createIndexFlags.ListenerName.Val, isLoadBalancer, tlsConfig, logger,
)
if err != nil {
logger.Error("failed to create AVS client", slog.Any("error", err))
return err
}

cancel()
defer adminClient.Close()

ctx, cancel = context.WithTimeout(context.Background(), createIndexFlags.timeout)
ctx, cancel := context.WithTimeout(context.Background(), createIndexFlags.timeout)
defer cancel()

// Inverted to make it easier to understand
Expand Down
53 changes: 15 additions & 38 deletions cmd/dropIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"log/slog"
"time"

avs "github.com/aerospike/avs-client-go"
commonFlags "github.com/aerospike/tools-common-go/flags"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
Expand All @@ -19,13 +18,13 @@ import (

//nolint:govet // Padding not a concern for a CLI
var dropIndexFlags = &struct {
flags.ClientFlags
namespace string
sets []string
indexName string
timeout time.Duration
clientFlags flags.ClientFlags
namespace string
sets []string
indexName string
timeout time.Duration
}{
ClientFlags: *flags.NewClientFlags(),
clientFlags: *flags.NewClientFlags(),
}

func newDropIndexFlagSet() *pflag.FlagSet {
Expand All @@ -34,7 +33,7 @@ func newDropIndexFlagSet() *pflag.FlagSet {
flagSet.StringArrayVarP(&dropIndexFlags.sets, flags.Sets, "s", nil, commonFlags.DefaultWrapHelpString("The sets for the index.")) //nolint:lll // For readability
flagSet.StringVarP(&dropIndexFlags.indexName, flags.IndexName, "i", "", commonFlags.DefaultWrapHelpString("The name of the index.")) //nolint:lll // For readability
flagSet.DurationVar(&dropIndexFlags.timeout, flags.Timeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability
flagSet.AddFlagSet(dropIndexFlags.NewClientFlagSet())
flagSet.AddFlagSet(dropIndexFlags.clientFlags.NewClientFlagSet())

return flagSet
}
Expand Down Expand Up @@ -65,43 +64,21 @@ func newDropIndexCommand() *cobra.Command {
},
RunE: func(_ *cobra.Command, _ []string) error {
logger.Debug("parsed flags",
slog.String(flags.Host, dropIndexFlags.Host.String()),
slog.String(flags.Seeds, dropIndexFlags.Seeds.String()),
slog.String(flags.ListenerName, dropIndexFlags.ListenerName.String()),
slog.Bool(flags.TLSCaFile, createIndexFlags.TLSRootCAFile != nil),
slog.Bool(flags.TLSCaPath, createIndexFlags.TLSRootCAPath != nil),
slog.Bool(flags.TLSCertFile, createIndexFlags.TLSCertFile != nil),
slog.Bool(flags.TLSKeyFile, createIndexFlags.TLSKeyFile != nil),
slog.Bool(flags.TLSKeyFilePass, createIndexFlags.TLSKeyFilePass != nil),
slog.String(flags.Namespace, dropIndexFlags.namespace),
slog.Any(flags.Sets, dropIndexFlags.sets),
slog.String(flags.IndexName, dropIndexFlags.indexName),
slog.Duration(flags.Timeout, dropIndexFlags.timeout),
append(dropIndexFlags.clientFlags.NewSLogAttr(),
slog.String(flags.Namespace, dropIndexFlags.namespace),
slog.Any(flags.Sets, dropIndexFlags.sets),
slog.String(flags.IndexName, dropIndexFlags.indexName),
slog.Duration(flags.Timeout, dropIndexFlags.timeout),
)...,
)

hosts, isLoadBalancer := parseBothHostSeedsFlag(dropIndexFlags.Seeds, dropIndexFlags.Host)

ctx, cancel := context.WithTimeout(context.Background(), dropIndexFlags.timeout)
defer cancel()

tlsConfig, err := dropIndexFlags.NewTLSConfig()
adminClient, err := createClientFromFlags(&dropIndexFlags.clientFlags, dropIndexFlags.timeout)
if err != nil {
logger.Error("failed to create TLS config", slog.Any("error", err))
return err
}

adminClient, err := avs.NewAdminClient(
ctx, hosts, dropIndexFlags.ListenerName.Val, isLoadBalancer, tlsConfig, logger,
)
if err != nil {
logger.Error("failed to create AVS client", slog.Any("error", err))
return err
}

cancel()
defer adminClient.Close()

ctx, cancel = context.WithTimeout(context.Background(), dropIndexFlags.timeout)
ctx, cancel := context.WithTimeout(context.Background(), dropIndexFlags.timeout)
defer cancel()

err = adminClient.IndexDrop(ctx, dropIndexFlags.namespace, dropIndexFlags.indexName)
Expand Down
22 changes: 20 additions & 2 deletions cmd/flags/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package flags

import (
"fmt"
"log/slog"

commonFlags "github.com/aerospike/tools-common-go/flags"
"github.com/spf13/pflag"
Expand All @@ -11,6 +12,8 @@ type ClientFlags struct {
Host *HostPortFlag
Seeds *SeedsSliceFlag
ListenerName StringOptionalFlag
User StringOptionalFlag
Password commonFlags.PasswordFlag
TLSFlags
}

Expand All @@ -26,9 +29,24 @@ func (cf *ClientFlags) NewClientFlagSet() *pflag.FlagSet {
flagSet := &pflag.FlagSet{}
flagSet.VarP(cf.Host, Host, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", Seeds))) //nolint:lll // For readability
flagSet.Var(cf.Seeds, Seeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", Host))) //nolint:lll // For readability
flagSet.VarP(&cf.ListenerName, ListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments."))

flagSet.VarP(&cf.ListenerName, ListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability
flagSet.VarP(&cf.User, User, "U", commonFlags.DefaultWrapHelpString("The AVS user to authenticate with.")) //nolint:lll // For readability
flagSet.VarP(&cf.Password, Password, "P", commonFlags.DefaultWrapHelpString("The AVS password for the specified user.")) //nolint:lll // For readability
flagSet.AddFlagSet(cf.NewTLSFlagSet(commonFlags.DefaultWrapHelpString))

return flagSet
}

func (cf *ClientFlags) NewSLogAttr() []any {
return []any{slog.String(Host, cf.Host.String()),
slog.String(Seeds, cf.Seeds.String()),
slog.String(ListenerName, cf.ListenerName.String()),
slog.String(User, cf.User.String()),
slog.String(Password, cf.Password.String()),
slog.Bool(TLSCaFile, cf.TLSRootCAFile != nil),
slog.Bool(TLSCaPath, cf.TLSRootCAPath != nil),
slog.Bool(TLSCertFile, cf.TLSCertFile != nil),
slog.Bool(TLSKeyFile, cf.TLSKeyFile != nil),
slog.Bool(TLSKeyFilePass, cf.TLSKeyFilePass != nil),
}
}
2 changes: 2 additions & 0 deletions cmd/flags/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ const (
Seeds = "seeds"
Host = "host"
ListenerName = "listener-name"
User = "user"
Password = "password"
Namespace = "namespace"
Sets = "sets"
IndexName = "index-name"
Expand Down
52 changes: 13 additions & 39 deletions cmd/listIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"sync"
"time"

avs "github.com/aerospike/avs-client-go"
"github.com/aerospike/avs-client-go/protos"
commonFlags "github.com/aerospike/tools-common-go/flags"
"github.com/spf13/cobra"
Expand All @@ -20,21 +19,18 @@ import (
)

var listIndexFlags = &struct {
flags.ClientFlags
verbose bool
timeout time.Duration
clientFlags flags.ClientFlags
verbose bool
timeout time.Duration
}{
ClientFlags: *flags.NewClientFlags(),
clientFlags: *flags.NewClientFlags(),
}

func newListIndexFlagSet() *pflag.FlagSet {
flagSet := &pflag.FlagSet{}
flagSet.VarP(listIndexFlags.Host, flags.Host, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", flags.Seeds))) //nolint:lll // For readability
flagSet.Var(listIndexFlags.Seeds, flags.Seeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", flags.Host))) //nolint:lll // For readability
flagSet.VarP(&listIndexFlags.ListenerName, flags.ListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability
flagSet.BoolVarP(&listIndexFlags.verbose, flags.Verbose, "v", false, commonFlags.DefaultWrapHelpString("Print detailed index information.")) //nolint:lll // For readability
flagSet.DurationVar(&listIndexFlags.timeout, flags.Timeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability
flagSet.AddFlagSet(listIndexFlags.NewClientFlagSet())
flagSet.BoolVarP(&listIndexFlags.verbose, flags.Verbose, "v", false, commonFlags.DefaultWrapHelpString("Print detailed index information.")) //nolint:lll // For readability
flagSet.DurationVar(&listIndexFlags.timeout, flags.Timeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability
flagSet.AddFlagSet(listIndexFlags.clientFlags.NewClientFlagSet())

return flagSet
}
Expand Down Expand Up @@ -62,41 +58,19 @@ func newListIndexCmd() *cobra.Command {
},
RunE: func(_ *cobra.Command, _ []string) error {
logger.Debug("parsed flags",
slog.String(flags.Host, listIndexFlags.Host.String()),
slog.String(flags.Seeds, listIndexFlags.Seeds.String()),
slog.String(flags.ListenerName, listIndexFlags.ListenerName.String()),
slog.Bool(flags.TLSCaFile, createIndexFlags.TLSRootCAFile != nil),
slog.Bool(flags.TLSCaPath, createIndexFlags.TLSRootCAPath != nil),
slog.Bool(flags.TLSCertFile, createIndexFlags.TLSCertFile != nil),
slog.Bool(flags.TLSKeyFile, createIndexFlags.TLSKeyFile != nil),
slog.Bool(flags.TLSKeyFilePass, createIndexFlags.TLSKeyFilePass != nil),
slog.Bool(flags.Verbose, listIndexFlags.verbose),
slog.Duration(flags.Timeout, listIndexFlags.timeout),
append(listIndexFlags.clientFlags.NewSLogAttr(),
slog.Bool(flags.Verbose, listIndexFlags.verbose),
slog.Duration(flags.Timeout, listIndexFlags.timeout),
)...,
)

hosts, isLoadBalancer := parseBothHostSeedsFlag(listIndexFlags.Seeds, listIndexFlags.Host)

ctx, cancel := context.WithTimeout(context.Background(), listIndexFlags.timeout)
defer cancel()

tlsConfig, err := listIndexFlags.NewTLSConfig()
adminClient, err := createClientFromFlags(&listIndexFlags.clientFlags, listIndexFlags.timeout)
if err != nil {
logger.Error("failed to create TLS config", slog.Any("error", err))
return err
}

adminClient, err := avs.NewAdminClient(
ctx, hosts, listIndexFlags.ListenerName.Val, isLoadBalancer, tlsConfig, logger,
)
if err != nil {
logger.Error("failed to create AVS client", slog.Any("error", err))
return err
}

cancel()
defer adminClient.Close()

ctx, cancel = context.WithTimeout(context.Background(), listIndexFlags.timeout)
ctx, cancel := context.WithTimeout(context.Background(), listIndexFlags.timeout)
defer cancel()

indexList, err := adminClient.IndexList(ctx)
Expand Down
11 changes: 6 additions & 5 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@ func init() {
common.SetupRoot(rootCmd, "aerospike-vector-search", "0.0.0") // TODO: Handle version
viper.SetEnvPrefix("ASVEC")

if err := viper.BindEnv(flags.Host); err != nil {
logger.Error("failed to bind environment variable", slog.Any("error", err))
}
bindEnvs := []string{flags.Host, flags.Seeds, flags.User, flags.Password}

if err := viper.BindEnv(flags.Seeds); err != nil {
logger.Error("failed to bind environment variable", slog.Any("error", err))
// Bind specified flags to ASVEC_*
for _, env := range bindEnvs {
if err := viper.BindEnv(env); err != nil {
panic(fmt.Sprintf("failed to bind environment variable: %s", err))
}
}
}
Loading

0 comments on commit 92a3b46

Please sign in to comment.