diff --git a/cmd/analyze/analyze.go b/cmd/analyze/analyze.go
index ba5d469b8a..612f53912a 100644
--- a/cmd/analyze/analyze.go
+++ b/cmd/analyze/analyze.go
@@ -59,6 +59,7 @@ var AnalyzeCmd = &cobra.Command{
 			color.Red("Error: %v", err)
 			os.Exit(1)
 		}
+		defer config.Close()
 
 		config.RunAnalysis()
 
diff --git a/cmd/auth/add.go b/cmd/auth/add.go
index 7a4f69c084..2e195d02b9 100644
--- a/cmd/auth/add.go
+++ b/cmd/auth/add.go
@@ -149,7 +149,7 @@ func init() {
 	// add flag for url
 	addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
 	// add flag for endpointName
-	addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, (e.g `endpoint-xxxxxxxxxxxx`)")
+	addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, e.g. `endpoint-xxxxxxxxxxxx` (only for amazonbedrock, amazonsagemaker backends)")
 	// add flag for topP
 	addCmd.Flags().Float32VarP(&topP, "topp", "c", 0.5, "Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability.")
 	// max tokens
@@ -157,7 +157,7 @@ func init() {
 	// add flag for temperature
 	addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
 	// add flag for azure open ai engine/deployment name
-	addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
+	addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
 	//add flag for amazonbedrock region name
-	addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name")
+	addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock backend)")
 }
diff --git a/go.mod b/go.mod
index c00eb89af2..4867df31bc 100644
--- a/go.mod
+++ b/go.mod
@@ -31,6 +31,7 @@ require (
 	github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.2.1
 	github.com/aws/aws-sdk-go v1.49.15
 	github.com/cohere-ai/cohere-go v0.2.0
+	github.com/google/generative-ai-go v0.5.0
 	github.com/olekukonko/tablewriter v0.0.5
 	google.golang.org/api v0.155.0
 	sigs.k8s.io/controller-runtime v0.16.3
@@ -39,9 +40,11 @@
 
 require (
 	cloud.google.com/go v0.110.10 // indirect
+	cloud.google.com/go/ai v0.3.0 // indirect
 	cloud.google.com/go/compute v1.23.3 // indirect
 	cloud.google.com/go/compute/metadata v0.2.3 // indirect
 	cloud.google.com/go/iam v1.1.5 // indirect
+	cloud.google.com/go/longrunning v0.5.4 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.1 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.1 // indirect
 	github.com/AzureAD/microsoft-authentication-library-for-go v1.2.0 // indirect
diff --git a/go.sum b/go.sum
index 9aa65b5b69..c120707e57 100644
--- a/go.sum
+++ b/go.sum
@@ -49,6 +49,8 @@ cloud.google.com/go/accesscontextmanager v1.3.0/go.mod h1:TgCBehyr5gNMz7ZaH9xubp
 cloud.google.com/go/accesscontextmanager v1.4.0/go.mod h1:/Kjh7BBu/Gh83sv+K60vN9QE5NJcd80sU33vIe2IFPE=
 cloud.google.com/go/accesscontextmanager v1.6.0/go.mod h1:8XCvZWfYw3K/ji0iVnp+6pu7huxoQTLmxAbVjbloTtM=
 cloud.google.com/go/accesscontextmanager v1.7.0/go.mod h1:CEGLewx8dwa33aDAZQujl7Dx+uYhS0eay198wB/VumQ=
+cloud.google.com/go/ai v0.3.0 h1:M617N0brv+XFch2KToZUhv6ggzgFZMUnmDkNQjW2pYg=
+cloud.google.com/go/ai v0.3.0/go.mod h1:dTuQIBA8Kljuas5z1WNot1QZOl476A9TsFqEi6pzJlI=
 cloud.google.com/go/aiplatform v1.22.0/go.mod h1:ig5Nct50bZlzV6NvKaTwmplLLddFx0YReh9WfTO5jKw=
 cloud.google.com/go/aiplatform v1.24.0/go.mod h1:67UUvRBKG6GTayHKV8DBv2RtR1t93YRu5B1P3x99mYY=
 cloud.google.com/go/aiplatform v1.27.0/go.mod h1:Bvxqtl40l0WImSb04d0hXFU7gDOiq9jQmorivIiWcKg=
@@ -351,6 +353,8 @@ cloud.google.com/go/logging v1.7.0/go.mod h1:3xjP2CjkM3ZkO73aj4ASA5wRPGGCRrPIAeN
 cloud.google.com/go/longrunning v0.1.1/go.mod h1:UUFxuDWkv22EuY93jjmDMFT5GPQKeFVJBIF6QlTqdsE=
 cloud.google.com/go/longrunning v0.3.0/go.mod h1:qth9Y41RRSUE69rDcOn6DdK3HfQfsUI0YSmW3iIlLJc=
 cloud.google.com/go/longrunning v0.4.1/go.mod h1:4iWDqhBZ70CvZ6BfETbvam3T8FMvLK+eFj0E6AaRQTo=
+cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg=
+cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI=
 cloud.google.com/go/managedidentities v1.3.0/go.mod h1:UzlW3cBOiPrzucO5qWkNkh0w33KFtBJU281hacNvsdE=
 cloud.google.com/go/managedidentities v1.4.0/go.mod h1:NWSBYbEMgqmbZsLIyKvxrYbtqOsxY1ZrGM+9RgDqInM=
 cloud.google.com/go/managedidentities v1.5.0/go.mod h1:+dWcZ0JlUmpuxpIDfyP5pP5y0bLdRwOS4Lp7gMni/LA=
@@ -935,6 +939,8 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ
 github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
 github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
 github.com/google/flatbuffers v2.0.8+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
+github.com/google/generative-ai-go v0.5.0 h1:PfzPuSGdsmcSyPG7RIoijcKWZ7/x2kvgyNryvmXMUmA=
+github.com/google/generative-ai-go v0.5.0/go.mod h1:8fXQk4w+eyTzFokGGJrBFL0/xwXqm3QNhTqOWyX11zs=
 github.com/google/gnostic v0.7.0 h1:d7EpuFp8vVdML+y0JJJYiKeOLjKTdH/GvVkLOBWqJpw=
 github.com/google/gnostic v0.7.0/go.mod h1:IAcUyMl6vtC95f60EZ8oXyqTsOersP6HbwjeG7EyDPM=
 github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 h1:0VpGH+cDhbDtdcweoyCVsF3fhN8kejK6rFe/2FFX2nU=
diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go
index 732d4e7bb4..25381d7e8a 100644
--- a/pkg/ai/amazonbedrock.go
+++ b/pkg/ai/amazonbedrock.go
@@ -2,15 +2,8 @@ package ai
 
 import (
 	"context"
-	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"strings"
-
-	"github.com/fatih/color"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/session"
@@ -19,8 +12,9 @@
 
 // AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service.
 type AmazonBedRockClient struct {
+	nopCloser
+
 	client      *bedrockruntime.BedrockRuntime
-	language    string
 	model       string
 	temperature float32
 }
@@ -91,8 +85,8 @@ func GetRegionOrDefault(region string) string {
 	return BEDROCK_DEFAULT_REGION
 }
 
-// Configure configures the AmazonBedRockClient with the provided configuration and language.
-func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error {
+// Configure configures the AmazonBedRockClient with the provided configuration.
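+// (The prompt language is no longer passed here; the caller now renders it into the prompt template.)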
+func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 
 	// Create a new AWS session
 	providerRegion := GetRegionOrDefault(config.GetProviderRegion())
@@ -107,7 +101,6 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error
 
 	// Create a new BedrockRuntime client
 	a.client = bedrockruntime.New(sess)
-	a.language = language
 	a.model = GetModelOrDefault(config.GetModel())
 	a.temperature = config.GetTemperature()
 
@@ -115,7 +108,7 @@
 }
 
 // GetCompletion sends a request to the model for generating completion based on the provided prompt.
-func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
 
 	// Prepare the input data for the model invocation
 	request := map[string]interface{}{
@@ -152,44 +145,6 @@
 	return output.Completion, nil
 }
 
-// Parse generates a completion for the provided prompt using the Amazon Bedrock model.
-func (a *AmazonBedRockClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-	if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-		response, err := cache.Load(cacheKey)
-		if err != nil {
-			return "", err
-		}
-
-		if response != "" {
-			output, err := base64.StdEncoding.DecodeString(response)
-			if err != nil {
-				color.Red("error decoding cached data: %v", err)
-				return "", nil
-			}
-			return string(output), nil
-		}
-	}
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-
-	if err != nil {
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
 // GetName returns the name of the AmazonBedRockClient.
 func (a *AmazonBedRockClient) GetName() string {
 	return "amazonbedrock"
diff --git a/pkg/ai/amazonsagemaker.go b/pkg/ai/amazonsagemaker.go
index e85be89188..cae9e9660e 100644
--- a/pkg/ai/amazonsagemaker.go
+++ b/pkg/ai/amazonsagemaker.go
@@ -15,15 +15,8 @@ package ai
 
 import (
 	"context"
-	"encoding/base64"
-	"fmt"
-	"strings"
-
 	"encoding/json"
-
-	"github.com/fatih/color"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
+	"fmt"
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/session"
@@ -31,8 +24,9 @@
 
 type SageMakerAIClient struct {
+	nopCloser
+
 	client      *sagemakerruntime.SageMakerRuntime
-	language    string
 	model       string
 	temperature float32
 	endpoint    string
@@ -63,7 +57,7 @@ type Parameters struct {
 	Temperature float64 `json:"temperature"`
 }
 
-func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
+func (c *SageMakerAIClient) Configure(config IAIConfig) error {
 
 	// Create a new AWS session
 	sess := session.Must(session.NewSessionWithOptions(session.Options{
@@ -71,7 +65,6 @@
 		SharedConfigState: session.SharedConfigEnable,
 	}))
 
-	c.language = language
 	// Create a new SageMaker runtime client
 	c.client = sagemakerruntime.New(sess)
 	c.model = config.GetModel()
@@ -82,18 +75,13 @@
 	return nil
 }
 
-func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (c *SageMakerAIClient) GetCompletion(_ context.Context, prompt string) (string, error) {
 	// Create a completion request
-
-	if len(promptTmpl) == 0 {
-		promptTmpl = PromptMap["default"]
-	}
-
 	request := Request{
 		Inputs: [][]Message{
 			{
 				{Role: "system", Content: "DEFAULT_PROMPT"},
-				{Role: "user", Content: fmt.Sprintf(promptTmpl, c.language, prompt)},
+				{Role: "user", Content: prompt},
 			},
 		},
@@ -142,29 +130,6 @@
 	return content, nil
 }
 
-func (a *SageMakerAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	// parse the text with the AI backend
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		color.Red("error getting completion: %v", err)
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", err
-	}
-
-	return response, nil
-}
-
-func (a *SageMakerAIClient) GetName() string {
+func (c *SageMakerAIClient) GetName() string {
 	return "amazonsagemaker"
 }
diff --git a/pkg/ai/azureopenai.go b/pkg/ai/azureopenai.go
index cd5e073959..eb8ee71728 100644
--- a/pkg/ai/azureopenai.go
+++ b/pkg/ai/azureopenai.go
@@ -2,27 +2,20 @@ package ai
 
 import (
 	"context"
-	"encoding/base64"
 	"errors"
-	"fmt"
-	"strings"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
-
-	"github.com/fatih/color"
 
 	"github.com/sashabaranov/go-openai"
 )
 
 type AzureAIClient struct {
+	nopCloser
+
 	client      *openai.Client
-	language    string
 	model       string
 	temperature float32
 }
 
-func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
+func (c *AzureAIClient) Configure(config IAIConfig) error {
 	token := config.GetPassword()
 	baseURL := config.GetBaseURL()
 	engine := config.GetEngine()
@@ -40,21 +33,20 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
 	if client == nil {
 		return errors.New("error creating Azure OpenAI client")
 	}
-	c.language = lang
 	c.client = client
 	c.model = config.GetModel()
 	c.temperature = config.GetTemperature()
 	return nil
 }
 
-func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
 	// Create a completion request
 	resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
 		Model: c.model,
 		Messages: []openai.ChatCompletionMessage{
 			{
 				Role:    openai.ChatMessageRoleUser,
-				Content: fmt.Sprintf(default_prompt, c.language, prompt),
+				Content: prompt,
 			},
 		},
 		Temperature: c.temperature,
@@ -65,42 +57,6 @@
 	return resp.Choices[0].Message.Content, nil
 }
 
-func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-	if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-		response, err := cache.Load(cacheKey)
-		if err != nil {
-			return "", err
-		}
-
-		if response != "" {
-			output, err := base64.StdEncoding.DecodeString(response)
-			if err != nil {
-				color.Red("error decoding cached data: %v", err)
-				return "", nil
-			}
-			return string(output), nil
-		}
-	}
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
-func (a *AzureAIClient) GetName() string {
+func (c *AzureAIClient) GetName() string {
 	return "azureopenai"
 }
diff --git a/pkg/ai/cohere.go b/pkg/ai/cohere.go
index 64a48c54ff..1b6158dfb4 100644
--- a/pkg/ai/cohere.go
+++ b/pkg/ai/cohere.go
@@ -15,26 +15,20 @@ package ai
 
 import (
 	"context"
-	"encoding/base64"
 	"errors"
-	"fmt"
-	"strings"
 
 	"github.com/cohere-ai/cohere-go"
-	"github.com/fatih/color"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
 )
 
 type CohereClient struct {
+	nopCloser
+
 	client      *cohere.Client
-	language    string
 	model       string
 	temperature float32
 }
 
-func (c *CohereClient) Configure(config IAIConfig, language string) error {
+func (c *CohereClient) Configure(config IAIConfig) error {
 	token := config.GetPassword()
 
 	client, err := cohere.CreateClient(token)
@@ -50,21 +44,17 @@ func (c *CohereClient) Configure(config IAIConfig, language string) error {
 	if client == nil {
 		return errors.New("error creating Cohere client")
 	}
-	c.language = language
 	c.client = client
 	c.model = config.GetModel()
 	c.temperature = config.GetTemperature()
 	return nil
 }
 
-func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl string) (string, error) {
+func (c *CohereClient) GetCompletion(_ context.Context, prompt string) (string, error) {
 	// Create a completion request
-	if len(promptTmpl) == 0 {
-		promptTmpl = PromptMap["default"]
-	}
 	resp, err := c.client.Generate(cohere.GenerateOptions{
 		Model:       c.model,
-		Prompt:      fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
+		Prompt:      prompt,
 		MaxTokens:   cohere.Uint(2048),
 		Temperature: cohere.Float64(float64(c.temperature)),
 		K:           cohere.Int(0),
@@ -77,42 +67,6 @@
 	return resp.Generations[0].Text, nil
 }
 
-func (a *CohereClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-	if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-		response, err := cache.Load(cacheKey)
-		if err != nil {
-			return "", err
-		}
-
-		if response != "" {
-			output, err := base64.StdEncoding.DecodeString(response)
-			if err != nil {
-				color.Red("error decoding cached data: %v", err)
-				return "", nil
-			}
-			return string(output), nil
-		}
-	}
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
-func (a *CohereClient) GetName() string {
+func (c *CohereClient) GetName() string {
 	return "cohere"
 }
diff --git a/pkg/ai/googlegenai.go b/pkg/ai/googlegenai.go
new file mode 100644
index 0000000000..1b439dc87c
--- /dev/null
+++ b/pkg/ai/googlegenai.go
@@ -0,0 +1,119 @@
+/*
+Copyright 2023 The K8sGPT Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package ai
+
+import (
+	"context"
+	"errors"
+	"fmt"
+
+	"github.com/fatih/color"
+	"github.com/google/generative-ai-go/genai"
+	"google.golang.org/api/option"
+)
+
+const googleAIClientName = "google"
+
+type GoogleGenAIClient struct {
+	client *genai.Client
+
+	model       string
+	temperature float32
+	topP        float32
+	maxTokens   int
+}
+
+func (c *GoogleGenAIClient) Configure(config IAIConfig) error {
+	ctx := context.Background()
+
+	// The configured password is either a plain API key or a JSON credentials blob.
+	token := config.GetPassword()
+	authOption := option.WithAPIKey(token)
+	if len(token) > 0 && token[0] == '{' {
+		authOption = option.WithCredentialsJSON([]byte(token))
+	}
+
+	client, err := genai.NewClient(ctx, authOption)
+	if err != nil {
+		return fmt.Errorf("creating genai Google SDK client: %w", err)
+	}
+
+	c.client = client
+	c.model = config.GetModel()
+	c.temperature = config.GetTemperature()
+	c.topP = config.GetTopP()
+	c.maxTokens = config.GetMaxTokens()
+	return nil
+}
+
+func (c *GoogleGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
+	// Available models are listed at https://ai.google.dev/models, e.g. gemini-pro.
+	model := c.client.GenerativeModel(c.model)
+	model.SetTemperature(c.temperature)
+	model.SetTopP(c.topP)
+	model.SetMaxOutputTokens(int32(c.maxTokens))
+
+	// The Google AI SDK accepts more input types than just text; for now we set an explicit text prompt type.
+	// Similarly, we could stream the response. For now k8sgpt does not support streaming.
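+	// genai.Text is the SDK's plain-text genai.Part implementation; GenerateContent
+	// accepts one or more such parts.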
+	resp, err := model.GenerateContent(ctx, genai.Text(prompt))
+	if err != nil {
+		return "", err
+	}
+
+	if len(resp.Candidates) == 0 {
+		if resp.PromptFeedback.BlockReason == genai.BlockReasonSafety {
+			for _, r := range resp.PromptFeedback.SafetyRatings {
+				if !r.Blocked {
+					continue
+				}
+				return "", fmt.Errorf("completion blocked due to %v with probability %v", r.Category.String(), r.Probability.String())
+			}
+		}
+		return "", errors.New("no completion returned; unknown reason")
+	}
+
+	// Format output.
+	// TODO(bwplotka): Provide richer output in certain cases, e.g. a suddenly finished
+	// completion, based on finish reasons or safety rankings.
+	got := resp.Candidates[0]
+	var output string
+	for _, part := range got.Content.Parts {
+		switch o := part.(type) {
+		case genai.Text:
+			output += string(o)
+			output += "\n"
+		default:
+			color.Yellow("found unsupported AI response part of type %T; ignoring", part)
+		}
+	}
+
+	if got.CitationMetadata != nil && len(got.CitationMetadata.CitationSources) > 0 {
+		output += "Citations:\n"
+		for _, source := range got.CitationMetadata.CitationSources {
+			// TODO(bwplotka): Give details around what exact words could be attributed to the citation.
+			output += fmt.Sprintf("* %s, %s\n", *source.URI, source.License)
+		}
+	}
+	return output, nil
+}
+
+func (c *GoogleGenAIClient) GetName() string {
+	return googleAIClientName
+}
+
+func (c *GoogleGenAIClient) Close() {
+	if err := c.client.Close(); err != nil {
+		color.Red("googleai client close error: %v", err)
+	}
+}
diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go
index a603c6b3c9..b4380ce3da 100644
--- a/pkg/ai/iai.go
+++ b/pkg/ai/iai.go
@@ -15,8 +15,6 @@ package ai
 
 import (
 	"context"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
 )
 
 var (
@@ -28,25 +26,38 @@
 		&CohereClient{},
 		&AmazonBedRockClient{},
 		&SageMakerAIClient{},
+		&GoogleGenAIClient{},
 	}
 	Backends = []string{
 		"openai",
 		"localai",
 		"azureopenai",
-		"noopai",
 		"cohere",
 		"amazonbedrock",
 		"amazonsagemaker",
+		googleAIClientName,
+		"noopai",
 	}
 )
 
+// IAI is an interface that all clients (representing backends) share.
 type IAI interface {
-	Configure(config IAIConfig, language string) error
-	GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error)
-	Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error)
+	// Configure sets up the client for the given configuration. This is expected to be
+	// executed once per client lifetime (e.g. one analysis CLI command invocation).
+	Configure(config IAIConfig) error
+	// GetCompletion generates text based on the given prompt.
+	GetCompletion(ctx context.Context, prompt string) (string, error)
+	// GetName returns the name of the backend/client.
 	GetName() string
+	// Close cleans up all resources. No other methods should be used on the
+	// object after this method is invoked.
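+	// Backends with no resources to release can embed nopCloser (defined below).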
+	Close()
 }
 
+type nopCloser struct{}
+
+func (nopCloser) Close() {}
+
 type IAIConfig interface {
 	GetPassword() string
 	GetModel() string
diff --git a/pkg/ai/noopai.go b/pkg/ai/noopai.go
index 7183e1758c..e1c4498331 100644
--- a/pkg/ai/noopai.go
+++ b/pkg/ai/noopai.go
@@ -15,58 +15,21 @@ package ai
 
 import (
 	"context"
-	"encoding/base64"
-	"fmt"
-	"strings"
-
-	"github.com/fatih/color"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
 )
 
 type NoOpAIClient struct {
-	client   string
-	language string
-	model    string
+	nopCloser
 }
 
-func (c *NoOpAIClient) Configure(config IAIConfig, language string) error {
-	token := config.GetPassword()
-	c.language = language
-	c.client = fmt.Sprintf("I am a noop client with the token %s ", token)
-	c.model = config.GetModel()
+func (c *NoOpAIClient) Configure(_ IAIConfig) error {
 	return nil
 }
 
-func (c *NoOpAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
-	// Create a completion request
+func (c *NoOpAIClient) GetCompletion(_ context.Context, prompt string) (string, error) {
 	response := "I am a noop response to the prompt " + prompt
 	return response, nil
 }
 
-func (a *NoOpAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	// parse the text with the AI backend
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		color.Red("error getting completion: %v", err)
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
-func (a *NoOpAIClient) GetName() string {
+func (c *NoOpAIClient) GetName() string {
 	return "noopai"
 }
diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go
index b37f2aec08..ec097e355b 100644
--- a/pkg/ai/openai.go
+++ b/pkg/ai/openai.go
@@ -15,22 +15,15 @@ package ai
 
 import (
 	"context"
-	"encoding/base64"
 	"errors"
-	"fmt"
-	"strings"
-
-	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
-	"github.com/k8sgpt-ai/k8sgpt/pkg/util"
 
 	"github.com/sashabaranov/go-openai"
-
-	"github.com/fatih/color"
 )
 
 type OpenAIClient struct {
+	nopCloser
+
 	client      *openai.Client
-	language    string
 	model       string
 	temperature float32
 }
@@ -43,7 +36,7 @@ const (
 	topP = 1.0
 )
 
-func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
+func (c *OpenAIClient) Configure(config IAIConfig) error {
 	token := config.GetPassword()
 	defaultConfig := openai.DefaultConfig(token)
 
@@ -56,24 +49,20 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
 	if client == nil {
 		return errors.New("error creating OpenAI client")
 	}
-	c.language = language
 	c.client = client
 	c.model = config.GetModel()
 	c.temperature = config.GetTemperature()
 	return nil
 }
 
-func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
+func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
 	// Create a completion request
-	if len(promptTmpl) == 0 {
-		promptTmpl = PromptMap["default"]
-	}
 	resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
 		Model: c.model,
 		Messages: []openai.ChatCompletionMessage{
 			{
 				Role:    "user",
-				Content: fmt.Sprintf(promptTmpl, c.language, prompt),
+				Content: prompt,
 			},
 		},
 		Temperature: c.temperature,
@@ -88,42 +77,6 @@
 	return resp.Choices[0].Message.Content, nil
 }
 
-func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
-	inputKey := strings.Join(prompt, " ")
-	// Check for cached data
-	cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
-
-	if !cache.IsCacheDisabled() && cache.Exists(cacheKey) {
-		response, err := cache.Load(cacheKey)
-		if err != nil {
-			return "", err
-		}
-
-		if response != "" {
-			output, err := base64.StdEncoding.DecodeString(response)
-			if err != nil {
-				color.Red("error decoding cached data: %v", err)
-				return "", nil
-			}
-			return string(output), nil
-		}
-	}
-
-	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
-	if err != nil {
-		return "", err
-	}
-
-	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
-
-	if err != nil {
-		color.Red("error storing value to cache: %v", err)
-		return "", nil
-	}
-
-	return response, nil
-}
-
-func (a *OpenAIClient) GetName() string {
+func (c *OpenAIClient) GetName() string {
 	return "openai"
 }
diff --git a/pkg/analysis/analysis.go b/pkg/analysis/analysis.go
index 6f23ea487c..85541491d8 100644
--- a/pkg/analysis/analysis.go
+++ b/pkg/analysis/analysis.go
@@ -15,12 +15,14 @@ package analysis
 
 import (
 	"context"
+	"encoding/base64"
 	"errors"
 	"fmt"
 	"reflect"
 	"strings"
 	"sync"
 
+	"github.com/fatih/color"
 	openapi_v2 "github.com/google/gnostic/openapiv2"
 	"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
 	"github.com/k8sgpt-ai/k8sgpt/pkg/analyzer"
@@ -36,6 +38,7 @@ type Analysis struct {
 	Context   context.Context
 	Filters   []string
 	Client    *kubernetes.Client
+	Language  string
 	AIClient  ai.IAI
 	Results   []common.Result
 	Errors    []string
@@ -95,6 +98,7 @@ func NewAnalysis(
 		Context:   context.Background(),
 		Filters:   filters,
 		Client:    client,
+		Language:  language,
 		Namespace: namespace,
 		Cache:     cache,
 		Explain:   explain,
@@ -134,7 +138,7 @@ func NewAnalysis(
 	}
 
 	aiClient := ai.NewClient(aiProvider.Name)
-	if err := aiClient.Configure(&aiProvider, language); err != nil {
+	if err := aiClient.Configure(&aiProvider); err != nil {
 		return nil, err
 	}
 	a.AIClient = aiClient
@@ -269,14 +273,14 @@ func (a *Analysis) GetAIResults(output string, anonymize bool) error {
 			}
 			texts = append(texts, failure.Text)
 		}
-		// If the resource `Kind` comes from a "integration plugin", maybe a customized prompt template will be involved.
-		var promptTemplate string
+
+		promptTemplate := ai.PromptMap["default"]
+		// If the resource `Kind` comes from an "integration plugin",
+		// a customized prompt template may be involved.
 		if prompt, ok := ai.PromptMap[analysis.Kind]; ok {
 			promptTemplate = prompt
-		} else {
-			promptTemplate = ai.PromptMap["default"]
 		}
-		parsedText, err := a.AIClient.Parse(a.Context, texts, a.Cache, promptTemplate)
+		result, err := a.getAIResultForSanitizedFailures(texts, promptTemplate)
 		if err != nil {
 			// FIXME: can we avoid checking if output is json multiple times?
 			// maybe implement the progress bar better?
@@ -284,23 +288,22 @@
 			if output != "json" {
 				_ = bar.Exit()
 			}
 
-			// Check for exhaustion
+			// Check for exhaustion.
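+			// (HTTP status 429 "Too Many Requests" is how providers signal that the rate limit or quota was hit.)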
 			if strings.Contains(err.Error(), "status code: 429") {
 				return fmt.Errorf("exhausted API quota for AI provider %s: %v", a.AIClient.GetName(), err)
-			} else {
-				return fmt.Errorf("failed while calling AI provider %s: %v", a.AIClient.GetName(), err)
 			}
+			return fmt.Errorf("failed while calling AI provider %s: %v", a.AIClient.GetName(), err)
 		}
 
 		if anonymize {
 			for _, failure := range analysis.Error {
 				for _, s := range failure.Sensitive {
-					parsedText = strings.ReplaceAll(parsedText, s.Masked, s.Unmasked)
+					result = strings.ReplaceAll(result, s.Masked, s.Unmasked)
 				}
 			}
 		}
 
-		analysis.Details = parsedText
+		analysis.Details = result
 		if output != "json" {
 			_ = bar.Add(1)
 		}
@@ -308,3 +311,44 @@
 	}
 	return nil
 }
+
+func (a *Analysis) getAIResultForSanitizedFailures(texts []string, promptTmpl string) (string, error) {
+	inputKey := strings.Join(texts, " ")
+	// Check for cached data.
+	// TODO(bwplotka): This might depend on model too (or even other client configuration pieces), fix it in later PRs.
+	cacheKey := util.GetCacheKey(a.AIClient.GetName(), a.Language, inputKey)
+
+	if !a.Cache.IsCacheDisabled() && a.Cache.Exists(cacheKey) {
+		response, err := a.Cache.Load(cacheKey)
+		if err != nil {
+			return "", err
+		}
+
+		if response != "" {
+			output, err := base64.StdEncoding.DecodeString(response)
+			if err == nil {
+				return string(output), nil
+			}
+			color.Red("error decoding cached data; ignoring cache item: %v", err)
+		}
+	}
+
+	// Process template.
+	prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey)
+	response, err := a.AIClient.GetCompletion(a.Context, prompt)
+	if err != nil {
+		return "", err
+	}
+
+	if err = a.Cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response))); err != nil {
+		color.Red("error storing value to cache; value won't be cached: %v", err)
+	}
+	return response, nil
+}
+
+func (a *Analysis) Close() {
+	if a.AIClient == nil {
+		return
+	}
+	a.AIClient.Close()
+}
diff --git a/pkg/server/analyze.go b/pkg/server/analyze.go
index 5c93d0154a..a27d17a728 100644
--- a/pkg/server/analyze.go
+++ b/pkg/server/analyze.go
@@ -35,10 +35,11 @@ func (h *handler) Analyze(ctx context.Context, i *schemav1.AnalyzeRequest) (
 		false, // Kubernetes Doc disabled in server mode
 	)
 	config.Context = ctx // Replace context for correct timeouts.
-
 	if err != nil {
 		return &schemav1.AnalyzeResponse{}, err
 	}
+	defer config.Close()
+
 	config.RunAnalysis()
 
 	if i.Explain {
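
With Parse removed, each backend now implements only Configure, GetCompletion,
GetName and Close; caching and prompt templating are handled by the caller in
getAIResultForSanitizedFailures. A minimal sketch of the resulting client
life-cycle, mirroring the calls made in NewAnalysis and the analyze command:

	aiClient := ai.NewClient(aiProvider.Name)
	if err := aiClient.Configure(&aiProvider); err != nil {
		return nil, err
	}
	defer aiClient.Close()

	response, err := aiClient.GetCompletion(ctx, prompt)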