Skip to content

Commit

Permalink
feat(validator): impl check template
Browse files Browse the repository at this point in the history
  • Loading branch information
jigsaw373 committed Apr 21, 2023
1 parent e2b84e1 commit b5a47a2
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 38 deletions.
2 changes: 2 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ linters:
- forcetypeassert
- interfacebloat
- musttag
- gci
- gochecknoglobals

issues:
exclude-rules:
Expand Down
66 changes: 44 additions & 22 deletions cmd/cli/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,40 @@ package cli

import (
"context"
"errors"
"fmt"
"regexp"
"strings"

openai "github.com/PullRequestInc/go-gpt3"
azureopenai "github.com/ia-ops/terraform-ai/pkg/gpt3"
"github.com/pkg/errors"
gptEncoder "github.com/samber/go-gpt-3-encoder"
)

const userRole = "user"

var maxTokensMap = map[string]int{
"code-davinci-002": 8001,
"text-davinci-003": 4097,
"gpt-3.5-turbo-0301": 4096,
"gpt-3.5-turbo": 4096,
"gpt-35-turbo-0301": 4096, // for azure
}
var (
maxTokensMap = map[string]int{
"code-davinci-002": 8001,
"text-davinci-003": 4097,
"gpt-3.5-turbo-0301": 4096,
"gpt-3.5-turbo": 4096,
"gpt-35-turbo-0301": 4096, // for azure
}
errToken = errors.New("invalid max tokens")
)

type oaiClients struct {
azureClient azureopenai.Client
openAIClient openai.Client
}

func newOAIClients() (oaiClients, error) {
var oaiClient openai.Client
var azureClient azureopenai.Client
var err error
var (
oaiClient openai.Client
azureClient azureopenai.Client
err error
)

if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" {
oaiClient = openai.NewClient(*openAIAPIKey)
Expand All @@ -42,58 +47,72 @@ func newOAIClients() (oaiClients, error) {

azureClient, err = azureopenai.NewClient(*azureOpenAIEndpoint, *openAIAPIKey, *openAIDeploymentName)
if err != nil {
return oaiClients{}, err
return oaiClients{}, fmt.Errorf("error create Azure client: %w", err)
}
}

clients := oaiClients{
azureClient: azureClient,
openAIClient: oaiClient,
}

return clients, nil
}

func gptCompletion(ctx context.Context, client oaiClients, prompts []string, deploymentName string) (string, error) {
temp := float32(*sensitivity)

maxTokens, err := calculateMaxTokens(prompts, deploymentName)
if err != nil {
return "", err
return "", fmt.Errorf("error calculate max token: %w", err)
}

var prompt strings.Builder
fmt.Fprintf(&prompt, "You are a Terraform HCL generator, only generate valid Terraform HCL templates.")
_, err = fmt.Fprintf(&prompt, "You are a Terraform HCL generator, only generate valid Terraform HCL templates.")

if err != nil {
return "", fmt.Errorf("error prompt string builder: %w", err)
}

for _, p := range prompts {
fmt.Fprintf(&prompt, "%s\n", p)
_, err = fmt.Fprintf(&prompt, "%s\n", p)
if err != nil {
return "", fmt.Errorf("error range prompt: %w", err)
}
}

if azureOpenAIEndpoint == nil || *azureOpenAIEndpoint == "" {
if isGptTurbo(deploymentName) {
resp, err := client.openaiGptChatCompletion(ctx, prompt, maxTokens, temp)
if err != nil {
return "", err
return "", fmt.Errorf("error openai GptChat completion: %w", err)
}

return resp, nil
}

resp, err := client.openaiGptCompletion(ctx, prompt, maxTokens, temp)
if err != nil {
return "", err
return "", fmt.Errorf("error openai Gpt completion: %w", err)
}

return resp, nil
}

if isGptTurbo35(deploymentName) {
resp, err := client.azureGptChatCompletion(ctx, prompt, maxTokens, temp)
if err != nil {
return "", err
return "", fmt.Errorf("error azure GptChat completion: %w", err)
}

return resp, nil
}

resp, err := client.azureGptCompletion(ctx, prompt, maxTokens, temp)
if err != nil {
return "", err
return "", fmt.Errorf("error azure Gpt completion: %w", err)
}

return resp, nil
}

Expand All @@ -108,7 +127,7 @@ func isGptTurbo35(deploymentName string) bool {
func calculateMaxTokens(prompts []string, deploymentName string) (*int, error) {
maxTokensFinal, ok := maxTokensMap[deploymentName]
if !ok {
return nil, fmt.Errorf("deploymentName %q not found in max tokens map", deploymentName)
return nil, errors.Wrapf(errToken, "deploymentName %q not found in max tokens map", deploymentName)
}

if *maxTokens > 0 {
Expand All @@ -117,19 +136,22 @@ func calculateMaxTokens(prompts []string, deploymentName string) (*int, error) {

encoder, err := gptEncoder.NewEncoder()
if err != nil {
return nil, err
return nil, fmt.Errorf("error encode gpt: %w", err)
}

// start at 100 since the encoder at times doesn't get it exactly correct
totalTokens := 100

for _, prompt := range prompts {
tokens, err := encoder.Encode(prompt)
if err != nil {
return nil, err
return nil, fmt.Errorf("error encode prompt: %w", err)
}

totalTokens += len(tokens)
}

remainingTokens := maxTokensFinal - totalTokens

return &remainingTokens, nil
}
19 changes: 11 additions & 8 deletions cmd/cli/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ import (
openai "github.com/PullRequestInc/go-gpt3"
azureopenai "github.com/ia-ops/terraform-ai/pkg/gpt3"
"github.com/ia-ops/terraform-ai/pkg/utils"
"github.com/pkg/errors"
)

var errResp = errors.New("invalid response")

func (c *oaiClients) openaiGptCompletion(ctx context.Context, prompt strings.Builder, maxTokens *int, temp float32) (string, error) {
resp, err := c.openAIClient.CompletionWithEngine(ctx, *openAIDeploymentName, openai.CompletionRequest{
Prompt: []string{prompt.String()},
Expand All @@ -19,11 +22,11 @@ func (c *oaiClients) openaiGptCompletion(ctx context.Context, prompt strings.Bui
Temperature: &temp,
})
if err != nil {
return "", err
return "", fmt.Errorf("error openai completion: %w", err)
}

if len(resp.Choices) != 1 {
return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices))
return "", errors.Wrapf(errResp, "expected choices to be 1 but received: %d", len(resp.Choices))
}

return resp.Choices[0].Text, nil
Expand All @@ -43,11 +46,11 @@ func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt strings
Temperature: &temp,
})
if err != nil {
return "", err
return "", fmt.Errorf("error openai gpt completion: %w", err)
}

if len(resp.Choices) != 1 {
return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices))
return "", errors.Wrapf(errResp, "expected choices to be 1 but received: %d", len(resp.Choices))
}

