diff --git a/.golangci.yaml b/.golangci.yaml index 195f7e7..122c198 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -38,6 +38,8 @@ linters: - forcetypeassert - interfacebloat - musttag + - gci + - gochecknoglobals issues: exclude-rules: diff --git a/cmd/cli/completion.go b/cmd/cli/completion.go index ec5a251..19a97cd 100644 --- a/cmd/cli/completion.go +++ b/cmd/cli/completion.go @@ -2,25 +2,28 @@ package cli import ( "context" - "errors" "fmt" "regexp" "strings" openai "github.com/PullRequestInc/go-gpt3" azureopenai "github.com/ia-ops/terraform-ai/pkg/gpt3" + "github.com/pkg/errors" gptEncoder "github.com/samber/go-gpt-3-encoder" ) const userRole = "user" -var maxTokensMap = map[string]int{ - "code-davinci-002": 8001, - "text-davinci-003": 4097, - "gpt-3.5-turbo-0301": 4096, - "gpt-3.5-turbo": 4096, - "gpt-35-turbo-0301": 4096, // for azure -} +var ( + maxTokensMap = map[string]int{ + "code-davinci-002": 8001, + "text-davinci-003": 4097, + "gpt-3.5-turbo-0301": 4096, + "gpt-3.5-turbo": 4096, + "gpt-35-turbo-0301": 4096, // for azure + } + errToken = errors.New("invalid max tokens") +) type oaiClients struct { azureClient azureopenai.Client @@ -28,9 +31,11 @@ type oaiClients struct { } func newOAIClients() (oaiClients, error) { - var oaiClient openai.Client - var azureClient azureopenai.Client - var err error + var ( + oaiClient openai.Client + azureClient azureopenai.Client + err error + ) if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" { oaiClient = openai.NewClient(*openAIAPIKey) @@ -42,7 +47,7 @@ func newOAIClients() (oaiClients, error) { azureClient, err = azureopenai.NewClient(*azureOpenAIEndpoint, *openAIAPIKey, *openAIDeploymentName) if err != nil { - return oaiClients{}, err + return oaiClients{}, fmt.Errorf("error create Azure client: %w", err) } } @@ -50,50 +55,64 @@ func newOAIClients() (oaiClients, error) { azureClient: azureClient, openAIClient: oaiClient, } + return clients, nil } func gptCompletion(ctx context.Context, client oaiClients, prompts []string, deploymentName string) (string, error) { temp := float32(*sensitivity) + maxTokens, err := calculateMaxTokens(prompts, deploymentName) if err != nil { - return "", err + return "", fmt.Errorf("error calculate max token: %w", err) } var prompt strings.Builder - fmt.Fprintf(&prompt, "You are a Terraform HCL generator, only generate valid Terraform HCL templates.") + _, err = fmt.Fprintf(&prompt, "You are a Terraform HCL generator, only generate valid Terraform HCL templates.") + + if err != nil { + return "", fmt.Errorf("error prompt string builder: %w", err) + } + for _, p := range prompts { - fmt.Fprintf(&prompt, "%s\n", p) + _, err = fmt.Fprintf(&prompt, "%s\n", p) + if err != nil { + return "", fmt.Errorf("error range prompt: %w", err) + } } if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" { if isGptTurbo(deploymentName) { resp, err := client.openaiGptChatCompletion(ctx, prompt, maxTokens, temp) if err != nil { - return "", err + return "", fmt.Errorf("error openai GptChat completion: %w", err) } + return resp, nil } resp, err := client.openaiGptCompletion(ctx, prompt, maxTokens, temp) if err != nil { - return "", err + return "", fmt.Errorf("error openai Gpt completion: %w", err) } + return resp, nil } if isGptTurbo35(deploymentName) { resp, err := client.azureGptChatCompletion(ctx, prompt, maxTokens, temp) if err != nil { - return "", err + return "", fmt.Errorf("error azure GptChat completion: %w", err) } + return resp, nil } resp, err := client.azureGptCompletion(ctx, prompt, maxTokens, temp) if err != nil { - return "", err + return "", fmt.Errorf("error azure Gpt completion: %w", err) } + return resp, nil } @@ -108,7 +127,7 @@ func isGptTurbo35(deploymentName string) bool { func calculateMaxTokens(prompts []string, deploymentName string) (*int, error) { maxTokensFinal, ok := maxTokensMap[deploymentName] if !ok { - return nil, fmt.Errorf("deploymentName %q not found in max tokens map", deploymentName) + return nil, errors.Wrapf(errToken, "deploymentName %q not found in max tokens map", deploymentName) } if *maxTokens > 0 { @@ -117,19 +136,22 @@ func calculateMaxTokens(prompts []string, deploymentName string) (*int, error) { encoder, err := gptEncoder.NewEncoder() if err != nil { - return nil, err + return nil, fmt.Errorf("error encode gpt: %w", err) } // start at 100 since the encoder at times doesn't get it exactly correct totalTokens := 100 + for _, prompt := range prompts { tokens, err := encoder.Encode(prompt) if err != nil { - return nil, err + return nil, fmt.Errorf("error encode prompt: %w", err) } + totalTokens += len(tokens) } remainingTokens := maxTokensFinal - totalTokens + return &remainingTokens, nil } diff --git a/cmd/cli/openai.go b/cmd/cli/openai.go index 14500ea..a7caaef 100644 --- a/cmd/cli/openai.go +++ b/cmd/cli/openai.go @@ -8,8 +8,11 @@ import ( openai "github.com/PullRequestInc/go-gpt3" azureopenai "github.com/ia-ops/terraform-ai/pkg/gpt3" "github.com/ia-ops/terraform-ai/pkg/utils" + "github.com/pkg/errors" ) +var errResp = errors.New("invalid response") + func (c *oaiClients) openaiGptCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) { resp, err := c.openAIClient.CompletionWithEngine(ctx, *openAIDeploymentName, openai.CompletionRequest{ Prompt: []string{prompt.String()}, @@ -19,11 +22,11 @@ func (c *oaiClients) openaiGptCompletion(ctx context.Context, prompt strings.Bui Temperature: &temp, }) if err != nil { - return "", err + return "", fmt.Errorf("error openai completion: %w", err) } if len(resp.Choices) != 1 { - return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices)) + return "", errors.Wrapf(errResp, "expected choices to be 1 but received: %d", len(resp.Choices)) } return resp.Choices[0].Text, nil @@ -43,11 +46,11 @@ func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt strings Temperature: &temp, }) if err != nil { - return "", err + return "", fmt.Errorf("error openai gpt completion: %w", err) } if len(resp.Choices) != 1 { - return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices)) + return "", errors.Wrapf(errResp, "expected choices to be 1 but received: %d", len(resp.Choices)) } return resp.Choices[0].Message.Content, nil @@ -62,11 +65,11 @@ func (c *oaiClients) azureGptCompletion(ctx context.Context, prompt strings.Buil Temperature: &temp, }) if err != nil { - return "", err + return "", fmt.Errorf("error azure completion: %w", err) } if len(resp.Choices) != 1 { - return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices)) + return "", errors.Wrapf(errResp, "expected choices to be 1 but received: %d", len(resp.Choices)) } return resp.Choices[0].Text, nil @@ -86,11 +89,11 @@ func (c *oaiClients) azureGptChatCompletion(ctx context.Context, prompt strings. Temperature: &temp, }) if err != nil { - return "", err + return "", fmt.Errorf("error azure chatgpt completion: %w", err) } if len(resp.Choices) != 1 { - return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices)) + return "", errors.Wrapf(errResp, "expected choices to be 1 but received: %d", len(resp.Choices)) } return resp.Choices[0].Message.Content, nil diff --git a/cmd/cli/root.go b/cmd/cli/root.go index 453f5a5..c54d95d 100644 --- a/cmd/cli/root.go +++ b/cmd/cli/root.go @@ -9,7 +9,9 @@ import ( "os/signal" "strconv" + ter "github.com/ia-ops/terraform-ai/pkg/terraform" "github.com/manifoldco/promptui" + "github.com/pkg/errors" "github.com/spf13/cobra" "github.com/walles/env" ) @@ -23,6 +25,7 @@ var ( azureOpenAIEndpoint = flag.String("azure-openai-endpoint", env.GetOr("AZURE_OPENAI_ENDPOINT", env.String, ""), "The endpoint for Azure OpenAI service. If provided, Azure OpenAI service will be used instead of OpenAI service.") requireConfirmation = flag.Bool("require-confirmation", env.GetOr("REQUIRE_CONFIRMATION", strconv.ParseBool, true), "Whether to require confirmation before executing the command. Defaults to true.") sensitivity = flag.Float64("sensitivity", env.GetOr("SENSITIVITY", env.WithBitSize(strconv.ParseFloat, 64), 0.0), "The sensitivity to use for the model. Range is between 0 and 1. Set closer to 0 if your want output to be more deterministic but less creative. Defaults to 0.0.") + errPrompt = errors.New("invalid prompt") ) func InitAndExecute() { @@ -49,9 +52,9 @@ func RootCmd() *cobra.Command { return cmd } -func runCommand(cmd *cobra.Command, args []string) error { +func runCommand(_ *cobra.Command, args []string) error { if len(args) == 0 { - return fmt.Errorf("prompt must be provided") + return errors.Wrap(errPrompt, "prompt must be provided") } return run(args) @@ -63,7 +66,7 @@ func run(args []string) error { oaiClients, err := newOAIClients() if err != nil { - return err + return fmt.Errorf("error run command: %w", err) } completion, err := gptCompletion(ctx, oaiClients, args, *openAIDeploymentName) @@ -72,18 +75,19 @@ func run(args []string) error { } text := fmt.Sprintf("✨ Attempting to apply the following template: %s", completion) - fmt.Println(text) + log.Println(text) confirmation, err := getUserConfirmation(*requireConfirmation) if err != nil { - return err + return fmt.Errorf("error running select: %w", err) } if confirmation { - if err = applyTemplate(completion); err != nil { - return err + if err = ter.CheckTemplate(completion); err != nil { + return fmt.Errorf("error check template: %w", err) } } + return nil } @@ -99,7 +103,7 @@ func getUserConfirmation(requireConfirmation bool) (bool, error) { _, result, err := prompt.Run() if err != nil { - return false, err + return false, fmt.Errorf("error running select: %w", err) } return result == "Apply", nil diff --git a/go.mod b/go.mod index 8817a9f..83c50ae 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,9 @@ go 1.19 require ( github.com/PullRequestInc/go-gpt3 v1.1.15 + github.com/hashicorp/hcl/v2 v2.16.2 github.com/manifoldco/promptui v0.9.0 + github.com/pkg/errors v0.9.1 github.com/samber/go-gpt-3-encoder v0.3.1 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.2 @@ -13,14 +15,19 @@ require ( ) require ( + github.com/agext/levenshtein v1.2.1 // indirect + github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.7.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/samber/lo v1.37.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/zclconf/go-cty v1.12.1 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect golang.org/x/sys v0.7.0 // indirect + golang.org/x/text v0.9.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 34a418a..7b57822 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ github.com/PullRequestInc/go-gpt3 v1.1.15 h1:pidXZbpqZVW0bp8NBNKDb+/++6PFdYfht9vw2CVpaUs= github.com/PullRequestInc/go-gpt3 v1.1.15/go.mod h1:F9yzAy070LhkqHS2154/IH0HVj5xq5g83gLTj7xzyfw= +github.com/agext/levenshtein v1.2.1 h1:QmvMAjj2aEICytGiWzmxoE0x2KZvE0fvmqMOfy2tjT8= +github.com/agext/levenshtein v1.2.1/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= +github.com/apparentlymart/go-textseg/v13 v13.0.0 h1:Y+KvPE1NYz0xl601PVImeQfFyEy6iT90AvPUL1NNfNw= +github.com/apparentlymart/go-textseg/v13 v13.0.0/go.mod h1:ZK2fH7c4NqDTLtiYLvIkEghdlcqw7yxLeM89kiTRPUo= github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8= @@ -13,8 +17,12 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo= github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= +github.com/hashicorp/hcl/v2 v2.16.2 h1:mpkHZh/Tv+xet3sy3F9Ld4FyI2tUpWe9x3XtPx9f1a0= +github.com/hashicorp/hcl/v2 v2.16.2/go.mod h1:JRmR89jycNkrrqnMmvPDMd56n1rQJ2Q6KocSLCMCXng= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -25,13 +33,18 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4= github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= github.com/maxbrunsfeld/counterfeiter/v6 v6.2.3/go.mod h1:1ftk08SazyElaaNvmqAfZWGwJzshjCfBXDLoQtPAMNk= +github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 h1:DpOJ2HYzCv8LZP15IdmG+YdwD2luVPHITV96TkirNBM= +github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.9.0/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -55,6 +68,8 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/walles/env v0.0.4 h1:v+cQHLwlASHaybe9VPfRZsmHsdL9HNxfX1yvNkEQsno= github.com/walles/env v0.0.4/go.mod h1:YBVhW14DflZB4j6OO2hyHzjSi3cBDi4lzPXG45hfoTo= +github.com/zclconf/go-cty v1.12.1 h1:PcupnljUm9EIvbgSHQnHhUr3fO6oFmkOrvs2BAFNXXY= +github.com/zclconf/go-cty v1.12.1/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= @@ -77,6 +92,8 @@ golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20200301222351-066e0c02454c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/terraform/validator.go b/pkg/terraform/validator.go new file mode 100644 index 0000000..cb55380 --- /dev/null +++ b/pkg/terraform/validator.go @@ -0,0 +1,20 @@ +package terraform + +import ( + "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/hclsyntax" + "github.com/pkg/errors" +) + +var errTemplate = errors.New("invalid terraform template") + +func CheckTemplate(completion string) error { + template := []byte(completion) + _, parseDiags := hclsyntax.ParseConfig(template, "", hcl.Pos{Line: 2, Column: 1}) + + if len(parseDiags) != 0 { + return errors.Wrapf(errTemplate, "expected valid template but: %s", completion) + } + + return nil +}