Skip to content

Commit

Permalink
add hnsw params
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Schmidt committed May 22, 2024
1 parent 4ef27d9 commit 5a91438
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 41 deletions.
101 changes: 62 additions & 39 deletions cmd/createIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,32 @@ package cmd
import (
"context"
"log/slog"
"net"

avs "github.com/aerospike/aerospike-proximus-client-go"
"github.com/aerospike/aerospike-proximus-client-go/protos"
commonFlags "github.com/aerospike/tools-common-go/flags"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)

// requiredFlags lists the names of the local (non-persistent) flags that init
// marks as required on createIndexCmd.
var requiredFlags = []string{
flagNameNamespace,
flagNameIndexName,
flagNameDimension,
flagNameDistance,
}

// persistentRequiredFlags lists the names of the persistent flags that init
// marks as required on createIndexCmd. It is currently empty.
var persistentRequiredFlags = []string{}

// Names of the HNSW-specific flags that init registers on createIndexCmd.
const (
// Maximum number of bi-directional links per HNSW vertex.
flagNameMaxEdges = "hnsw-max-edges"
// Number of candidate nearest neighbors shortlisted during index creation.
flagNameConstructionEf = "hnsw-ef-construction"
// Default number of candidate nearest neighbors shortlisted during search.
flagNameEf = "hnsw-ef"
// Maximum number of records to fit in a batch.
flagNameBatchMaxRecords = "hnsw-batch-max-records"
// Maximum time in milliseconds to wait before finalizing a batch.
flagNameBatchInterval = "hnsw-batch-interval"
// Disables batching for index updates when set.
flagNameBatchDisabled = "hnsw-batch-disabled"
)

// createIndexCmd represents the createIndex command
var createIndexCmd = &cobra.Command{
Use: "index",
Expand All @@ -25,18 +43,18 @@ Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.`,
Run: func(cmd *cobra.Command, args []string) {
host := viper.GetString("host")
port := viper.GetInt("port")
hostPort := avs.NewHostPort(host, port, false)
namespace := viper.GetString("namespace")
sets := viper.GetStringSlice("sets")
indexName := viper.GetString("index-name")
vectorField := viper.GetString("vector-field")
dimension := viper.GetUint32("dimension")
// distanceMetric := viper.GetInt("distance-metric")
indexMeta := viper.GetStringMapString("index-meta")

logger.Debug("Parsed flags", slog.String("host", host), slog.Int("port", port), slog.String("namespace", namespace), slog.Any("sets", sets), slog.String("index-name", indexName), slog.String("vector-field", vectorField), slog.Uint64("dimension", uint64(dimension)), slog.Any("index-meta", indexMeta))
seed := viper.GetString(flagNameSeeds)
port := viper.GetInt(flagNamePort)
hostPort := avs.NewHostPort(seed, port, false)
namespace := viper.GetString(flagNameNamespace)
sets := viper.GetStringSlice(flagNameSets)
indexName := viper.GetString(flagNameIndexName)
vectorField := viper.GetString(flagNameVector)
dimension := viper.GetUint32(flagNameDimension)
indexMeta := viper.GetStringMapString(flagNameIndexMeta)
distanceMetric := viper.GetString(flagNameDistance)

logger.Debug("parsed flags", slog.String("seeds", seed), slog.Int("port", port), slog.String("namespace", namespace), slog.Any("sets", sets), slog.String("index-name", indexName), slog.String("vector-field", vectorField), slog.Uint64("dimension", uint64(dimension)), slog.Any("index-meta", indexMeta))

ctx := context.TODO()

Expand All @@ -48,7 +66,7 @@ to quickly create a Cobra application.`,
}