return resp.Choices[0].Message.Content, nil
Expand All @@ -62,11 +65,11 @@ func (c *oaiClients) azureGptCompletion(ctx context.Context, prompt strings.Buil
Temperature: &temp,
})
if err != nil {
return "", err
return "", fmt.Errorf("error azure completion: %w", err)
}

if len(resp.Choices) != 1 {
return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices))
return "", errors.Wrapf(errResp, "expected choices to be 1 but received: %d", len(resp.Choices))
}

return resp.Choices[0].Text, nil
Expand All @@ -86,11 +89,11 @@ func (c *oaiClients) azureGptChatCompletion(ctx context.Context, prompt strings.
Temperature: &temp,
})
if err != nil {
return "", err
return "", fmt.Errorf("error azure chatgpt completion: %w", err)
}

if len(resp.Choices) != 1 {
return "", fmt.Errorf("expected choices to be 1 but received: %d", len(resp.Choices))
return "", errors.Wrapf(errResp, "expected choices to be 1 but received: %d", len(resp.Choices))
}

return resp.Choices[0].Message.Content, nil
Expand Down
20 changes: 12 additions & 8 deletions cmd/cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
"os/signal"
"strconv"

ter "github.com/ia-ops/terraform-ai/pkg/terraform"
"github.com/manifoldco/promptui"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"github.com/walles/env"
)
Expand All @@ -23,6 +25,7 @@ var (
azureOpenAIEndpoint = flag.String("azure-openai-endpoint", env.GetOr("AZURE_OPENAI_ENDPOINT", env.String, ""), "The endpoint for Azure OpenAI service. If provided, Azure OpenAI service will be used instead of OpenAI service.")
requireConfirmation = flag.Bool("require-confirmation", env.GetOr("REQUIRE_CONFIRMATION", strconv.ParseBool, true), "Whether to require confirmation before executing the command. Defaults to true.")
sensitivity = flag.Float64("sensitivity", env.GetOr("SENSITIVITY", env.WithBitSize(strconv.ParseFloat, 64), 0.0), "The sensitivity to use for the model. Range is between 0 and 1. Set closer to 0 if your want output to be more deterministic but less creative. Defaults to 0.0.")
errPrompt = errors.New("invalid prompt")
)

func InitAndExecute() {
Expand All @@ -49,9 +52,9 @@ func RootCmd() *cobra.Command {
return cmd
}

func runCommand(cmd *cobra.Command, args []string) error {
func runCommand(_ *cobra.Command, args []string) error {
if len(args) == 0 {
return fmt.Errorf("prompt must be provided")
return errors.Wrap(errPrompt, "prompt must be provided")
}

return run(args)
Expand All @@ -63,7 +66,7 @@ func run(args []string) error {

oaiClients, err := newOAIClients()
if err != nil {
return err
return fmt.Errorf("error run command: %w", err)
}

completion, err := gptCompletion(ctx, oaiClients, args, *openAIDeploymentName)
Expand All @@ -72,18 +75,19 @@ func run(args []string) error {
}

text := fmt.Sprintf("✨ Attempting to apply the following template: %s", completion)
fmt.Println(text)
log.Println(text)

confirmation, err := getUserConfirmation(*requireConfirmation)
if err != nil {
return err
return fmt.Errorf("error running select: %w", err)
}

if confirmation {
if err = applyTemplate(completion); err != nil {
return err
if err = ter.CheckTemplate(completion); err != nil {
return fmt.Errorf("error check template: %w", err)
}
}

return nil
}

Expand All @@ -99,7 +103,7 @@ func getUserConfirmation(requireConfirmation bool) (bool, error) {

_, result, err := prompt.Run()
if err != nil {
return false, err
return false, fmt.Errorf("error running select: %w", err)
}

return result == "Apply", nil
Expand Down
7 changes: 7 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ go 1.19

require (
github.com/PullRequestInc/go-gpt3 v1.1.15
github.com/hashicorp/hcl/v2 v2.16.2
github.com/manifoldco/promptui v0.9.0
github.com/pkg/errors v0.9.1
github.com/samber/go-gpt-3-encoder v0.3.1
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.8.2
Expand All @@ -13,14 +15,19 @@ require (
)

require (
github.com/agext/levenshtein v1.2.1 // indirect
github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.7.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/samber/lo v1.37.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/zclconf/go-cty v1.12.1 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/sys v0.7.0 // indirect
golang.org/x/text v0.9.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading

0 comments on commit b5a47a2

Please sign in to comment.