// TODO: parse cosine
err = adminClient.IndexCreate(ctx, namespace, sets, indexName, vectorField, dimension, protos.VectorDistanceMetric_COSINE, nil, indexMeta)
err = adminClient.IndexCreate(ctx, namespace, sets, indexName, vectorField, dimension, protos.VectorDistanceMetric(protos.VectorDistanceMetric_value[distanceMetric]), nil, indexMeta)
if err != nil {
logger.Error("unable to create index", slog.Any("error", err))
view.Printf("Unable to create index: %v", err)
Expand All @@ -61,33 +79,38 @@ to quickly create a Cobra application.`,

func init() {
createCmd.AddCommand(createIndexCmd)
createIndexCmd.PersistentFlags().IPP("host", "h", net.ParseIP("127.0.0.1"), "TODO")
createIndexCmd.PersistentFlags().IntP("port", "p", 5000, "TODO")
createIndexCmd.Flags().StringP("namespace", "n", "", "TODO")
createIndexCmd.Flags().StringArrayP("sets", "s", nil, "TODO")
createIndexCmd.Flags().StringP("index-name", "i", "", "TODO")
createIndexCmd.Flags().StringP("vector-field", "v", "vector", "TODO")
createIndexCmd.Flags().IntP("dimension", "d", 0, "TODO")
createIndexCmd.Flags().Uint32P("distance-metric", "m", 0, "TODO")
createIndexCmd.Flags().StringToStringP("index-meta", "e", nil, "TODO")
// TODO hnsw metadata

createIndexCmd.MarkFlagRequired("namespace")
createIndexCmd.MarkFlagRequired("set")
createIndexCmd.MarkFlagRequired("index-name")
// createIndexCmd.MarkFlagRequired("vector-field")
createIndexCmd.MarkFlagRequired("dimension")
// createIndexCmd.MarkFlagRequired("distance-metric")
viper.BindPFlags(createIndexCmd.PersistentFlags())
viper.BindPFlags(createIndexCmd.Flags())
persistentFlags := NewFlagSetBuilder(createIndexCmd.PersistentFlags())
flags := NewFlagSetBuilder(createIndexCmd.Flags())

persistentFlags.AddSeedFlag()
persistentFlags.AddPortFlag()

flags.AddNamespaceFlag()
flags.AddSetsFlag()
flags.AddIndexNameFlag()
flags.AddVectorFieldFlag()
flags.AddDimensionFlag()
flags.AddDistanceMetricFlag()
flags.AddIndexMetaFlag()

flags.Uint32(flagNameMaxEdges, 0, commonFlags.DefaultWrapHelpString("Maximum number bi-directional links per HNSW vertex. Greater values of 'm' in general provide better recall for data with high dimensionality, while lower values work well for data with lower dimensionality. The storage space required for the index increases proportionally with 'm'. The default value is 16."))
flags.Uint32(flagNameConstructionEf, 0, commonFlags.DefaultWrapHelpString("The number of candidate nearest neighbors shortlisted during index creation. Larger values provide better recall at the cost of longer index update times. The default is 100."))
flags.Uint32(flagNameEf, 0, commonFlags.DefaultWrapHelpString("The default number of candidate nearest neighbors shortlisted during search. Larger values provide better recall at the cost of longer search times. The default is 100."))
flags.Uint32(flagNameBatchMaxRecords, 0, commonFlags.DefaultWrapHelpString("Maximum number of records to fit in a batch. The default value is 10000."))
flags.Uint32(flagNameBatchInterval, 0, commonFlags.DefaultWrapHelpString("The maximum amount of time in milliseconds to wait before finalizing a batch. The default value is 10000."))
flags.Bool(flagNameBatchDisabled, false, commonFlags.DefaultWrapHelpString("Disables batching for index updates. Default is false meaning batching is enabled."))

// Here you will define your flags and configuration settings.
for _, flag := range requiredFlags {
createIndexCmd.MarkFlagRequired(flag)
}

// Cobra supports Persistent Flags which will work for this command
// and all subcommands, e.g.:
// createIndexCmd.PersistentFlags().String("foo", "", "A help for foo")
for _, flag := range persistentRequiredFlags {
createIndexCmd.MarkPersistentFlagRequired(flag)
}

// TODO hnsw metadata
viper.BindPFlags(createIndexCmd.PersistentFlags())
viper.BindPFlags(createIndexCmd.Flags())

// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
// createIndexCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
40 changes: 40 additions & 0 deletions cmd/delete.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
Copyright © 2024 NAME HERE <EMAIL ADDRESS>
*/
package cmd

import (
"fmt"

"github.com/spf13/cobra"
)

// deleteCmd is the "delete" command. Invoked directly it only prints a
// placeholder message.
// NOTE(review): Short and Long still contain the generated scaffold text.
var deleteCmd = &cobra.Command{
	Run: func(_ *cobra.Command, _ []string) {
		fmt.Println("delete called")
	},
	Use:   "delete",
	Short: "A brief description of your command",
	Long: `A longer description that spans multiple lines and likely contains examples
and usage of using your command. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.`,
}

// init attaches deleteCmd to rootCmd. The command defines no flags of its own
// yet.
func init() {
	rootCmd.AddCommand(deleteCmd)
}
40 changes: 40 additions & 0 deletions cmd/deleteIndex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
Copyright © 2024 NAME HERE <EMAIL ADDRESS>
*/
package cmd

import (
"fmt"

"github.com/spf13/cobra"
)

// deleteIndexCmd is the "deleteIndex" command. Invoked directly it only prints
// a placeholder message.
// NOTE(review): Use is "deleteIndex" while createIndexCmd uses "index";
// confirm which naming is intended. Short and Long still contain the
// generated scaffold text.
var deleteIndexCmd = &cobra.Command{
	Run: func(_ *cobra.Command, _ []string) {
		fmt.Println("deleteIndex called")
	},
	Use:   "deleteIndex",
	Short: "A brief description of your command",
	Long: `A longer description that spans multiple lines and likely contains examples
and usage of using your command. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.`,
}

// init attaches deleteIndexCmd to deleteCmd as a subcommand. The command
// defines no flags of its own yet.
func init() {
	deleteCmd.AddCommand(deleteIndexCmd)
}
102 changes: 102 additions & 0 deletions cmd/flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package cmd

import (
	"fmt"
	"net"
	"sort"
	"strings"

	"github.com/aerospike/aerospike-proximus-client-go/protos"
	"github.com/spf13/pflag"
)

// Names of the connection and index flags registered by the FlagSetBuilder
// Add*Flag helpers. The same names are used as viper keys when the values are
// read back.
const (
flagNameSeeds = "seeds"
flagNamePort = "port"
flagNameNamespace = "namespace"
flagNameSets = "sets"
flagNameIndexName = "index-name"
flagNameVector = "vector-field"
flagNameDimension = "dimension"
flagNameDistance = "distance-metric"
flagNameIndexMeta = "index-meta"
)

// FlagSetBuilder embeds a *pflag.FlagSet, so every FlagSet method is available
// on it directly, and adds Add*Flag helpers that register this tool's flags
// with fixed names, shorthands, defaults and help text.
type FlagSetBuilder struct {
*pflag.FlagSet
}

// NewFlagSetBuilder returns a FlagSetBuilder whose helpers register flags on
// the given flagSet.
func NewFlagSetBuilder(flagSet *pflag.FlagSet) *FlagSetBuilder {
	builder := new(FlagSetBuilder)
	builder.FlagSet = flagSet

	return builder
}

// AddSeedFlag registers the seed host flag (--seeds, shorthand -h) as a single
// IP address that defaults to 127.0.0.1.
// TODO: Should this be a list of IPs? Should we support IP:PORT?
func (fsb *FlagSetBuilder) AddSeedFlag() {
	defaultSeed := net.ParseIP("127.0.0.1")
	fsb.FlagSet.IPP(flagNameSeeds, "h", defaultSeed, "The AVS seed host for cluster discovery.")
}

// AddPortFlag registers the seed port flag (--port, shorthand -p) as an int
// that defaults to 5000.
func (fsb *FlagSetBuilder) AddPortFlag() {
	const defaultPort = 5000

	fsb.FlagSet.IntP(flagNamePort, "p", defaultPort, "The AVS seed port for cluster discovery.")
}

// AddNamespaceFlag registers the namespace flag (--namespace, shorthand -n) as
// a string with an empty default.
func (fsb *FlagSetBuilder) AddNamespaceFlag() {
	const noDefault = ""

	fsb.FlagSet.StringP(flagNameNamespace, "n", noDefault, "The namespace for the index.")
}

// AddSetsFlag registers the sets flag (--sets, shorthand -s) as a string array
// with a nil default.
func (fsb *FlagSetBuilder) AddSetsFlag() {
	var noSets []string

	fsb.FlagSet.StringArrayP(flagNameSets, "s", noSets, "The sets for the index.")
}

// AddIndexNameFlag registers the index name flag (--index-name, shorthand -i)
// as a string with an empty default.
func (fsb *FlagSetBuilder) AddIndexNameFlag() {
	const noDefault = ""

	fsb.FlagSet.StringP(flagNameIndexName, "i", noDefault, "The name of the index.")
}

// AddVectorFieldFlag registers the vector field flag (--vector-field,
// shorthand -v) as a string that defaults to "vector-field".
func (fsb *FlagSetBuilder) AddVectorFieldFlag() {
	const defaultField = "vector-field"

	fsb.FlagSet.StringP(flagNameVector, "v", defaultField, "The name of the vector field.")
}

// AddDimensionFlag registers the dimension flag (--dimension, shorthand -d)
// with a default of 0.
//
// The flag is an unsigned 32-bit value because the create-index command reads
// it with viper.GetUint32. Registering it as uint32 makes flag parsing reject
// negative or out-of-range input up front, and matches how the HNSW numeric
// flags are registered.
func (fsb *FlagSetBuilder) AddDimensionFlag() {
	fsb.Uint32P(flagNameDimension, "d", 0, "The dimension of the vector field.")
}

// AddDistanceMetricFlag registers the distance metric flag (shorthand -m),
// backed by a DistanceMetricFlag value that starts out empty.
//
// The flag name comes from flagNameDistance so that it cannot drift from the
// name used for the required-flag list and the viper lookup.
func (fsb *FlagSetBuilder) AddDistanceMetricFlag() {
	distMetric := DistanceMetricFlag("")
	fsb.VarP(&distMetric, flagNameDistance, "m", "The distance metric for the index.")
}

// AddIndexMetaFlag registers the index metadata flag (--index-meta, shorthand
// -e) as a string-to-string map with a nil default.
func (fsb *FlagSetBuilder) AddIndexMetaFlag() {
	var noMeta map[string]string

	fsb.FlagSet.StringToStringP(flagNameIndexMeta, "e", noMeta, "The metadata for the index.")
}

// DistanceMetricFlag is the string-backed value behind the distance metric
// flag. Its pointer type has Set, Type and String methods and is what
// AddDistanceMetricFlag passes to VarP.
type DistanceMetricFlag string

// distanceMetricSet is protos.VectorDistanceMetric_value, keyed by metric
// name. It serves as the set of valid VectorDistanceMetric names for the
// distance metric flag; the mapped value is not meant to carry meaning here.
var distanceMetricSet = protos.VectorDistanceMetric_value

// Set validates val against the known distance metric names and stores it.
//
// Matching is case-insensitive: val is upper-cased before the lookup and the
// upper-cased name is what gets stored. The name itself must be stored, not
// the map's integer value, because converting an integer to a string type
// yields a one-rune string rather than the metric name.
//
// It returns an error, leaving the current value unchanged, when val is not a
// known metric name.
func (mode *DistanceMetricFlag) Set(val string) error {
	name := strings.ToUpper(val)
	if _, ok := distanceMetricSet[name]; !ok {
		return fmt.Errorf("unrecognized distance metric")
	}

	*mode = DistanceMetricFlag(name)

	return nil
}

// Type returns the valid distance metric names joined by commas, which is the
// string reported as this flag's type.
//
// The names are sorted because map iteration order is random in Go; without
// sorting the returned string would differ from run to run.
func (mode *DistanceMetricFlag) Type() string {
	names := make([]string, 0, len(distanceMetricSet))

	for name := range distanceMetricSet {
		names = append(names, name)
	}

	sort.Strings(names)

	return strings.Join(names, ",")
}

// String returns the flag's current value as a plain string.
func (mode *DistanceMetricFlag) String() string {
	current := *mode

	return string(current)
}
3 changes: 1 addition & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ Copyright © 2024 NAME HERE <EMAIL ADDRESS>
package cmd

import (
"io"
"log/slog"
"os"

Expand All @@ -13,7 +12,7 @@ import (
"github.com/spf13/viper"
)

var logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))
var logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
var view = NewView(os.Stdout)

// rootCmd represents the base command when called without any subcommands
Expand Down

0 comments on commit 5a91438

Please sign in to comment.